naga/proc/overloads/
regular.rs

/*! A representation for highly regular overload sets common in Naga IR.

Many Naga builtin functions' overload sets have a highly regular
structure. For example, many arithmetic functions can be applied to
any floating-point type, or any vector thereof. This module defines a
handful of types that represent such overload sets simply and
efficiently.

*/
10
11use crate::common::{DiagnosticDebug, ForDebugWithTypes};
12use crate::ir;
13use crate::proc::overloads::constructor_set::{ConstructorSet, ConstructorSize};
14use crate::proc::overloads::rule::{Conclusion, Rule};
15use crate::proc::overloads::scalar_set::ScalarSet;
16use crate::proc::overloads::OverloadSet;
17use crate::proc::{GlobalCtx, TypeResolution};
18use crate::UniqueArena;
19
20use alloc::vec::Vec;
21use core::fmt;
22
23/// Overload sets represented as sets of scalars and constructors.
24///
25/// This type represents an [`OverloadSet`] using a bitset of scalar
26/// types and a bitset of type constructors that might be applied to
27/// those scalars. The overload set contains a rule for every possible
28/// combination of scalars and constructors, essentially the cartesian
29/// product of the two sets.
30///
31/// For example, if the arity is 2, set of scalars is { AbstractFloat,
32/// `f32` }, and the set of constructors is { `vec2`, `vec3` }, then
33/// that represents the set of overloads:
34///
35/// - (`vec2<AbstractFloat>`, `vec2<AbstractFloat>`) -> `vec2<AbstractFloat>`
36/// - (`vec2<f32>`, `vec2<f32>`) -> `vec2<f32>`
37/// - (`vec3<AbstractFloat>`, `vec3<AbstractFloat>`) -> `vec3<AbstractFloat>`
38/// - (`vec3<f32>`, `vec3<f32>`) -> `vec3<f32>`
39///
40/// The `conclude` value says how to determine the return type from
41/// the argument type.
42///
43/// Restrictions:
44///
45/// - All overloads must take the same number of arguments.
46///
47/// - For any given overload, all its arguments must have the same
48///   type.
49#[derive(Clone)]
50pub(in crate::proc::overloads) struct Regular {
51    /// The number of arguments in the rules.
52    pub arity: usize,
53
54    /// The set of type constructors to apply.
55    pub constructors: ConstructorSet,
56
57    /// The set of scalars to apply them to.
58    pub scalars: ScalarSet,
59
60    /// How to determine a member rule's return type given the type of
61    /// its arguments.
62    pub conclude: ConclusionRule,
63}
64
65impl Regular {
66    pub(in crate::proc::overloads) const EMPTY: Regular = Regular {
67        arity: 0,
68        constructors: ConstructorSet::empty(),
69        scalars: ScalarSet::empty(),
70        conclude: ConclusionRule::ArgumentType,
71    };
72
73    /// Return an iterator over all the argument types allowed by `self`.
74    ///
75    /// Return an iterator that produces, for each overload in `self`, the
76    /// constructor and scalar of its argument types and return type.
77    ///
78    /// A [`Regular`] value can only represent overload sets where, in
79    /// each overload, all the arguments have the same type, and the
80    /// return type is always going to be a determined by the argument
81    /// types, so giving the constructor and scalar is sufficient to
82    /// characterize the entire rule.
83    fn members(&self) -> impl Iterator<Item = (ConstructorSize, ir::Scalar)> {
84        let scalars = self.scalars;
85        self.constructors.members().flat_map(move |constructor| {
86            let size = constructor.size();
87            // Technically, we don't need the "most general" `TypeInner` here,
88            // but since `ScalarSet::members` only produces singletons anyway,
89            // the effect is the same.
90            scalars
91                .members()
92                .map(move |singleton| (size, singleton.most_general_scalar()))
93        })
94    }
95
96    fn rules(&self) -> impl Iterator<Item = Rule> {
97        let arity = self.arity;
98        let conclude = self.conclude;
99        self.members()
100            .map(move |(size, scalar)| make_rule(arity, size, scalar, conclude))
101    }
102}
103
104impl OverloadSet for Regular {
105    fn is_empty(&self) -> bool {
106        self.constructors.is_empty() || self.scalars.is_empty()
107    }
108
109    fn min_arguments(&self) -> usize {
110        assert!(!self.is_empty());
111        self.arity
112    }
113
114    fn max_arguments(&self) -> usize {
115        assert!(!self.is_empty());
116        self.arity
117    }
118
119    fn arg(&self, i: usize, ty: &ir::TypeInner, types: &UniqueArena<ir::Type>) -> Self {
120        if i >= self.arity {
121            return Self::EMPTY;
122        }
123
124        let constructor = ConstructorSet::singleton(ty);
125
126        let scalars = match ty.scalar_for_conversions(types) {
127            Some(ty_scalar) => ScalarSet::convertible_from(ty_scalar),
128            None => ScalarSet::empty(),
129        };
130
131        Self {
132            arity: self.arity,
133
134            // Constrain all member rules' constructors to match `ty`'s.
135            constructors: self.constructors & constructor,
136
137            // Constrain all member rules' arguments to be something
138            // that `ty` can be converted to.
139            scalars: self.scalars & scalars,
140
141            conclude: self.conclude,
142        }
143    }
144
145    fn concrete_only(self, _types: &UniqueArena<ir::Type>) -> Self {
146        Self {
147            scalars: self.scalars & ScalarSet::CONCRETE,
148            ..self
149        }
150    }
151
152    fn most_preferred(&self) -> Rule {
153        assert!(!self.is_empty());
154
155        // If there is more than one constructor allowed, then we must
156        // not have had any arguments supplied at all. In any case, we
157        // don't have any unambiguously preferred candidate.
158        assert!(self.constructors.is_singleton());
159
160        let size = self.constructors.size();
161        let scalar = self.scalars.most_general_scalar();
162        make_rule(self.arity, size, scalar, self.conclude)
163    }
164
165    fn overload_list(&self, _gctx: &GlobalCtx<'_>) -> Vec<Rule> {
166        self.rules().collect()
167    }
168
169    fn allowed_args(&self, i: usize, _gctx: &GlobalCtx<'_>) -> Vec<TypeResolution> {
170        if i >= self.arity {
171            return Vec::new();
172        }
173        self.members()
174            .map(|(size, scalar)| TypeResolution::Value(size.to_inner(scalar)))
175            .collect()
176    }
177
178    fn for_debug(&self, types: &UniqueArena<ir::Type>) -> impl fmt::Debug {
179        DiagnosticDebug((self, types))
180    }
181}
182
183/// Construct a [`Regular`] member [`Rule`] for the given arity and type.
184///
185/// [`Regular`] can only represent rules where all the argument types and the
186/// return type are the same, so just knowing `arity` and `inner` is sufficient.
187///
188/// [`Rule`]: crate::proc::overloads::Rule
189fn make_rule(
190    arity: usize,
191    size: ConstructorSize,
192    scalar: ir::Scalar,
193    conclusion_rule: ConclusionRule,
194) -> Rule {
195    let inner = size.to_inner(scalar);
196    let arg = TypeResolution::Value(inner.clone());
197    Rule {
198        arguments: core::iter::repeat(arg.clone()).take(arity).collect(),
199        conclusion: conclusion_rule.conclude(size, scalar),
200    }
201}
202
203/// Conclusion-computing rules.
204#[derive(Clone, Copy, Debug)]
205#[repr(u8)]
206pub(in crate::proc::overloads) enum ConclusionRule {
207    ArgumentType,
208    Scalar,
209    Frexp,
210    Modf,
211    U32,
212    Vec2F,
213    Vec4F,
214    Vec4I,
215    Vec4U,
216}
217
218impl ConclusionRule {
219    fn conclude(self, size: ConstructorSize, scalar: ir::Scalar) -> Conclusion {
220        match self {
221            Self::ArgumentType => Conclusion::Value(size.to_inner(scalar)),
222            Self::Scalar => Conclusion::Value(ir::TypeInner::Scalar(scalar)),
223            Self::Frexp => Conclusion::for_frexp_modf(ir::MathFunction::Frexp, size, scalar),
224            Self::Modf => Conclusion::for_frexp_modf(ir::MathFunction::Modf, size, scalar),
225            Self::U32 => Conclusion::Value(ir::TypeInner::Scalar(ir::Scalar::U32)),
226            Self::Vec2F => Conclusion::Value(ir::TypeInner::Vector {
227                size: ir::VectorSize::Bi,
228                scalar: ir::Scalar::F32,
229            }),
230            Self::Vec4F => Conclusion::Value(ir::TypeInner::Vector {
231                size: ir::VectorSize::Quad,
232                scalar: ir::Scalar::F32,
233            }),
234            Self::Vec4I => Conclusion::Value(ir::TypeInner::Vector {
235                size: ir::VectorSize::Quad,
236                scalar: ir::Scalar::I32,
237            }),
238            Self::Vec4U => Conclusion::Value(ir::TypeInner::Vector {
239                size: ir::VectorSize::Quad,
240                scalar: ir::Scalar::U32,
241            }),
242        }
243    }
244}
245
246impl fmt::Debug for DiagnosticDebug<(&Regular, &UniqueArena<ir::Type>)> {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        let (regular, types) = self.0;
249        let rules: Vec<Rule> = regular.rules().collect();
250        f.debug_struct("List")
251            .field("rules", &rules.for_debug(types))
252            .field("conclude", &regular.conclude)
253            .finish()
254    }
255}
256
257impl ForDebugWithTypes for &Regular {}
258
259impl fmt::Debug for DiagnosticDebug<(&[Rule], &UniqueArena<ir::Type>)> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        let (rules, types) = self.0;
262        f.debug_list()
263            .entries(rules.iter().map(|rule| rule.for_debug(types)))
264            .finish()
265    }
266}
267
268impl ForDebugWithTypes for &[Rule] {}
269
270/// Construct a [`Regular`] [`OverloadSet`].
271///
272/// Examples:
273///
274/// - `regular!(2, SCALAR|VECN of FLOAT)`: An overload set whose rules take two
275///   arguments of the same type: a floating-point scalar (possibly abstract) or
276///   a vector of such. The return type is the same as the argument type.
277///
278/// - `regular!(1, VECN of FLOAT -> Scalar)`: An overload set whose rules take
279///   one argument that is a vector of floats, and whose return type is the leaf
280///   scalar type of the argument type.
281///
282/// The constructor values (before the `<` angle brackets `>`) are
283/// constants from [`ConstructorSet`].
284///
285/// The scalar values (inside the `<` angle brackets `>`) are
286/// constants from [`ScalarSet`].
287///
288/// When a return type identifier is given, it is treated as a variant
289/// of the the [`ConclusionRule`] enum.
290macro_rules! regular {
291    // regular!(ARITY, CONSTRUCTOR of SCALAR)
292    ( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|*) => {
293        {
294            use $crate::proc::overloads;
295            use overloads::constructor_set::constructor_set;
296            use overloads::regular::{Regular, ConclusionRule};
297            use overloads::scalar_set::scalar_set;
298            Regular {
299                arity: $arity,
300                constructors: constructor_set!( $( $constr )|* ),
301                scalars: scalar_set!( $( $scalar )|* ),
302                conclude: ConclusionRule::ArgumentType,
303            }
304        }
305    };
306
307    // regular!(ARITY, CONSTRUCTOR of SCALAR -> CONCLUSION_RULE)
308    ( $arity:literal , $( $constr:ident )|* of $( $scalar:ident )|* -> $conclude:ident) => {
309        {
310            use $crate::proc::overloads;
311            use overloads::constructor_set::constructor_set;
312            use overloads::regular::{Regular, ConclusionRule};
313            use overloads::scalar_set::scalar_set;
314            Regular {
315                arity: $arity,
316                constructors:constructor_set!( $( $constr )|* ),
317                scalars: scalar_set!( $( $scalar )|* ),
318                conclude: ConclusionRule::$conclude,
319            }
320        }
321    };
322}
323
324pub(in crate::proc::overloads) use regular;
325
#[cfg(test)]
mod test {
    use super::*;
    use crate::ir;

