1use alloc::{
2 format,
3 string::{String, ToString},
4 vec,
5 vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14 arena::{Arena, Handle, HandleVec, UniqueArena},
15 ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16 ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// Lets a macro expansion emit a literal `$` token.
///
/// Usage: `with_dollar_sign! { ($d:tt) => { ... } }`. The body is installed as the single
/// rule of a helper macro, which is then invoked with `$`. Inside the body, `$d` therefore
/// expands to a `$` token. This is what allows a `macro_rules!` expansion to define another
/// `macro_rules!` that has its own metavariables and repetitions.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define a throwaway macro whose rule is the caller's body...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...and call it with a bare `$`, binding it to the caller's `$d:tt` metavariable.
        __with_dollar_sign!($);
    }
}
33
/// Generates the machinery for evaluating a numeric built-in component-wise.
///
/// `gen_component_wise_extractor! { $ident -> $target, literals: [...], scalar_kinds: [...] }`
/// emits the following items:
///
/// - `enum $target<const N: usize>`. It has one variant per `literals` entry, and each
///   variant holds `N` values of the listed Rust type.
/// - `impl From<$target<1>> for Expression`, which turns a single value back into an
///   `Expression::Literal`.
/// - `fn $ident(eval, span, exprs, handler)`. It gathers the `N` argument expressions into a
///   `$target<N>`, calls `handler`, and registers the result. Vector arguments
///   (`Compose` of one of `scalar_kinds`) are handled by recursing once per component
///   index.
/// - A same-named `$ident!` macro that applies one closure body to every `$target` variant.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`] variants, each holding `N` values, used to evaluate
        /// numeric built-ins component-wise.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are literals of the same variant, `handler` is called once with all of
        /// them. If `exprs` are `Compose` expressions of the same vector type, `handler` is
        /// called once per component index, and the outputs are registered as a new vector
        /// `Compose` of that type. Any other combination yields
        /// `ConstantEvaluatorError::InvalidMathArg`.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            // The first argument decides which shape (literal variant or vector type)
            // every other argument must match.
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Passes a handle through `eval_zero_value_and_splat` and borrows the resulting
            // expression. NOTE(review): presumably that method rewrites `ZeroValue`/`Splat`
            // into literal/`Compose` form so they match the arms below — confirm there.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: all arguments must be literals of this same variant.
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: all arguments must be `Compose`s of the same vector type.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            // One flattened component list per argument.
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each component index, gather that component from every
                            // argument and evaluate the scalar case recursively.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // `$d` stands for a literal `$`, so the nested macro below can declare its own
        // metavariables without them being consumed by this outer macro.
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// `Scalar<N>` / `component_wise_scalar`: covers every numeric literal kind. That is the
// abstract kinds (stored as `f64`/`i64`), `f16`, `f32`, and 32/64-bit signed and unsigned
// integers. Used by built-ins that accept any numeric type (e.g. `abs`, `min`, `max`,
// `clamp`). Vectors of any of the listed scalar kinds (i.e. not `bool`) are accepted.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// `Float<N>` / `component_wise_float`: floating-point literals only (abstract as `f64`,
// `f32`, `f16`). Used by the trigonometric, exponential and rounding built-ins.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// `ConcreteInt<N>` / `component_wise_concrete_int`: concrete 32-bit integers only (`u32`,
// `i32`). Used by the bit-counting built-ins.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// `Signed<N>` / `component_wise_signed`: the signed numeric kinds only. That is floats
// (abstract, `f32`, `f16`) and signed integers (abstract, `i32`); unsigned integers are
// excluded. Used by `sign`.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// The source language whose rules govern evaluation, together with the restrictions on
/// which kinds of expressions may be appended.
#[derive(Debug)]
enum Behavior<'a> {
    /// WGSL rules.
    Wgsl(WgslRestrictions<'a>),
    /// GLSL rules.
    Glsl(GlslRestrictions<'a>),
}
276
277impl Behavior<'_> {
278 const fn has_runtime_restrictions(&self) -> bool {
280 matches!(
281 self,
282 &Behavior::Wgsl(WgslRestrictions::Runtime(_))
283 | &Behavior::Glsl(GlslRestrictions::Runtime(_))
284 )
285 }
286}
287
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` appends expressions to a single arena (`expressions`). That arena is
/// either the module's global expression arena (`for_wgsl_module` / `for_glsl_module`) or a
/// function's expression arena (`for_wgsl_function` / `for_glsl_function`).
/// [`ConstantEvaluator::try_eval_and_append`] evaluates an [`Expression`] where possible and
/// appends the result.
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's rules apply, and whether evaluation happens at module scope or inside
    /// a function.
    behavior: Behavior<'a>,

    /// The module's type arena. It is mutable because evaluation may create new types (e.g.
    /// the vector types produced by splats and swizzles).
    types: &'a mut UniqueArena<Type>,

    /// The module's constants.
    constants: &'a Arena<Constant>,

    /// The module's pipeline-overridable constants.
    overrides: &'a Arena<Override>,

    /// The arena that evaluated expressions are appended to.
    expressions: &'a mut Arena<Expression>,

    /// The [`ExpressionKind`] of every expression in `expressions`.
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// Layout information supplied by the caller.
    // NOTE(review): not used in the part of the evaluator visible here; presumably used for
    // size/alignment queries elsewhere — confirm.
    layouter: &'a mut crate::proc::Layouter,
}
332
/// Restrictions on which expressions may be appended when evaluating WGSL.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// Only const-expressions are accepted. `None` means module scope; `Some` means the
    /// evaluator appends to a function's arena (`for_wgsl_function` with `is_const` set).
    Const(Option<FunctionLocalData<'a>>),
    /// Module scope where override-expressions are accepted as well as const-expressions.
    Override,
    /// Inside a function body, where runtime expressions may also be appended.
    Runtime(FunctionLocalData<'a>),
}
345
/// Restrictions on which expressions may be appended when evaluating GLSL.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// Module scope: only const-expressions are accepted.
    Const,
    /// Inside a function body, where runtime expressions may also be appended.
    Runtime(FunctionLocalData<'a>),
}
355
/// Extra context needed when the evaluator appends to a function's expression arena rather
/// than the module's global one.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// The module's global expression arena. Constant initializers are deep-copied from here
    /// into the function's arena (see `check_and_get`).
    global_expressions: &'a Arena<Expression>,
    /// The function's expression emitter.
    // NOTE(review): not used in the visible part of this file; presumably used when
    // appending runtime expressions — confirm.
    emitter: &'a mut super::Emitter,
    /// The function block being built (usage not visible here).
    block: &'a mut crate::Block,
}
363
/// How constant an expression is.
///
/// Variants are declared from most to least constant. The derived `Ord` follows declaration
/// order, so the kind of a compound expression can be computed as the `max` of its operands'
/// kinds. Do not reorder the variants.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values, constants and other const-expressions.
    Const,
    /// Depends on at least one pipeline-overridable constant.
    Override,
    /// Anything else; only usable at runtime.
    Runtime,
}
370
/// Records the [`ExpressionKind`] of each expression in an arena, indexed by the expression's
/// handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    inner: HandleVec<Expression, ExpressionKind>,
}
375
376impl ExpressionKindTracker {
377 pub const fn new() -> Self {
378 Self {
379 inner: HandleVec::new(),
380 }
381 }
382
383 pub fn force_non_const(&mut self, value: Handle<Expression>) {
385 self.inner[value] = ExpressionKind::Runtime;
386 }
387
388 pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
389 self.inner.insert(value, expr_type);
390 }
391
392 pub fn is_const(&self, h: Handle<Expression>) -> bool {
393 matches!(self.type_of(h), ExpressionKind::Const)
394 }
395
396 pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
397 matches!(
398 self.type_of(h),
399 ExpressionKind::Const | ExpressionKind::Override
400 )
401 }
402
403 fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
404 self.inner[value]
405 }
406
407 pub fn from_arena(arena: &Arena<Expression>) -> Self {
408 let mut tracker = Self {
409 inner: HandleVec::with_capacity(arena.len()),
410 };
411 for (handle, expr) in arena.iter() {
412 tracker
413 .inner
414 .insert(handle, tracker.type_of_with_expr(expr));
415 }
416 tracker
417 }
418
419 fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
420 match *expr {
421 Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
422 ExpressionKind::Const
423 }
424 Expression::Override(_) => ExpressionKind::Override,
425 Expression::Compose { ref components, .. } => {
426 let mut expr_type = ExpressionKind::Const;
427 for component in components {
428 expr_type = expr_type.max(self.type_of(*component))
429 }
430 expr_type
431 }
432 Expression::Splat { value, .. } => self.type_of(value),
433 Expression::AccessIndex { base, .. } => self.type_of(base),
434 Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
435 Expression::Swizzle { vector, .. } => self.type_of(vector),
436 Expression::Unary { expr, .. } => self.type_of(expr),
437 Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
438 Expression::Math {
439 arg,
440 arg1,
441 arg2,
442 arg3,
443 ..
444 } => self
445 .type_of(arg)
446 .max(
447 arg1.map(|arg| self.type_of(arg))
448 .unwrap_or(ExpressionKind::Const),
449 )
450 .max(
451 arg2.map(|arg| self.type_of(arg))
452 .unwrap_or(ExpressionKind::Const),
453 )
454 .max(
455 arg3.map(|arg| self.type_of(arg))
456 .unwrap_or(ExpressionKind::Const),
457 ),
458 Expression::As { expr, .. } => self.type_of(expr),
459 Expression::Select {
460 condition,
461 accept,
462 reject,
463 } => self
464 .type_of(condition)
465 .max(self.type_of(accept))
466 .max(self.type_of(reject)),
467 Expression::Relational { argument, .. } => self.type_of(argument),
468 Expression::ArrayLength(expr) => self.type_of(expr),
469 _ => ExpressionKind::Runtime,
470 }
471 }
472}
473
/// Errors produced while evaluating an expression at compile time.
///
/// `try_eval_and_append` treats `NotImplemented` and `InvalidBinaryOpArgs` specially. When
/// runtime expressions are permitted, these two errors make it append the expression
/// unevaluated instead of failing.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    /// `arrayLength` is never constant-evaluated for WGSL.
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// Holds the function, the expected argument count, and the supplied argument count.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    #[error("Cannot apply relational function to type")]
    InvalidRelationalArg(RelationalFunction),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Constructor expects {expected} components, found {actual}")]
    InvalidVectorComposeLength { expected: usize, actual: usize },
    #[error("Constructor must only contain vector or scalar arguments")]
    InvalidVectorComposeComponent,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// An operand has not been recorded as a const-expression (see `check`).
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// The operation is valid but constant evaluation for it is not implemented yet.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// An `Override` expression was asked to be evaluated directly.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A runtime expression was submitted where runtime expressions are not permitted.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression was submitted in a WGSL const context.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
567
568impl<'a> ConstantEvaluator<'a> {
569 pub fn for_wgsl_module(
574 module: &'a mut crate::Module,
575 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
576 layouter: &'a mut crate::proc::Layouter,
577 in_override_ctx: bool,
578 ) -> Self {
579 Self::for_module(
580 Behavior::Wgsl(if in_override_ctx {
581 WgslRestrictions::Override
582 } else {
583 WgslRestrictions::Const(None)
584 }),
585 module,
586 global_expression_kind_tracker,
587 layouter,
588 )
589 }
590
591 pub fn for_glsl_module(
596 module: &'a mut crate::Module,
597 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
598 layouter: &'a mut crate::proc::Layouter,
599 ) -> Self {
600 Self::for_module(
601 Behavior::Glsl(GlslRestrictions::Const),
602 module,
603 global_expression_kind_tracker,
604 layouter,
605 )
606 }
607
608 fn for_module(
609 behavior: Behavior<'a>,
610 module: &'a mut crate::Module,
611 global_expression_kind_tracker: &'a mut ExpressionKindTracker,
612 layouter: &'a mut crate::proc::Layouter,
613 ) -> Self {
614 Self {
615 behavior,
616 types: &mut module.types,
617 constants: &module.constants,
618 overrides: &module.overrides,
619 expressions: &mut module.global_expressions,
620 expression_kind_tracker: global_expression_kind_tracker,
621 layouter,
622 }
623 }
624
625 pub fn for_wgsl_function(
630 module: &'a mut crate::Module,
631 expressions: &'a mut Arena<Expression>,
632 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
633 layouter: &'a mut crate::proc::Layouter,
634 emitter: &'a mut super::Emitter,
635 block: &'a mut crate::Block,
636 is_const: bool,
637 ) -> Self {
638 let local_data = FunctionLocalData {
639 global_expressions: &module.global_expressions,
640 emitter,
641 block,
642 };
643 Self {
644 behavior: Behavior::Wgsl(if is_const {
645 WgslRestrictions::Const(Some(local_data))
646 } else {
647 WgslRestrictions::Runtime(local_data)
648 }),
649 types: &mut module.types,
650 constants: &module.constants,
651 overrides: &module.overrides,
652 expressions,
653 expression_kind_tracker: local_expression_kind_tracker,
654 layouter,
655 }
656 }
657
658 pub fn for_glsl_function(
663 module: &'a mut crate::Module,
664 expressions: &'a mut Arena<Expression>,
665 local_expression_kind_tracker: &'a mut ExpressionKindTracker,
666 layouter: &'a mut crate::proc::Layouter,
667 emitter: &'a mut super::Emitter,
668 block: &'a mut crate::Block,
669 ) -> Self {
670 Self {
671 behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
672 global_expressions: &module.global_expressions,
673 emitter,
674 block,
675 })),
676 types: &mut module.types,
677 constants: &module.constants,
678 overrides: &module.overrides,
679 expressions,
680 expression_kind_tracker: local_expression_kind_tracker,
681 layouter,
682 }
683 }
684
685 pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
686 crate::proc::GlobalCtx {
687 types: self.types,
688 constants: self.constants,
689 overrides: self.overrides,
690 global_expressions: match self.function_local_data() {
691 Some(data) => data.global_expressions,
692 None => self.expressions,
693 },
694 }
695 }
696
697 fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
698 if !self.expression_kind_tracker.is_const(expr) {
699 log::debug!("check: SubexpressionsAreNotConstant");
700 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
701 }
702 Ok(())
703 }
704
705 fn check_and_get(
706 &mut self,
707 expr: Handle<Expression>,
708 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
709 match self.expressions[expr] {
710 Expression::Constant(c) => {
711 if let Some(function_local_data) = self.function_local_data() {
714 self.copy_from(
716 self.constants[c].init,
717 function_local_data.global_expressions,
718 )
719 } else {
720 Ok(self.constants[c].init)
722 }
723 }
724 _ => {
725 self.check(expr)?;
726 Ok(expr)
727 }
728 }
729 }
730
731 pub fn try_eval_and_append(
755 &mut self,
756 expr: Expression,
757 span: Span,
758 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
759 match self.expression_kind_tracker.type_of_with_expr(&expr) {
760 ExpressionKind::Const => {
761 let eval_result = self.try_eval_and_append_impl(&expr, span);
762 if self.behavior.has_runtime_restrictions()
767 && matches!(
768 eval_result,
769 Err(ConstantEvaluatorError::NotImplemented(_)
770 | ConstantEvaluatorError::InvalidBinaryOpArgs,)
771 )
772 {
773 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
774 } else {
775 eval_result
776 }
777 }
778 ExpressionKind::Override => match self.behavior {
779 Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
780 Ok(self.append_expr(expr, span, ExpressionKind::Override))
781 }
782 Behavior::Wgsl(WgslRestrictions::Const(_)) => {
783 Err(ConstantEvaluatorError::OverrideExpr)
784 }
785 Behavior::Glsl(_) => {
786 unreachable!()
787 }
788 },
789 ExpressionKind::Runtime => {
790 if self.behavior.has_runtime_restrictions() {
791 Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
792 } else {
793 Err(ConstantEvaluatorError::RuntimeExpr)
794 }
795 }
796 }
797 }
798
799 const fn is_global_arena(&self) -> bool {
801 matches!(
802 self.behavior,
803 Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
804 | Behavior::Glsl(GlslRestrictions::Const)
805 )
806 }
807
808 const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
809 match self.behavior {
810 Behavior::Wgsl(
811 WgslRestrictions::Runtime(ref function_local_data)
812 | WgslRestrictions::Const(Some(ref function_local_data)),
813 )
814 | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
815 Some(function_local_data)
816 }
817 _ => None,
818 }
819 }
820
/// Evaluates `expr` and appends the result. `try_eval_and_append` calls this for
/// expressions it classifies as `ExpressionKind::Const`.
///
/// Operands are resolved with `check_and_get` before the operation-specific helper runs.
/// Inside a function, `check_and_get` may append copies of constant initializers to the
/// arena, so operands are resolved in a fixed left-to-right order. Expression kinds that
/// can never be constant are rejected with a dedicated error.
fn try_eval_and_append_impl(
    &mut self,
    expr: &Expression,
    span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
    log::trace!("try_eval_and_append: {:?}", expr);
    match *expr {
        // In the global arena, reuse the constant's initializer handle instead of appending
        // a new `Constant` expression.
        Expression::Constant(c) if self.is_global_arena() => {
            Ok(self.constants[c].init)
        }
        Expression::Override(_) => Err(ConstantEvaluatorError::Override),
        // Leaves are already fully evaluated; in a function arena a `Constant` is appended
        // as-is.
        Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
            self.register_evaluated_expr(expr.clone(), span)
        }
        Expression::Compose { ty, ref components } => {
            let components = components
                .iter()
                .map(|component| self.check_and_get(*component))
                .collect::<Result<Vec<_>, _>>()?;
            self.register_evaluated_expr(Expression::Compose { ty, components }, span)
        }
        Expression::Splat { size, value } => {
            let value = self.check_and_get(value)?;
            self.register_evaluated_expr(Expression::Splat { size, value }, span)
        }
        Expression::AccessIndex { base, index } => {
            let base = self.check_and_get(base)?;

            self.access(base, index as usize, span)
        }
        Expression::Access { base, index } => {
            let base = self.check_and_get(base)?;
            let index = self.check_and_get(index)?;

            self.access(base, self.constant_index(index)?, span)
        }
        Expression::Swizzle {
            size,
            vector,
            pattern,
        } => {
            let vector = self.check_and_get(vector)?;

            self.swizzle(size, span, vector, pattern)
        }
        Expression::Unary { expr, op } => {
            let expr = self.check_and_get(expr)?;

            self.unary_op(op, expr, span)
        }
        Expression::Binary { left, right, op } => {
            let left = self.check_and_get(left)?;
            let right = self.check_and_get(right)?;

            self.binary_op(op, left, right, span)
        }
        Expression::Math {
            fun,
            arg,
            arg1,
            arg2,
            arg3,
        } => {
            let arg = self.check_and_get(arg)?;
            let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
            let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

            self.math(arg, arg1, arg2, arg3, fun, span)
        }
        Expression::As {
            convert,
            expr,
            kind,
        } => {
            let expr = self.check_and_get(expr)?;

            // `convert: None` is a bitcast, which is not evaluated here.
            match convert {
                Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                None => Err(ConstantEvaluatorError::NotImplemented(
                    "bitcast built-in function".into(),
                )),
            }
        }
        Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
            "select built-in function".into(),
        )),
        Expression::Relational { fun, argument } => {
            let argument = self.check_and_get(argument)?;
            self.relational(fun, argument, span)
        }
        // Only GLSL evaluates array length at compile time.
        Expression::ArrayLength(expr) => match self.behavior {
            Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
            Behavior::Glsl(_) => {
                let expr = self.check_and_get(expr)?;
                self.array_length(expr, span)
            }
        },
        // Everything below can never be a constant.
        Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
        Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
        Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
        Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
        Expression::WorkGroupUniformLoadResult { .. } => {
            Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
        }
        Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
        Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
        Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
        Expression::ImageSample { .. }
        | Expression::ImageLoad { .. }
        | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
        Expression::RayQueryProceedResult
        | Expression::RayQueryGetIntersection { .. }
        | Expression::RayQueryVertexPositions { .. } => {
            Err(ConstantEvaluatorError::RayQueryExpression)
        }
        Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
        Expression::SubgroupOperationResult { .. } => {
            Err(ConstantEvaluatorError::SubgroupExpression)
        }
    }
}
945
946 fn splat(
959 &mut self,
960 value: Handle<Expression>,
961 size: crate::VectorSize,
962 span: Span,
963 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
964 match self.expressions[value] {
965 Expression::Literal(literal) => {
966 let scalar = literal.scalar();
967 let ty = self.types.insert(
968 Type {
969 name: None,
970 inner: TypeInner::Vector { size, scalar },
971 },
972 span,
973 );
974 let expr = Expression::Compose {
975 ty,
976 components: vec![value; size as usize],
977 };
978 self.register_evaluated_expr(expr, span)
979 }
980 Expression::ZeroValue(ty) => {
981 let inner = match self.types[ty].inner {
982 TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
983 _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
984 };
985 let res_ty = self.types.insert(Type { name: None, inner }, span);
986 let expr = Expression::ZeroValue(res_ty);
987 self.register_evaluated_expr(expr, span)
988 }
989 _ => Err(ConstantEvaluatorError::SplatScalarOnly),
990 }
991 }
992
993 fn swizzle(
994 &mut self,
995 size: crate::VectorSize,
996 span: Span,
997 src_constant: Handle<Expression>,
998 pattern: [crate::SwizzleComponent; 4],
999 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1000 let mut get_dst_ty = |ty| match self.types[ty].inner {
1001 TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1002 Type {
1003 name: None,
1004 inner: TypeInner::Vector { size, scalar },
1005 },
1006 span,
1007 )),
1008 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1009 };
1010
1011 match self.expressions[src_constant] {
1012 Expression::ZeroValue(ty) => {
1013 let dst_ty = get_dst_ty(ty)?;
1014 let expr = Expression::ZeroValue(dst_ty);
1015 self.register_evaluated_expr(expr, span)
1016 }
1017 Expression::Splat { value, .. } => {
1018 let expr = Expression::Splat { size, value };
1019 self.register_evaluated_expr(expr, span)
1020 }
1021 Expression::Compose { ty, ref components } => {
1022 let dst_ty = get_dst_ty(ty)?;
1023
1024 let mut flattened = [src_constant; 4]; let len =
1026 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1027 .zip(flattened.iter_mut())
1028 .map(|(component, elt)| *elt = component)
1029 .count();
1030 let flattened = &flattened[..len];
1031
1032 let swizzled_components = pattern[..size as usize]
1033 .iter()
1034 .map(|&sc| {
1035 let sc = sc as usize;
1036 if let Some(elt) = flattened.get(sc) {
1037 Ok(*elt)
1038 } else {
1039 Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1040 }
1041 })
1042 .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1043 let expr = Expression::Compose {
1044 ty: dst_ty,
1045 components: swizzled_components,
1046 };
1047 self.register_evaluated_expr(expr, span)
1048 }
1049 _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1050 }
1051 }
1052
1053 fn math(
1054 &mut self,
1055 arg: Handle<Expression>,
1056 arg1: Option<Handle<Expression>>,
1057 arg2: Option<Handle<Expression>>,
1058 arg3: Option<Handle<Expression>>,
1059 fun: crate::MathFunction,
1060 span: Span,
1061 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1062 let expected = fun.argument_count();
1063 let given = Some(arg)
1064 .into_iter()
1065 .chain(arg1)
1066 .chain(arg2)
1067 .chain(arg3)
1068 .count();
1069 if expected != given {
1070 return Err(ConstantEvaluatorError::InvalidMathArgCount(
1071 fun, expected, given,
1072 ));
1073 }
1074
1075 match fun {
1077 crate::MathFunction::Abs => {
1079 component_wise_scalar(self, span, [arg], |args| match args {
1080 Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1081 Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1082 Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1083 Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1084 Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1085 Scalar::U32([e]) => Ok(Scalar::U32([e])), Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1087 Scalar::U64([e]) => Ok(Scalar::U64([e])),
1088 })
1089 }
1090 crate::MathFunction::Min => {
1091 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1092 Ok([e1.min(e2)])
1093 })
1094 }
1095 crate::MathFunction::Max => {
1096 component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1097 Ok([e1.max(e2)])
1098 })
1099 }
1100 crate::MathFunction::Clamp => {
1101 component_wise_scalar!(
1102 self,
1103 span,
1104 [arg, arg1.unwrap(), arg2.unwrap()],
1105 |e, low, high| {
1106 if low > high {
1107 Err(ConstantEvaluatorError::InvalidClamp)
1108 } else {
1109 Ok([e.clamp(low, high)])
1110 }
1111 }
1112 )
1113 }
1114 crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1115 Float::F16([e]) => Ok(Float::F16(
1116 [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1117 )),
1118 Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1119 Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1120 }),
1121
1122 crate::MathFunction::Cos => {
1124 component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1125 }
1126 crate::MathFunction::Cosh => {
1127 component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1128 }
1129 crate::MathFunction::Sin => {
1130 component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1131 }
1132 crate::MathFunction::Sinh => {
1133 component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1134 }
1135 crate::MathFunction::Tan => {
1136 component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1137 }
1138 crate::MathFunction::Tanh => {
1139 component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1140 }
1141 crate::MathFunction::Acos => {
1142 component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1143 }
1144 crate::MathFunction::Asin => {
1145 component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1146 }
1147 crate::MathFunction::Atan => {
1148 component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1149 }
1150 crate::MathFunction::Asinh => {
1151 component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1152 }
1153 crate::MathFunction::Acosh => {
1154 component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1155 }
1156 crate::MathFunction::Atanh => {
1157 component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1158 }
1159 crate::MathFunction::Radians => {
1160 component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1161 }
1162 crate::MathFunction::Degrees => {
1163 component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1164 }
1165
1166 crate::MathFunction::Ceil => {
1168 component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1169 }
1170 crate::MathFunction::Floor => {
1171 component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1172 }
1173 crate::MathFunction::Round => {
1174 component_wise_float(self, span, [arg], |e| match e {
1175 Float::Abstract([e]) => Ok(Float::Abstract([e.round_ties_even()])),
1176 Float::F32([e]) => Ok(Float::F32([e.round_ties_even()])),
1177 Float::F16([e]) => {
1178 fn round_ties_even(x: f64) -> f64 {
1186 let i = x as i64;
1187 let f = (x - i as f64).abs();
1188 if f == 0.5 {
1189 if i & 1 == 1 {
1190 (x.abs() + 0.5).copysign(x)
1192 } else {
1193 (x.abs() - 0.5).copysign(x)
1194 }
1195 } else {
1196 x.round()
1197 }
1198 }
1199
1200 Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1201 }
1202 })
1203 }
1204 crate::MathFunction::Fract => {
1205 component_wise_float!(self, span, [arg], |e| {
1206 Ok([e - e.floor()])
1209 })
1210 }
1211 crate::MathFunction::Trunc => {
1212 component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1213 }
1214
1215 crate::MathFunction::Exp => {
1217 component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1218 }
1219 crate::MathFunction::Exp2 => {
1220 component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1221 }
1222 crate::MathFunction::Log => {
1223 component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1224 }
1225 crate::MathFunction::Log2 => {
1226 component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1227 }
1228 crate::MathFunction::Pow => {
1229 component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1230 Ok([e1.powf(e2)])
1231 })
1232 }
1233
1234 crate::MathFunction::Sign => {
1236 component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1237 }
1238 crate::MathFunction::Fma => {
1239 component_wise_float!(
1240 self,
1241 span,
1242 [arg, arg1.unwrap(), arg2.unwrap()],
1243 |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1244 )
1245 }
1246 crate::MathFunction::Step => {
1247 component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1248 Float::Abstract([edge, x]) => {
1249 Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1250 }
1251 Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1252 Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1253 f16::one()
1254 } else {
1255 f16::zero()
1256 }])),
1257 })
1258 }
1259 crate::MathFunction::Sqrt => {
1260 component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1261 }
1262 crate::MathFunction::InverseSqrt => {
1263 component_wise_float(self, span, [arg], |e| match e {
1264 Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1265 Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1266 Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1267 })
1268 }
1269
1270 crate::MathFunction::CountTrailingZeros => {
1272 component_wise_concrete_int!(self, span, [arg], |e| {
1273 #[allow(clippy::useless_conversion)]
1274 Ok([e
1275 .trailing_zeros()
1276 .try_into()
1277 .expect("bit count overflowed 32 bits, somehow!?")])
1278 })
1279 }
1280 crate::MathFunction::CountLeadingZeros => {
1281 component_wise_concrete_int!(self, span, [arg], |e| {
1282 #[allow(clippy::useless_conversion)]
1283 Ok([e
1284 .leading_zeros()
1285 .try_into()
1286 .expect("bit count overflowed 32 bits, somehow!?")])
1287 })
1288 }
1289 crate::MathFunction::CountOneBits => {
1290 component_wise_concrete_int!(self, span, [arg], |e| {
1291 #[allow(clippy::useless_conversion)]
1292 Ok([e
1293 .count_ones()
1294 .try_into()
1295 .expect("bit count overflowed 32 bits, somehow!?")])
1296 })
1297 }
1298 crate::MathFunction::ReverseBits => {
1299 component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1300 }
1301 crate::MathFunction::FirstTrailingBit => {
1302 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1303 }
1304 crate::MathFunction::FirstLeadingBit => {
1305 component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1306 }
1307
1308 crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1310
1311 crate::MathFunction::Atan2
1313 | crate::MathFunction::Modf
1314 | crate::MathFunction::Frexp
1315 | crate::MathFunction::Ldexp
1316 | crate::MathFunction::Dot
1317 | crate::MathFunction::Outer
1318 | crate::MathFunction::Distance
1319 | crate::MathFunction::Length
1320 | crate::MathFunction::Normalize
1321 | crate::MathFunction::FaceForward
1322 | crate::MathFunction::Reflect
1323 | crate::MathFunction::Refract
1324 | crate::MathFunction::Mix
1325 | crate::MathFunction::SmoothStep
1326 | crate::MathFunction::Inverse
1327 | crate::MathFunction::Transpose
1328 | crate::MathFunction::Determinant
1329 | crate::MathFunction::QuantizeToF16
1330 | crate::MathFunction::ExtractBits
1331 | crate::MathFunction::InsertBits
1332 | crate::MathFunction::Pack4x8snorm
1333 | crate::MathFunction::Pack4x8unorm
1334 | crate::MathFunction::Pack2x16snorm
1335 | crate::MathFunction::Pack2x16unorm
1336 | crate::MathFunction::Pack2x16float
1337 | crate::MathFunction::Pack4xI8
1338 | crate::MathFunction::Pack4xU8
1339 | crate::MathFunction::Unpack4x8snorm
1340 | crate::MathFunction::Unpack4x8unorm
1341 | crate::MathFunction::Unpack2x16snorm
1342 | crate::MathFunction::Unpack2x16unorm
1343 | crate::MathFunction::Unpack2x16float
1344 | crate::MathFunction::Unpack4xI8
1345 | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1346 format!("{fun:?} built-in function"),
1347 )),
1348 }
1349 }
1350
1351 fn cross_product(
1353 &mut self,
1354 a: Handle<Expression>,
1355 b: Handle<Expression>,
1356 span: Span,
1357 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1358 use Literal as Li;
1359
1360 let (a, ty) = self.extract_vec::<3>(a)?;
1361 let (b, _) = self.extract_vec::<3>(b)?;
1362
1363 let product = match (a, b) {
1364 (
1365 [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1366 [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1367 ) => {
1368 let p = cross_product(
1373 [a0 as f64, a1 as f64, a2 as f64],
1374 [b0 as f64, b1 as f64, b2 as f64],
1375 );
1376 [
1377 Li::AbstractFloat(p[0]),
1378 Li::AbstractFloat(p[1]),
1379 Li::AbstractFloat(p[2]),
1380 ]
1381 }
1382 (
1383 [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1384 [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1385 ) => {
1386 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1387 [
1388 Li::AbstractFloat(p[0]),
1389 Li::AbstractFloat(p[1]),
1390 Li::AbstractFloat(p[2]),
1391 ]
1392 }
1393 ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1394 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1395 [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1396 }
1397 ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1398 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1399 [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1400 }
1401 ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1402 let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1403 [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1404 }
1405 _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1406 };
1407
1408 let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1409 let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1410 let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1411
1412 self.register_evaluated_expr(
1413 Expression::Compose {
1414 ty,
1415 components: vec![p0, p1, p2],
1416 },
1417 span,
1418 )
1419 }
1420
1421 fn extract_vec<const N: usize>(
1429 &mut self,
1430 expr: Handle<Expression>,
1431 ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1432 let span = self.expressions.get_span(expr);
1433 let expr = self.eval_zero_value_and_splat(expr, span)?;
1434 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1435 return Err(ConstantEvaluatorError::InvalidMathArg);
1436 };
1437
1438 let mut value = [Literal::Bool(false); N];
1439 for (component, elt) in
1440 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1441 .zip(value.iter_mut())
1442 {
1443 let Expression::Literal(literal) = self.expressions[component] else {
1444 return Err(ConstantEvaluatorError::InvalidMathArg);
1445 };
1446 *elt = literal;
1447 }
1448
1449 Ok((value, ty))
1450 }
1451
1452 fn array_length(
1453 &mut self,
1454 array: Handle<Expression>,
1455 span: Span,
1456 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1457 match self.expressions[array] {
1458 Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1459 match self.types[ty].inner {
1460 TypeInner::Array { size, .. } => match size {
1461 ArraySize::Constant(len) => {
1462 let expr = Expression::Literal(Literal::U32(len.get()));
1463 self.register_evaluated_expr(expr, span)
1464 }
1465 ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1466 ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1467 },
1468 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1469 }
1470 }
1471 _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1472 }
1473 }
1474
1475 fn access(
1476 &mut self,
1477 base: Handle<Expression>,
1478 index: usize,
1479 span: Span,
1480 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1481 match self.expressions[base] {
1482 Expression::ZeroValue(ty) => {
1483 let ty_inner = &self.types[ty].inner;
1484 let components = ty_inner
1485 .components()
1486 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1487
1488 if index >= components as usize {
1489 Err(ConstantEvaluatorError::InvalidAccessBase)
1490 } else {
1491 let ty_res = ty_inner
1492 .component_type(index)
1493 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1494 let ty = match ty_res {
1495 crate::proc::TypeResolution::Handle(ty) => ty,
1496 crate::proc::TypeResolution::Value(inner) => {
1497 self.types.insert(Type { name: None, inner }, span)
1498 }
1499 };
1500 self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1501 }
1502 }
1503 Expression::Splat { size, value } => {
1504 if index >= size as usize {
1505 Err(ConstantEvaluatorError::InvalidAccessBase)
1506 } else {
1507 Ok(value)
1508 }
1509 }
1510 Expression::Compose { ty, ref components } => {
1511 let _ = self.types[ty]
1512 .inner
1513 .components()
1514 .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1515
1516 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1517 .nth(index)
1518 .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1519 }
1520 _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1521 }
1522 }
1523
1524 fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1525 match self.expressions[expr] {
1526 Expression::ZeroValue(ty)
1527 if matches!(
1528 self.types[ty].inner,
1529 TypeInner::Scalar(crate::Scalar {
1530 kind: ScalarKind::Uint,
1531 ..
1532 })
1533 ) =>
1534 {
1535 Ok(0)
1536 }
1537 Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1538 _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1539 }
1540 }
1541
1542 fn eval_zero_value_and_splat(
1549 &mut self,
1550 mut expr: Handle<Expression>,
1551 span: Span,
1552 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1553 if let Expression::Compose { ty, ref components } = self.expressions[expr] {
1556 let components = components
1557 .clone()
1558 .iter()
1559 .map(|component| self.eval_zero_value_and_splat(*component, span))
1560 .collect::<Result<_, _>>()?;
1561 expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
1562 }
1563
1564 if let Expression::Splat { size, value } = self.expressions[expr] {
1568 expr = self.splat(value, size, span)?;
1569 }
1570 if let Expression::ZeroValue(ty) = self.expressions[expr] {
1571 expr = self.eval_zero_value_impl(ty, span)?;
1572 }
1573 Ok(expr)
1574 }
1575
1576 fn eval_zero_value(
1582 &mut self,
1583 expr: Handle<Expression>,
1584 span: Span,
1585 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1586 match self.expressions[expr] {
1587 Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1588 _ => Ok(expr),
1589 }
1590 }
1591
1592 fn eval_zero_value_impl(
1598 &mut self,
1599 ty: Handle<Type>,
1600 span: Span,
1601 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1602 match self.types[ty].inner {
1603 TypeInner::Scalar(scalar) => {
1604 let expr = Expression::Literal(
1605 Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1606 );
1607 self.register_evaluated_expr(expr, span)
1608 }
1609 TypeInner::Vector { size, scalar } => {
1610 let scalar_ty = self.types.insert(
1611 Type {
1612 name: None,
1613 inner: TypeInner::Scalar(scalar),
1614 },
1615 span,
1616 );
1617 let el = self.eval_zero_value_impl(scalar_ty, span)?;
1618 let expr = Expression::Compose {
1619 ty,
1620 components: vec![el; size as usize],
1621 };
1622 self.register_evaluated_expr(expr, span)
1623 }
1624 TypeInner::Matrix {
1625 columns,
1626 rows,
1627 scalar,
1628 } => {
1629 let vec_ty = self.types.insert(
1630 Type {
1631 name: None,
1632 inner: TypeInner::Vector { size: rows, scalar },
1633 },
1634 span,
1635 );
1636 let el = self.eval_zero_value_impl(vec_ty, span)?;
1637 let expr = Expression::Compose {
1638 ty,
1639 components: vec![el; columns as usize],
1640 };
1641 self.register_evaluated_expr(expr, span)
1642 }
1643 TypeInner::Array {
1644 base,
1645 size: ArraySize::Constant(size),
1646 ..
1647 } => {
1648 let el = self.eval_zero_value_impl(base, span)?;
1649 let expr = Expression::Compose {
1650 ty,
1651 components: vec![el; size.get() as usize],
1652 };
1653 self.register_evaluated_expr(expr, span)
1654 }
1655 TypeInner::Struct { ref members, .. } => {
1656 let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1657 let mut components = Vec::with_capacity(members.len());
1658 for ty in types {
1659 components.push(self.eval_zero_value_impl(ty, span)?);
1660 }
1661 let expr = Expression::Compose { ty, components };
1662 self.register_evaluated_expr(expr, span)
1663 }
1664 _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1665 }
1666 }
1667
    /// Constant-evaluate a conversion of `expr` to the scalar type `target`.
    ///
    /// A top-level `ZeroValue` operand is first expanded with
    /// [`Self::eval_zero_value`]. Then:
    /// - a `Literal` is converted according to the per-target table below;
    /// - a `Compose` whose type is a vector or matrix has every component cast
    ///   recursively. It is given a newly inserted vector/matrix type of the
    ///   same shape with scalar `target`;
    /// - a `Splat` has its splatted value cast (using that value's own span)
    ///   and keeps its size.
    ///
    /// Conversions out of `AbstractInt`/`AbstractFloat` go through
    /// `try_from_abstract`, whose errors are propagated.
    ///
    /// # Errors
    ///
    /// Returns [`ConstantEvaluatorError::InvalidCastArg`] in these cases:
    /// - any other expression kind;
    /// - a `Compose` whose type is not a vector or matrix (arrays are handled
    ///   by `cast_array`);
    /// - a literal/target pair listed as unsupported below;
    /// - any target scalar not matched explicitly.
    ///
    /// The resulting expression is registered with `register_evaluated_expr`,
    /// whose errors are also propagated.
    pub fn cast(
        &mut self,
        expr: Handle<Expression>,
        target: crate::Scalar,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        use crate::Scalar as Sc;

        let expr = self.eval_zero_value(expr, span)?;

        // Builds the `InvalidCastArg` error lazily, describing the source
        // expression (handle + debug form) and the requested target type.
        let make_error = || -> Result<_, ConstantEvaluatorError> {
            let from = format!("{:?} {:?}", expr, self.expressions[expr]);

            #[cfg(feature = "wgsl-in")]
            let to = target.to_wgsl_for_diagnostics();

            #[cfg(not(feature = "wgsl-in"))]
            let to = format!("{target:?}");

            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
        };

        use crate::proc::type_methods::IntFloatLimits;

        let expr = match self.expressions[expr] {
            Expression::Literal(literal) => {
                let literal = match target {
                    Sc::I32 => Literal::I32(match literal {
                        Literal::I32(v) => v,
                        // `as` between same-width integers reinterprets the bits.
                        Literal::U32(v) => v as i32,
                        // Clamp to the float bounds given by `IntFloatLimits`
                        // before the `as` conversion.
                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
                        // NOTE(review): `to_i32` returns `None` for values it
                        // cannot represent (e.g. non-finite); this `unwrap`
                        // assumes f16 literals reaching here are finite — confirm.
                        Literal::F16(v) => f16::to_i32(&v).unwrap(),
                        Literal::Bool(v) => v as i32,
                        // 64-bit sources are not converted to 32-bit targets here.
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
                    }),
                    Sc::U32 => Literal::U32(match literal {
                        // `as` between same-width integers reinterprets the bits.
                        Literal::I32(v) => v as u32,
                        Literal::U32(v) => v,
                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
                        // Negative values are first raised to zero so that the
                        // unsigned conversion cannot fail on sign.
                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
                        Literal::Bool(v) => v as u32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
                    }),
                    Sc::I64 => Literal::I64(match literal {
                        Literal::I32(v) => v as i64,
                        Literal::U32(v) => v as i64,
                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::Bool(v) => v as i64,
                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
                        Literal::I64(v) => v,
                        // Bit reinterpretation between same-width integers.
                        Literal::U64(v) => v as i64,
                        // NOTE(review): same finiteness assumption as the i32 case.
                        Literal::F16(v) => f16::to_i64(&v).unwrap(),
                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
                    }),
                    Sc::U64 => Literal::U64(match literal {
                        Literal::I32(v) => v as u64,
                        Literal::U32(v) => v as u64,
                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        Literal::Bool(v) => v as u64,
                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
                        // Bit reinterpretation between same-width integers.
                        Literal::I64(v) => v as u64,
                        Literal::U64(v) => v,
                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
                    }),
                    Sc::F16 => Literal::F16(match literal {
                        Literal::F16(v) => v,
                        Literal::F32(v) => f16::from_f32(v),
                        Literal::F64(v) => f16::from_f64(v),
                        // NOTE(review): these `unwrap`s assume `half`'s
                        // `FromPrimitive` impls never return `None` — confirm.
                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
                        Literal::I64(v) => f16::from_i64(v).unwrap(),
                        Literal::U64(v) => f16::from_u64(v).unwrap(),
                        Literal::I32(v) => f16::from_i32(v).unwrap(),
                        Literal::U32(v) => f16::from_u32(v).unwrap(),
                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
                    }),
                    Sc::F32 => Literal::F32(match literal {
                        Literal::I32(v) => v as f32,
                        Literal::U32(v) => v as f32,
                        Literal::F32(v) => v,
                        Literal::Bool(v) => v as u32 as f32,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                        Literal::F16(v) => f16::to_f32(v),
                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
                    }),
                    Sc::F64 => Literal::F64(match literal {
                        Literal::I32(v) => v as f64,
                        Literal::U32(v) => v as f64,
                        Literal::F16(v) => f16::to_f64(v),
                        Literal::F32(v) => v as f64,
                        Literal::F64(v) => v,
                        Literal::Bool(v) => v as u32 as f64,
                        // 64-bit integers are not converted to f64 here.
                        Literal::I64(_) | Literal::U64(_) => return make_error(),
                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
                    }),
                    // Any non-zero numeric value converts to `true`.
                    Sc::BOOL => Literal::Bool(match literal {
                        Literal::I32(v) => v != 0,
                        Literal::U32(v) => v != 0,
                        Literal::F32(v) => v != 0.0,
                        Literal::F16(v) => v != f16::zero(),
                        Literal::Bool(v) => v,
                        Literal::AbstractInt(v) => v != 0,
                        Literal::AbstractFloat(v) => v != 0.0,
                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
                            return make_error();
                        }
                    }),
                    // Only abstract sources may become `AbstractFloat`.
                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
                        Literal::AbstractInt(v) => {
                            // `as` rounds to the nearest representable `f64`;
                            // magnitudes above 2^53 may lose precision.
                            v as f64
                        }
                        Literal::AbstractFloat(v) => v,
                        _ => return make_error(),
                    }),
                    _ => {
                        log::debug!("Constant evaluator refused to convert value to {target:?}");
                        return make_error();
                    }
                };
                Expression::Literal(literal)
            }
            Expression::Compose {
                ty,
                components: ref src_components,
            } => {
                // Keep the shape of the vector/matrix, swapping only the scalar.
                let ty_inner = match self.types[ty].inner {
                    TypeInner::Vector { size, .. } => TypeInner::Vector {
                        size,
                        scalar: target,
                    },
                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
                        columns,
                        rows,
                        scalar: target,
                    },
                    _ => return make_error(),
                };

                // Each component is cast recursively; for a matrix this casts
                // each of its (vector) components in turn.
                let mut components = src_components.clone();
                for component in &mut components {
                    *component = self.cast(*component, target, span)?;
                }

                let ty = self.types.insert(
                    Type {
                        name: None,
                        inner: ty_inner,
                    },
                    span,
                );

                Expression::Compose { ty, components }
            }
            Expression::Splat { size, value } => {
                // Cast the single splatted value, attributed to its own span.
                let value_span = self.expressions.get_span(value);
                let cast_value = self.cast(value, target, value_span)?;
                Expression::Splat {
                    size,
                    value: cast_value,
                }
            }
            _ => return make_error(),
        };

        self.register_evaluated_expr(expr, span)
    }
1858
1859 pub fn cast_array(
1872 &mut self,
1873 expr: Handle<Expression>,
1874 target: crate::Scalar,
1875 span: Span,
1876 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1877 let expr = self.check_and_get(expr)?;
1878
1879 let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1880 return self.cast(expr, target, span);
1881 };
1882
1883 let TypeInner::Array {
1884 base: _,
1885 size,
1886 stride: _,
1887 } = self.types[ty].inner
1888 else {
1889 return self.cast(expr, target, span);
1890 };
1891
1892 let mut components = components.clone();
1893 for component in &mut components {
1894 *component = self.cast_array(*component, target, span)?;
1895 }
1896
1897 let first = components.first().unwrap();
1898 let new_base = match self.resolve_type(*first)? {
1899 crate::proc::TypeResolution::Handle(ty) => ty,
1900 crate::proc::TypeResolution::Value(inner) => {
1901 self.types.insert(Type { name: None, inner }, span)
1902 }
1903 };
1904 let mut layouter = core::mem::take(self.layouter);
1905 layouter.update(self.to_ctx()).unwrap();
1906 *self.layouter = layouter;
1907
1908 let new_base_stride = self.layouter[new_base].to_stride();
1909 let new_array_ty = self.types.insert(
1910 Type {
1911 name: None,
1912 inner: TypeInner::Array {
1913 base: new_base,
1914 size,
1915 stride: new_base_stride,
1916 },
1917 },
1918 span,
1919 );
1920
1921 let compose = Expression::Compose {
1922 ty: new_array_ty,
1923 components,
1924 };
1925 self.register_evaluated_expr(compose, span)
1926 }
1927
1928 fn unary_op(
1929 &mut self,
1930 op: UnaryOperator,
1931 expr: Handle<Expression>,
1932 span: Span,
1933 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1934 let expr = self.eval_zero_value_and_splat(expr, span)?;
1935
1936 let expr = match self.expressions[expr] {
1937 Expression::Literal(value) => Expression::Literal(match op {
1938 UnaryOperator::Negate => match value {
1939 Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1940 Literal::I64(v) => Literal::I64(v.wrapping_neg()),
1941 Literal::F32(v) => Literal::F32(-v),
1942 Literal::F16(v) => Literal::F16(-v),
1943 Literal::F64(v) => Literal::F64(-v),
1944 Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1945 Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1946 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1947 },
1948 UnaryOperator::LogicalNot => match value {
1949 Literal::Bool(v) => Literal::Bool(!v),
1950 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1951 },
1952 UnaryOperator::BitwiseNot => match value {
1953 Literal::I32(v) => Literal::I32(!v),
1954 Literal::I64(v) => Literal::I64(!v),
1955 Literal::U32(v) => Literal::U32(!v),
1956 Literal::U64(v) => Literal::U64(!v),
1957 Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1958 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1959 },
1960 }),
1961 Expression::Compose {
1962 ty,
1963 components: ref src_components,
1964 } => {
1965 match self.types[ty].inner {
1966 TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1967 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1968 }
1969
1970 let mut components = src_components.clone();
1971 for component in &mut components {
1972 *component = self.unary_op(op, *component, span)?;
1973 }
1974
1975 Expression::Compose { ty, components }
1976 }
1977 _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1978 };
1979
1980 self.register_evaluated_expr(expr, span)
1981 }
1982
1983 fn binary_op(
1984 &mut self,
1985 op: BinaryOperator,
1986 left: Handle<Expression>,
1987 right: Handle<Expression>,
1988 span: Span,
1989 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1990 let left = self.eval_zero_value_and_splat(left, span)?;
1991 let right = self.eval_zero_value_and_splat(right, span)?;
1992
1993 let expr = match (&self.expressions[left], &self.expressions[right]) {
1994 (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
1995 let literal = match op {
1996 BinaryOperator::Equal => Literal::Bool(left_value == right_value),
1997 BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
1998 BinaryOperator::Less => Literal::Bool(left_value < right_value),
1999 BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2000 BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2001 BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2002
2003 _ => match (left_value, right_value) {
2004 (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2005 BinaryOperator::Add => a.wrapping_add(b),
2006 BinaryOperator::Subtract => a.wrapping_sub(b),
2007 BinaryOperator::Multiply => a.wrapping_mul(b),
2008 BinaryOperator::Divide => {
2009 if b == 0 {
2010 return Err(ConstantEvaluatorError::DivisionByZero);
2011 } else {
2012 a.wrapping_div(b)
2013 }
2014 }
2015 BinaryOperator::Modulo => {
2016 if b == 0 {
2017 return Err(ConstantEvaluatorError::RemainderByZero);
2018 } else {
2019 a.wrapping_rem(b)
2020 }
2021 }
2022 BinaryOperator::And => a & b,
2023 BinaryOperator::ExclusiveOr => a ^ b,
2024 BinaryOperator::InclusiveOr => a | b,
2025 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2026 }),
2027 (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2028 BinaryOperator::ShiftLeft => {
2029 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2030 return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2031 }
2032 a.checked_shl(b)
2033 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2034 }
2035 BinaryOperator::ShiftRight => a
2036 .checked_shr(b)
2037 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2038 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2039 }),
2040 (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2041 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2042 ConstantEvaluatorError::Overflow("addition".into())
2043 })?,
2044 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2045 ConstantEvaluatorError::Overflow("subtraction".into())
2046 })?,
2047 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2048 ConstantEvaluatorError::Overflow("multiplication".into())
2049 })?,
2050 BinaryOperator::Divide => a
2051 .checked_div(b)
2052 .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2053 BinaryOperator::Modulo => a
2054 .checked_rem(b)
2055 .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2056 BinaryOperator::And => a & b,
2057 BinaryOperator::ExclusiveOr => a ^ b,
2058 BinaryOperator::InclusiveOr => a | b,
2059 BinaryOperator::ShiftLeft => a
2060 .checked_mul(
2061 1u32.checked_shl(b)
2062 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2063 )
2064 .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2065 BinaryOperator::ShiftRight => a
2066 .checked_shr(b)
2067 .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2068 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2069 }),
2070 (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2071 BinaryOperator::Add => a + b,
2072 BinaryOperator::Subtract => a - b,
2073 BinaryOperator::Multiply => a * b,
2074 BinaryOperator::Divide => a / b,
2075 BinaryOperator::Modulo => a % b,
2076 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2077 }),
2078 (Literal::AbstractInt(a), Literal::U32(b)) => {
2079 Literal::AbstractInt(match op {
2080 BinaryOperator::ShiftLeft => {
2081 if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2082 return Err(ConstantEvaluatorError::Overflow(
2083 "<<".to_string(),
2084 ));
2085 }
2086 a.checked_shl(b).unwrap_or(0)
2087 }
2088 BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2089 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2090 })
2091 }
2092 (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2093 BinaryOperator::Add => a + b,
2094 BinaryOperator::Subtract => a - b,
2095 BinaryOperator::Multiply => a * b,
2096 BinaryOperator::Divide => a / b,
2097 BinaryOperator::Modulo => a % b,
2098 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2099 }),
2100 (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2101 Literal::AbstractInt(match op {
2102 BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2103 ConstantEvaluatorError::Overflow("addition".into())
2104 })?,
2105 BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2106 ConstantEvaluatorError::Overflow("subtraction".into())
2107 })?,
2108 BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2109 ConstantEvaluatorError::Overflow("multiplication".into())
2110 })?,
2111 BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2112 if b == 0 {
2113 ConstantEvaluatorError::DivisionByZero
2114 } else {
2115 ConstantEvaluatorError::Overflow("division".into())
2116 }
2117 })?,
2118 BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2119 if b == 0 {
2120 ConstantEvaluatorError::RemainderByZero
2121 } else {
2122 ConstantEvaluatorError::Overflow("remainder".into())
2123 }
2124 })?,
2125 BinaryOperator::And => a & b,
2126 BinaryOperator::ExclusiveOr => a ^ b,
2127 BinaryOperator::InclusiveOr => a | b,
2128 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2129 })
2130 }
2131 (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2132 Literal::AbstractFloat(match op {
2133 BinaryOperator::Add => a + b,
2134 BinaryOperator::Subtract => a - b,
2135 BinaryOperator::Multiply => a * b,
2136 BinaryOperator::Divide => a / b,
2137 BinaryOperator::Modulo => a % b,
2138 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2139 })
2140 }
2141 (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2142 BinaryOperator::LogicalAnd => a && b,
2143 BinaryOperator::LogicalOr => a || b,
2144 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2145 }),
2146 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2147 },
2148 };
2149 Expression::Literal(literal)
2150 }
2151 (
2152 &Expression::Compose {
2153 components: ref src_components,
2154 ty,
2155 },
2156 &Expression::Literal(_),
2157 ) => {
2158 let mut components = src_components.clone();
2159 for component in &mut components {
2160 *component = self.binary_op(op, *component, right, span)?;
2161 }
2162 Expression::Compose { ty, components }
2163 }
2164 (
2165 &Expression::Literal(_),
2166 &Expression::Compose {
2167 components: ref src_components,
2168 ty,
2169 },
2170 ) => {
2171 let mut components = src_components.clone();
2172 for component in &mut components {
2173 *component = self.binary_op(op, left, *component, span)?;
2174 }
2175 Expression::Compose { ty, components }
2176 }
2177 (
2178 &Expression::Compose {
2179 components: ref left_components,
2180 ty: left_ty,
2181 },
2182 &Expression::Compose {
2183 components: ref right_components,
2184 ty: right_ty,
2185 },
2186 ) => {
2187 let left_flattened = crate::proc::flatten_compose(
2191 left_ty,
2192 left_components,
2193 self.expressions,
2194 self.types,
2195 );
2196 let right_flattened = crate::proc::flatten_compose(
2197 right_ty,
2198 right_components,
2199 self.expressions,
2200 self.types,
2201 );
2202
2203 let mut flattened = Vec::with_capacity(left_components.len());
2206 flattened.extend(left_flattened.zip(right_flattened));
2207
2208 match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2209 (
2210 &TypeInner::Vector {
2211 size: left_size, ..
2212 },
2213 &TypeInner::Vector {
2214 size: right_size, ..
2215 },
2216 ) if left_size == right_size => {
2217 self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2218 }
2219 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2220 }
2221 }
2222 _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2223 };
2224
2225 self.register_evaluated_expr(expr, span)
2226 }
2227
2228 fn binary_op_vector(
2229 &mut self,
2230 op: BinaryOperator,
2231 size: crate::VectorSize,
2232 components: &[(Handle<Expression>, Handle<Expression>)],
2233 left_ty: Handle<Type>,
2234 span: Span,
2235 ) -> Result<Expression, ConstantEvaluatorError> {
2236 let ty = match op {
2237 BinaryOperator::Equal
2239 | BinaryOperator::NotEqual
2240 | BinaryOperator::Less
2241 | BinaryOperator::LessEqual
2242 | BinaryOperator::Greater
2243 | BinaryOperator::GreaterEqual => self.types.insert(
2244 Type {
2245 name: None,
2246 inner: TypeInner::Vector {
2247 size,
2248 scalar: crate::Scalar::BOOL,
2249 },
2250 },
2251 span,
2252 ),
2253
2254 BinaryOperator::Add
2257 | BinaryOperator::Subtract
2258 | BinaryOperator::Multiply
2259 | BinaryOperator::Divide
2260 | BinaryOperator::Modulo
2261 | BinaryOperator::And
2262 | BinaryOperator::ExclusiveOr
2263 | BinaryOperator::InclusiveOr
2264 | BinaryOperator::ShiftLeft
2265 | BinaryOperator::ShiftRight => left_ty,
2266
2267 BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2268 return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2270 }
2271 };
2272
2273 let components = components
2274 .iter()
2275 .map(|&(left, right)| self.binary_op(op, left, right, span))
2276 .collect::<Result<Vec<_>, _>>()?;
2277
2278 Ok(Expression::Compose { ty, components })
2279 }
2280
2281 fn relational(
2282 &mut self,
2283 fun: RelationalFunction,
2284 arg: Handle<Expression>,
2285 span: Span,
2286 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2287 let arg = self.eval_zero_value_and_splat(arg, span)?;
2288 match fun {
2289 RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2290 Expression::Literal(Literal::Bool(_)) => Ok(arg),
2291 Expression::Compose { ty, ref components }
2292 if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2293 {
2294 let components =
2295 crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2296 .map(|component| match self.expressions[component] {
2297 Expression::Literal(Literal::Bool(val)) => Ok(val),
2298 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2299 })
2300 .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2301 let result = match fun {
2302 RelationalFunction::All => components.iter().all(|c| *c),
2303 RelationalFunction::Any => components.iter().any(|c| *c),
2304 _ => unreachable!(),
2305 };
2306 self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2307 }
2308 _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2309 },
2310 _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2311 "{fun:?} built-in function"
2312 ))),
2313 }
2314 }
2315
2316 fn copy_from(
2324 &mut self,
2325 expr: Handle<Expression>,
2326 expressions: &Arena<Expression>,
2327 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2328 let span = expressions.get_span(expr);
2329 match expressions[expr] {
2330 ref expr @ (Expression::Literal(_)
2331 | Expression::Constant(_)
2332 | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2333 Expression::Compose { ty, ref components } => {
2334 let mut components = components.clone();
2335 for component in &mut components {
2336 *component = self.copy_from(*component, expressions)?;
2337 }
2338 self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2339 }
2340 Expression::Splat { size, value } => {
2341 let value = self.copy_from(value, expressions)?;
2342 self.register_evaluated_expr(Expression::Splat { size, value }, span)
2343 }
2344 _ => {
2345 log::debug!("copy_from: SubexpressionsAreNotConstant");
2346 Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2347 }
2348 }
2349 }
2350
2351 fn vector_compose_flattened_size(
2353 &self,
2354 components: &[Handle<Expression>],
2355 ) -> Result<usize, ConstantEvaluatorError> {
2356 components
2357 .iter()
2358 .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2359 let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2360 TypeInner::Scalar(_) => 1,
2361 TypeInner::Vector { size, .. } => size as usize,
2365 _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2366 };
2367 Ok(acc + size)
2368 })
2369 }
2370
2371 fn register_evaluated_expr(
2372 &mut self,
2373 expr: Expression,
2374 span: Span,
2375 ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2376 if let Expression::Literal(literal) = expr {
2381 crate::valid::check_literal_value(literal)?;
2382 }
2383
2384 if let Expression::Compose { ty, ref components } = expr {
2388 if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2389 let expected = size as usize;
2390 let actual = self.vector_compose_flattened_size(components)?;
2391 if expected != actual {
2392 return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2393 expected,
2394 actual,
2395 });
2396 }
2397 }
2398 }
2399
2400 Ok(self.append_expr(expr, span, ExpressionKind::Const))
2401 }
2402
    /// Appends `expr` (with `span`) to `self.expressions`, records the new handle in
    /// `self.expression_kind_tracker` as `expr_type`, and returns the handle.
    ///
    /// When the evaluator carries function-local data (WGSL `Runtime`, WGSL `Const(Some(..))`,
    /// or GLSL `Runtime`), its emitter is running, and `expr.needs_pre_emit()` is true, the
    /// expression is kept outside the running emit range. The emitter is finished and
    /// whatever `finish` returns is appended to `function_local_data.block`. Then the
    /// expression is appended, and the emitter is started again.
    /// NOTE(review): presumably "needs pre-emit" expressions must not be covered by an
    /// `Emit` statement — confirm against `Expression::needs_pre_emit`.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range so the new expression lands after it,
                    // then reopen a fresh range that starts after the new expression.
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            // No function-local emitter to manage: append directly.
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
2433
2434 fn resolve_type(
2435 &self,
2436 expr: Handle<Expression>,
2437 ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2438 use crate::proc::TypeResolution as Tr;
2439 use crate::Expression as Ex;
2440 let resolution = match self.expressions[expr] {
2441 Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2442 Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2443 Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2444 Ex::Splat { size, value } => {
2445 let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2446 return Err(ConstantEvaluatorError::SplatScalarOnly);
2447 };
2448 Tr::Value(TypeInner::Vector { scalar, size })
2449 }
2450 _ => {
2451 log::debug!("resolve_type: SubexpressionsAreNotConstant");
2452 return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2453 }
2454 };
2455
2456 Ok(resolution)
2457 }
2458}
2459
2460fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2461 let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2465 match e {
2466 idx @ 0..=31 => idx,
2467 32 => u32::MAX,
2468 _ => unreachable!(),
2469 }
2470 };
2471 match concrete_int {
2472 ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2473 ConcreteInt::I32([e]) => {
2474 ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2475 }
2476 }
2477}
2478
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected bit index); an input without set bits maps to all ones.
    let signed_cases = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    let unsigned_cases = [(0, u32::MAX), (1, 0), (2, 1), (1 << 31, 31), (u32::MAX, 0)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2539
2540fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2541 let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2545 match e {
2546 idx @ 0..=31 => 31 - idx,
2547 32 => u32::MAX,
2548 _ => unreachable!(),
2549 }
2550 };
2551 match concrete_int {
2552 ConcreteInt::I32([e]) => ConcreteInt::I32([{
2553 let rtl_bit_index = if e.is_negative() {
2554 e.leading_ones()
2555 } else {
2556 e.leading_zeros()
2557 };
2558 rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2559 }]),
2560 ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2561 }
2562}
2563
#[test]
fn first_leading_bit_smoke() {
    // (input, expected bit index); inputs with no bit differing from the sign bit
    // map to all ones.
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Positive powers of two: the single set bit is the answer. Bit 31 is the sign
    // bit, so it is excluded.
    for idx in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    // `-(1 << idx)` has bits idx..=31 set and the rest clear. The highest 0 is bit idx - 1.
    for idx in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2627
/// Fallible conversion of an abstract numeric value of type `T` into `Self`.
///
/// The impls in this module use `T = i64` and `T = f64`. Presumably these are the
/// representations of WGSL's `AbstractInt` and `AbstractFloat`; confirm against callers.
/// Impls that reject a value report it as
/// [`ConstantEvaluatorError::AutomaticConversionLossy`]. Not every impl rejects
/// out-of-range input: the `f64` to integer impls saturate or clamp instead.
trait TryFromAbstract<T>: Sized {
    /// Converts `value` to `Self`, or returns an error describing why it cannot be.
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
2653
2654impl TryFromAbstract<i64> for i32 {
2655 fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2656 i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2657 value: format!("{value:?}"),
2658 to_type: "i32",
2659 })
2660 }
2661}
2662
2663impl TryFromAbstract<i64> for u32 {
2664 fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2665 u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2666 value: format!("{value:?}"),
2667 to_type: "u32",
2668 })
2669 }
2670}
2671
2672impl TryFromAbstract<i64> for u64 {
2673 fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2674 u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2675 value: format!("{value:?}"),
2676 to_type: "u64",
2677 })
2678 }
2679}
2680
2681impl TryFromAbstract<i64> for i64 {
2682 fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2683 Ok(value)
2684 }
2685}
2686
2687impl TryFromAbstract<i64> for f32 {
2688 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2689 let f = value as f32;
2690 Ok(f)
2694 }
2695}
2696
2697impl TryFromAbstract<f64> for f32 {
2698 fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2699 let f = value as f32;
2700 if f.is_infinite() {
2701 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2702 value: format!("{value:?}"),
2703 to_type: "f32",
2704 });
2705 }
2706 Ok(f)
2707 }
2708}
2709
2710impl TryFromAbstract<i64> for f64 {
2711 fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2712 let f = value as f64;
2713 Ok(f)
2717 }
2718}
2719
2720impl TryFromAbstract<f64> for f64 {
2721 fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2722 Ok(value)
2723 }
2724}
2725
2726impl TryFromAbstract<f64> for i32 {
2727 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2728 Ok(value as i32)
2741 }
2742}
2743
2744impl TryFromAbstract<f64> for u32 {
2745 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2746 Ok(value as u32)
2749 }
2750}
2751
2752impl TryFromAbstract<f64> for i64 {
2753 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2754 use crate::proc::type_methods::IntFloatLimits;
2757 Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
2758 }
2759}
2760
2761impl TryFromAbstract<f64> for u64 {
2762 fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2763 use crate::proc::type_methods::IntFloatLimits;
2766 Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
2767 }
2768}
2769
2770impl TryFromAbstract<f64> for f16 {
2771 fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
2772 let f = f16::from_f64(value);
2773 if f.is_infinite() {
2774 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2775 value: format!("{value:?}"),
2776 to_type: "f16",
2777 });
2778 }
2779 Ok(f)
2780 }
2781}
2782
2783impl TryFromAbstract<i64> for f16 {
2784 fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
2785 let f = f16::from_i64(value);
2786 if f.is_none() {
2787 return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2788 value: format!("{value:?}"),
2789 to_type: "f16",
2790 });
2791 }
2792 Ok(f.unwrap())
2793 }
2794}
2795
/// Returns the cross product `a × b` of two 3-component vectors.
///
/// Each component is the usual 2×2 determinant, e.g. `x = a.y * b.z - a.z * b.y`.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
2808
// Unit tests that drive `ConstantEvaluator` directly over hand-built global arenas,
// using `Behavior::Wgsl(WgslRestrictions::Const(None))` throughout.
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use crate::{
        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
        UniqueArena, VectorSize,
    };

    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};

    /// Folds unary operators over constants. It checks `-4` and `!4` on an `i32`
    /// constant, and `!` on a `vec2<i32>` constant composed from the inits of the
    /// constants `4` and `8`.
    #[test]
    fn unary_op() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h1 = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
            },
            Default::default(),
        );

        let vec_h = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec![constants[h].init, constants[h1].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());
        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());

        let expr2 = Expression::Unary {
            op: UnaryOperator::Negate,
            expr,
        };

        let expr3 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr,
        };

        let expr4 = Expression::Unary {
            op: UnaryOperator::BitwiseNot,
            expr: expr1,
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res1 = solver
            .try_eval_and_append(expr2, Default::default())
            .unwrap();
        let res2 = solver
            .try_eval_and_append(expr3, Default::default())
            .unwrap();
        let res3 = solver
            .try_eval_and_append(expr4, Default::default())
            .unwrap();

        // `solver` is not used past this point, so the arena can be read directly again.
        assert_eq!(
            global_expressions[res1],
            Expression::Literal(Literal::I32(-4))
        );

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::I32(!4))
        );

        let res3_inner = &global_expressions[res3];

        match *res3_inner {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!4))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::I32(!8))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }
    }

    /// Casting the `i32` constant `4` to `bool` folds to the literal `true`.
    #[test]
    fn cast() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let scalar_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: scalar_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let expr = global_expressions.append(Expression::Constant(h), Default::default());

        let root = Expression::As {
            expr,
            kind: ScalarKind::Bool,
            convert: Some(crate::BOOL_WIDTH),
        };

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let res = solver
            .try_eval_and_append(root, Default::default())
            .unwrap();

        assert_eq!(
            global_expressions[res],
            Expression::Literal(Literal::Bool(true))
        );
    }

    /// Indexes into a constant 2-column, 3-row `f32` matrix with columns `(0, 1, 2)` and
    /// `(3, 4, 5)`. `AccessIndex 1` must yield the second column as a `vec3<f32>`
    /// `Compose`, and indexing that result at `2` must yield `5.0`.
    #[test]
    fn access() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let matrix_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Matrix {
                    columns: VectorSize::Bi,
                    rows: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let vec_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Tri,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let mut vec1_components = Vec::with_capacity(3);
        let mut vec2_components = Vec::with_capacity(3);

        for i in 0..3 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec1_components.push(h)
        }

        for i in 3..6 {
            let h = global_expressions.append(
                Expression::Literal(Literal::F32(i as f32)),
                Default::default(),
            );

            vec2_components.push(h)
        }

        let vec1 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec1_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let vec2 = constants.append(
            Constant {
                name: None,
                ty: vec_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: vec_ty,
                        components: vec2_components,
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: matrix_ty,
                init: global_expressions.append(
                    Expression::Compose {
                        ty: matrix_ty,
                        components: vec![constants[vec1].init, constants[vec2].init],
                    },
                    Default::default(),
                ),
            },
            Default::default(),
        );

        let base = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let root1 = Expression::AccessIndex { base, index: 1 };

        let res1 = solver
            .try_eval_and_append(root1, Default::default())
            .unwrap();

        // Chain a second access on the already-evaluated column.
        let root2 = Expression::AccessIndex {
            base: res1,
            index: 2,
        };

        let res2 = solver
            .try_eval_and_append(root2, Default::default())
            .unwrap();

        match global_expressions[res1] {
            Expression::Compose {
                ref ty,
                ref components,
            } => {
                assert_eq!(*ty, vec_ty);
                let mut components_iter = components.iter().copied();
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(3.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(4.))
                );
                assert_eq!(
                    global_expressions[components_iter.next().unwrap()],
                    Expression::Literal(Literal::F32(5.))
                );
                assert!(components_iter.next().is_none());
            }
            _ => panic!("Expected vector"),
        }

        assert_eq!(
            global_expressions[res2],
            Expression::Literal(Literal::F32(5.))
        );
    }

    /// A `vec2<i32>` `Compose` whose components are both `Expression::Constant` (value
    /// `4`) evaluates, and negating it yields a `vec2<i32>` `Compose` of `-4` literals.
    #[test]
    fn compose_of_constants() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Compose {
                    ty: vec2_i32_ty,
                    components: vec![h_expr, h_expr],
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }

    /// A `Splat` of an `i32` constant (`4`) to `vec2` evaluates, and negating it yields
    /// a `vec2<i32>` `Compose` of `-4` literals.
    #[test]
    fn splat_of_constant() {
        let mut types = UniqueArena::new();
        let mut constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::I32),
            },
            Default::default(),
        );

        let vec2_i32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::I32,
                },
            },
            Default::default(),
        );

        let h = constants.append(
            Constant {
                name: None,
                ty: i32_ty,
                init: global_expressions
                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
            },
            Default::default(),
        );

        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_compose = solver
            .try_eval_and_append(
                Expression::Splat {
                    size: VectorSize::Bi,
                    value: h_expr,
                },
                Default::default(),
            )
            .unwrap();
        let solved_negate = solver
            .try_eval_and_append(
                Expression::Unary {
                    op: UnaryOperator::Negate,
                    expr: solved_compose,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_negate] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_i32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::I32(-4)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }

    /// Adding `splat(ZeroValue(f32))` and `splat(5.0)` (both `vec2`) yields a
    /// `vec2<f32>` `Compose` whose components are all `5.0` literals.
    #[test]
    fn splat_of_zero_value() {
        let mut types = UniqueArena::new();
        let constants = Arena::new();
        let overrides = Arena::new();
        let mut global_expressions = Arena::new();

        let f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Scalar(crate::Scalar::F32),
            },
            Default::default(),
        );

        let vec2_f32_ty = types.insert(
            Type {
                name: None,
                inner: TypeInner::Vector {
                    size: VectorSize::Bi,
                    scalar: crate::Scalar::F32,
                },
            },
            Default::default(),
        );

        let five =
            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
        let five_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: five,
            },
            Default::default(),
        );
        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
        let zero_splat = global_expressions.append(
            Expression::Splat {
                size: VectorSize::Bi,
                value: zero,
            },
            Default::default(),
        );

        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
        let mut solver = ConstantEvaluator {
            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
            types: &mut types,
            constants: &constants,
            overrides: &overrides,
            expressions: &mut global_expressions,
            expression_kind_tracker,
            layouter: &mut crate::proc::Layouter::default(),
        };

        let solved_add = solver
            .try_eval_and_append(
                Expression::Binary {
                    op: crate::BinaryOperator::Add,
                    left: zero_splat,
                    right: five_splat,
                },
                Default::default(),
            )
            .unwrap();

        let pass = match global_expressions[solved_add] {
            Expression::Compose { ty, ref components } => {
                ty == vec2_f32_ty
                    && components.iter().all(|&component| {
                        let component = &global_expressions[component];
                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
                    })
            }
            _ => false,
        };
        if !pass {
            panic!("unexpected evaluation result")
        }
    }
}