1use alloc::{
7 format,
8 string::{String, ToString},
9 vec,
10 vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19 arena::{Arena, Handle, HandleVec, UniqueArena},
20 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21 ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// Lets a macro expansion define another `macro_rules!` macro that has its own
/// metavariables.
///
/// The tokens passed in become the single rule of a temporary macro,
/// `__with_dollar_sign`. That macro is then invoked with a literal `$` token.
/// A rule written as `($d:tt) => { ... }` can therefore use `$d` wherever the
/// generated inner macro needs a `$` (see its use in
/// `gen_component_wise_extractor`).
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define a throwaway macro whose only rule is the caller's tokens...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and immediately invoke it, handing it a bare `$` token.
        __with_dollar_sign!($);
    }
}
38
/// Generates a component-wise evaluation helper for one family of literal types.
///
/// Given `$ident -> $target`, the expansion defines:
///
/// - `enum $target<const N: usize>`: one variant per `literals` entry. Each
///   variant holds `N` values of the listed Rust type (e.g. `F32([f32; N])`).
/// - `impl From<$target<1>> for Expression`: turns a single value back into an
///   `Expression::Literal`.
/// - `fn $ident(eval, span, exprs, handler)`:
///   - Collects the `N` argument expressions into one `$target<N>`, applies
///     `handler`, and registers the result via `register_evaluated_expr`.
///   - Vector arguments (a `Compose` of a vector type whose scalar kind is listed
///     in `scalar_kinds`) are processed one component index at a time,
///     recursively.
///   - The per-component results are reassembled into a `Compose` of the first
///     argument's type.
/// - A shorthand `macro_rules! $ident` that writes the per-variant `handler`
///   match for the caller when every variant uses the same body.
///
/// The shorthand macro has metavariables of its own, which is why it is
/// generated through `with_dollar_sign!`.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        // One variant per accepted literal type, each holding `N` values
        // (one per argument of the math function being evaluated).
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        // A single-value result converts straight back into a literal expression.
        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            // A plain `fn` pointer, so any closure passed here must not capture.
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            // Only `$target<1>` has an `Into<Expression>` impl in this macro,
            // so in practice handlers produce a single value (M == 1).
            $target<M>: Into<Expression>,
        {
            // The first argument is read unconditionally below.
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Normalizes one argument with `eval_zero_value_and_splat` (defined
            // elsewhere; presumably expands `ZeroValue`/`Splat` into literals or
            // composes — confirm) and borrows the resulting expression.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // Dispatch on the shape of the first argument; every other argument
            // must have the same shape.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: every argument must be the same literal variant.
                    &Expression::Literal(Literal::$literal(x)) => {
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                // Vector case: split every argument into its components and
                // evaluate component-wise by recursing into `$ident`.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One group of flattened components per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    // Remaining arguments must be composes of the
                                    // exact same type as the first one.
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each lane, gather that lane from every argument
                            // and evaluate it as a scalar call.
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            // Reassemble the lanes into a vector of the first argument's type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        with_dollar_sign! {
            ($d:tt) => {
                // Shorthand: `$ident!(eval, span, [a, b], |x, y| body)` builds the
                // handler by matching every `$target` variant with the same
                // bindings and wrapping `body`'s `Ok` value back into that
                // variant.
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
// Component-wise helper accepting every numeric scalar type (floats and ints,
// concrete and abstract). Generates `Scalar<N>` and `component_wise_scalar`.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}

// Component-wise helper restricted to floating-point values (abstract, f32, f16).
// Generates `Float<N>` and `component_wise_float`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}

// Component-wise helper restricted to 32-bit concrete integers (u32, i32).
// Generates `ConcreteInt<N>` and `component_wise_concrete_int`.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}

// Component-wise helper restricted to signed types (floats, abstract int, i32).
// Generates `Signed<N>` and `component_wise_signed`.
// NOTE(review): `Sint` vectors of i64 pass the kind check but have no I64 literal
// arm, so they end in `InvalidMathArg` — confirm that is intended.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
281
/// A scalar or a vector of literal values, stored as a small array of one element type.
///
/// Each variant corresponds to the [`Literal`] variant of the same name. A scalar
/// is represented as a one-element array (see `LiteralVector::from_literal`).
/// At most `crate::VectorSize::MAX` components are held.
#[derive(Debug)]
enum LiteralVector {
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
296
297impl LiteralVector {
298 #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
299 fn len(&self) -> usize {
300 match *self {
301 LiteralVector::F64(ref v) => v.len(),
302 LiteralVector::F32(ref v) => v.len(),
303 LiteralVector::F16(ref v) => v.len(),
304 LiteralVector::U32(ref v) => v.len(),
305 LiteralVector::I32(ref v) => v.len(),
306 LiteralVector::U64(ref v) => v.len(),
307 LiteralVector::I64(ref v) => v.len(),
308 LiteralVector::Bool(ref v) => v.len(),
309 LiteralVector::AbstractInt(ref v) => v.len(),
310 LiteralVector::AbstractFloat(ref v) => v.len(),
311 }
312 }
313
314 fn from_literal(literal: Literal) -> Self {
316 fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
317 let mut v = ArrayVec::new();
318 v.push(val);
319 v
320 }
321 match literal {
322 Literal::F64(e) => Self::F64(arrayvec_of(e)),
323 Literal::F32(e) => Self::F32(arrayvec_of(e)),
324 Literal::U32(e) => Self::U32(arrayvec_of(e)),
325 Literal::I32(e) => Self::I32(arrayvec_of(e)),
326 Literal::U64(e) => Self::U64(arrayvec_of(e)),
327 Literal::I64(e) => Self::I64(arrayvec_of(e)),
328 Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
329 Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
330 Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
331 Literal::F16(e) => Self::F16(arrayvec_of(e)),
332 }
333 }
334
335 fn from_literal_vec(
340 components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
341 ) -> Result<Self, ConstantEvaluatorError> {
342 assert!(!components.is_empty());
343 macro_rules! compose_literals {
345 ($components:expr, $variant:ident, $self_variant:ident) => {{
346 let mut out = ArrayVec::new();
347 for l in &$components {
348 match l {
349 &Literal::$variant(v) => out.push(v),
350 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
351 }
352 }
353 Self::$self_variant(out)
354 }};
355 }
356 Ok(match components[0] {
357 Literal::I32(_) => compose_literals!(components, I32, I32),
358 Literal::U32(_) => compose_literals!(components, U32, U32),
359 Literal::I64(_) => compose_literals!(components, I64, I64),
360 Literal::U64(_) => compose_literals!(components, U64, U64),
361 Literal::F32(_) => compose_literals!(components, F32, F32),
362 Literal::F64(_) => compose_literals!(components, F64, F64),
363 Literal::Bool(_) => compose_literals!(components, Bool, Bool),
364 Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
365 Literal::AbstractFloat(_) => {
366 compose_literals!(components, AbstractFloat, AbstractFloat)
367 }
368 Literal::F16(_) => compose_literals!(components, F16, F16),
369 })
370 }
371
372 #[allow(dead_code)]
373 fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
375 macro_rules! decompose_literals {
376 ($v:expr, $variant:ident) => {{
377 let mut out = ArrayVec::new();
378 for e in $v {
379 out.push(Literal::$variant(*e));
380 }
381 out
382 }};
383 }
384 match *self {
385 LiteralVector::F64(ref v) => decompose_literals!(v, F64),
386 LiteralVector::F32(ref v) => decompose_literals!(v, F32),
387 LiteralVector::F16(ref v) => decompose_literals!(v, F16),
388 LiteralVector::U32(ref v) => decompose_literals!(v, U32),
389 LiteralVector::I32(ref v) => decompose_literals!(v, I32),
390 LiteralVector::U64(ref v) => decompose_literals!(v, U64),
391 LiteralVector::I64(ref v) => decompose_literals!(v, I64),
392 LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
393 LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
394 LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
395 }
396 }
397
398 #[allow(dead_code)]
399 fn register_as_evaluated_expr(
401 &self,
402 eval: &mut ConstantEvaluator<'_>,
403 span: Span,
404 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
405 let lit_vec = self.to_literal_vec();
406 assert!(!lit_vec.is_empty());
407 let expr = if lit_vec.len() == 1 {
408 Expression::Literal(lit_vec[0])
409 } else {
410 Expression::Compose {
411 ty: eval.types.insert(
412 Type {
413 name: None,
414 inner: TypeInner::Vector {
415 size: match lit_vec.len() {
416 2 => crate::VectorSize::Bi,
417 3 => crate::VectorSize::Tri,
418 4 => crate::VectorSize::Quad,
419 _ => unreachable!(),
420 },
421 scalar: lit_vec[0].scalar(),
422 },
423 },
424 Span::UNDEFINED,
425 ),
426 components: lit_vec
427 .iter()
428 .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
429 .collect::<Result<_, _>>()?,
430 }
431 };
432 eval.register_evaluated_expr(expr, span)
433 }
434}
435
/// Matches one or more [`LiteralVector`]s by variant and wraps each arm's result.
///
/// Usage shape (from the first rule):
///
/// ```text
/// match_literal_vector!(match (a, b) => Out {
///     Float => |a, b| { body },
///     Integer => |a, b| -> SomeVariant { body },
///     Bool => |a, b| { body },
/// })
/// ```
///
/// - Each arm names a `LiteralVector` variant. Two shorthands are accepted:
///   - `Integer` expands to `U32`, `I32`, `U64`, `I64` and `AbstractInt`.
///   - `Float` expands to `F16`, `F32`, `F64` and `AbstractFloat`.
///   For each expanded variant, the arm's body is copied.
/// - The closure-like parameter list binds, by `ref`, the inner array of each
///   element of the scrutinee tuple, in order. All elements must have the arm's
///   variant. A single parameter matches a single (un-tupled) value; the
///   generated parentheses are allowed via `unused_parens`.
/// - The body's value is wrapped as `$out::<Variant>(body)`. With `-> Ret` it is
///   wrapped as `$out::Ret(body)` instead.
/// - The whole invocation evaluates to `Ok(..)` for the first matching arm, or
///   `Err(ConstantEvaluatorError::InvalidMathArg)` if none match.
///
/// Implementation: a tail-recursive "muncher".
/// - `@inner_start` seeds an empty accumulator.
/// - The first `@inner` rule pops the next arm type and re-dispatches on it.
/// - The `Integer`/`Float` rules splice in their concrete type list and
///   duplicate the arm body to match.
/// - The generic `$ty` rule moves one (type, arm) pair into the accumulator.
/// - The empty-list rule hands everything to `@inner_finish`, which emits the
///   final `match`.
macro_rules! match_literal_vector {
    // Entry point: normalize arms into `{ vars ; ret? ; body }` groups plus a type list.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start munching with an empty accumulator (`[]`) on the left of `<>`.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next type from the list and dispatch on it (Integer/Float/concrete).
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` shorthand: replace with the five integer variants, body repeated 5x.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body }
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` shorthand: replace with the four float variants, body repeated 4x.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body },
                { $($var),+ ; $($ret)? ; $body }
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Concrete variant with more types remaining: move (type, arm) into the
    // accumulator and continue with the next type.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last concrete variant: add it to the accumulator and emit the match.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the final `match`: one arm per accumulated (type, arm) pair, plus a
    // fallback returning `InvalidMathArg`.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // No explicit return variant: wrap in the variant named like the input type.
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // Explicit `-> $ret`: wrap in that variant instead.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
611
612#[derive(Debug)]
613enum Behavior<'a> {
614 Wgsl(WgslRestrictions<'a>),
615 Glsl(GlslRestrictions<'a>),
616}
617
618impl Behavior<'_> {
619 const fn has_runtime_restrictions(&self) -> bool {
621 matches!(
622 self,
623 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
624 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
625 )
626 }
627}
628
/// Evaluates (folds) constant expressions and appends the results to an expression arena.
///
/// `behavior` selects the source language (WGSL or GLSL). It also selects whether
/// the evaluator works on module-scope (global) expressions or on a function's
/// local expressions; see the `for_*` constructors.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Language and context restrictions that decide which expression kinds are accepted.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Mutable because evaluation may insert new types (e.g. the vector type built
    /// in `LiteralVector::register_as_evaluated_expr`).
    types: &'a mut UniqueArena<Type>,

    /// The module's constants; `Expression::Constant` handles index into this arena.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena that evaluated expressions are appended to.
    ///
    /// This is the module's global expressions for the module-scope constructors,
    /// or the function's expression arena for the function constructors.
    expressions: &'a mut Arena<Expression>,

    /// Records the [`ExpressionKind`] of expressions in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Type layouter.
    ///
    /// NOTE(review): not used in the visible part of this file; presumably
    /// consulted for size/alignment queries elsewhere — confirm.
    layouter: &'a mut crate::proc::Layouter,
}
673
/// Evaluation context for WGSL input.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions are accepted.
    ///
    /// `None` at module scope (`for_wgsl_module`). `Some` inside a function body
    /// (`for_wgsl_function` with `is_const`), carrying the function-local data.
    Const(Option<FunctionLocalData<'a>>),
    /// Module scope where override-expressions are also accepted
    /// (`for_wgsl_module` with `in_override_ctx`).
    Override,
    /// Non-const function body: runtime and override expressions may be appended.
    Runtime(FunctionLocalData<'a>),
}

/// Evaluation context for GLSL input.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope (`for_glsl_module`): only const-expressions are accepted.
    Const,
    /// Function body (`for_glsl_function`): runtime expressions may be appended.
    Runtime(FunctionLocalData<'a>),
}
696
/// State needed when the evaluator is working on a function's local expression arena.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expressions.
    ///
    /// Used to read constant initializers (`check_and_get`) and to build
    /// `to_ctx`.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): not used in the visible code; presumably tracks emitted
    /// ranges of runtime expressions — confirm.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): not used in the visible code; presumably the block that
    /// emitted statements are pushed into — confirm.
    block: &'a mut crate::Block,
}
704
/// How "constant" an expression is.
///
/// Variant order matters: the derived `Ord` gives `Const < Override < Runtime`.
/// `ExpressionKindTracker::type_of_with_expr` combines operands with `max`, so a
/// compound expression takes the least-constant kind among its operands.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Fully known at shader-creation time.
    Const,
    /// Depends on a pipeline-overridable constant.
    Override,
    /// Only known at runtime.
    Runtime,
}
711
/// Records an [`ExpressionKind`] for each expression handle of one expression arena.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
716
717impl ExpressionKindTracker {
718 pub const fn new() -> Self {
719 Self {
720 inner: HandleVec::new(),
721 }
722 }
723
724 pub fn force_non_const(&mut self, value: Handle<Expression>) {
726 self.inner[value] = ExpressionKind::Runtime;
727 }
728
729 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
730 self.inner.insert(value, expr_type);
731 }
732
733 pub fn is_const(&self, h: Handle<Expression>) -> bool {
734 matches!(self.type_of(h), ExpressionKind::Const)
735 }
736
737 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
738 matches!(
739 self.type_of(h),
740 ExpressionKind::Const | ExpressionKind::Override
741 )
742 }
743
744 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
745 self.inner[value]
746 }
747
748 pub fn from_arena(arena: &Arena<Expression>) -> Self {
749 let mut tracker = Self {
750 inner: HandleVec::with_capacity(arena.len()),
751 };
752 for (handle, expr) in arena.iter() {
753 tracker
754 .inner
755 .insert(handle, tracker.type_of_with_expr(expr));
756 }
757 tracker
758 }
759
760 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
761 match *expr {
762 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
763 ExpressionKind::Const
764 }
765 Expression::Override(_) => ExpressionKind::Override,
766 Expression::Compose { ref components, .. } => {
767 let mut expr_type = ExpressionKind::Const;
768 for component in components {
769 expr_type = expr_type.max(self.type_of(*component))
770 }
771 expr_type
772 }
773 Expression::Splat { value, .. } => self.type_of(value),
774 Expression::AccessIndex { base, .. } => self.type_of(base),
775 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
776 Expression::Swizzle { vector, .. } => self.type_of(vector),
777 Expression::Unary { expr, .. } => self.type_of(expr),
778 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
779 Expression::Math {
780 arg,
781 arg1,
782 arg2,
783 arg3,
784 ..
785 } => self
786 .type_of(arg)
787 .max(
788 arg1.map(|arg| self.type_of(arg))
789 .unwrap_or(ExpressionKind::Const),
790 )
791 .max(
792 arg2.map(|arg| self.type_of(arg))
793 .unwrap_or(ExpressionKind::Const),
794 )
795 .max(
796 arg3.map(|arg| self.type_of(arg))
797 .unwrap_or(ExpressionKind::Const),
798 ),
799 Expression::As { expr, .. } => self.type_of(expr),
800 Expression::Select {
801 condition,
802 accept,
803 reject,
804 } => self
805 .type_of(condition)
806 .max(self.type_of(accept))
807 .max(self.type_of(reject)),
808 Expression::Relational { argument, .. } => self.type_of(argument),
809 Expression::ArrayLength(expr) => self.type_of(expr),
810 _ => ExpressionKind::Runtime,
811 }
812 }
813}
814
815#[derive(Clone, Debug, thiserror::Error)]
816#[cfg_attr(test, derive(PartialEq))]
817pub enum ConstantEvaluatorError {
818 #[error("Constants cannot access function arguments")]
819 FunctionArg,
820 #[error("Constants cannot access global variables")]
821 GlobalVariable,
822 #[error("Constants cannot access local variables")]
823 LocalVariable,
824 #[error("Cannot get the array length of a non array type")]
825 InvalidArrayLengthArg,
826 #[error("Constants cannot get the array length of a dynamically sized array")]
827 ArrayLengthDynamic,
828 #[error("Cannot call arrayLength on array sized by override-expression")]
829 ArrayLengthOverridden,
830 #[error("Constants cannot call functions")]
831 Call,
832 #[error("Constants don't support workGroupUniformLoad")]
833 WorkGroupUniformLoadResult,
834 #[error("Constants don't support atomic functions")]
835 Atomic,
836 #[error("Constants don't support derivative functions")]
837 Derivative,
838 #[error("Constants don't support load expressions")]
839 Load,
840 #[error("Constants don't support image expressions")]
841 ImageExpression,
842 #[error("Constants don't support ray query expressions")]
843 RayQueryExpression,
844 #[error("Constants don't support subgroup expressions")]
845 SubgroupExpression,
846 #[error("Cannot access the type")]
847 InvalidAccessBase,
848 #[error("Cannot access at the index")]
849 InvalidAccessIndex,
850 #[error("Cannot access with index of type")]
851 InvalidAccessIndexTy,
852 #[error("Constants don't support array length expressions")]
853 ArrayLength,
854 #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
855 InvalidCastArg { from: String, to: String },
856 #[error("Cannot apply the unary op to the argument")]
857 InvalidUnaryOpArg,
858 #[error("Cannot apply the binary op to the arguments")]
859 InvalidBinaryOpArgs,
860 #[error("Cannot apply math function to type")]
861 InvalidMathArg,
862 #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
863 InvalidMathArgCount(crate::MathFunction, usize, usize),
864 #[error("{0} built-in function argument is out of valid range")]
865 InvalidMathArgValue(String),
866 #[error("Cannot apply relational function to type")]
867 InvalidRelationalArg(RelationalFunction),
868 #[error("value of `low` is greater than `high` for clamp built-in function")]
869 InvalidClamp,
870 #[error("Constructor expects {expected} components, found {actual}")]
871 InvalidVectorComposeLength { expected: usize, actual: usize },
872 #[error("Constructor must only contain vector or scalar arguments")]
873 InvalidVectorComposeComponent,
874 #[error("Splat is defined only on scalar values")]
875 SplatScalarOnly,
876 #[error("Can only swizzle vector constants")]
877 SwizzleVectorOnly,
878 #[error("swizzle component not present in source expression")]
879 SwizzleOutOfBounds,
880 #[error("Type is not constructible")]
881 TypeNotConstructible,
882 #[error("Subexpression(s) are not constant")]
883 SubexpressionsAreNotConstant,
884 #[error("Not implemented as constant expression: {0}")]
885 NotImplemented(String),
886 #[error("{0} operation overflowed")]
887 Overflow(String),
888 #[error(
889 "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
890 )]
891 AutomaticConversionLossy {
892 value: String,
893 to_type: &'static str,
894 },
895 #[error("Division by zero")]
896 DivisionByZero,
897 #[error("Remainder by zero")]
898 RemainderByZero,
899 #[error("RHS of shift operation is greater than or equal to 32")]
900 ShiftedMoreThan32Bits,
901 #[error(transparent)]
902 Literal(#[from] crate::valid::LiteralError),
903 #[error("Can't use pipeline-overridable constants in const-expressions")]
904 Override,
905 #[error("Unexpected runtime-expression")]
906 RuntimeExpr,
907 #[error("Unexpected override-expression")]
908 OverrideExpr,
909 #[error("Expected boolean expression for condition argument of `select`, got something else")]
910 SelectScalarConditionNotABool,
911 #[error(
912 "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
913 reject,
914 accept
915 )]
916 SelectVecRejectAcceptSizeMismatch {
917 reject: crate::VectorSize,
918 accept: crate::VectorSize,
919 },
920 #[error("Expected boolean vector for condition arg., got something else")]
921 SelectConditionNotAVecBool,
922 #[error(
923 "Expected same number of vector components between condition, accept, and reject args., got something else",
924 )]
925 SelectConditionVecSizeMismatch,
926 #[error(
927 "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
928 )]
929 SelectAcceptRejectTypeMismatch,
930 #[error("Cooperative operations can't be constant")]
931 CooperativeOperation,
932}
933
934impl<'a> ConstantEvaluator<'a> {
935 pub const fn for_wgsl_module(
940 module: &'a mut crate::Module,
941 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
942 layouter: &'a mut crate::proc::Layouter,
943 in_override_ctx: bool,
944 ) -> Self {
945 Self::for_module(
946 Behavior::Wgsl(if in_override_ctx {
947 WgslRestrictions::Override
948 } else {
949 WgslRestrictions::Const(None)
950 }),
951 module,
952 global_expression_kind_tracker,
953 layouter,
954 )
955 }
956
957 pub const fn for_glsl_module(
962 module: &'a mut crate::Module,
963 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
964 layouter: &'a mut crate::proc::Layouter,
965 ) -> Self {
966 Self::for_module(
967 Behavior::Glsl(GlslRestrictions::Const),
968 module,
969 global_expression_kind_tracker,
970 layouter,
971 )
972 }
973
974 const fn for_module(
975 behavior: Behavior<'a>,
976 module: &'a mut crate::Module,
977 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
978 layouter: &'a mut crate::proc::Layouter,
979 ) -> Self {
980 Self {
981 behavior,
982 types: &mut module.types,
983 constants: &module.constants,
984 overrides: &module.overrides,
985 expressions: &mut module.global_expressions,
986 expression_kind_tracker: global_expression_kind_tracker,
987 layouter,
988 }
989 }
990
991 pub const fn for_wgsl_function(
996 module: &'a mut crate::Module,
997 expressions: &'a mut Arena<Expression>,
998 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
999 layouter: &'a mut crate::proc::Layouter,
1000 emitter: &'a mut super::Emitter,
1001 block: &'a mut crate::Block,
1002 is_const: bool,
1003 ) -> Self {
1004 let local_data = FunctionLocalData {
1005 global_expressions: &module.global_expressions,
1006 emitter,
1007 block,
1008 };
1009 Self {
1010 behavior: Behavior::Wgsl(if is_const {
1011 WgslRestrictions::Const(Some(local_data))
1012 } else {
1013 WgslRestrictions::Runtime(local_data)
1014 }),
1015 types: &mut module.types,
1016 constants: &module.constants,
1017 overrides: &module.overrides,
1018 expressions,
1019 expression_kind_tracker: local_expression_kind_tracker,
1020 layouter,
1021 }
1022 }
1023
1024 pub const fn for_glsl_function(
1029 module: &'a mut crate::Module,
1030 expressions: &'a mut Arena<Expression>,
1031 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1032 layouter: &'a mut crate::proc::Layouter,
1033 emitter: &'a mut super::Emitter,
1034 block: &'a mut crate::Block,
1035 ) -> Self {
1036 Self {
1037 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1038 global_expressions: &module.global_expressions,
1039 emitter,
1040 block,
1041 })),
1042 types: &mut module.types,
1043 constants: &module.constants,
1044 overrides: &module.overrides,
1045 expressions,
1046 expression_kind_tracker: local_expression_kind_tracker,
1047 layouter,
1048 }
1049 }
1050
1051 pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1052 crate::proc::GlobalCtx {
1053 types: self.types,
1054 constants: self.constants,
1055 overrides: self.overrides,
1056 global_expressions: match self.function_local_data() {
1057 Some(data) => data.global_expressions,
1058 None => self.expressions,
1059 },
1060 }
1061 }
1062
1063 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1064 if !self.expression_kind_tracker.is_const(expr) {
1065 log::debug!("check: SubexpressionsAreNotConstant");
1066 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1067 }
1068 Ok(())
1069 }
1070
1071 fn check_and_get(
1072 &mut self,
1073 expr: Handle<Expression>,
1074 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1075 match self.expressions[expr] {
1076 Expression::Constant(c) => {
1077 if let Some(function_local_data) = self.function_local_data() {
1080 self.copy_from(
1082 self.constants[c].init,
1083 function_local_data.global_expressions,
1084 )
1085 } else {
1086 Ok(self.constants[c].init)
1088 }
1089 }
1090 _ => {
1091 self.check(expr)?;
1092 Ok(expr)
1093 }
1094 }
1095 }
1096
1097 pub fn try_eval_and_append(
1121 &mut self,
1122 expr: Expression,
1123 span: Span,
1124 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1125 match self.expression_kind_tracker.type_of_with_expr(&expr) {
1126 ExpressionKind::Const => {
1127 let eval_result = self.try_eval_and_append_impl(&expr, span);
1128 if self.behavior.has_runtime_restrictions()
1133 && matches!(
1134 eval_result,
1135 Err(ConstantEvaluatorError::NotImplemented(_)
1136 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1137 )
1138 {
1139 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1140 } else {
1141 eval_result
1142 }
1143 }
1144 ExpressionKind::Override => match self.behavior {
1145 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1146 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1147 }
1148 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1149 Err(ConstantEvaluatorError::OverrideExpr)
1150 }
1151
1152 Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1154 Ok(self.append_expr(expr, span, ExpressionKind::Override))
1155 }
1156 Behavior::Glsl(GlslRestrictions::Const) => {
1157 Err(ConstantEvaluatorError::OverrideExpr)
1158 }
1159 },
1160 ExpressionKind::Runtime => {
1161 if self.behavior.has_runtime_restrictions() {
1162 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1163 } else {
1164 Err(ConstantEvaluatorError::RuntimeExpr)
1165 }
1166 }
1167 }
1168 }
1169
1170 const fn is_global_arena(&self) -> bool {
1172 matches!(
1173 self.behavior,
1174 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1175 | Behavior::Glsl(GlslRestrictions::Const)
1176 )
1177 }
1178
1179 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1180 match self.behavior {
1181 Behavior::Wgsl(
1182 WgslRestrictions::Runtime(ref function_local_data)
1183 | WgslRestrictions::Const(Some(ref function_local_data)),
1184 )
1185 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1186 Some(function_local_data)
1187 }
1188 _ => None,
1189 }
1190 }
1191
1192 fn try_eval_and_append_impl(
1193 &mut self,
1194 expr: &Expression,
1195 span: Span,
1196 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1197 log::trace!("try_eval_and_append: {expr:?}");
1198 match *expr {
1199 Expression::Constant(c) if self.is_global_arena() => {
1200 Ok(self.constants[c].init)
1203 }
1204 Expression::Override(_) => Err(ConstantEvaluatorError::Override),
1205 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
1206 self.register_evaluated_expr(expr.clone(), span)
1207 }
1208 Expression::Compose { ty, ref components } => {
1209 let components = components
1210 .iter()
1211 .map(|component| self.check_and_get(*component))
1212 .collect::<Result<Vec<_>, _>>()?;
1213 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
1214 }
1215 Expression::Splat { size, value } => {
1216 let value = self.check_and_get(value)?;
1217 self.register_evaluated_expr(Expression::Splat { size, value }, span)
1218 }
1219 Expression::AccessIndex { base, index } => {
1220 let base = self.check_and_get(base)?;
1221
1222 self.access(base, index as usize, span)
1223 }
1224 Expression::Access { base, index } => {
1225 let base = self.check_and_get(base)?;
1226 let index = self.check_and_get(index)?;
1227
1228 let index_val: u32 = self
1229 .to_ctx()
1230 .get_const_val_from(index, self.expressions)
1231 .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
1232 self.access(base, index_val as usize, span)
1233 }
1234 Expression::Swizzle {
1235 size,
1236 vector,
1237 pattern,
1238 } => {
1239 let vector = self.check_and_get(vector)?;
1240
1241 self.swizzle(size, span, vector, pattern)
1242 }
1243 Expression::Unary { expr, op } => {
1244 let expr = self.check_and_get(expr)?;
1245
1246 self.unary_op(op, expr, span)
1247 }
1248 Expression::Binary { left, right, op } => {
1249 let left = self.check_and_get(left)?;
1250 let right = self.check_and_get(right)?;
1251
1252 self.binary_op(op, left, right, span)
1253 }
1254 Expression::Math {
1255 fun,
1256 arg,
1257 arg1,
1258 arg2,
1259 arg3,
1260 } => {
1261 let arg = self.check_and_get(arg)?;
1262 let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
1263 let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
1264 let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
1265
1266 self.math(arg, arg1, arg2, arg3, fun, span)
1267 }
1268 Expression::As {
1269 convert,
1270 expr,
1271 kind,
1272 } => {
1273 let expr = self.check_and_get(expr)?;
1274
1275 match convert {
1276 Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
1277 None => Err(ConstantEvaluatorError::NotImplemented(
1278 "bitcast built-in function".into(),
1279 )),
1280 }
1281 }
1282 Expression::Select {
1283 reject,
1284 accept,
1285 condition,
1286 } => {
1287 let mut arg = |expr| self.check_and_get(expr);
1288
1289 let reject = arg(reject)?;
1290 let accept = arg(accept)?;
1291 let condition = arg(condition)?;
1292
1293 self.select(reject, accept, condition, span)
1294 }
1295 Expression::Relational { fun, argument } => {
1296 let argument = self.check_and_get(argument)?;
1297 self.relational(fun, argument, span)
1298 }
1299 Expression::ArrayLength(expr) => match self.behavior {
1300 Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
1301 Behavior::Glsl(_) => {
1302 let expr = self.check_and_get(expr)?;
1303 self.array_length(expr, span)
1304 }
1305 },
1306 Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
1307 Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
1308 Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
1309 Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
1310 Expression::WorkGroupUniformLoadResult { .. } => {
1311 Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
1312 }
1313 Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
1314 Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
1315 Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
1316 Expression::ImageSample { .. }
1317 | Expression::ImageLoad { .. }
1318 | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
1319 Expression::RayQueryProceedResult
1320 | Expression::RayQueryGetIntersection { .. }
1321 | Expression::RayQueryVertexPositions { .. } => {
1322 Err(ConstantEvaluatorError::RayQueryExpression)
1323 }
1324 Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
1325 Expression::SubgroupOperationResult { .. } => {
1326 Err(ConstantEvaluatorError::SubgroupExpression)
1327 }
1328 Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
1329 Err(ConstantEvaluatorError::CooperativeOperation)
1330 }
1331 }
1332 }
1333
1334 fn splat(
1347 &mut self,
1348 value: Handle<Expression>,
1349 size: crate::VectorSize,
1350 span: Span,
1351 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1352 match self.expressions[value] {
1353 Expression::Literal(literal) => {
1354 let scalar = literal.scalar();
1355 let ty = self.types.insert(
1356 Type {
1357 name: None,
1358 inner: TypeInner::Vector { size, scalar },
1359 },
1360 span,
1361 );
1362 let expr = Expression::Compose {
1363 ty,
1364 components: vec![value; size as usize],
1365 };
1366 self.register_evaluated_expr(expr, span)
1367 }
1368 Expression::ZeroValue(ty) => {
1369 let inner = match self.types[ty].inner {
1370 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1371 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1372 };
1373 let res_ty = self.types.insert(Type { name: None, inner }, span);
1374 let expr = Expression::ZeroValue(res_ty);
1375 self.register_evaluated_expr(expr, span)
1376 }
1377 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1378 }
1379 }
1380
1381 fn swizzle(
1382 &mut self,
1383 size: crate::VectorSize,
1384 span: Span,
1385 src_constant: Handle<Expression>,
1386 pattern: [crate::SwizzleComponent; 4],
1387 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1388 let mut get_dst_ty = |ty| match self.types[ty].inner {
1389 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1390 Type {
1391 name: None,
1392 inner: TypeInner::Vector { size, scalar },
1393 },
1394 span,
1395 )),
1396 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1397 };
1398
1399 match self.expressions[src_constant] {
1400 Expression::ZeroValue(ty) => {
1401 let dst_ty = get_dst_ty(ty)?;
1402 let expr = Expression::ZeroValue(dst_ty);
1403 self.register_evaluated_expr(expr, span)
1404 }
1405 Expression::Splat { value, .. } => {
1406 let expr = Expression::Splat { size, value };
1407 self.register_evaluated_expr(expr, span)
1408 }
1409 Expression::Compose { ty, ref components } => {
1410 let dst_ty = get_dst_ty(ty)?;
1411
1412 let mut flattened = [src_constant; 4]; let len =
1414 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1415 .zip(flattened.iter_mut())
1416 .map(|(component, elt)| *elt = component)
1417 .count();
1418 let flattened = &flattened[..len];
1419
1420 let swizzled_components = pattern[..size as usize]
1421 .iter()
1422 .map(|&sc| {
1423 let sc = sc as usize;
1424 if let Some(elt) = flattened.get(sc) {
1425 Ok(*elt)
1426 } else {
1427 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1428 }
1429 })
1430 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1431 let expr = Expression::Compose {
1432 ty: dst_ty,
1433 components: swizzled_components,
1434 };
1435 self.register_evaluated_expr(expr, span)
1436 }
1437 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1438 }
1439 }
1440
1441 fn math(
1442 &mut self,
1443 arg: Handle<Expression>,
1444 arg1: Option<Handle<Expression>>,
1445 arg2: Option<Handle<Expression>>,
1446 arg3: Option<Handle<Expression>>,
1447 fun: crate::MathFunction,
1448 span: Span,
1449 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1450 let expected = fun.argument_count();
1451 let given = Some(arg)
1452 .into_iter()
1453 .chain(arg1)
1454 .chain(arg2)
1455 .chain(arg3)
1456 .count();
1457 if expected != given {
1458 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1459 fun, expected, given,
1460 ));
1461 }
1462
1463 match fun {
1465 crate::MathFunction::Abs => {
1467 component_wise_scalar(self, span, [arg], |args| match args {
1468 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1469 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1470 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1471 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1472 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1473 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1475 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1476 })
1477 }
1478 crate::MathFunction::Min => {
1479 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1480 Ok([e1.min(e2)])
1481 })
1482 }
1483 crate::MathFunction::Max => {
1484 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1485 Ok([e1.max(e2)])
1486 })
1487 }
1488 crate::MathFunction::Clamp => {
1489 component_wise_scalar!(
1490 self,
1491 span,
1492 [arg, arg1.unwrap(), arg2.unwrap()],
1493 |e, low, high| {
1494 if low > high {
1495 Err(ConstantEvaluatorError::InvalidClamp)
1496 } else {
1497 Ok([e.clamp(low, high)])
1498 }
1499 }
1500 )
1501 }
1502 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1503 Float::F16([e]) => Ok(Float::F16(
1504 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1505 )),
1506 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1507 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1508 }),
1509
1510 crate::MathFunction::Cos => {
1512 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1513 }
1514 crate::MathFunction::Cosh => {
1515 component_wise_float!(self, span, [arg], |e| {
1516 let result = e.cosh();
1517 if result.is_finite() {
1518 Ok([result])
1519 } else {
1520 Err(ConstantEvaluatorError::Overflow("cosh".into()))
1521 }
1522 })
1523 }
1524 crate::MathFunction::Sin => {
1525 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1526 }
1527 crate::MathFunction::Sinh => {
1528 component_wise_float!(self, span, [arg], |e| {
1529 let result = e.sinh();
1530 if result.is_finite() {
1531 Ok([result])
1532 } else {
1533 Err(ConstantEvaluatorError::Overflow("sinh".into()))
1534 }
1535 })
1536 }
1537 crate::MathFunction::Tan => {
1538 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1539 }
1540 crate::MathFunction::Tanh => {
1541 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1542 }
1543 crate::MathFunction::Acos => {
1544 component_wise_float!(self, span, [arg], |e| {
1545 if e.abs() <= One::one() {
1546 Ok([e.acos()])
1547 } else {
1548 Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1549 }
1550 })
1551 }
1552 crate::MathFunction::Asin => {
1553 component_wise_float!(self, span, [arg], |e| {
1554 if e.abs() <= One::one() {
1555 Ok([e.asin()])
1556 } else {
1557 Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1558 }
1559 })
1560 }
1561 crate::MathFunction::Atan => {
1562 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1563 }
1564 crate::MathFunction::Atan2 => {
1565 component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1566 Ok([y.atan2(x)])
1567 })
1568 }
1569 crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1570 Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1571 Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1572 Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1573 }),
1574 crate::MathFunction::Acosh => {
1575 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1576 }
1577 crate::MathFunction::Atanh => {
1578 component_wise_float!(self, span, [arg], |e| {
1579 if e.abs() < One::one() {
1580 Ok([e.atanh()])
1581 } else {
1582 Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1583 }
1584 })
1585 }
1586 crate::MathFunction::Radians => {
1587 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1588 }
1589 crate::MathFunction::Degrees => {
1590 component_wise_float!(self, span, [arg], |e| {
1591 let result = e.to_degrees();
1592 if result.is_finite() {
1593 Ok([result])
1594 } else {
1595 Err(ConstantEvaluatorError::Overflow("degrees".into()))
1596 }
1597 })
1598 }
1599
1600 crate::MathFunction::Ceil => {
1602 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1603 }
1604 crate::MathFunction::Floor => {
1605 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1606 }
1607 crate::MathFunction::Round => {
1608 component_wise_float(self, span, [arg], |e| match e {
1609 Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1610 Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1611 Float::F16([e]) => {
1612 fn round_ties_even(x: f64) -> f64 {
1620 let i = x as i64;
1621 let f = (x - i as f64).abs();
1622 if f == 0.5 {
1623 if i & 1 == 1 {
1624 (x.abs() + 0.5).copysign(x)
1626 } else {
1627 (x.abs() - 0.5).copysign(x)
1628 }
1629 } else {
1630 x.round()
1631 }
1632 }
1633
1634 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1635 }
1636 })
1637 }
1638 crate::MathFunction::Fract => {
1639 component_wise_float!(self, span, [arg], |e| {
1640 Ok([e - e.floor()])
1643 })
1644 }
1645 crate::MathFunction::Trunc => {
1646 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1647 }
1648
1649 crate::MathFunction::Exp => {
1651 component_wise_float!(self, span, [arg], |e| {
1652 let result = e.exp();
1653 if result.is_finite() {
1654 Ok([result])
1655 } else {
1656 Err(ConstantEvaluatorError::Overflow("exp".into()))
1657 }
1658 })
1659 }
1660 crate::MathFunction::Exp2 => {
1661 component_wise_float!(self, span, [arg], |e| {
1662 let result = e.exp2();
1663 if result.is_finite() {
1664 Ok([result])
1665 } else {
1666 Err(ConstantEvaluatorError::Overflow("exp2".into()))
1667 }
1668 })
1669 }
1670 crate::MathFunction::Log => {
1671 component_wise_float!(self, span, [arg], |e| {
1672 if e > Zero::zero() {
1673 Ok([e.ln()])
1674 } else {
1675 Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1676 }
1677 })
1678 }
1679 crate::MathFunction::Log2 => {
1680 component_wise_float!(self, span, [arg], |e| {
1681 if e > Zero::zero() {
1682 Ok([e.log2()])
1683 } else {
1684 Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1685 }
1686 })
1687 }
1688 crate::MathFunction::Pow => {
1689 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1690 Ok([e1.powf(e2)])
1691 })
1692 }
1693
1694 crate::MathFunction::Sign => {
1696 component_wise_signed!(self, span, [arg], |e| {
1697 Ok([if e.is_zero() {
1698 Zero::zero()
1699 } else {
1700 e.signum()
1701 }])
1702 })
1703 }
1704 crate::MathFunction::Fma => {
1705 component_wise_float!(
1706 self,
1707 span,
1708 [arg, arg1.unwrap(), arg2.unwrap()],
1709 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1710 )
1711 }
1712 crate::MathFunction::Step => {
1713 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1714 Float::Abstract([edge, x]) => {
1715 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1716 }
1717 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1718 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1719 f16::one()
1720 } else {
1721 f16::zero()
1722 }])),
1723 })
1724 }
1725 crate::MathFunction::Sqrt => {
1726 component_wise_float!(self, span, [arg], |e| {
1727 if e >= Zero::zero() {
1728 Ok([e.sqrt()])
1729 } else {
1730 Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1731 }
1732 })
1733 }
1734 crate::MathFunction::InverseSqrt => {
1735 component_wise_float(self, span, [arg], |e| match e {
1736 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1737 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1738 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1739 })
1740 }
1741
1742 crate::MathFunction::CountTrailingZeros => {
1744 component_wise_concrete_int!(self, span, [arg], |e| {
1745 #[allow(clippy::useless_conversion)]
1746 Ok([e
1747 .trailing_zeros()
1748 .try_into()
1749 .expect("bit count overflowed 32 bits, somehow!?")])
1750 })
1751 }
1752 crate::MathFunction::CountLeadingZeros => {
1753 component_wise_concrete_int!(self, span, [arg], |e| {
1754 #[allow(clippy::useless_conversion)]
1755 Ok([e
1756 .leading_zeros()
1757 .try_into()
1758 .expect("bit count overflowed 32 bits, somehow!?")])
1759 })
1760 }
1761 crate::MathFunction::CountOneBits => {
1762 component_wise_concrete_int!(self, span, [arg], |e| {
1763 #[allow(clippy::useless_conversion)]
1764 Ok([e
1765 .count_ones()
1766 .try_into()
1767 .expect("bit count overflowed 32 bits, somehow!?")])
1768 })
1769 }
1770 crate::MathFunction::ReverseBits => {
1771 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1772 }
1773 crate::MathFunction::FirstTrailingBit => {
1774 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1775 }
1776 crate::MathFunction::FirstLeadingBit => {
1777 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1778 }
1779
1780 crate::MathFunction::Dot4I8Packed => {
1782 self.packed_dot_product(arg, arg1.unwrap(), span, true)
1783 }
1784 crate::MathFunction::Dot4U8Packed => {
1785 self.packed_dot_product(arg, arg1.unwrap(), span, false)
1786 }
1787 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1788 crate::MathFunction::Dot => {
1789 let e1 = self.extract_vec(arg, false)?;
1791 let e2 = self.extract_vec(arg1.unwrap(), false)?;
1792 if e1.len() != e2.len() {
1793 return Err(ConstantEvaluatorError::InvalidMathArg);
1794 }
1795
1796 fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1797 where
1798 P: num_traits::Float,
1799 {
1800 let result = a
1801 .iter()
1802 .zip(b.iter())
1803 .map(|(&aa, &bb)| aa * bb)
1804 .fold(P::zero(), |acc, x| acc + x);
1805 if result.is_finite() {
1806 Ok(result)
1807 } else {
1808 Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1809 }
1810 }
1811
1812 fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1813 where
1814 P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1815 {
1816 a.iter()
1817 .zip(b.iter())
1818 .map(|(&aa, bb)| aa.checked_mul(bb))
1819 .try_fold(P::zero(), |acc, x| {
1820 if let Some(x) = x {
1821 acc.checked_add(&x)
1822 } else {
1823 None
1824 }
1825 })
1826 .ok_or(ConstantEvaluatorError::Overflow(
1827 "in dot built-in".to_string(),
1828 ))
1829 }
1830
1831 fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1832 where
1833 P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1834 {
1835 a.iter()
1836 .zip(b.iter())
1837 .map(|(&aa, bb)| aa.wrapping_mul(bb))
1838 .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1839 }
1840
1841 let result = match_literal_vector!(match (e1, e2) => Literal {
1842 Float => |e1, e2| { float_dot_checked(e1, e2)? },
1843 AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1844 I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1845 U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1846 })?;
1847 self.register_evaluated_expr(Expression::Literal(result), span)
1848 }
1849 crate::MathFunction::Length => {
1850 let e1 = self.extract_vec(arg, true)?;
1852
1853 fn float_length<F>(e: &[F]) -> F
1854 where
1855 F: core::ops::Mul<F>,
1856 F: num_traits::Float + iter::Sum,
1857 {
1858 if e.len() == 1 {
1859 e[0].abs()
1861 } else {
1862 e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1863 }
1864 }
1865
1866 let result = match_literal_vector!(match e1 => Literal {
1867 Float => |e1| { float_length(e1) },
1868 })?;
1869 self.register_evaluated_expr(Expression::Literal(result), span)
1870 }
1871 crate::MathFunction::Distance => {
1872 let e1 = self.extract_vec(arg, true)?;
1874 let e2 = self.extract_vec(arg1.unwrap(), true)?;
1875 if e1.len() != e2.len() {
1876 return Err(ConstantEvaluatorError::InvalidMathArg);
1877 }
1878
1879 fn float_distance<F>(a: &[F], b: &[F]) -> F
1880 where
1881 F: core::ops::Mul<F>,
1882 F: num_traits::Float + iter::Sum + core::ops::Sub,
1883 {
1884 if a.len() == 1 {
1885 (a[0] - b[0]).abs()
1887 } else {
1888 a.iter()
1889 .zip(b.iter())
1890 .map(|(&aa, &bb)| aa - bb)
1891 .map(|ei| ei * ei)
1892 .sum::<F>()
1893 .sqrt()
1894 }
1895 }
1896 let result = match_literal_vector!(match (e1, e2) => Literal {
1897 Float => |e1, e2| { float_distance(e1, e2) },
1898 })?;
1899 self.register_evaluated_expr(Expression::Literal(result), span)
1900 }
1901 crate::MathFunction::Normalize => {
1902 let e1 = self.extract_vec(arg, true)?;
1904
1905 fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1906 where
1907 F: core::ops::Mul<F>,
1908 F: num_traits::Float + iter::Sum,
1909 {
1910 let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1911 let mut out = ArrayVec::new();
1912 for &ei in e {
1913 out.push(ei / len);
1914 }
1915 out
1916 }
1917
1918 let result = match_literal_vector!(match e1 => LiteralVector {
1919 Float => |e1| { float_normalize(e1) },
1920 })?;
1921 result.register_as_evaluated_expr(self, span)
1922 }
1923
1924 crate::MathFunction::Modf
1926 | crate::MathFunction::Frexp
1927 | crate::MathFunction::Ldexp
1928 | crate::MathFunction::Outer
1929 | crate::MathFunction::FaceForward
1930 | crate::MathFunction::Reflect
1931 | crate::MathFunction::Refract
1932 | crate::MathFunction::Mix
1933 | crate::MathFunction::SmoothStep
1934 | crate::MathFunction::Inverse
1935 | crate::MathFunction::Transpose
1936 | crate::MathFunction::Determinant
1937 | crate::MathFunction::QuantizeToF16
1938 | crate::MathFunction::ExtractBits
1939 | crate::MathFunction::InsertBits
1940 | crate::MathFunction::Pack4x8snorm
1941 | crate::MathFunction::Pack4x8unorm
1942 | crate::MathFunction::Pack2x16snorm
1943 | crate::MathFunction::Pack2x16unorm
1944 | crate::MathFunction::Pack2x16float
1945 | crate::MathFunction::Pack4xI8
1946 | crate::MathFunction::Pack4xU8
1947 | crate::MathFunction::Pack4xI8Clamp
1948 | crate::MathFunction::Pack4xU8Clamp
1949 | crate::MathFunction::Unpack4x8snorm
1950 | crate::MathFunction::Unpack4x8unorm
1951 | crate::MathFunction::Unpack2x16snorm
1952 | crate::MathFunction::Unpack2x16unorm
1953 | crate::MathFunction::Unpack2x16float
1954 | crate::MathFunction::Unpack4xI8
1955 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1956 format!("{fun:?} built-in function"),
1957 )),
1958 }
1959 }
1960
1961 fn packed_dot_product(
1963 &mut self,
1964 a: Handle<Expression>,
1965 b: Handle<Expression>,
1966 span: Span,
1967 signed: bool,
1968 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1969 let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1970 return Err(ConstantEvaluatorError::InvalidMathArg);
1971 };
1972 let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1973 return Err(ConstantEvaluatorError::InvalidMathArg);
1974 };
1975
1976 let result = if signed {
1977 Literal::I32(
1978 (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1979 + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1980 + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1981 + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1982 )
1983 } else {
1984 Literal::U32(
1985 (a & 0xFF) * (b & 0xFF)
1986 + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1987 + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1988 + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1989 )
1990 };
1991
1992 self.register_evaluated_expr(Expression::Literal(result), span)
1993 }
1994
1995 fn cross_product(
1997 &mut self,
1998 a: Handle<Expression>,
1999 b: Handle<Expression>,
2000 span: Span,
2001 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2002 use Literal as Li;
2003
2004 let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2005 let (b, _) = self.extract_vec_with_size::<3>(b)?;
2006
2007 let product = match (a, b) {
2008 (
2009 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2010 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2011 ) => {
2012 let p = cross_product(
2017 [a0 as f64, a1 as f64, a2 as f64],
2018 [b0 as f64, b1 as f64, b2 as f64],
2019 );
2020 [
2021 Li::AbstractFloat(p[0]),
2022 Li::AbstractFloat(p[1]),
2023 Li::AbstractFloat(p[2]),
2024 ]
2025 }
2026 (
2027 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2028 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2029 ) => {
2030 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2031 [
2032 Li::AbstractFloat(p[0]),
2033 Li::AbstractFloat(p[1]),
2034 Li::AbstractFloat(p[2]),
2035 ]
2036 }
2037 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2038 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2039 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2040 }
2041 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2042 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2043 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2044 }
2045 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2046 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2047 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2048 }
2049 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2050 };
2051
2052 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2053 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2054 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2055
2056 self.register_evaluated_expr(
2057 Expression::Compose {
2058 ty,
2059 components: vec![p0, p1, p2],
2060 },
2061 span,
2062 )
2063 }
2064
2065 fn extract_vec_with_size<const N: usize>(
2073 &mut self,
2074 expr: Handle<Expression>,
2075 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2076 let span = self.expressions.get_span(expr);
2077 let expr = self.eval_zero_value_and_splat(expr, span)?;
2078 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2079 return Err(ConstantEvaluatorError::InvalidMathArg);
2080 };
2081
2082 let mut value = [Literal::Bool(false); N];
2083 for (component, elt) in
2084 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2085 .zip(value.iter_mut())
2086 {
2087 let Expression::Literal(literal) = self.expressions[component] else {
2088 return Err(ConstantEvaluatorError::InvalidMathArg);
2089 };
2090 *elt = literal;
2091 }
2092
2093 Ok((value, ty))
2094 }
2095
2096 fn extract_vec(
2104 &mut self,
2105 expr: Handle<Expression>,
2106 allow_single: bool,
2107 ) -> Result<LiteralVector, ConstantEvaluatorError> {
2108 let span = self.expressions.get_span(expr);
2109 let expr = self.eval_zero_value_and_splat(expr, span)?;
2110
2111 match self.expressions[expr] {
2112 Expression::Literal(literal) if allow_single => {
2113 Ok(LiteralVector::from_literal(literal))
2114 }
2115 Expression::Compose { ty, ref components } => {
2116 let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2117 for expr in
2118 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2119 {
2120 match self.expressions[expr] {
2121 Expression::Literal(l) => components_out.push(l),
2122 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2123 }
2124 }
2125 LiteralVector::from_literal_vec(components_out)
2126 }
2127 _ => Err(ConstantEvaluatorError::InvalidMathArg),
2128 }
2129 }
2130
2131 fn array_length(
2132 &mut self,
2133 array: Handle<Expression>,
2134 span: Span,
2135 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2136 match self.expressions[array] {
2137 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2138 match self.types[ty].inner {
2139 TypeInner::Array { size, .. } => match size {
2140 ArraySize::Constant(len) => {
2141 let expr = Expression::Literal(Literal::U32(len.get()));
2142 self.register_evaluated_expr(expr, span)
2143 }
2144 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2145 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2146 },
2147 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2148 }
2149 }
2150 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2151 }
2152 }
2153
2154 fn access(
2155 &mut self,
2156 base: Handle<Expression>,
2157 index: usize,
2158 span: Span,
2159 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2160 match self.expressions[base] {
2161 Expression::ZeroValue(ty) => {
2162 let ty_inner = &self.types[ty].inner;
2163 let components = ty_inner
2164 .components()
2165 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2166
2167 if index >= components as usize {
2168 Err(ConstantEvaluatorError::InvalidAccessBase)
2169 } else {
2170 let ty_res = ty_inner
2171 .component_type(index)
2172 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2173 let ty = match ty_res {
2174 crate::proc::TypeResolution::Handle(ty) => ty,
2175 crate::proc::TypeResolution::Value(inner) => {
2176 self.types.insert(Type { name: None, inner }, span)
2177 }
2178 };
2179 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2180 }
2181 }
2182 Expression::Splat { size, value } => {
2183 if index >= size as usize {
2184 Err(ConstantEvaluatorError::InvalidAccessBase)
2185 } else {
2186 Ok(value)
2187 }
2188 }
2189 Expression::Compose { ty, ref components } => {
2190 let _ = self.types[ty]
2191 .inner
2192 .components()
2193 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2194
2195 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2196 .nth(index)
2197 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2198 }
2199 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2200 }
2201 }
2202
2203 fn eval_zero_value_and_splat(
2210 &mut self,
2211 mut expr: Handle<Expression>,
2212 span: Span,
2213 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2214 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2217 let components = components
2218 .clone()
2219 .iter()
2220 .map(|component| self.eval_zero_value_and_splat(*component, span))
2221 .collect::<Result<_, _>>()?;
2222 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2223 }
2224
2225 if let Expression::Splat { size, value } = self.expressions[expr] {
2229 expr = self.splat(value, size, span)?;
2230 }
2231 if let Expression::ZeroValue(ty) = self.expressions[expr] {
2232 expr = self.eval_zero_value_impl(ty, span)?;
2233 }
2234 Ok(expr)
2235 }
2236
2237 fn eval_zero_value(
2243 &mut self,
2244 expr: Handle<Expression>,
2245 span: Span,
2246 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2247 match self.expressions[expr] {
2248 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2249 _ => Ok(expr),
2250 }
2251 }
2252
2253 fn eval_zero_value_impl(
2259 &mut self,
2260 ty: Handle<Type>,
2261 span: Span,
2262 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2263 match self.types[ty].inner {
2264 TypeInner::Scalar(scalar) => {
2265 let expr = Expression::Literal(
2266 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2267 );
2268 self.register_evaluated_expr(expr, span)
2269 }
2270 TypeInner::Vector { size, scalar } => {
2271 let scalar_ty = self.types.insert(
2272 Type {
2273 name: None,
2274 inner: TypeInner::Scalar(scalar),
2275 },
2276 span,
2277 );
2278 let el = self.eval_zero_value_impl(scalar_ty, span)?;
2279 let expr = Expression::Compose {
2280 ty,
2281 components: vec![el; size as usize],
2282 };
2283 self.register_evaluated_expr(expr, span)
2284 }
2285 TypeInner::Matrix {
2286 columns,
2287 rows,
2288 scalar,
2289 } => {
2290 let vec_ty = self.types.insert(
2291 Type {
2292 name: None,
2293 inner: TypeInner::Vector { size: rows, scalar },
2294 },
2295 span,
2296 );
2297 let el = self.eval_zero_value_impl(vec_ty, span)?;
2298 let expr = Expression::Compose {
2299 ty,
2300 components: vec![el; columns as usize],
2301 };
2302 self.register_evaluated_expr(expr, span)
2303 }
2304 TypeInner::Array {
2305 base,
2306 size: ArraySize::Constant(size),
2307 ..
2308 } => {
2309 let el = self.eval_zero_value_impl(base, span)?;
2310 let expr = Expression::Compose {
2311 ty,
2312 components: vec![el; size.get() as usize],
2313 };
2314 self.register_evaluated_expr(expr, span)
2315 }
2316 TypeInner::Struct { ref members, .. } => {
2317 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2318 let mut components = Vec::with_capacity(members.len());
2319 for ty in types {
2320 components.push(self.eval_zero_value_impl(ty, span)?);
2321 }
2322 let expr = Expression::Compose { ty, components };
2323 self.register_evaluated_expr(expr, span)
2324 }
2325 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2326 }
2327 }
2328
/// Cast the scalar, vector, or matrix constant `expr` to the scalar type `target`.
///
/// `expr` is first passed through `eval_zero_value` (presumably expanding a
/// `ZeroValue` into an inspectable form — confirm against its definition).
/// Then:
///
/// - a `Literal` is converted directly according to the table below;
/// - a vector/matrix `Compose` is cast component-wise, and a new type with the
///   same shape but `target` as its scalar is inserted for the result;
/// - a `Splat` casts its splatted value, using that value's own span.
///
/// # Errors
///
/// Returns `ConstantEvaluatorError::InvalidCastArg` for any unsupported
/// source/target combination (including other expression kinds). It also
/// propagates errors from `try_from_abstract` (abstract-value narrowing) and
/// from `register_evaluated_expr`.
pub fn cast(
    &mut self,
    expr: Handle<Expression>,
    target: crate::Scalar,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    use crate::Scalar as Sc;

    let expr = self.eval_zero_value(expr, span)?;

    // Builds the `InvalidCastArg` error naming both the offending source
    // expression (handle + debug form) and the requested target type.
    let make_error = || -> Result<_, ConstantEvaluatorError> {
        let from = format!("{:?} {:?}", expr, self.expressions[expr]);

        // With the WGSL frontend enabled, spell the target in WGSL syntax.
        #[cfg(feature = "wgsl-in")]
        let to = target.to_wgsl_for_diagnostics();

        #[cfg(not(feature = "wgsl-in"))]
        let to = format!("{target:?}");

        Err(ConstantEvaluatorError::InvalidCastArg { from, to })
    };

    use crate::proc::type_methods::IntFloatLimits;

    let expr = match self.expressions[expr] {
        Expression::Literal(literal) => {
            let literal = match target {
                Sc::I32 => Literal::I32(match literal {
                    Literal::I32(v) => v,
                    Literal::U32(v) => v as i32,
                    // Clamp into the integer's float-representable range
                    // before the `as` conversion.
                    Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                    // f16's finite range fits in i32, so this presumably only
                    // fails on NaN/infinity — NOTE(review): confirm such
                    // literals are rejected before reaching here.
                    Literal::F16(v) => f16::to_i32(&v).unwrap(),
                    Literal::Bool(v) => v as i32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                }),
                Sc::U32 => Literal::U32(match literal {
                    Literal::I32(v) => v as u32,
                    Literal::U32(v) => v,
                    Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                    // Negative values are clamped to zero first; see the F16
                    // note on the I32 arm regarding the `unwrap`.
                    Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                    Literal::Bool(v) => v as u32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                }),
                Sc::I64 => Literal::I64(match literal {
                    Literal::I32(v) => v as i64,
                    Literal::U32(v) => v as i64,
                    Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::Bool(v) => v as i64,
                    Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                    Literal::I64(v) => v,
                    Literal::U64(v) => v as i64,
                    Literal::F16(v) => f16::to_i64(&v).unwrap(),
                    Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                }),
                Sc::U64 => Literal::U64(match literal {
                    Literal::I32(v) => v as u64,
                    Literal::U32(v) => v as u64,
                    Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::Bool(v) => v as u64,
                    Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                    Literal::I64(v) => v as u64,
                    Literal::U64(v) => v,
                    Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                    Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                }),
                // Conversions into f16 may round or overflow to infinity; the
                // resulting literal is validated by `register_evaluated_expr`.
                Sc::F16 => Literal::F16(match literal {
                    Literal::F16(v) => v,
                    Literal::F32(v) => f16::from_f32(v),
                    Literal::F64(v) => f16::from_f64(v),
                    Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                    Literal::I64(v) => f16::from_i64(v).unwrap(),
                    Literal::U64(v) => f16::from_u64(v).unwrap(),
                    Literal::I32(v) => f16::from_i32(v).unwrap(),
                    Literal::U32(v) => f16::from_u32(v).unwrap(),
                    Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                    Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                }),
                Sc::F32 => Literal::F32(match literal {
                    Literal::I32(v) => v as f32,
                    Literal::U32(v) => v as f32,
                    Literal::F32(v) => v,
                    Literal::Bool(v) => v as u32 as f32,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                    Literal::F16(v) => f16::to_f32(v),
                    Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                }),
                Sc::F64 => Literal::F64(match literal {
                    Literal::I32(v) => v as f64,
                    Literal::U32(v) => v as f64,
                    Literal::F16(v) => f16::to_f64(v),
                    Literal::F32(v) => v as f64,
                    Literal::F64(v) => v,
                    Literal::Bool(v) => v as u32 as f64,
                    Literal::I64(_) | Literal::U64(_) => return make_error(),
                    Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                    Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                }),
                // Any nonzero value converts to `true`.
                Sc::BOOL => Literal::Bool(match literal {
                    Literal::I32(v) => v != 0,
                    Literal::U32(v) => v != 0,
                    Literal::F32(v) => v != 0.0,
                    Literal::F16(v) => v != f16::zero(),
                    Literal::Bool(v) => v,
                    Literal::AbstractInt(v) => v != 0,
                    Literal::AbstractFloat(v) => v != 0.0,
                    Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                        return make_error();
                    }
                }),
                // Abstract types only accept abstract sources.
                Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                    Literal::AbstractInt(v) => {
                        // Inexact conversion is accepted; f64's range far
                        // exceeds i64's, so no range check is needed.
                        v as f64
                    }
                    Literal::AbstractFloat(v) => v,
                    _ => return make_error(),
                }),
                Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
                    Literal::AbstractInt(v) => v,
                    _ => return make_error(),
                }),
                _ => {
                    log::debug!("Constant evaluator refused to convert value to {target:?}");
                    return make_error();
                }
            };
            Expression::Literal(literal)
        }
        Expression::Compose {
            ty,
            components: ref src_components,
        } => {
            // Keep the shape of the vector/matrix, swapping in `target` as
            // the scalar type; anything else cannot be cast.
            let ty_inner = match self.types[ty].inner {
                TypeInner::Vector { size, .. } => TypeInner::Vector {
                    size,
                    scalar: target,
                },
                TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                    columns,
                    rows,
                    scalar: target,
                },
                _ => return make_error(),
            };

            // Cast each component (scalars, or column vectors for matrices).
            let mut components = src_components.clone();
            for component in &mut components {
                *component = self.cast(*component, target, span)?;
            }

            let ty = self.types.insert(
                Type {
                    name: None,
                    inner: ty_inner,
                },
                span,
            );

            Expression::Compose { ty, components }
        }
        Expression::Splat { size, value } => {
            let value_span = self.expressions.get_span(value);
            let cast_value = self.cast(value, target, value_span)?;
            Expression::Splat {
                size,
                value: cast_value,
            }
        }
        _ => return make_error(),
    };

    self.register_evaluated_expr(expr, span)
}
2523
/// Cast `expr` to `target`, also handling constant arrays.
///
/// Anything that is not an array `Compose` is delegated to [`Self::cast`].
/// For an array, every element is cast recursively, so nested arrays are
/// handled too. A new array type is then built whose base is the resolved type
/// of the first cast element, with the same size and a stride taken from the
/// layouter.
///
/// # Errors
///
/// Propagates errors from `check_and_get`, `cast`, `resolve_type`, and
/// `register_evaluated_expr`.
///
/// # Panics
///
/// - Panics if the array compose has no components (`first().unwrap()`).
/// - Panics if updating the layouter fails (`update(..).unwrap()`).
pub fn cast_array(
    &mut self,
    expr: Handle<Expression>,
    target: crate::Scalar,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    let expr = self.check_and_get(expr)?;

    let Expression::Compose { ty, ref components } = self.expressions[expr] else {
        return self.cast(expr, target, span);
    };

    let TypeInner::Array {
        base: _,
        size,
        stride: _,
    } = self.types[ty].inner
    else {
        return self.cast(expr, target, span);
    };

    let mut components = components.clone();
    for component in &mut components {
        *component = self.cast_array(*component, target, span)?;
    }

    // All elements share one type, so the first element determines the new base.
    let first = components.first().unwrap();
    let new_base = match self.resolve_type(*first)? {
        crate::proc::TypeResolution::Handle(ty) => ty,
        crate::proc::TypeResolution::Value(inner) => {
            self.types.insert(Type { name: None, inner }, span)
        }
    };
    // The layouter is temporarily moved out so it can be updated with a
    // context borrowed from `self`, then put back. The update is presumably
    // needed so that `new_base` (possibly just inserted) has a layout.
    let mut layouter = core::mem::take(self.layouter);
    layouter.update(self.to_ctx()).unwrap();
    *self.layouter = layouter;

    let new_base_stride = self.layouter[new_base].to_stride();
    let new_array_ty = self.types.insert(
        Type {
            name: None,
            inner: TypeInner::Array {
                base: new_base,
                size,
                stride: new_base_stride,
            },
        },
        span,
    );

    let compose = Expression::Compose {
        ty: new_array_ty,
        components,
    };
    self.register_evaluated_expr(compose, span)
}
2592
2593 fn unary_op(
2594 &mut self,
2595 op: UnaryOperator,
2596 expr: Handle<Expression>,
2597 span: Span,
2598 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2599 let expr = self.eval_zero_value_and_splat(expr, span)?;
2600
2601 let expr = match self.expressions[expr] {
2602 Expression::Literal(value) => Expression::Literal(match op {
2603 UnaryOperator::Negate => match value {
2604 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2605 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2606 Literal::F32(v) => Literal::F32(-v),
2607 Literal::F16(v) => Literal::F16(-v),
2608 Literal::F64(v) => Literal::F64(-v),
2609 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2610 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2611 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2612 },
2613 UnaryOperator::LogicalNot => match value {
2614 Literal::Bool(v) => Literal::Bool(!v),
2615 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2616 },
2617 UnaryOperator::BitwiseNot => match value {
2618 Literal::I32(v) => Literal::I32(!v),
2619 Literal::I64(v) => Literal::I64(!v),
2620 Literal::U32(v) => Literal::U32(!v),
2621 Literal::U64(v) => Literal::U64(!v),
2622 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2623 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2624 },
2625 }),
2626 Expression::Compose {
2627 ty,
2628 components: ref src_components,
2629 } => {
2630 match self.types[ty].inner {
2631 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2632 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2633 }
2634
2635 let mut components = src_components.clone();
2636 for component in &mut components {
2637 *component = self.unary_op(op, *component, span)?;
2638 }
2639
2640 Expression::Compose { ty, components }
2641 }
2642 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2643 };
2644
2645 self.register_evaluated_expr(expr, span)
2646 }
2647
2648 fn binary_op(
2649 &mut self,
2650 op: BinaryOperator,
2651 left: Handle<Expression>,
2652 right: Handle<Expression>,
2653 span: Span,
2654 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2655 let left = self.eval_zero_value_and_splat(left, span)?;
2656 let right = self.eval_zero_value_and_splat(right, span)?;
2657
2658 let expr = match (&self.expressions[left], &self.expressions[right]) {
2663 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2664 if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
2665 && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
2666 {
2667 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2668 }
2669
2670 let literal = match op {
2671 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2672 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2673 BinaryOperator::Less => Literal::Bool(left_value < right_value),
2674 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2675 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2676 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2677
2678 _ => match (left_value, right_value) {
2679 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2680 BinaryOperator::Add => a.wrapping_add(b),
2681 BinaryOperator::Subtract => a.wrapping_sub(b),
2682 BinaryOperator::Multiply => a.wrapping_mul(b),
2683 BinaryOperator::Divide => {
2684 if b == 0 {
2685 return Err(ConstantEvaluatorError::DivisionByZero);
2686 } else {
2687 a.wrapping_div(b)
2688 }
2689 }
2690 BinaryOperator::Modulo => {
2691 if b == 0 {
2692 return Err(ConstantEvaluatorError::RemainderByZero);
2693 } else {
2694 a.wrapping_rem(b)
2695 }
2696 }
2697 BinaryOperator::And => a & b,
2698 BinaryOperator::ExclusiveOr => a ^ b,
2699 BinaryOperator::InclusiveOr => a | b,
2700 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2701 }),
2702 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2703 BinaryOperator::ShiftLeft => {
2704 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2705 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2706 }
2707 a.checked_shl(b)
2708 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2709 }
2710 BinaryOperator::ShiftRight => a
2711 .checked_shr(b)
2712 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2713 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2714 }),
2715 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2716 BinaryOperator::Add => a.wrapping_add(b),
2717 BinaryOperator::Subtract => a.wrapping_sub(b),
2718 BinaryOperator::Multiply => a.wrapping_mul(b),
2719 BinaryOperator::Divide => a
2720 .checked_div(b)
2721 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2722 BinaryOperator::Modulo => a
2723 .checked_rem(b)
2724 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2725 BinaryOperator::And => a & b,
2726 BinaryOperator::ExclusiveOr => a ^ b,
2727 BinaryOperator::InclusiveOr => a | b,
2728 BinaryOperator::ShiftLeft => a
2729 .checked_mul(
2730 1u32.checked_shl(b)
2731 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2732 )
2733 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2734 BinaryOperator::ShiftRight => a
2735 .checked_shr(b)
2736 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2737 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2738 }),
2739 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2740 BinaryOperator::Add => a + b,
2741 BinaryOperator::Subtract => a - b,
2742 BinaryOperator::Multiply => a * b,
2743 BinaryOperator::Divide => a / b,
2744 BinaryOperator::Modulo => a % b,
2745 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2746 }),
2747 (Literal::AbstractInt(a), Literal::U32(b)) => {
2748 Literal::AbstractInt(match op {
2749 BinaryOperator::ShiftLeft => {
2750 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2751 return Err(ConstantEvaluatorError::Overflow(
2752 "<<".to_string(),
2753 ));
2754 }
2755 a.checked_shl(b).unwrap_or(0)
2756 }
2757 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2758 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2759 })
2760 }
2761 (Literal::F16(a), Literal::F16(b)) => {
2762 let result = match op {
2763 BinaryOperator::Add => a + b,
2764 BinaryOperator::Subtract => a - b,
2765 BinaryOperator::Multiply => a * b,
2766 BinaryOperator::Divide => a / b,
2767 BinaryOperator::Modulo => a % b,
2768 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2769 };
2770 if !result.is_finite() {
2771 return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2772 }
2773 Literal::F16(result)
2774 }
2775 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2776 Literal::AbstractInt(match op {
2777 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2778 ConstantEvaluatorError::Overflow("addition".into())
2779 })?,
2780 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2781 ConstantEvaluatorError::Overflow("subtraction".into())
2782 })?,
2783 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2784 ConstantEvaluatorError::Overflow("multiplication".into())
2785 })?,
2786 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2787 if b == 0 {
2788 ConstantEvaluatorError::DivisionByZero
2789 } else {
2790 ConstantEvaluatorError::Overflow("division".into())
2791 }
2792 })?,
2793 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2794 if b == 0 {
2795 ConstantEvaluatorError::RemainderByZero
2796 } else {
2797 ConstantEvaluatorError::Overflow("remainder".into())
2798 }
2799 })?,
2800 BinaryOperator::And => a & b,
2801 BinaryOperator::ExclusiveOr => a ^ b,
2802 BinaryOperator::InclusiveOr => a | b,
2803 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2804 })
2805 }
2806 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2807 let result = match op {
2808 BinaryOperator::Add => a + b,
2809 BinaryOperator::Subtract => a - b,
2810 BinaryOperator::Multiply => a * b,
2811 BinaryOperator::Divide => a / b,
2812 BinaryOperator::Modulo => a % b,
2813 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2814 };
2815 if !result.is_finite() {
2816 return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2817 }
2818 Literal::AbstractFloat(result)
2819 }
2820 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2821 BinaryOperator::LogicalAnd => a && b,
2822 BinaryOperator::LogicalOr => a || b,
2823 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2824 }),
2825 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2826 },
2827 };
2828 Expression::Literal(literal)
2829 }
2830 (
2831 &Expression::Compose {
2832 components: ref src_components,
2833 ty,
2834 },
2835 &Expression::Literal(_),
2836 ) => {
2837 if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2838 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2839 }
2840 let mut components = src_components.clone();
2841 for component in &mut components {
2842 *component = self.binary_op(op, *component, right, span)?;
2843 }
2844 Expression::Compose { ty, components }
2845 }
2846 (
2847 &Expression::Literal(_),
2848 &Expression::Compose {
2849 components: ref src_components,
2850 ty,
2851 },
2852 ) => {
2853 if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2854 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2855 }
2856 let mut components = src_components.clone();
2857 for component in &mut components {
2858 *component = self.binary_op(op, left, *component, span)?;
2859 }
2860 Expression::Compose { ty, components }
2861 }
2862 (
2863 &Expression::Compose {
2864 components: ref left_components,
2865 ty: left_ty,
2866 },
2867 &Expression::Compose {
2868 components: ref right_components,
2869 ty: right_ty,
2870 },
2871 ) => {
2872 let left_flattened = crate::proc::flatten_compose(
2876 left_ty,
2877 left_components,
2878 self.expressions,
2879 self.types,
2880 )
2881 .collect::<Vec<_>>();
2882 let right_flattened = crate::proc::flatten_compose(
2883 right_ty,
2884 right_components,
2885 self.expressions,
2886 self.types,
2887 )
2888 .collect::<Vec<_>>();
2889
2890 self.binary_op_compose(
2891 op,
2892 &left_flattened,
2893 &right_flattened,
2894 left_ty,
2895 right_ty,
2896 span,
2897 )?
2898 }
2899 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2900 };
2901
2902 return self.register_evaluated_expr(expr, span);
2903
2904 fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
2905 let is_numeric_vec = matches!(
2906 compose_ty, TypeInner::Vector { scalar, .. }
2907 if scalar.kind != ScalarKind::Bool
2908 );
2909 let is_allowed_vec_scalar_op = matches!(
2910 op,
2911 BinaryOperator::Add
2912 | BinaryOperator::Subtract
2913 | BinaryOperator::Multiply
2914 | BinaryOperator::Divide
2915 | BinaryOperator::Modulo
2916 );
2917 let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
2918 let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
2919 is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
2920 }
2921 }
2922
/// Evaluate `op` between two flattened `Compose` operands.
///
/// `left_components` / `right_components` are the flattened components of
/// operands whose types are `left_ty` / `right_ty`. Supported shape pairs:
///
/// - vector `op` vector of equal size, via `binary_op_vector`;
/// - vector * matrix, when the vector size equals the matrix's row count;
/// - matrix * vector, when the vector size equals the matrix's column count;
/// - matrix +/- matrix of identical dimensions, component-wise, keeping `left_ty`;
/// - matrix * matrix, when the left column count equals the right row count.
///
/// # Errors
///
/// Returns `ConstantEvaluatorError::InvalidBinaryOpArgs` for any other
/// combination. Errors from the per-shape helpers are propagated.
fn binary_op_compose(
    &mut self,
    op: BinaryOperator,
    left_components: &[Handle<Expression>],
    right_components: &[Handle<Expression>],
    left_ty: Handle<Type>,
    right_ty: Handle<Type>,
    span: Span,
) -> Result<Expression, ConstantEvaluatorError> {
    match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
        // vecN op vecN
        (
            &TypeInner::Vector {
                size: left_size, ..
            },
            &TypeInner::Vector {
                size: right_size, ..
            },
        ) if left_size == right_size => self.binary_op_vector(
            op,
            left_size,
            left_components,
            right_components,
            left_ty,
            span,
        ),
        // vecR * matCxR -> vecC
        (
            &TypeInner::Vector { size, .. },
            &TypeInner::Matrix {
                columns,
                rows,
                scalar,
            },
        ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
            left_components,
            right_components,
            columns,
            scalar,
            span,
        ),
        // matCxR * vecC -> vecR
        (
            &TypeInner::Matrix {
                columns,
                rows,
                scalar,
            },
            &TypeInner::Vector { size, .. },
        ) if op == BinaryOperator::Multiply && size == columns => {
            self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
        }
        // matrix op matrix
        (
            &TypeInner::Matrix {
                columns: left_columns,
                rows: left_rows,
                scalar,
            },
            &TypeInner::Matrix {
                columns: right_columns,
                rows: right_rows,
                ..
            },
        ) => match op {
            BinaryOperator::Add | BinaryOperator::Subtract
                if left_columns == right_columns && left_rows == right_rows =>
            {
                // Column vectors are combined pairwise; each recursive call
                // handles one vector +/- vector.
                let components = left_components
                    .iter()
                    .zip(right_components)
                    .map(|(&left, &right)| self.binary_op(op, left, right, span))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(Expression::Compose {
                    ty: left_ty,
                    components,
                })
            }
            // matKxR * matCxK -> matCxR
            BinaryOperator::Multiply if left_columns == right_rows => self
                .multiply_matrix_matrix(
                    left_components,
                    right_components,
                    left_rows,
                    right_columns,
                    scalar,
                    span,
                ),
            _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
        },
        _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
    }
}
3015
3016 fn binary_op_vector(
3017 &mut self,
3018 op: BinaryOperator,
3019 size: crate::VectorSize,
3020 left_components: &[Handle<Expression>],
3021 right_components: &[Handle<Expression>],
3022 left_ty: Handle<Type>,
3023 span: Span,
3024 ) -> Result<Expression, ConstantEvaluatorError> {
3025 let ty = match op {
3026 BinaryOperator::Equal
3028 | BinaryOperator::NotEqual
3029 | BinaryOperator::Less
3030 | BinaryOperator::LessEqual
3031 | BinaryOperator::Greater
3032 | BinaryOperator::GreaterEqual => self.types.insert(
3033 Type {
3034 name: None,
3035 inner: TypeInner::Vector {
3036 size,
3037 scalar: crate::Scalar::BOOL,
3038 },
3039 },
3040 span,
3041 ),
3042
3043 BinaryOperator::Add
3046 | BinaryOperator::Subtract
3047 | BinaryOperator::Multiply
3048 | BinaryOperator::Divide
3049 | BinaryOperator::Modulo
3050 | BinaryOperator::And
3051 | BinaryOperator::ExclusiveOr
3052 | BinaryOperator::InclusiveOr
3053 | BinaryOperator::ShiftLeft
3054 | BinaryOperator::ShiftRight => left_ty,
3055
3056 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3057 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3059 }
3060 };
3061
3062 let components = left_components
3063 .iter()
3064 .zip(right_components)
3065 .map(|(&left, &right)| self.binary_op(op, left, right, span))
3066 .collect::<Result<Vec<_>, _>>()?;
3067
3068 Ok(Expression::Compose { ty, components })
3069 }
3070
3071 fn multiply_vector_matrix(
3072 &mut self,
3073 vec_components: &[Handle<Expression>],
3074 mat_components: &[Handle<Expression>],
3075 mat_columns: crate::VectorSize,
3076 scalar: crate::Scalar,
3077 span: Span,
3078 ) -> Result<Expression, ConstantEvaluatorError> {
3079 let ty = self.types.insert(
3080 Type {
3081 name: None,
3082 inner: TypeInner::Vector {
3083 size: mat_columns,
3084 scalar,
3085 },
3086 },
3087 span,
3088 );
3089 let components = mat_components
3090 .iter()
3091 .map(|&column| {
3092 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3093 unreachable!()
3094 };
3095 self.dot_exprs(
3096 vec_components.iter().cloned(),
3097 components.clone().into_iter(),
3098 span,
3099 )
3100 })
3101 .collect::<Result<Vec<_>, _>>()?;
3102 Ok(Expression::Compose { ty, components })
3103 }
3104
3105 fn multiply_matrix_vector(
3106 &mut self,
3107 mat_components: &[Handle<Expression>],
3108 vec_components: &[Handle<Expression>],
3109 mat_rows: crate::VectorSize,
3110 scalar: crate::Scalar,
3111 span: Span,
3112 ) -> Result<Expression, ConstantEvaluatorError> {
3113 let ty = self.types.insert(
3114 Type {
3115 name: None,
3116 inner: TypeInner::Vector {
3117 size: mat_rows,
3118 scalar,
3119 },
3120 },
3121 span,
3122 );
3123
3124 let flatten = self.flatten_matrix(mat_components);
3125 let nr = mat_rows as usize;
3126 let components = (0..nr)
3127 .map(|r| {
3128 let row = flatten.iter().skip(r).step_by(nr).cloned();
3129 self.dot_exprs(row, vec_components.iter().cloned(), span)
3130 })
3131 .collect::<Result<Vec<_>, _>>()?;
3132 Ok(Expression::Compose { ty, components })
3133 }
3134
3135 fn multiply_matrix_matrix(
3136 &mut self,
3137 left_components: &[Handle<Expression>],
3138 right_components: &[Handle<Expression>],
3139 left_rows: crate::VectorSize,
3140 right_columns: crate::VectorSize,
3141 scalar: crate::Scalar,
3142 span: Span,
3143 ) -> Result<Expression, ConstantEvaluatorError> {
3144 let left_nc = left_components.len();
3145 let left_nr = left_rows as usize;
3146 let right_nc = right_columns as usize;
3147 let right_nr = left_nc;
3148
3149 let mut result = Vec::with_capacity(right_nc);
3150 let result_ty = self.types.insert(
3151 Type {
3152 name: None,
3153 inner: TypeInner::Matrix {
3154 columns: right_columns,
3155 rows: left_rows,
3156 scalar,
3157 },
3158 },
3159 span,
3160 );
3161 let result_column_ty = self.types.insert(
3162 Type {
3163 name: None,
3164 inner: TypeInner::Vector {
3165 size: left_rows,
3166 scalar,
3167 },
3168 },
3169 span,
3170 );
3171
3172 let left_flattened = self.flatten_matrix(left_components);
3173 let right_flattened = self.flatten_matrix(right_components);
3174 for c in 0..right_nc {
3175 let result_column = (0..left_nr)
3176 .map(|r| {
3177 let row = left_flattened.iter().skip(r).step_by(left_nr);
3178 let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3179 self.dot_exprs(row.cloned(), column.cloned(), span)
3180 })
3181 .collect::<Result<Vec<_>, _>>()?;
3182 let expr = Expression::Compose {
3183 ty: result_column_ty,
3184 components: result_column,
3185 };
3186 let handle = self.register_evaluated_expr(expr, span)?;
3187 result.push(handle);
3188 }
3189 Ok(Expression::Compose {
3190 ty: result_ty,
3191 components: result,
3192 })
3193 }
3194
3195 fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3196 let mut flattened = ArrayVec::<_, 16>::new();
3197 for &column in columns {
3198 let Expression::Compose { ref components, .. } = self.expressions[column] else {
3199 unreachable!()
3200 };
3201 flattened.extend(components.iter().cloned());
3202 }
3203 flattened
3204 }
3205
3206 fn dot_exprs(
3207 &mut self,
3208 left: impl Iterator<Item = Handle<Expression>>,
3209 right: impl Iterator<Item = Handle<Expression>>,
3210 span: Span,
3211 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3212 let mut acc = None;
3213 for (l, r) in left.zip(right) {
3214 let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3215 match acc.as_mut() {
3216 Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3217 None => acc = Some(result),
3218 }
3219 }
3220 Ok(acc.unwrap())
3221 }
3222
3223 fn relational(
3224 &mut self,
3225 fun: RelationalFunction,
3226 arg: Handle<Expression>,
3227 span: Span,
3228 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3229 let arg = self.eval_zero_value_and_splat(arg, span)?;
3230 match fun {
3231 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3232 Expression::Literal(Literal::Bool(_)) => Ok(arg),
3233 Expression::Compose { ty, ref components }
3234 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3235 {
3236 let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3237 for component in
3238 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3239 {
3240 match self.expressions[component] {
3241 Expression::Literal(Literal::Bool(val)) => {
3242 bool_components.push(val);
3243 }
3244 _ => {
3245 return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3246 }
3247 }
3248 }
3249 let components = bool_components;
3250 let result = match fun {
3251 RelationalFunction::All => components.iter().all(|c| *c),
3252 RelationalFunction::Any => components.iter().any(|c| *c),
3253 _ => unreachable!(),
3254 };
3255 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3256 }
3257 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3258 },
3259 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3260 "{fun:?} built-in function"
3261 ))),
3262 }
3263 }
3264
3265 fn copy_from(
3273 &mut self,
3274 expr: Handle<Expression>,
3275 expressions: &Arena<Expression>,
3276 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3277 let span = expressions.get_span(expr);
3278 match expressions[expr] {
3279 ref expr @ (Expression::Literal(_)
3280 | Expression::Constant(_)
3281 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3282 Expression::Compose { ty, ref components } => {
3283 let mut components = components.clone();
3284 for component in &mut components {
3285 *component = self.copy_from(*component, expressions)?;
3286 }
3287 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3288 }
3289 Expression::Splat { size, value } => {
3290 let value = self.copy_from(value, expressions)?;
3291 self.register_evaluated_expr(Expression::Splat { size, value }, span)
3292 }
3293 _ => {
3294 log::debug!("copy_from: SubexpressionsAreNotConstant");
3295 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3296 }
3297 }
3298 }
3299
3300 fn vector_compose_flattened_size(
3302 &self,
3303 components: &[Handle<Expression>],
3304 ) -> Result<usize, ConstantEvaluatorError> {
3305 components
3306 .iter()
3307 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3308 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3309 TypeInner::Scalar(_) => 1,
3310 TypeInner::Vector { size, .. } => size as usize,
3314 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3315 };
3316 Ok(acc + size)
3317 })
3318 }
3319
3320 fn register_evaluated_expr(
3321 &mut self,
3322 expr: Expression,
3323 span: Span,
3324 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3325 if let Expression::Literal(literal) = expr {
3330 crate::valid::check_literal_value(literal)?;
3331 }
3332
3333 if let Expression::Compose { ty, ref components } = expr {
3337 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3338 let expected = size as usize;
3339 let actual = self.vector_compose_flattened_size(components)?;
3340 if expected != actual {
3341 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3342 expected,
3343 actual,
3344 });
3345 }
3346 }
3347 }
3348
3349 Ok(self.append_expr(expr, span, ExpressionKind::Const))
3350 }
3351
/// Append `expr` to the expression arena and record its kind as `expr_type`.
///
/// In the function-local behaviors (WGSL `Runtime`, WGSL `Const` with
/// function-local data, GLSL `Runtime`), there is special handling when the
/// emitter is running and `expr.needs_pre_emit()` is true. The current emit
/// range is finished and pushed onto the block, the expression is appended,
/// and the emitter is restarted. Presumably this keeps such expressions out of
/// `Emit` statements — confirm against `needs_pre_emit`'s definition. All
/// other cases append directly.
///
/// Returns the handle of the newly appended expression.
fn append_expr(
    &mut self,
    expr: Expression,
    span: Span,
    expr_type: ExpressionKind,
) -> Handle<Expression> {
    let h = match self.behavior {
        Behavior::Wgsl(
            WgslRestrictions::Runtime(ref mut function_local_data)
            | WgslRestrictions::Const(Some(ref mut function_local_data)),
        )
        | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
            let is_running = function_local_data.emitter.is_running();
            let needs_pre_emit = expr.needs_pre_emit();
            if is_running && needs_pre_emit {
                // Close the pending emit range before appending, then reopen it
                // so subsequent expressions continue to be emitted.
                function_local_data
                    .block
                    .extend(function_local_data.emitter.finish(self.expressions));
                let h = self.expressions.append(expr, span);
                function_local_data.emitter.start(self.expressions);
                h
            } else {
                self.expressions.append(expr, span)
            }
        }
        _ => self.expressions.append(expr, span),
    };
    self.expression_kind_tracker.insert(h, expr_type);
    h
}
3382
3383 fn resolve_type(
3388 &self,
3389 expr: Handle<Expression>,
3390 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3391 use crate::proc::TypeResolution as Tr;
3392 use crate::Expression as Ex;
3393 let resolution = match self.expressions[expr] {
3394 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3395 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3396 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3397 Ex::Splat { size, value } => {
3398 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3399 return Err(ConstantEvaluatorError::SplatScalarOnly);
3400 };
3401 Tr::Value(TypeInner::Vector { scalar, size })
3402 }
3403 _ => {
3404 log::debug!("resolve_type: SubexpressionsAreNotConstant");
3405 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3406 }
3407 };
3408
3409 Ok(resolution)
3410 }
3411
/// Constant-evaluate `select(reject, accept, condition)`.
///
/// All three operands are first normalized with `eval_zero_value_and_splat`.
/// Two operand shapes are then handled:
///
/// - scalar: `reject` and `accept` are both literals and `condition` is a
///   `bool` literal;
/// - vector: `reject` and `accept` are both `Compose` expressions of the same
///   vector size, and `condition` is either a `bool` literal (used for every
///   component) or a `Compose` whose type is a `bool` vector of that size.
///
/// `accept` (or each of its components) is cast to `reject`'s scalar type.
/// In the vector case the result is a new `Compose` with `reject`'s type.
/// Any other combination of `reject`/`accept` shapes yields
/// [`ConstantEvaluatorError::SelectAcceptRejectTypeMismatch`].
fn select(
    &mut self,
    reject: Handle<Expression>,
    accept: Handle<Expression>,
    condition: Handle<Expression>,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    // Normalize zero values and splats so the operands can be inspected directly.
    let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);

    let reject = arg(reject)?;
    let accept = arg(accept)?;
    let condition = arg(condition)?;

    // Choose between one pair of scalar components. `accept` is cast to
    // `reject_scalar` before `condition` is consulted, so a failing cast is an
    // error even when `reject` would have been chosen.
    let select_single_component =
        |this: &mut Self, reject_scalar, reject, accept, condition| {
            let accept = this.cast(accept, reject_scalar, span)?;
            if condition {
                Ok(accept)
            } else {
                Ok(reject)
            }
        };

    match (&self.expressions[reject], &self.expressions[accept]) {
        // Scalar select: the condition must be a scalar `bool` literal.
        (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
            let reject_scalar = reject_lit.scalar();
            let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
            else {
                return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
            };
            select_single_component(self, reject_scalar, reject, accept, condition)
        }
        // Vector select: performed component by component.
        (
            &Expression::Compose {
                ty: reject_ty,
                components: ref reject_components,
            },
            &Expression::Compose {
                ty: accept_ty,
                components: ref accept_components,
            },
        ) => {
            // Returns (vector size, scalar) for a vector type.
            // NOTE(review): both `unwrap`s panic unless `ty` is a vector type;
            // this presumably relies on callers having already rejected
            // non-scalar/vector `select` operands — confirm.
            let ty_deets = |ty| {
                let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
                (size.unwrap(), scalar)
            };

            // `reject` and `accept` must have the same number of components.
            let expected_vec_size = {
                let [(reject_vec_size, _), (accept_vec_size, _)] =
                    [reject_ty, accept_ty].map(ty_deets);

                if reject_vec_size != accept_vec_size {
                    return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
                        reject: reject_vec_size,
                        accept: accept_vec_size,
                    });
                }
                reject_vec_size
            };

            // One `bool` per component. A scalar condition is repeated for
            // every component (this assumes `VectorSize`'s discriminant is
            // its component count). A vector condition must be a `bool`
            // vector of the same size.
            let condition_components = match self.expressions[condition] {
                Expression::Literal(Literal::Bool(condition)) => {
                    vec![condition; (expected_vec_size as u8).into()]
                }
                Expression::Compose {
                    ty: condition_ty,
                    components: ref condition_components,
                } => {
                    let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
                    if condition_scalar.kind != ScalarKind::Bool {
                        return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
                    }
                    if condition_vec_size != expected_vec_size {
                        return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
                    }
                    condition_components
                        .iter()
                        .copied()
                        .map(|component| match &self.expressions[component] {
                            &Expression::Literal(Literal::Bool(condition)) => condition,
                            // Presumed unreachable: the components of an
                            // evaluated `bool` vector are `bool` literals.
                            _ => unreachable!(),
                        })
                        .collect()
                }

                _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
            };

            // Build the result component-wise; it keeps `reject`'s type.
            let evaluated = Expression::Compose {
                ty: reject_ty,
                components: reject_components
                    .clone()
                    .into_iter()
                    .zip(accept_components.clone().into_iter())
                    .zip(condition_components.into_iter())
                    .map(|((reject, accept), condition)| {
                        // Presumed unreachable: evaluated vector components
                        // are literals.
                        let reject_scalar = match &self.expressions[reject] {
                            &Expression::Literal(lit) => lit.scalar(),
                            _ => unreachable!(),
                        };
                        select_single_component(self, reject_scalar, reject, accept, condition)
                    })
                    .collect::<Result<_, _>>()?,
            };
            self.register_evaluated_expr(evaluated, span)
        }
        _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
    }
}
3521}
3522
3523fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3524 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3528 match e {
3529 idx @ 0..=31 => idx,
3530 32 => u32::MAX,
3531 _ => unreachable!(),
3532 }
3533 };
3534 match concrete_int {
3535 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3536 ConcreteInt::I32([e]) => {
3537 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3538 }
3539 }
3540}
3541
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected) pairs for signed values, including the edge cases.
    let i32_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in i32_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // A single set bit reports its own position.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    // (input, expected) pairs for unsigned values.
    let u32_cases = [(0, u32::MAX), (1, 0), (2, 1), (1 << 31, 31), (u32::MAX, 0)];
    for (input, expected) in u32_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
