1use crate::proc::overloads::any_overload_set::AnyOverloadSet;
4use crate::proc::overloads::list::List;
5use crate::proc::overloads::regular::regular;
6use crate::proc::overloads::utils::{
7 concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
8 scalar_or_vecn, triples, vector_sizes,
9};
10use crate::proc::overloads::OverloadSet;
11
12use crate::ir;
13
14impl ir::MathFunction {
15 pub fn overloads(self) -> impl OverloadSet {
16 use ir::MathFunction as Mf;
17
18 let set: AnyOverloadSet = match self {
19 Mf::Abs | Mf::Sign => regular!(1, SCALAR|VECN of NUMERIC).into(),
21
22 Mf::Min | Mf::Max => regular!(2, SCALAR|VECN of NUMERIC).into(),
24
25 Mf::Clamp => regular!(3, SCALAR|VECN of NUMERIC).into(),
27
28 Mf::Sin
30 | Mf::Cos
31 | Mf::Tan
32 | Mf::Asin
33 | Mf::Acos
34 | Mf::Atan
35 | Mf::Sinh
36 | Mf::Cosh
37 | Mf::Tanh
38 | Mf::Asinh
39 | Mf::Acosh
40 | Mf::Atanh
41 | Mf::Saturate
42 | Mf::Radians
43 | Mf::Degrees
44 | Mf::Ceil
45 | Mf::Floor
46 | Mf::Round
47 | Mf::Fract
48 | Mf::Trunc
49 | Mf::Exp
50 | Mf::Exp2
51 | Mf::Log
52 | Mf::Log2
53 | Mf::Sqrt
54 | Mf::InverseSqrt => regular!(1, SCALAR|VECN of FLOAT).into(),
55
56 Mf::Atan2 | Mf::Pow | Mf::Step => regular!(2, SCALAR|VECN of FLOAT).into(),
58
59 Mf::Fma | Mf::SmoothStep => regular!(3, SCALAR|VECN of FLOAT).into(),
61
62 Mf::CountTrailingZeros
64 | Mf::CountLeadingZeros
65 | Mf::CountOneBits
66 | Mf::ReverseBits
67 | Mf::FirstTrailingBit
68 | Mf::FirstLeadingBit => regular!(1, SCALAR|VECN of CONCRETE_INTEGER).into(),
69
70 Mf::Pack4x8snorm | Mf::Pack4x8unorm => regular!(1, VEC4 of F32 -> U32).into(),
72 Mf::Pack2x16snorm | Mf::Pack2x16unorm | Mf::Pack2x16float => {
73 regular!(1, VEC2 of F32 -> U32).into()
74 }
75 Mf::Pack4xI8 => regular!(1, VEC4 of I32 -> U32).into(),
76 Mf::Pack4xU8 => regular!(1, VEC4 of U32 -> U32).into(),
77
78 Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => regular!(1, SCALAR of U32 -> Vec4F).into(),
80 Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack2x16float => {
81 regular!(1, SCALAR of U32 -> Vec2F).into()
82 }
83 Mf::Unpack4xI8 => regular!(1, SCALAR of U32 -> Vec4I).into(),
84 Mf::Unpack4xU8 => regular!(1, SCALAR of U32 -> Vec4U).into(),
85
86 Mf::Dot => regular!(2, VECN of NUMERIC -> Scalar).into(),
88 Mf::Modf => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Modf).into(),
89 Mf::Frexp => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Frexp).into(),
90 Mf::Ldexp => ldexp().into(),
91 Mf::Outer => outer().into(),
92 Mf::Cross => regular!(2, VEC3 of FLOAT).into(),
93 Mf::Distance => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
94 Mf::Length => regular!(1, SCALAR|VECN of FLOAT_ABSTRACT_UNIMPLEMENTED -> Scalar).into(),
95 Mf::Normalize => regular!(1, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
96 Mf::FaceForward => regular!(3, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
97 Mf::Reflect => regular!(2, VECN of FLOAT_ABSTRACT_UNIMPLEMENTED).into(),
98 Mf::Refract => refract().into(),
99 Mf::Mix => mix().into(),
100 Mf::Inverse => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT).into(),
101 Mf::Transpose => transpose().into(),
102 Mf::Determinant => regular!(1, MAT2X2|MAT3X3|MAT4X4 of FLOAT -> Scalar).into(),
103 Mf::QuantizeToF16 => regular!(1, SCALAR|VECN of F32).into(),
104 Mf::ExtractBits => extract_bits().into(),
105 Mf::InsertBits => insert_bits().into(),
106 };
107
108 set
109 }
110}
111
112fn ldexp() -> List {
113 fn exponent_from_mantissa(mantissa: ir::Scalar) -> ir::Scalar {
115 match mantissa.kind {
116 ir::ScalarKind::AbstractFloat => ir::Scalar::ABSTRACT_INT,
117 ir::ScalarKind::Float => ir::Scalar::I32,
118 _ => unreachable!("not a float scalar"),
119 }
120 }
121
122 list(
123 float_scalars_unimplemented_abstract().flat_map(|mantissa_scalar| {
125 let exponent_scalar = exponent_from_mantissa(mantissa_scalar);
127 scalar_or_vecn(mantissa_scalar)
129 .zip(scalar_or_vecn(exponent_scalar))
130 .map(move |(mantissa, exponent)| {
131 let result = mantissa.clone();
132 rule([mantissa, exponent], result)
133 })
134 }),
135 )
136}
137
138fn outer() -> List {
139 list(
140 triples(
141 vector_sizes(),
142 vector_sizes(),
143 float_scalars_unimplemented_abstract(),
144 )
145 .map(|(cols, rows, scalar)| {
146 let left = ir::TypeInner::Vector { size: cols, scalar };
147 let right = ir::TypeInner::Vector { size: rows, scalar };
148 let result = ir::TypeInner::Matrix {
149 columns: cols,
150 rows,
151 scalar,
152 };
153 rule([left, right], result)
154 }),
155 )
156}
157
158fn refract() -> List {
159 list(
160 pairs(vector_sizes(), float_scalars_unimplemented_abstract()).map(|(size, scalar)| {
161 let incident = ir::TypeInner::Vector { size, scalar };
162 let normal = incident.clone();
163 let ratio = ir::TypeInner::Scalar(scalar);
164 let result = incident.clone();
165 rule([incident, normal, ratio], result)
166 }),
167 )
168}
169
170fn transpose() -> List {
171 list(
172 triples(vector_sizes(), vector_sizes(), float_scalars()).map(|(a, b, scalar)| {
173 let input = ir::TypeInner::Matrix {
174 columns: a,
175 rows: b,
176 scalar,
177 };
178 let output = ir::TypeInner::Matrix {
179 columns: b,
180 rows: a,
181 scalar,
182 };
183 rule([input], output)
184 }),
185 )
186}
187
188fn extract_bits() -> List {
189 list(concrete_int_scalars().flat_map(|scalar| {
190 scalar_or_vecn(scalar).map(|input| {
191 let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
192 let count = ir::TypeInner::Scalar(ir::Scalar::U32);
193 let output = input.clone();
194 rule([input, offset, count], output)
195 })
196 }))
197}
198
199fn insert_bits() -> List {
200 list(concrete_int_scalars().flat_map(|scalar| {
201 scalar_or_vecn(scalar).map(|input| {
202 let newbits = input.clone();
203 let offset = ir::TypeInner::Scalar(ir::Scalar::U32);
204 let count = ir::TypeInner::Scalar(ir::Scalar::U32);
205 let output = input.clone();
206 rule([input, newbits, offset, count], output)
207 })
208 }))
209}
210
211fn mix() -> List {
212 list(float_scalars().flat_map(|scalar| {
213 scalar_or_vecn(scalar).flat_map(move |input| {
214 let scalar_ratio = ir::TypeInner::Scalar(scalar);
215 [
216 rule([input.clone(), input.clone(), input.clone()], input.clone()),
217 rule([input.clone(), input.clone(), scalar_ratio], input),
218 ]
219 })
220 }))
221}