    const fn scalar(scalar: ir::Scalar) -> ir::TypeInner {
        ir::TypeInner::Scalar(scalar)
    }

    const fn vec2(scalar: ir::Scalar) -> ir::TypeInner {
        ir::TypeInner::Vector {
            scalar,
            size: ir::VectorSize::Bi,
        }
    }

    const fn vec3(scalar: ir::Scalar) -> ir::TypeInner {
        ir::TypeInner::Vector {
            scalar,
            size: ir::VectorSize::Tri,
        }
    }

    /// Assert that `set` has a most preferred candidate whose type
    /// conclusion is `expected`.
    #[track_caller]
    fn check_return_type(set: &Regular, expected: &ir::TypeInner, arena: &UniqueArena<ir::Type>) {
        assert!(!set.is_empty());

        let special_types = ir::SpecialTypes::default();

        let preferred = set.most_preferred();
        let conclusion = preferred.conclusion;
        let resolution = conclusion
            .into_resolution(&special_types)
            .expect("special types should have been pre-registered");
        let inner = resolution.inner_with(arena);

        assert!(
            inner.equivalent(expected, arena),
            "Expected {:?}, got {:?}",
            expected.for_debug(arena),
            inner.for_debug(arena),
        );
    }

    #[test]
    fn unary_vec_or_scalar_numeric_scalar() {
        let arena = UniqueArena::default();

        let builtin = regular!(1, SCALAR of NUMERIC);

        let ok = builtin.arg(0, &scalar(ir::Scalar::U32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);

        let err = builtin.arg(0, &scalar(ir::Scalar::BOOL), &arena);
        assert!(err.is_empty());
    }

    #[test]
    fn unary_vec_or_scalar_numeric_vector() {
        let arena = UniqueArena::default();

        let builtin = regular!(1, VECN|SCALAR of NUMERIC);

        let ok = builtin.arg(0, &vec3(ir::Scalar::F64), &arena);
        check_return_type(&ok, &vec3(ir::Scalar::F64), &arena);

        let err = builtin.arg(0, &vec3(ir::Scalar::BOOL), &arena);
        assert!(err.is_empty());
    }

    #[test]
    fn unary_vec_or_scalar_numeric_matrix() {
        let arena = UniqueArena::default();

        let builtin = regular!(1, VECN|SCALAR of NUMERIC);

        let err = builtin.arg(
            0,
            &ir::TypeInner::Matrix {
                columns: ir::VectorSize::Tri,
                rows: ir::VectorSize::Tri,
                scalar: ir::Scalar::F32,
            },
            &arena,
        );
        assert!(err.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn binary_vec_or_scalar_numeric_scalar() {
        let arena = UniqueArena::default();

        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::F32), &arena)
            .arg(1, &scalar(ir::Scalar::F32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
            .arg(1, &scalar(ir::Scalar::F32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::F32), &arena)
            .arg(1, &scalar(ir::Scalar::ABSTRACT_INT), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::U32), &arena)
            .arg(1, &scalar(ir::Scalar::U32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::U32), &arena)
            .arg(1, &scalar(ir::Scalar::ABSTRACT_INT), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::ABSTRACT_INT), &arena)
            .arg(1, &scalar(ir::Scalar::U32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);

        // Not numeric.
        let err = builtin
            .arg(0, &scalar(ir::Scalar::BOOL), &arena)
            .arg(1, &scalar(ir::Scalar::BOOL), &arena);
        assert!(err.is_empty());

        // Different floating-point types.
        let err = builtin
            .arg(0, &scalar(ir::Scalar::F32), &arena)
            .arg(1, &scalar(ir::Scalar::F64), &arena);
        assert!(err.is_empty());

        // Different constructor.
        let err = builtin
            .arg(0, &scalar(ir::Scalar::F32), &arena)
            .arg(1, &vec2(ir::Scalar::F32), &arena);
        assert!(err.is_empty());

        // Different vector size.
        let err = builtin
            .arg(0, &vec2(ir::Scalar::F32), &arena)
            .arg(1, &vec3(ir::Scalar::F32), &arena);
        assert!(err.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn binary_vec_or_scalar_numeric_vector() {
        let arena = UniqueArena::default();

        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        let ok = builtin
            .arg(0, &vec3(ir::Scalar::F32), &arena)
            .arg(1, &vec3(ir::Scalar::F32), &arena);
        check_return_type(&ok, &vec3(ir::Scalar::F32), &arena);

        // Different vector sizes.
        let err = builtin
            .arg(0, &vec2(ir::Scalar::F32), &arena)
            .arg(1, &vec3(ir::Scalar::F32), &arena);
        assert!(err.is_empty());

        // Different vector scalars.
        let err = builtin
            .arg(0, &vec3(ir::Scalar::F32), &arena)
            .arg(1, &vec3(ir::Scalar::F64), &arena);
        assert!(err.is_empty());

        // Mix of vectors and scalars.
        let err = builtin
            .arg(0, &scalar(ir::Scalar::F32), &arena)
            .arg(1, &vec3(ir::Scalar::F32), &arena);
        assert!(err.is_empty());
    }

    #[test]
    #[rustfmt::skip]
    fn binary_vec_or_scalar_numeric_vector_abstract() {
        let arena = UniqueArena::default();

        let builtin = regular!(2, VECN|SCALAR of NUMERIC);

        let ok = builtin
            .arg(0, &vec2(ir::Scalar::ABSTRACT_INT), &arena)
            .arg(1, &vec2(ir::Scalar::U32), &arena);
        check_return_type(&ok, &vec2(ir::Scalar::U32), &arena);

        let ok = builtin
            .arg(0, &vec3(ir::Scalar::ABSTRACT_INT), &arena)
            .arg(1, &vec3(ir::Scalar::F32), &arena);
        check_return_type(&ok, &vec3(ir::Scalar::F32), &arena);

        let ok = builtin
            .arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
            .arg(1, &scalar(ir::Scalar::F32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);

        let err = builtin
            .arg(0, &scalar(ir::Scalar::ABSTRACT_FLOAT), &arena)
            .arg(1, &scalar(ir::Scalar::U32), &arena);
        assert!(err.is_empty());

        let err = builtin
            .arg(0, &scalar(ir::Scalar::I32), &arena)
            .arg(1, &scalar(ir::Scalar::U32), &arena);
        assert!(err.is_empty());
    }

    /// Supplying an argument at a position no member rule has must
    /// produce an empty set, even if earlier arguments were acceptable.
    #[test]
    #[rustfmt::skip]
    fn argument_index_out_of_range() {
        let arena = UniqueArena::default();

        // A unary builtin has no rule accepting a second argument.
        let builtin = regular!(1, SCALAR of NUMERIC);
        let err = builtin.arg(1, &scalar(ir::Scalar::U32), &arena);
        assert!(err.is_empty());

        // A binary builtin has no rule accepting a third argument.
        let builtin = regular!(2, VECN|SCALAR of NUMERIC);
        let err = builtin
            .arg(0, &scalar(ir::Scalar::F32), &arena)
            .arg(1, &scalar(ir::Scalar::F32), &arena)
            .arg(2, &scalar(ir::Scalar::F32), &arena);
        assert!(err.is_empty());
    }

    /// The `Scalar` conclusion rule returns the argument's leaf scalar.
    #[test]
    fn conclusion_scalar() {
        let arena = UniqueArena::default();

        let builtin = regular!(1, VECN of NUMERIC -> Scalar);

        let ok = builtin.arg(0, &vec3(ir::Scalar::F32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::F32), &arena);

        let ok = builtin.arg(0, &vec2(ir::Scalar::I32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::I32), &arena);
    }

    /// The `U32` conclusion rule ignores the argument type entirely.
    #[test]
    fn conclusion_u32() {
        let arena = UniqueArena::default();

        let builtin = regular!(1, VECN|SCALAR of NUMERIC -> U32);

        let ok = builtin.arg(0, &vec2(ir::Scalar::F32), &arena);
        check_return_type(&ok, &scalar(ir::Scalar::U32), &arena);
    }
}