3602
3603fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3604 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3608 match e {
3609 idx @ 0..=31 => 31 - idx,
3610 32 => u32::MAX,
3611 _ => unreachable!(),
3612 }
3613 };
3614 match concrete_int {
3615 ConcreteInt::I32([e]) => ConcreteInt::I32([{
3616 let rtl_bit_index = if e.is_negative() {
3617 e.leading_ones()
3618 } else {
3619 e.leading_zeros()
3620 };
3621 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3622 }]),
3623 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3624 }
3625}
3626
#[test]
fn first_leading_bit_smoke() {
    // (input, expected) pairs for signed values.
    let i32_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in i32_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two: the single set bit is the leading one.
    for idx in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // -(2^idx) is ones followed by `idx` zeros, so its highest `0` bit is `idx - 1`.
    for idx in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    // (input, expected) pairs for unsigned values.
    let u32_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in u32_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
3690
/// Fallible conversion from an abstract source value of type `T` into a
/// concrete scalar type.
///
/// The implementations in this file use `i64` and `f64` as `T` (presumably
/// standing for abstract integers and abstract floats respectively).
trait TryFromAbstract<T>: Sized {
    /// Convert `value`, returning an error (e.g.
    /// [`ConstantEvaluatorError::AutomaticConversionLossy`]) when the
    /// implementation rejects the conversion.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3716
3717impl TryFromAbstract<i64> for i32 {
3718 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3719 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3720 value: format!("{value:?}"),
3721 to_type: "i32",
3722 })
3723 }
3724}
3725
3726impl TryFromAbstract<i64> for u32 {
3727 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3728 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3729 value: format!("{value:?}"),
3730 to_type: "u32",
3731 })
3732 }
3733}
3734
3735impl TryFromAbstract<i64> for u64 {
3736 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3737 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3738 value: format!("{value:?}"),
3739 to_type: "u64",
3740 })
3741 }
3742}
3743
3744impl TryFromAbstract<i64> for i64 {
3745 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3746 Ok(value)
3747 }
3748}
3749
3750impl TryFromAbstract<i64> for f32 {
3751 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3752 let f = value as f32;
3753 Ok(f)
3757 }
3758}
3759
3760impl TryFromAbstract<f64> for f32 {
3761 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3762 let f = value as f32;
3763 if f.is_infinite() {
3764 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3765 value: format!("{value:?}"),
3766 to_type: "f32",
3767 });
3768 }
3769 Ok(f)
3770 }
3771}
3772
3773impl TryFromAbstract<i64> for f64 {
3774 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3775 let f = value as f64;
3776 Ok(f)
3780 }
3781}
3782
3783impl TryFromAbstract<f64> for f64 {
3784 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3785 Ok(value)
3786 }
3787}
3788
3789impl TryFromAbstract<f64> for i32 {
3790 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3791 Ok(value as i32)
3804 }
3805}
3806
3807impl TryFromAbstract<f64> for u32 {
3808 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3809 Ok(value as u32)
3812 }
3813}
3814
3815impl TryFromAbstract<f64> for i64 {
3816 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3817 use crate::proc::type_methods::IntFloatLimits;
3820 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3821 }
3822}
3823
3824impl TryFromAbstract<f64> for u64 {
3825 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3826 use crate::proc::type_methods::IntFloatLimits;
3829 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3830 }
3831}
3832
3833impl TryFromAbstract<f64> for f16 {
3834 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3835 let f = f16::from_f64(value);
3836 if f.is_infinite() {
3837 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3838 value: format!("{value:?}"),
3839 to_type: "f16",
3840 });
3841 }
3842 Ok(f)
3843 }
3844}
3845
3846impl TryFromAbstract<i64> for f16 {
3847 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3848 let f = f16::from_i64(value);
3849 if f.is_none() {
3850 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3851 value: format!("{value:?}"),
3852 to_type: "f16",
3853 });
3854 }
3855 Ok(f.unwrap())
3856 }
3857}
3858
/// Cross product `a × b` of two 3-component vectors.
///
/// Each component is a difference of two products, evaluated in the usual
/// textbook order.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [a_x, a_y, a_z] = a;
    let [b_x, b_y, b_z] = b;
    [
        a_y * b_z - a_z * b_y,
        a_z * b_x - a_x * b_z,
        a_x * b_y - a_y * b_x,
    ]
}
3871
3872#[cfg(test)]
3873mod tests {
3874 use alloc::{vec, vec::Vec};
3875
3876 use crate::{
3877 Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
3878 Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
3879 };
3880
3881 use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3882
#[test]
fn unary_op() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );
    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    // Scalar constants 4 and 8, then a vec2<i32> constant whose components
    // are those two initializers.
    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let eight_init =
        global_expressions.append(Expression::Literal(Literal::I32(8)), Default::default());
    let _eight = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: eight_init,
        },
        Default::default(),
    );
    let pair_init = global_expressions.append(
        Expression::Compose {
            ty: vec2_i32_ty,
            components: vec![four_init, eight_init],
        },
        Default::default(),
    );
    let pair = constants.append(
        Constant {
            name: None,
            ty: vec2_i32_ty,
            init: pair_init,
        },
        Default::default(),
    );

    let four_expr = global_expressions.append(Expression::Constant(four), Default::default());
    let pair_expr = global_expressions.append(Expression::Constant(pair), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let mut eval_unary = |op, expr| {
        solver
            .try_eval_and_append(Expression::Unary { op, expr }, Default::default())
            .unwrap()
    };
    let negated = eval_unary(UnaryOperator::Negate, four_expr);
    let inverted = eval_unary(UnaryOperator::BitwiseNot, four_expr);
    let inverted_pair = eval_unary(UnaryOperator::BitwiseNot, pair_expr);

    assert_eq!(
        global_expressions[negated],
        Expression::Literal(Literal::I32(-4))
    );
    assert_eq!(
        global_expressions[inverted],
        Expression::Literal(Literal::I32(!4))
    );

    // `!vec2(4, 8)` folds component-wise.
    let Expression::Compose { ty, ref components } = global_expressions[inverted_pair] else {
        panic!("Expected vector");
    };
    assert_eq!(ty, vec2_i32_ty);
    assert_eq!(components.len(), 2);
    assert_eq!(
        global_expressions[components[0]],
        Expression::Literal(Literal::I32(!4))
    );
    assert_eq!(
        global_expressions[components[1]],
        Expression::Literal(Literal::I32(!8))
    );
}
4015
#[test]
fn matrix_op() {
    let mut helper = MatrixTestHelper::new();

    // The helper's operands hold consecutive integers: the size-n vector is
    // [0, 1, .., n - 1], and column `c` of a matrix with `r` rows is
    // [c * r, .., c * r + r - 1]. Each expected value spells out the
    // corresponding dot products in those terms.
    for nc in 2..=4 {
        for nr in 2..=4 {
            // vecR * matCxR: one dot product per matrix column.
            let expected: Vec<f32> = (0..nc)
                .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
                .collect();
            assert_eq!(helper.eval_vector_multiply_matrix(nc, nr), expected);

            // matCxR * vecC: one dot product per matrix row.
            let expected: Vec<f32> = (0..nr)
                .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
                .collect();
            assert_eq!(helper.eval_matrix_multiply_vector(nc, nr), expected);

            // matKxR * matCxK, flattened column by column.
            for k in 2..=4 {
                let expected: Vec<f32> = (0..nc)
                    .flat_map(|c| {
                        (0..nr).map(move |r| {
                            (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
                        })
                    })
                    .collect();
                assert_eq!(helper.eval_matrix_multiply_matrix(nr, nc, k), expected);
            }
        }
    }
}
4054
/// Fixture shared by the matrix/vector multiplication tests.
///
/// `new` fills `expressions` with the `f32` literals `0.0..=15.0` and builds
/// `Compose` expressions from them for every vector size and matrix shape
/// between 2 and 4.
struct MatrixTestHelper {
    /// Vector and matrix types referenced by the composed expressions.
    types: UniqueArena<Type>,
    /// Literals, composed operands, and any results appended by the evaluator.
    expressions: Arena<Expression>,
    /// Vector operands keyed by component count `n`; the components are the
    /// literals `0.0, 1.0, .., n - 1`.
    vec_exprs: FastHashMap<usize, Handle<Expression>>,
    /// Matrix operands keyed by `(columns, rows)`; column `c` holds the
    /// literals `c * rows .. c * rows + rows`.
    mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
}
4066
4067 impl MatrixTestHelper {
4068 fn new() -> Self {
4069 let mut types = UniqueArena::new();
4070 let mut expressions = Arena::new();
4071 let span = crate::Span::default();
4072
4073 let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4074 for c in 2..=4 {
4075 let vec_ty = types.insert(
4076 Type {
4077 name: None,
4078 inner: TypeInner::Vector {
4079 size: Self::int_to_vector_size(c),
4080 scalar: crate::Scalar::F32,
4081 },
4082 },
4083 span,
4084 );
4085 vec_tys.insert(c, vec_ty);
4086 for r in 2..=4 {
4087 let mat_ty = types.insert(
4088 Type {
4089 name: None,
4090 inner: TypeInner::Matrix {
4091 columns: Self::int_to_vector_size(c),
4092 rows: Self::int_to_vector_size(r),
4093 scalar: crate::Scalar::F32,
4094 },
4095 },
4096 span,
4097 );
4098 mat_tys.insert((c, r), mat_ty);
4099 }
4100 }
4101
4102 let mut lit_exprs = FastHashMap::default();
4103 for i in 0..16 {
4104 let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4105 lit_exprs.insert(i, expr);
4106 }
4107
4108 let mut vec_exprs = FastHashMap::default();
4109 for c in 2..=4 {
4110 let expr = expressions.append(
4111 Expression::Compose {
4112 ty: *vec_tys.get(&c).unwrap(),
4113 components: (0..c)
4114 .map(|i| *lit_exprs.get(&i).unwrap())
4115 .collect::<Vec<_>>(),
4116 },
4117 span,
4118 );
4119 vec_exprs.insert(c, expr);
4120 }
4121
4122 let mut mat_exprs = FastHashMap::default();
4123 for c in 2..=4 {
4124 for r in 2..=4 {
4125 let mut columns = Vec::with_capacity(c);
4126 for cc in 0..c {
4127 let start = cc * r;
4128 let expr = expressions.append(
4129 Expression::Compose {
4130 ty: *vec_tys.get(&r).unwrap(),
4131 components: (start..start + r)
4132 .map(|i| *lit_exprs.get(&i).unwrap())
4133 .collect::<Vec<_>>(),
4134 },
4135 span,
4136 );
4137 columns.push(expr);
4138 }
4139
4140 let expr = expressions.append(
4141 Expression::Compose {
4142 ty: *mat_tys.get(&(c, r)).unwrap(),
4143 components: columns,
4144 },
4145 span,
4146 );
4147 mat_exprs.insert((c, r), expr);
4148 }
4149 }
4150
4151 Self {
4152 types,
4153 expressions,
4154 vec_exprs,
4155 mat_exprs,
4156 }
4157 }
4158
4159 fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4161 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4162 let mut solver = ConstantEvaluator {
4163 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4164 types: &mut self.types,
4165 constants: &Arena::new(),
4166 overrides: &Arena::new(),
4167 expressions: &mut self.expressions,
4168 expression_kind_tracker,
4169 layouter: &mut crate::proc::Layouter::default(),
4170 };
4171
4172 let result = solver
4173 .try_eval_and_append(
4174 Expression::Binary {
4175 op: BinaryOperator::Multiply,
4176 left: *self.vec_exprs.get(&nr).unwrap(),
4177 right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4178 },
4179 Default::default(),
4180 )
4181 .unwrap();
4182 self.flatten(result)
4183 }
4184
4185 fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4187 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4188 let mut solver = ConstantEvaluator {
4189 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4190 types: &mut self.types,
4191 constants: &Arena::new(),
4192 overrides: &Arena::new(),
4193 expressions: &mut self.expressions,
4194 expression_kind_tracker,
4195 layouter: &mut crate::proc::Layouter::default(),
4196 };
4197
4198 let result = solver
4199 .try_eval_and_append(
4200 Expression::Binary {
4201 op: BinaryOperator::Multiply,
4202 left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4203 right: *self.vec_exprs.get(&nc).unwrap(),
4204 },
4205 Default::default(),
4206 )
4207 .unwrap();
4208 self.flatten(result)
4209 }
4210
4211 fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4214 let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4215 let mut solver = ConstantEvaluator {
4216 behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4217 types: &mut self.types,
4218 constants: &Arena::new(),
4219 overrides: &Arena::new(),
4220 expressions: &mut self.expressions,
4221 expression_kind_tracker,
4222 layouter: &mut crate::proc::Layouter::default(),
4223 };
4224
4225 let result = solver
4226 .try_eval_and_append(
4227 Expression::Binary {
4228 op: BinaryOperator::Multiply,
4229 left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4230 right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4231 },
4232 Default::default(),
4233 )
4234 .unwrap();
4235 self.flatten(result)
4236 }
4237
4238 fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4239 let Expression::Compose {
4240 ref components,
4241 ref ty,
4242 } = self.expressions[expr]
4243 else {
4244 unreachable!()
4245 };
4246
4247 match self.types[*ty].inner {
4248 TypeInner::Vector { .. } => components
4249 .iter()
4250 .map(|&comp| {
4251 let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4252 unreachable!()
4253 };
4254 v
4255 })
4256 .collect(),
4257 TypeInner::Matrix { .. } => components
4258 .iter()
4259 .flat_map(|&comp| self.flatten(comp))
4260 .collect(),
4261 _ => unreachable!(),
4262 }
4263 }
4264
4265 fn int_to_vector_size(int: usize) -> VectorSize {
4266 match int {
4267 2 => VectorSize::Bi,
4268 3 => VectorSize::Tri,
4269 4 => VectorSize::Quad,
4270 _ => unreachable!(),
4271 }
4272 }
4273 }
4274
#[test]
fn cast() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );

    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let four_expr = global_expressions.append(Expression::Constant(four), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    // `bool(4)` is expected to fold to `true`.
    let as_bool = solver
        .try_eval_and_append(
            Expression::As {
                expr: four_expr,
                kind: ScalarKind::Bool,
                convert: Some(crate::BOOL_WIDTH),
            },
            Default::default(),
        )
        .unwrap();

    assert_eq!(
        global_expressions[as_bool],
        Expression::Literal(Literal::Bool(true))
    );
}
4328
#[test]
fn access() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let mat2x3_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Matrix {
                columns: VectorSize::Bi,
                rows: VectorSize::Tri,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let vec3_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Tri,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // Literals 0, 1, 2 form the first column; 3, 4, 5 the second.
    let col0_literals: Vec<_> = (0..3)
        .map(|i| {
            global_expressions
                .append(Expression::Literal(Literal::F32(i as f32)), Default::default())
        })
        .collect();
    let col1_literals: Vec<_> = (3..6)
        .map(|i| {
            global_expressions
                .append(Expression::Literal(Literal::F32(i as f32)), Default::default())
        })
        .collect();

    let col0_init = global_expressions.append(
        Expression::Compose {
            ty: vec3_f32_ty,
            components: col0_literals,
        },
        Default::default(),
    );
    let col0 = constants.append(
        Constant {
            name: None,
            ty: vec3_f32_ty,
            init: col0_init,
        },
        Default::default(),
    );
    let col1_init = global_expressions.append(
        Expression::Compose {
            ty: vec3_f32_ty,
            components: col1_literals,
        },
        Default::default(),
    );
    let col1 = constants.append(
        Constant {
            name: None,
            ty: vec3_f32_ty,
            init: col1_init,
        },
        Default::default(),
    );

    let matrix_init = global_expressions.append(
        Expression::Compose {
            ty: mat2x3_f32_ty,
            components: vec![constants[col0].init, constants[col1].init],
        },
        Default::default(),
    );
    let matrix = constants.append(
        Constant {
            name: None,
            ty: mat2x3_f32_ty,
            init: matrix_init,
        },
        Default::default(),
    );

    let base = global_expressions.append(Expression::Constant(matrix), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    // matrix[1] is the second column; matrix[1][2] is its last component.
    let accessed_column = solver
        .try_eval_and_append(Expression::AccessIndex { base, index: 1 }, Default::default())
        .unwrap();
    let accessed_element = solver
        .try_eval_and_append(
            Expression::AccessIndex {
                base: accessed_column,
                index: 2,
            },
            Default::default(),
        )
        .unwrap();

    let Expression::Compose { ty, ref components } = global_expressions[accessed_column] else {
        panic!("Expected vector");
    };
    assert_eq!(ty, vec3_f32_ty);
    assert_eq!(components.len(), 3);
    for (&component, expected) in components.iter().zip([3., 4., 5.]) {
        assert_eq!(
            global_expressions[component],
            Expression::Literal(Literal::F32(expected))
        );
    }

    assert_eq!(
        global_expressions[accessed_element],
        Expression::Literal(Literal::F32(5.))
    );
}
4482
#[test]
fn compose_of_constants() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );
    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let four_expr = global_expressions.append(Expression::Constant(four), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    // -vec2(FOUR, FOUR) should fold to vec2(-4, -4).
    let composed = solver
        .try_eval_and_append(
            Expression::Compose {
                ty: vec2_i32_ty,
                components: vec![four_expr, four_expr],
            },
            Default::default(),
        )
        .unwrap();
    let negated = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: composed,
            },
            Default::default(),
        )
        .unwrap();

    let all_minus_four = matches!(
        global_expressions[negated],
        Expression::Compose { ty, ref components }
            if ty == vec2_i32_ty
                && components.iter().all(|&component| {
                    matches!(
                        global_expressions[component],
                        Expression::Literal(Literal::I32(-4))
                    )
                })
    );
    assert!(all_minus_four, "unexpected evaluation result");
}
4565
#[test]
fn splat_of_constant() {
    let mut types = UniqueArena::new();
    let mut constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::I32),
        },
        Default::default(),
    );
    let vec2_i32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::I32,
            },
        },
        Default::default(),
    );

    let four_init =
        global_expressions.append(Expression::Literal(Literal::I32(4)), Default::default());
    let four = constants.append(
        Constant {
            name: None,
            ty: i32_ty,
            init: four_init,
        },
        Default::default(),
    );
    let four_expr = global_expressions.append(Expression::Constant(four), Default::default());

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    // -vec2(FOUR) (a splat) should fold to vec2(-4, -4).
    let splatted = solver
        .try_eval_and_append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: four_expr,
            },
            Default::default(),
        )
        .unwrap();
    let negated = solver
        .try_eval_and_append(
            Expression::Unary {
                op: UnaryOperator::Negate,
                expr: splatted,
            },
            Default::default(),
        )
        .unwrap();

    let all_minus_four = matches!(
        global_expressions[negated],
        Expression::Compose { ty, ref components }
            if ty == vec2_i32_ty
                && components.iter().all(|&component| {
                    matches!(
                        global_expressions[component],
                        Expression::Literal(Literal::I32(-4))
                    )
                })
    );
    assert!(all_minus_four, "unexpected evaluation result");
}
4648
#[test]
fn splat_of_zero_value() {
    let mut types = UniqueArena::new();
    let constants = Arena::new();
    let overrides = Arena::new();
    let mut global_expressions = Arena::new();

    let f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Scalar(crate::Scalar::F32),
        },
        Default::default(),
    );
    let vec2_f32_ty = types.insert(
        Type {
            name: None,
            inner: TypeInner::Vector {
                size: VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );

    // splat(f32()) + splat(5.0) should fold to vec2(5.0, 5.0).
    let five =
        global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
    let five_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: five,
        },
        Default::default(),
    );
    let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
    let zero_splat = global_expressions.append(
        Expression::Splat {
            size: VectorSize::Bi,
            value: zero,
        },
        Default::default(),
    );

    let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
    let mut solver = ConstantEvaluator {
        behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
        types: &mut types,
        constants: &constants,
        overrides: &overrides,
        expressions: &mut global_expressions,
        expression_kind_tracker,
        layouter: &mut crate::proc::Layouter::default(),
    };

    let sum = solver
        .try_eval_and_append(
            Expression::Binary {
                op: BinaryOperator::Add,
                left: zero_splat,
                right: five_splat,
            },
            Default::default(),
        )
        .unwrap();

    let all_five = matches!(
        global_expressions[sum],
        Expression::Compose { ty, ref components }
            if ty == vec2_f32_ty
                && components.iter().all(|&component| {
                    matches!(
                        global_expressions[component],
                        Expression::Literal(Literal::F32(5.0))
                    )
                })
    );
    assert!(all_five, "unexpected evaluation result");
}
4729}