naga/proc/overloads/mathfunction.rs

//! Overload sets for [`ir::MathFunction`].
2
3use crate::proc::overloads::any_overload_set::AnyOverloadSet;
4use crate::proc::overloads::list::List;
5use crate::proc::overloads::regular::regular;
6use crate::proc::overloads::utils::{
7    concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
8    scalar_or_vecn, triples, vector_sizes,
9};
10use crate::proc::overloads::OverloadSet;
11
12use crate::ir;
13
14impl ir::MathFunction {
15    pub fn overloads(self) -> impl OverloadSet {
16        use ir::MathFunction as Mf;
17
18        let set: AnyOverloadSet = match self {
19            // Component-wise unary numeric operations
20            Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
21
22            // Component-wise binary numeric operations
23            Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
24
25            // Component-wise ternary numeric operations
26            Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
27
28            // Component-wise unary floating-point operations
29            Mf::Sin
30            | Mf::Cos
31            | Mf::Tan
32            | Mf::Asin
33            | Mf::Acos
34            | Mf::Atan
35            | Mf::Sinh
36            | Mf::Cosh
37            | Mf::Tanh
38            | Mf::Asinh
39            | Mf::Acosh
40            | Mf::Atanh
41            | Mf::Saturate
42            | Mf::Radians
43            | Mf::Degrees
44            | Mf::Ceil
45            | Mf::Floor
46            | Mf::Round
47            | Mf::Fract
48            | Mf::Trunc
49            | Mf::Exp
50            | Mf::Exp2
51            | Mf::Log
52            | Mf::Log2
53            | Mf::Sqrt
54            | Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
55
56            // Component-wise binary floating-point operations
57            Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
58
59            // Component-wise ternary floating-point operations
60            Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
61
62            // Component-wise unary concrete integer operations
63            Mf::CountTrailingZeros
64            | Mf::CountLeadingZeros
65            | Mf::CountOneBits
66            | Mf::ReverseBits
67            | Mf::FirstTrailingBit
68            | Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
69
70            // Packing functions
71            Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
72            Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
73                regular!(1, VEC2 of F32 -> U32).into()
74            }
75            Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
76            Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
77
78            // Unpacking functions
79            Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
80            Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
81                regular!(1, SCALAR of U32 -> Vec2F).into()
82            }
83            Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
84            Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
85
86            // One-off operations
87            Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
88            Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
89            Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
90            Mf::Ldexp => ldexp().into(),
91            Mf::Outer => outer().into(),
92            Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
93            Mf::Distance => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
94            Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
95            Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
96            Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
97            Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
98            Mf::Refract => refract().into(),
99            Mf::Mix => mix().into(),
100            Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
101            Mf::Transpose => transpose().into(),
102            Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
103            Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
104            Mf::ExtractBits => extract_bits().into(),
105            Mf::InsertBits => insert_bits().into(),
106        };
107
108        set
109    }
110}
111
112fn ldexp() -> List {
113    /// Construct the exponent scalar given the mantissa's inner.
114    fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
115        match mantissa.kind {
116            ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
117            ir::ScalarKind::Float => ir::Scalar::I32,
118            _ => unreachable!("not a float scalar"),
119        }
120    }
121
122    list(
123        // The ldexp mantissa argument can be any floating-point type.
124        float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
125            // The exponent type is the integer counterpart of the mantissa type.
126            let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
127            // There are scalar and vector component-wise overloads.
128            scalar_or_vecn(mantissa_scalar)
129                .zip(scalar_or_vecn(exponent_scalar))
130                .map(move |(mantissa, exponent)| {
131                    let result = mantissa.clone();
132                    rule([mantissa, exponent], result)
133                })
134        }),
135    )
136}
137
138fn outer() -> List {
139    list(
140        triples(
141            vector_sizes(),
142            vector_sizes(),
143            float_scalars_unimplemented_abstract(),
144        )
145        .map(|(cols, rows, scalar)| {
146            let left = ir::TypeInner::Vector { size: cols, scalar };
147            let right = ir::TypeInner::Vector { size: rows, scalar };
148            let result = ir::TypeInner::Matrix {
149                columns: cols,
150                rows,
151                scalar,
152            };
153            rule([left, right], result)
154        }),
155    )
156}
157
158fn refract() -> List {
159    list(
160        pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
161            let incident = ir::TypeInner::Vector { size, scalar };
162            let normal = incident.clone();
163            let ratio = ir::TypeInner::Scalar(scalar);
164            let result = incident.clone();
165            rule([incident, normal, ratio], result)
166        }),
167    )
168}
169
170fn transpose() -> List {
171    list(
172        triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
173            let input = ir::TypeInner::Matrix {
174                columns: a,
175                rows: b,
176                scalar,
177            };
178            let output = ir::TypeInner::Matrix {
179                columns: b,
180                rows: a,
181                scalar,
182            };
183            rule([input], output)
184        }),
185    )
186}
187
188fn extract_bits() -> List {
189    list(concrete_int_scalars().flat_map(|scalar| {
190        scalar_or_vecn(scalar).map(|input| {
191            let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
192            let count = ir::TypeInner::Scalar(ir::Scalar::U32);
193            let output = input.clone();
194            rule([input, offset, count], output)
195        })
196    }))
197}
198
199fn insert_bits() -> List {
200    list(concrete_int_scalars().flat_map(|scalar| {
201        scalar_or_vecn(scalar).map(|input| {
202            let newbits = input.clone();
203            let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
204            let count = ir::TypeInner::Scalar(ir::Scalar::U32);
205            let output = input.clone();
206            rule([input, newbits, offset, count], output)
207        })
208    }))
209}
210
211fn mix() -> List {
212    list(float_scalars().flat_map(|scalar| {
213        scalar_or_vecn(scalar).flat_map(move |input| {
214            let scalar_ratio = ir::TypeInner::Scalar(scalar);
215            [
216                rule([input.clone(), input.clone(), input.clone()], input.clone()),
217                rule([input.clone(), input.clone(), scalar_ratio], input),
218            ]
219        })
220    }))
221}