naga/proc/constant_evaluator.rs

1use alloc::{
2    format,
3    string::{String, ToString},
4    vec,
5    vec::Vec,
6};
7use core::iter;
8
9use arrayvec::ArrayVec;
10use half::f16;
11use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
12
13use crate::{
14    arena::{Arena, Handle, HandleVec, UniqueArena},
15    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
16    ScalarKind, Span, Type, TypeInner, UnaryOperator,
17};
18
19#[cfg(feature = "wgsl-in")]
20use crate::common::wgsl::TryToWgsl;
21
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// The input tokens become the rules of a temporary helper macro, which is then invoked with a
/// single literal `$` token. Writing the rule as `($d:tt) => { ... }` therefore lets the body use
/// `$d` wherever a literal `$` must appear in the generated code (for example
/// `$d ($d arg:ident),+` to declare a repetition inside a nested `macro_rules!`).
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper using the caller-supplied rules verbatim...
        macro_rules! __with_dollar_sign { $($body)* }
        // ...then invoke it with a bare `$`, which binds to the caller's `tt` metavariable.
        __with_dollar_sign!($);
    }
}
33
/// Generates the helpers used to evaluate numeric built-ins component by component.
///
/// An invocation `gen_component_wise_extractor! { foo -> Foo, literals: [...], scalar_kinds: [...] }`
/// expands to:
///
/// - an enum `Foo<const N: usize>` with one variant per `literals` entry; an entry
///   `$literal => $mapping: $ty` produces the variant `$mapping([$ty; N])`, which corresponds to
///   `Literal::$literal`;
/// - `impl From<Foo<1>> for Expression`, turning a single result back into a literal expression;
/// - a function `foo` that gathers `N` argument expressions into a `Foo<N>` and passes it to a
///   handler, recursing per component when the arguments are vectors whose scalar kind is listed
///   in `scalar_kinds`;
/// - a macro `foo!` that applies one handler body to every variant of `Foo`.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        fn $ident<const N: usize, const M: usize, F>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            mut handler: F,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
            F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
        {
            // The first argument decides which shape (scalar literal or vector) is expected.
            assert!(N > 0);
            // Every shape or kind mismatch below is reported as this single error.
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Run `eval_zero_value_and_splat` on an argument and borrow the resulting expression.
            // NOTE(review): judging by its name, this presumably rewrites `ZeroValue`/`Splat`
            // into literal/`Compose` form so the match below only has to handle those; confirm.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            let new_expr = match sanitize!(exprs.next().unwrap())? {
                // Scalar case: every argument must be a literal of the same variant as the first.
                // One value comes from `once` and N - 1 from `exprs`, so the collected ArrayVec is
                // always full and `into_inner` cannot fail.
                $(
                    &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
                        .chain(exprs.map(|expr| {
                            sanitize!(expr).and_then(|expr| match expr {
                                &Expression::Literal(Literal::$literal(x)) => Ok(x),
                                _ => Err(err.clone()),
                            })
                        }))
                        .collect::<Result<ArrayVec<_, N>, _>>()
                        .map(|a| a.into_inner().unwrap())
                        .map($target::$mapping)
                        .and_then(|comps| Ok(handler(comps)?.into())),
                )+
                // Vector case: flatten every argument into its scalar components. All arguments
                // must have exactly the same vector type as the first one.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            component_groups.push(crate::proc::flatten_compose(
                                first_ty,
                                components,
                                eval.expressions,
                                eval.types,
                            ).collect());
                            // NOTE(review): the remaining N - 1 groups are collected into an
                            // ArrayVec sized `VectorSize::MAX`, not `N - 1`. This presumably
                            // relies on N - 1 <= VectorSize::MAX, since ArrayVec panics when its
                            // capacity is exceeded. Confirm against the call sites.
                            component_groups.extend(
                                exprs
                                    .map(|expr| {
                                        sanitize!(expr).and_then(|expr| match expr {
                                            &Expression::Compose { ty, ref components }
                                                if &eval.types[ty].inner
                                                    == &eval.types[first_ty].inner =>
                                            {
                                                Ok(crate::proc::flatten_compose(
                                                    ty,
                                                    components,
                                                    eval.expressions,
                                                    eval.types,
                                                ).collect())
                                            }
                                            _ => Err(err.clone()),
                                        })
                                    })
                                    .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
                                    )?,
                            );
                            // The first group plus the N - 1 extended groups fill all N slots.
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // For each lane, gather the `idx`-th scalar of every argument and
                            // recurse. A flattened group shorter than `size` yields `err` rather
                            // than panicking. Each recursive call registers its scalar result.
                            for idx in 0..(size as u8).into() {
                                let group = component_groups
                                    .iter()
                                    .map(|cs| cs.get(idx).cloned().ok_or(err.clone()))
                                    .collect::<Result<ArrayVec<_, N>, _>>()?
                                    .into_inner()
                                    .unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler.clone(),
                                )?);
                            }
                            // Reassemble the per-lane results as a vector of the first argument's type.
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    // Composites other than vectors are not decomposed.
                    _ => return Err(err),
                },
                // Any literal kind not listed in `literals`, or any other expression.
                _ => return Err(err),
            }?;
            eval.register_evaluated_expr(new_expr, span)
        }

        // Emit a same-named `macro_rules!` that applies one RHS to every variant; inside it,
        // `$d` stands for a literal `$` (see `with_dollar_sign`).
        with_dollar_sign! {
            ($d:tt) => {
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
207
// `Scalar`: every numeric literal kind listed below. Vector arguments whose scalar kind appears in
// `scalar_kinds` are evaluated component by component. Literal kinds not listed here, and
// non-vector composites, are rejected with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
228
// `Float`: floating-point literals only (abstract, `f32`, `f16`). Note that the abstract variant is
// named `Float::Abstract`, not `Float::AbstractFloat`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
241
// `ConcreteInt`: 32-bit concrete integers (`u32`, `i32`) only. Abstract and 64-bit integer
// literals are rejected with `InvalidMathArg`.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
253
// `Signed`: the signed literal kinds listed below (abstract/`f32`/`f16` floats, abstract and `i32`
// integers).
// NOTE(review): `I64` is not listed even though `Sint` vectors are accepted, so vectors of `i64`
// presumably fail with `InvalidMathArg` once their components are inspected. Confirm whether
// this is intentional before adding it, because call sites may match on `Signed` exhaustively.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
270
/// Which front-end language's constant-evaluation rules a [`ConstantEvaluator`] follows, together
/// with the restrictions that apply to the arena it is appending to.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL's rules, under the given restrictions.
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL's rules, under the given restrictions.
    Glsl(GlslRestrictions<'a>),
}
276
277impl Behavior<'_> {
278    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
279    const fn has_runtime_restrictions(&self) -> bool {
280        matches!(
281            self,
282            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
283                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
284        )
285    }
286}
287
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    ///
    /// Depending on the constructor used, this is either the module's global
    /// expression arena or a function's expression arena.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`].
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// The layouter supplied by the constructor.
    ///
    /// NOTE(review): presumably consulted wherever evaluation needs type sizes
    /// or alignments; confirm against its uses later in the file.
    layouter: &'a mut crate::proc::Layouter,
}
332
/// What a [`ConstantEvaluator`] following WGSL's rules may do with expressions
/// of each [`ExpressionKind`].
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// The payload is `None` when appending to the module's global expression
    /// arena, and `Some` when appending to a function's expression arena.
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    ///
    /// Used for the module's global expression arena.
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    ///
    /// Used for a function's expression arena.
    Runtime(FunctionLocalData<'a>),
}
345
/// What a [`ConstantEvaluator`] following GLSL's rules may do with expressions
/// of each [`ExpressionKind`].
///
/// There is no `Override` variant. `ConstantEvaluator::try_eval_and_append`
/// treats an override-expression under GLSL rules as unreachable.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// Used for the module's global expression arena.
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    ///
    /// Used for a function's expression arena.
    Runtime(FunctionLocalData<'a>),
}
355
/// Extra state a [`ConstantEvaluator`] carries when it appends to a function's
/// expression arena rather than the module's global one.
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    ///
    /// The module's global expression arena. Constant initializers found here
    /// are deep-copied into the function's arena when a constant is referenced.
    global_expressions: &'a Arena<Expression>,
    /// NOTE(review): presumably the emitter for the function body being built,
    /// used when appending runtime expressions; confirm.
    emitter: &'a mut super::Emitter,
    /// NOTE(review): presumably the block that emitted expressions are added
    /// to; confirm.
    block: &'a mut crate::Block,
}
363
/// Classifies an expression as a const-, override-, or runtime-expression.
///
/// The declaration order is significant. The derived `Ord` ranks
/// `Const < Override < Runtime`, and [`ExpressionKindTracker`] combines the kinds
/// of an expression's operands with `max`, so a compound expression takes the
/// largest kind among its operands.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// Built only from literals, zero values, constants, and other
    /// const-expressions.
    Const,
    /// Depends on at least one pipeline-overridable constant (an `Override`
    /// expression), but not on any runtime-expression.
    Override,
    /// Anything that is neither `Const` nor `Override`.
    Runtime,
}
370
/// Records the [`ExpressionKind`] of each expression in one expression arena,
/// indexed by expression handle.
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// The recorded kind for each expression handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
375
376impl ExpressionKindTracker {
377    pub const fn new() -> Self {
378        Self {
379            inner: HandleVec::new(),
380        }
381    }
382
383    /// Forces the the expression to not be const
384    pub fn force_non_const(&mut self, value: Handle<Expression>) {
385        self.inner[value] = ExpressionKind::Runtime;
386    }
387
388    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
389        self.inner.insert(value, expr_type);
390    }
391
392    pub fn is_const(&self, h: Handle<Expression>) -> bool {
393        matches!(self.type_of(h), ExpressionKind::Const)
394    }
395
396    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
397        matches!(
398            self.type_of(h),
399            ExpressionKind::Const | ExpressionKind::Override
400        )
401    }
402
403    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
404        self.inner[value]
405    }
406
407    pub fn from_arena(arena: &Arena<Expression>) -> Self {
408        let mut tracker = Self {
409            inner: HandleVec::with_capacity(arena.len()),
410        };
411        for (handle, expr) in arena.iter() {
412            tracker
413                .inner
414                .insert(handle, tracker.type_of_with_expr(expr));
415        }
416        tracker
417    }
418
419    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
420        match *expr {
421            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
422                ExpressionKind::Const
423            }
424            Expression::Override(_) => ExpressionKind::Override,
425            Expression::Compose { ref components, .. } => {
426                let mut expr_type = ExpressionKind::Const;
427                for component in components {
428                    expr_type = expr_type.max(self.type_of(*component))
429                }
430                expr_type
431            }
432            Expression::Splat { value, .. } => self.type_of(value),
433            Expression::AccessIndex { base, .. } => self.type_of(base),
434            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
435            Expression::Swizzle { vector, .. } => self.type_of(vector),
436            Expression::Unary { expr, .. } => self.type_of(expr),
437            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
438            Expression::Math {
439                arg,
440                arg1,
441                arg2,
442                arg3,
443                ..
444            } => self
445                .type_of(arg)
446                .max(
447                    arg1.map(|arg| self.type_of(arg))
448                        .unwrap_or(ExpressionKind::Const),
449                )
450                .max(
451                    arg2.map(|arg| self.type_of(arg))
452                        .unwrap_or(ExpressionKind::Const),
453                )
454                .max(
455                    arg3.map(|arg| self.type_of(arg))
456                        .unwrap_or(ExpressionKind::Const),
457                ),
458            Expression::As { expr, .. } => self.type_of(expr),
459            Expression::Select {
460                condition,
461                accept,
462                reject,
463            } => self
464                .type_of(condition)
465                .max(self.type_of(accept))
466                .max(self.type_of(reject)),
467            Expression::Relational { argument, .. } => self.type_of(argument),
468            Expression::ArrayLength(expr) => self.type_of(expr),
469            _ => ExpressionKind::Runtime,
470        }
471    }
472}
473
/// Errors reported while evaluating, or refusing to evaluate, an expression at
/// compile time.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
    #[error("Constants cannot access function arguments")]
    FunctionArg,
    #[error("Constants cannot access global variables")]
    GlobalVariable,
    #[error("Constants cannot access local variables")]
    LocalVariable,
    #[error("Cannot get the array length of a non array type")]
    InvalidArrayLengthArg,
    #[error("Constants cannot get the array length of a dynamically sized array")]
    ArrayLengthDynamic,
    #[error("Cannot call arrayLength on array sized by override-expression")]
    ArrayLengthOverridden,
    #[error("Constants cannot call functions")]
    Call,
    #[error("Constants don't support workGroupUniformLoad")]
    WorkGroupUniformLoadResult,
    #[error("Constants don't support atomic functions")]
    Atomic,
    #[error("Constants don't support derivative functions")]
    Derivative,
    #[error("Constants don't support load expressions")]
    Load,
    #[error("Constants don't support image expressions")]
    ImageExpression,
    #[error("Constants don't support ray query expressions")]
    RayQueryExpression,
    #[error("Constants don't support subgroup expressions")]
    SubgroupExpression,
    #[error("Cannot access the type")]
    InvalidAccessBase,
    #[error("Cannot access at the index")]
    InvalidAccessIndex,
    #[error("Cannot access with index of type")]
    InvalidAccessIndexTy,
    #[error("Constants don't support array length expressions")]
    ArrayLength,
    /// The scalar components of the expression described by `from` cannot be
    /// cast to the type described by `to`.
    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
    InvalidCastArg { from: String, to: String },
    #[error("Cannot apply the unary op to the argument")]
    InvalidUnaryOpArg,
    /// The operands are not valid for the binary operator.
    ///
    /// When the evaluator may emit runtime-expressions,
    /// `ConstantEvaluator::try_eval_and_append` does not report this error for
    /// a const-expression. It appends the expression unevaluated instead.
    #[error("Cannot apply the binary op to the arguments")]
    InvalidBinaryOpArgs,
    #[error("Cannot apply math function to type")]
    InvalidMathArg,
    /// The built-in received the wrong number of arguments. The payload holds
    /// the function, the expected count, and the supplied count.
    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
    InvalidMathArgCount(crate::MathFunction, usize, usize),
    /// The payload records the relational function, but the message does not
    /// include it.
    #[error("Cannot apply relational function to type")]
    InvalidRelationalArg(RelationalFunction),
    #[error("value of `low` is greater than `high` for clamp built-in function")]
    InvalidClamp,
    #[error("Constructor expects {expected} components, found {actual}")]
    InvalidVectorComposeLength { expected: usize, actual: usize },
    #[error("Constructor must only contain vector or scalar arguments")]
    InvalidVectorComposeComponent,
    #[error("Splat is defined only on scalar values")]
    SplatScalarOnly,
    #[error("Can only swizzle vector constants")]
    SwizzleVectorOnly,
    #[error("swizzle component not present in source expression")]
    SwizzleOutOfBounds,
    #[error("Type is not constructible")]
    TypeNotConstructible,
    /// Returned by `ConstantEvaluator::check` when an operand is not recorded
    /// as a const-expression.
    #[error("Subexpression(s) are not constant")]
    SubexpressionsAreNotConstant,
    /// This part of constant evaluation has not been implemented.
    ///
    /// When the evaluator may emit runtime-expressions,
    /// `ConstantEvaluator::try_eval_and_append` does not report this error. It
    /// appends the expression unevaluated instead.
    #[error("Not implemented as constant expression: {0}")]
    NotImplemented(String),
    #[error("{0} operation overflowed")]
    Overflow(String),
    /// An abstract value cannot be represented accurately in the concrete type
    /// `to_type`.
    #[error(
        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
    )]
    AutomaticConversionLossy {
        value: String,
        to_type: &'static str,
    },
    #[error("Division by zero")]
    DivisionByZero,
    #[error("Remainder by zero")]
    RemainderByZero,
    /// According to its message, this covers shift amounts greater than *or
    /// equal to* 32, despite the variant's name.
    #[error("RHS of shift operation is greater than or equal to 32")]
    ShiftedMoreThan32Bits,
    /// Converted from a `crate::valid::LiteralError` via `#[from]`. Its message
    /// is forwarded unchanged.
    #[error(transparent)]
    Literal(#[from] crate::valid::LiteralError),
    /// Constant evaluation reached an `Expression::Override`.
    #[error("Can't use pipeline-overridable constants in const-expressions")]
    Override,
    /// A runtime-expression was submitted to an evaluator that may not emit
    /// runtime-expressions.
    #[error("Unexpected runtime-expression")]
    RuntimeExpr,
    /// An override-expression was submitted under WGSL const restrictions.
    #[error("Unexpected override-expression")]
    OverrideExpr,
}
567
568impl<'a> ConstantEvaluator<'a> {
569    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
570    /// constant expression arena.
571    ///
572    /// Report errors according to WGSL's rules for constant evaluation.
573    pub fn for_wgsl_module(
574        module: &'a mut crate::Module,
575        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
576        layouter: &'a mut crate::proc::Layouter,
577        in_override_ctx: bool,
578    ) -> Self {
579        Self::for_module(
580            Behavior::Wgsl(if in_override_ctx {
581                WgslRestrictions::Override
582            } else {
583                WgslRestrictions::Const(None)
584            }),
585            module,
586            global_expression_kind_tracker,
587            layouter,
588        )
589    }
590
591    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
592    /// constant expression arena.
593    ///
594    /// Report errors according to GLSL's rules for constant evaluation.
595    pub fn for_glsl_module(
596        module: &'a mut crate::Module,
597        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
598        layouter: &'a mut crate::proc::Layouter,
599    ) -> Self {
600        Self::for_module(
601            Behavior::Glsl(GlslRestrictions::Const),
602            module,
603            global_expression_kind_tracker,
604            layouter,
605        )
606    }
607
608    fn for_module(
609        behavior: Behavior<'a>,
610        module: &'a mut crate::Module,
611        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
612        layouter: &'a mut crate::proc::Layouter,
613    ) -> Self {
614        Self {
615            behavior,
616            types: &mut module.types,
617            constants: &module.constants,
618            overrides: &module.overrides,
619            expressions: &mut module.global_expressions,
620            expression_kind_tracker: global_expression_kind_tracker,
621            layouter,
622        }
623    }
624
625    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
626    /// expression arena.
627    ///
628    /// Report errors according to WGSL's rules for constant evaluation.
629    pub fn for_wgsl_function(
630        module: &'a mut crate::Module,
631        expressions: &'a mut Arena<Expression>,
632        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
633        layouter: &'a mut crate::proc::Layouter,
634        emitter: &'a mut super::Emitter,
635        block: &'a mut crate::Block,
636        is_const: bool,
637    ) -> Self {
638        let local_data = FunctionLocalData {
639            global_expressions: &module.global_expressions,
640            emitter,
641            block,
642        };
643        Self {
644            behavior: Behavior::Wgsl(if is_const {
645                WgslRestrictions::Const(Some(local_data))
646            } else {
647                WgslRestrictions::Runtime(local_data)
648            }),
649            types: &mut module.types,
650            constants: &module.constants,
651            overrides: &module.overrides,
652            expressions,
653            expression_kind_tracker: local_expression_kind_tracker,
654            layouter,
655        }
656    }
657
658    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
659    /// expression arena.
660    ///
661    /// Report errors according to GLSL's rules for constant evaluation.
662    pub fn for_glsl_function(
663        module: &'a mut crate::Module,
664        expressions: &'a mut Arena<Expression>,
665        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
666        layouter: &'a mut crate::proc::Layouter,
667        emitter: &'a mut super::Emitter,
668        block: &'a mut crate::Block,
669    ) -> Self {
670        Self {
671            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
672                global_expressions: &module.global_expressions,
673                emitter,
674                block,
675            })),
676            types: &mut module.types,
677            constants: &module.constants,
678            overrides: &module.overrides,
679            expressions,
680            expression_kind_tracker: local_expression_kind_tracker,
681            layouter,
682        }
683    }
684
685    pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
686        crate::proc::GlobalCtx {
687            types: self.types,
688            constants: self.constants,
689            overrides: self.overrides,
690            global_expressions: match self.function_local_data() {
691                Some(data) => data.global_expressions,
692                None => self.expressions,
693            },
694        }
695    }
696
697    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
698        if !self.expression_kind_tracker.is_const(expr) {
699            log::debug!("check: SubexpressionsAreNotConstant");
700            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
701        }
702        Ok(())
703    }
704
705    fn check_and_get(
706        &mut self,
707        expr: Handle<Expression>,
708    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
709        match self.expressions[expr] {
710            Expression::Constant(c) => {
711                // Are we working in a function's expression arena, or the
712                // module's constant expression arena?
713                if let Some(function_local_data) = self.function_local_data() {
714                    // Deep-copy the constant's value into our arena.
715                    self.copy_from(
716                        self.constants[c].init,
717                        function_local_data.global_expressions,
718                    )
719                } else {
720                    // "See through" the constant and use its initializer.
721                    Ok(self.constants[c].init)
722                }
723            }
724            _ => {
725                self.check(expr)?;
726                Ok(expr)
727            }
728        }
729    }
730
731    /// Try to evaluate `expr` at compile time.
732    ///
733    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
734    /// we can determine its value at compile time, we append an expression
735    /// representing its value - a tree of [`Literal`], [`Compose`],
736    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
737    /// `self` contributes to.
738    ///
739    /// If `expr`'s value cannot be determined at compile time, and `self` is
740    /// contributing to some function's expression arena, then append `expr` to
741    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
742    /// contributing to the module's constant expression arena; since `expr`'s
743    /// value is not a constant, return an error.
744    ///
745    /// We only consider `expr` itself, without recursing into its operands. Its
746    /// operands must all have been produced by prior calls to
747    /// `try_eval_and_append`, to ensure that they have already been reduced to
748    /// an evaluated form if possible.
749    ///
750    /// [`Literal`]: Expression::Literal
751    /// [`Compose`]: Expression::Compose
752    /// [`ZeroValue`]: Expression::ZeroValue
753    /// [`Swizzle`]: Expression::Swizzle
754    pub fn try_eval_and_append(
755        &mut self,
756        expr: Expression,
757        span: Span,
758    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
759        match self.expression_kind_tracker.type_of_with_expr(&expr) {
760            ExpressionKind::Const => {
761                let eval_result = self.try_eval_and_append_impl(&expr, span);
762                // We should be able to evaluate `Const` expressions at this
763                // point. If we failed to, then that probably means we just
764                // haven't implemented that part of constant evaluation. Work
765                // around this by simply emitting it as a run-time expression.
766                if self.behavior.has_runtime_restrictions()
767                    && matches!(
768                        eval_result,
769                        Err(ConstantEvaluatorError::NotImplemented(_)
770                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
771                    )
772                {
773                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
774                } else {
775                    eval_result
776                }
777            }
778            ExpressionKind::Override => match self.behavior {
779                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
780                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
781                }
782                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
783                    Err(ConstantEvaluatorError::OverrideExpr)
784                }
785                Behavior::Glsl(_) => {
786                    unreachable!()
787                }
788            },
789            ExpressionKind::Runtime => {
790                if self.behavior.has_runtime_restrictions() {
791                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
792                } else {
793                    Err(ConstantEvaluatorError::RuntimeExpr)
794                }
795            }
796        }
797    }
798
799    /// Is the [`Self::expressions`] arena the global module expression arena?
800    const fn is_global_arena(&self) -> bool {
801        matches!(
802            self.behavior,
803            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
804                | Behavior::Glsl(GlslRestrictions::Const)
805        )
806    }
807
808    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
809        match self.behavior {
810            Behavior::Wgsl(
811                WgslRestrictions::Runtime(ref function_local_data)
812                | WgslRestrictions::Const(Some(ref function_local_data)),
813            )
814            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
815                Some(function_local_data)
816            }
817            _ => None,
818        }
819    }
820
821    fn try_eval_and_append_impl(
822        &mut self,
823        expr: &Expression,
824        span: Span,
825    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
826        log::trace!("try_eval_and_append: {:?}", expr);
827        match *expr {
828            Expression::Constant(c) if self.is_global_arena() => {
829                // "See through" the constant and use its initializer.
830                // This is mainly done to avoid having constants pointing to other constants.
831                Ok(self.constants[c].init)
832            }
833            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
834            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
835                self.register_evaluated_expr(expr.clone(), span)
836            }
837            Expression::Compose { ty, ref components } => {
838                let components = components
839                    .iter()
840                    .map(|component| self.check_and_get(*component))
841                    .collect::<Result<Vec<_>, _>>()?;
842                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
843            }
844            Expression::Splat { size, value } => {
845                let value = self.check_and_get(value)?;
846                self.register_evaluated_expr(Expression::Splat { size, value }, span)
847            }
848            Expression::AccessIndex { base, index } => {
849                let base = self.check_and_get(base)?;
850
851                self.access(base, index as usize, span)
852            }
853            Expression::Access { base, index } => {
854                let base = self.check_and_get(base)?;
855                let index = self.check_and_get(index)?;
856
857                self.access(base, self.constant_index(index)?, span)
858            }
859            Expression::Swizzle {
860                size,
861                vector,
862                pattern,
863            } => {
864                let vector = self.check_and_get(vector)?;
865
866                self.swizzle(size, span, vector, pattern)
867            }
868            Expression::Unary { expr, op } => {
869                let expr = self.check_and_get(expr)?;
870
871                self.unary_op(op, expr, span)
872            }
873            Expression::Binary { left, right, op } => {
874                let left = self.check_and_get(left)?;
875                let right = self.check_and_get(right)?;
876
877                self.binary_op(op, left, right, span)
878            }
879            Expression::Math {
880                fun,
881                arg,
882                arg1,
883                arg2,
884                arg3,
885            } => {
886                let arg = self.check_and_get(arg)?;
887                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
888                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
889                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
890
891                self.math(arg, arg1, arg2, arg3, fun, span)
892            }
893            Expression::As {
894                convert,
895                expr,
896                kind,
897            } => {
898                let expr = self.check_and_get(expr)?;
899
900                match convert {
901                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
902                    None => Err(ConstantEvaluatorError::NotImplemented(
903                        "bitcast built-in function".into(),
904                    )),
905                }
906            }
907            Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
908                "select built-in function".into(),
909            )),
910            Expression::Relational { fun, argument } => {
911                let argument = self.check_and_get(argument)?;
912                self.relational(fun, argument, span)
913            }
914            Expression::ArrayLength(expr) => match self.behavior {
915                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
916                Behavior::Glsl(_) => {
917                    let expr = self.check_and_get(expr)?;
918                    self.array_length(expr, span)
919                }
920            },
921            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
922            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
923            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
924            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
925            Expression::WorkGroupUniformLoadResult { .. } => {
926                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
927            }
928            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
929            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
930            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
931            Expression::ImageSample { .. }
932            | Expression::ImageLoad { .. }
933            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
934            Expression::RayQueryProceedResult
935            | Expression::RayQueryGetIntersection { .. }
936            | Expression::RayQueryVertexPositions { .. } => {
937                Err(ConstantEvaluatorError::RayQueryExpression)
938            }
939            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
940            Expression::SubgroupOperationResult { .. } => {
941                Err(ConstantEvaluatorError::SubgroupExpression)
942            }
943        }
944    }
945
946    /// Splat `value` to `size`, without using [`Splat`] expressions.
947    ///
948    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
949    /// build a vector with the given `size` whose components are all
950    /// `value`.
951    ///
952    /// Use `span` as the span of the inserted expressions and
953    /// resulting types.
954    ///
955    /// [`Splat`]: Expression::Splat
956    /// [`Compose`]: Expression::Compose
957    /// [`ZeroValue`]: Expression::ZeroValue
958    fn splat(
959        &mut self,
960        value: Handle<Expression>,
961        size: crate::VectorSize,
962        span: Span,
963    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
964        match self.expressions[value] {
965            Expression::Literal(literal) => {
966                let scalar = literal.scalar();
967                let ty = self.types.insert(
968                    Type {
969                        name: None,
970                        inner: TypeInner::Vector { size, scalar },
971                    },
972                    span,
973                );
974                let expr = Expression::Compose {
975                    ty,
976                    components: vec![value; size as usize],
977                };
978                self.register_evaluated_expr(expr, span)
979            }
980            Expression::ZeroValue(ty) => {
981                let inner = match self.types[ty].inner {
982                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
983                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
984                };
985                let res_ty = self.types.insert(Type { name: None, inner }, span);
986                let expr = Expression::ZeroValue(res_ty);
987                self.register_evaluated_expr(expr, span)
988            }
989            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
990        }
991    }
992
993    fn swizzle(
994        &mut self,
995        size: crate::VectorSize,
996        span: Span,
997        src_constant: Handle<Expression>,
998        pattern: [crate::SwizzleComponent; 4],
999    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1000        let mut get_dst_ty = |ty| match self.types[ty].inner {
1001            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1002                Type {
1003                    name: None,
1004                    inner: TypeInner::Vector { size, scalar },
1005                },
1006                span,
1007            )),
1008            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1009        };
1010
1011        match self.expressions[src_constant] {
1012            Expression::ZeroValue(ty) => {
1013                let dst_ty = get_dst_ty(ty)?;
1014                let expr = Expression::ZeroValue(dst_ty);
1015                self.register_evaluated_expr(expr, span)
1016            }
1017            Expression::Splat { value, .. } => {
1018                let expr = Expression::Splat { size, value };
1019                self.register_evaluated_expr(expr, span)
1020            }
1021            Expression::Compose { ty, ref components } => {
1022                let dst_ty = get_dst_ty(ty)?;
1023
1024                let mut flattened = [src_constant; 4]; // dummy value
1025                let len =
1026                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1027                        .zip(flattened.iter_mut())
1028                        .map(|(component, elt)| *elt = component)
1029                        .count();
1030                let flattened = &flattened[..len];
1031
1032                let swizzled_components = pattern[..size as usize]
1033                    .iter()
1034                    .map(|&sc| {
1035                        let sc = sc as usize;
1036                        if let Some(elt) = flattened.get(sc) {
1037                            Ok(*elt)
1038                        } else {
1039                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1040                        }
1041                    })
1042                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1043                let expr = Expression::Compose {
1044                    ty: dst_ty,
1045                    components: swizzled_components,
1046                };
1047                self.register_evaluated_expr(expr, span)
1048            }
1049            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1050        }
1051    }
1052
1053    fn math(
1054        &mut self,
1055        arg: Handle<Expression>,
1056        arg1: Option<Handle<Expression>>,
1057        arg2: Option<Handle<Expression>>,
1058        arg3: Option<Handle<Expression>>,
1059        fun: crate::MathFunction,
1060        span: Span,
1061    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1062        let expected = fun.argument_count();
1063        let given = Some(arg)
1064            .into_iter()
1065            .chain(arg1)
1066            .chain(arg2)
1067            .chain(arg3)
1068            .count();
1069        if expected != given {
1070            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1071                fun, expected, given,
1072            ));
1073        }
1074
1075        // NOTE: We try to match the declaration order of `MathFunction` here.
1076        match fun {
1077            // comparison
1078            crate::MathFunction::Abs => {
1079                component_wise_scalar(self, span, [arg], |args| match args {
1080                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1081                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1082                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1083                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
1084                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1085                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1086                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1087                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1088                })
1089            }
1090            crate::MathFunction::Min => {
1091                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1092                    Ok([e1.min(e2)])
1093                })
1094            }
1095            crate::MathFunction::Max => {
1096                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1097                    Ok([e1.max(e2)])
1098                })
1099            }
1100            crate::MathFunction::Clamp => {
1101                component_wise_scalar!(
1102                    self,
1103                    span,
1104                    [arg, arg1.unwrap(), arg2.unwrap()],
1105                    |e, low, high| {
1106                        if low > high {
1107                            Err(ConstantEvaluatorError::InvalidClamp)
1108                        } else {
1109                            Ok([e.clamp(low, high)])
1110                        }
1111                    }
1112                )
1113            }
1114            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1115                Float::F16([e]) => Ok(Float::F16(
1116                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1117                )),
1118                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1119                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1120            }),
1121
1122            // trigonometry
1123            crate::MathFunction::Cos => {
1124                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1125            }
1126            crate::MathFunction::Cosh => {
1127                component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
1128            }
1129            crate::MathFunction::Sin => {
1130                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1131            }
1132            crate::MathFunction::Sinh => {
1133                component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
1134            }
1135            crate::MathFunction::Tan => {
1136                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1137            }
1138            crate::MathFunction::Tanh => {
1139                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1140            }
1141            crate::MathFunction::Acos => {
1142                component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
1143            }
1144            crate::MathFunction::Asin => {
1145                component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
1146            }
1147            crate::MathFunction::Atan => {
1148                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1149            }
1150            crate::MathFunction::Asinh => {
1151                component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
1152            }
1153            crate::MathFunction::Acosh => {
1154                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1155            }
1156            crate::MathFunction::Atanh => {
1157                component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
1158            }
1159            crate::MathFunction::Radians => {
1160                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1161            }
1162            crate::MathFunction::Degrees => {
1163                component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
1164            }
1165
1166            // decomposition
1167            crate::MathFunction::Ceil => {
1168                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1169            }
1170            crate::MathFunction::Floor => {
1171                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1172            }
1173            crate::MathFunction::Round => {
1174                component_wise_float(self, span, [arg], |e| match e {
1175                    Float::Abstract([e]) => Ok(Float::Abstract([e.round_ties_even()])),
1176                    Float::F32([e]) => Ok(Float::F32([e.round_ties_even()])),
1177                    Float::F16([e]) => {
1178                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1179                        //
1180                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1181                        // which has licensing compatible with ours. See also
1182                        // <https://github.com/rust-lang/rust/issues/96710>.
1183                        //
1184                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1185                        fn round_ties_even(x: f64) -> f64 {
1186                            let i = x as i64;
1187                            let f = (x - i as f64).abs();
1188                            if f == 0.5 {
1189                                if i & 1 == 1 {
1190                                    // -1.5, 1.5, 3.5, ...
1191                                    (x.abs() + 0.5).copysign(x)
1192                                } else {
1193                                    (x.abs() - 0.5).copysign(x)
1194                                }
1195                            } else {
1196                                x.round()
1197                            }
1198                        }
1199
1200                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1201                    }
1202                })
1203            }
1204            crate::MathFunction::Fract => {
1205                component_wise_float!(self, span, [arg], |e| {
1206                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1207                    // here.
1208                    Ok([e - e.floor()])
1209                })
1210            }
1211            crate::MathFunction::Trunc => {
1212                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1213            }
1214
1215            // exponent
1216            crate::MathFunction::Exp => {
1217                component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
1218            }
1219            crate::MathFunction::Exp2 => {
1220                component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
1221            }
1222            crate::MathFunction::Log => {
1223                component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
1224            }
1225            crate::MathFunction::Log2 => {
1226                component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
1227            }
1228            crate::MathFunction::Pow => {
1229                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1230                    Ok([e1.powf(e2)])
1231                })
1232            }
1233
1234            // computational
1235            crate::MathFunction::Sign => {
1236                component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
1237            }
1238            crate::MathFunction::Fma => {
1239                component_wise_float!(
1240                    self,
1241                    span,
1242                    [arg, arg1.unwrap(), arg2.unwrap()],
1243                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1244                )
1245            }
1246            crate::MathFunction::Step => {
1247                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1248                    Float::Abstract([edge, x]) => {
1249                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1250                    }
1251                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1252                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1253                        f16::one()
1254                    } else {
1255                        f16::zero()
1256                    }])),
1257                })
1258            }
1259            crate::MathFunction::Sqrt => {
1260                component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
1261            }
1262            crate::MathFunction::InverseSqrt => {
1263                component_wise_float(self, span, [arg], |e| match e {
1264                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1265                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1266                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1267                })
1268            }
1269
1270            // bits
1271            crate::MathFunction::CountTrailingZeros => {
1272                component_wise_concrete_int!(self, span, [arg], |e| {
1273                    #[allow(clippy::useless_conversion)]
1274                    Ok([e
1275                        .trailing_zeros()
1276                        .try_into()
1277                        .expect("bit count overflowed 32 bits, somehow!?")])
1278                })
1279            }
1280            crate::MathFunction::CountLeadingZeros => {
1281                component_wise_concrete_int!(self, span, [arg], |e| {
1282                    #[allow(clippy::useless_conversion)]
1283                    Ok([e
1284                        .leading_zeros()
1285                        .try_into()
1286                        .expect("bit count overflowed 32 bits, somehow!?")])
1287                })
1288            }
1289            crate::MathFunction::CountOneBits => {
1290                component_wise_concrete_int!(self, span, [arg], |e| {
1291                    #[allow(clippy::useless_conversion)]
1292                    Ok([e
1293                        .count_ones()
1294                        .try_into()
1295                        .expect("bit count overflowed 32 bits, somehow!?")])
1296                })
1297            }
1298            crate::MathFunction::ReverseBits => {
1299                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1300            }
1301            crate::MathFunction::FirstTrailingBit => {
1302                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1303            }
1304            crate::MathFunction::FirstLeadingBit => {
1305                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1306            }
1307
1308            // vector
1309            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1310
1311            // unimplemented
1312            crate::MathFunction::Atan2
1313            | crate::MathFunction::Modf
1314            | crate::MathFunction::Frexp
1315            | crate::MathFunction::Ldexp
1316            | crate::MathFunction::Dot
1317            | crate::MathFunction::Outer
1318            | crate::MathFunction::Distance
1319            | crate::MathFunction::Length
1320            | crate::MathFunction::Normalize
1321            | crate::MathFunction::FaceForward
1322            | crate::MathFunction::Reflect
1323            | crate::MathFunction::Refract
1324            | crate::MathFunction::Mix
1325            | crate::MathFunction::SmoothStep
1326            | crate::MathFunction::Inverse
1327            | crate::MathFunction::Transpose
1328            | crate::MathFunction::Determinant
1329            | crate::MathFunction::QuantizeToF16
1330            | crate::MathFunction::ExtractBits
1331            | crate::MathFunction::InsertBits
1332            | crate::MathFunction::Pack4x8snorm
1333            | crate::MathFunction::Pack4x8unorm
1334            | crate::MathFunction::Pack2x16snorm
1335            | crate::MathFunction::Pack2x16unorm
1336            | crate::MathFunction::Pack2x16float
1337            | crate::MathFunction::Pack4xI8
1338            | crate::MathFunction::Pack4xU8
1339            | crate::MathFunction::Unpack4x8snorm
1340            | crate::MathFunction::Unpack4x8unorm
1341            | crate::MathFunction::Unpack2x16snorm
1342            | crate::MathFunction::Unpack2x16unorm
1343            | crate::MathFunction::Unpack2x16float
1344            | crate::MathFunction::Unpack4xI8
1345            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1346                format!("{fun:?} built-in function"),
1347            )),
1348        }
1349    }
1350
1351    /// Vector cross product.
1352    fn cross_product(
1353        &mut self,
1354        a: Handle<Expression>,
1355        b: Handle<Expression>,
1356        span: Span,
1357    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1358        use Literal as Li;
1359
1360        let (a, ty) = self.extract_vec::<3>(a)?;
1361        let (b, _) = self.extract_vec::<3>(b)?;
1362
1363        let product = match (a, b) {
1364            (
1365                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
1366                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
1367            ) => {
1368                // `cross` has no overload for AbstractInt, so AbstractInt
1369                // arguments are automatically converted to AbstractFloat. Since
1370                // `f64` has a much wider range than `i64`, there's no danger of
1371                // overflow here.
1372                let p = cross_product(
1373                    [a0 as f64, a1 as f64, a2 as f64],
1374                    [b0 as f64, b1 as f64, b2 as f64],
1375                );
1376                [
1377                    Li::AbstractFloat(p[0]),
1378                    Li::AbstractFloat(p[1]),
1379                    Li::AbstractFloat(p[2]),
1380                ]
1381            }
1382            (
1383                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
1384                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
1385            ) => {
1386                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1387                [
1388                    Li::AbstractFloat(p[0]),
1389                    Li::AbstractFloat(p[1]),
1390                    Li::AbstractFloat(p[2]),
1391                ]
1392            }
1393            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
1394                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1395                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
1396            }
1397            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
1398                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1399                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
1400            }
1401            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
1402                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
1403                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
1404            }
1405            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
1406        };
1407
1408        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
1409        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
1410        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
1411
1412        self.register_evaluated_expr(
1413            Expression::Compose {
1414                ty,
1415                components: vec![p0, p1, p2],
1416            },
1417            span,
1418        )
1419    }
1420
1421    /// Extract the values of a `vecN` from `expr`.
1422    ///
1423    /// Return the value of `expr`, whose type is `vecN<S>` for some
1424    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
1425    /// values.
1426    ///
1427    /// Also return the type handle from the `Compose` expression.
1428    fn extract_vec<const N: usize>(
1429        &mut self,
1430        expr: Handle<Expression>,
1431    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
1432        let span = self.expressions.get_span(expr);
1433        let expr = self.eval_zero_value_and_splat(expr, span)?;
1434        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1435            return Err(ConstantEvaluatorError::InvalidMathArg);
1436        };
1437
1438        let mut value = [Literal::Bool(false); N];
1439        for (component, elt) in
1440            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1441                .zip(value.iter_mut())
1442        {
1443            let Expression::Literal(literal) = self.expressions[component] else {
1444                return Err(ConstantEvaluatorError::InvalidMathArg);
1445            };
1446            *elt = literal;
1447        }
1448
1449        Ok((value, ty))
1450    }
1451
1452    fn array_length(
1453        &mut self,
1454        array: Handle<Expression>,
1455        span: Span,
1456    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1457        match self.expressions[array] {
1458            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
1459                match self.types[ty].inner {
1460                    TypeInner::Array { size, .. } => match size {
1461                        ArraySize::Constant(len) => {
1462                            let expr = Expression::Literal(Literal::U32(len.get()));
1463                            self.register_evaluated_expr(expr, span)
1464                        }
1465                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
1466                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
1467                    },
1468                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1469                }
1470            }
1471            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
1472        }
1473    }
1474
1475    fn access(
1476        &mut self,
1477        base: Handle<Expression>,
1478        index: usize,
1479        span: Span,
1480    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1481        match self.expressions[base] {
1482            Expression::ZeroValue(ty) => {
1483                let ty_inner = &self.types[ty].inner;
1484                let components = ty_inner
1485                    .components()
1486                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1487
1488                if index >= components as usize {
1489                    Err(ConstantEvaluatorError::InvalidAccessBase)
1490                } else {
1491                    let ty_res = ty_inner
1492                        .component_type(index)
1493                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
1494                    let ty = match ty_res {
1495                        crate::proc::TypeResolution::Handle(ty) => ty,
1496                        crate::proc::TypeResolution::Value(inner) => {
1497                            self.types.insert(Type { name: None, inner }, span)
1498                        }
1499                    };
1500                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
1501                }
1502            }
1503            Expression::Splat { size, value } => {
1504                if index >= size as usize {
1505                    Err(ConstantEvaluatorError::InvalidAccessBase)
1506                } else {
1507                    Ok(value)
1508                }
1509            }
1510            Expression::Compose { ty, ref components } => {
1511                let _ = self.types[ty]
1512                    .inner
1513                    .components()
1514                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
1515
1516                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1517                    .nth(index)
1518                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
1519            }
1520            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
1521        }
1522    }
1523
1524    fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
1525        match self.expressions[expr] {
1526            Expression::ZeroValue(ty)
1527                if matches!(
1528                    self.types[ty].inner,
1529                    TypeInner::Scalar(crate::Scalar {
1530                        kind: ScalarKind::Uint,
1531                        ..
1532                    })
1533                ) =>
1534            {
1535                Ok(0)
1536            }
1537            Expression::Literal(Literal::U32(index)) => Ok(index as usize),
1538            _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
1539        }
1540    }
1541
1542    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
1543    ///
1544    /// [`ZeroValue`]: Expression::ZeroValue
1545    /// [`Splat`]: Expression::Splat
1546    /// [`Literal`]: Expression::Literal
1547    /// [`Compose`]: Expression::Compose
1548    fn eval_zero_value_and_splat(
1549        &mut self,
1550        mut expr: Handle<Expression>,
1551        span: Span,
1552    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1553        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
1554        // each of its components.
1555        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
1556            let components = components
1557                .clone()
1558                .iter()
1559                .map(|component| self.eval_zero_value_and_splat(*component, span))
1560                .collect::<Result<_, _>>()?;
1561            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
1562        }
1563
1564        // The result of the splat() for a Splat of a scalar ZeroValue is a
1565        // vector ZeroValue, so we must call eval_zero_value_impl() after
1566        // splat() in order to ensure we have no ZeroValues remaining.
1567        if let Expression::Splat { size, value } = self.expressions[expr] {
1568            expr = self.splat(value, size, span)?;
1569        }
1570        if let Expression::ZeroValue(ty) = self.expressions[expr] {
1571            expr = self.eval_zero_value_impl(ty, span)?;
1572        }
1573        Ok(expr)
1574    }
1575
1576    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1577    ///
1578    /// [`ZeroValue`]: Expression::ZeroValue
1579    /// [`Literal`]: Expression::Literal
1580    /// [`Compose`]: Expression::Compose
1581    fn eval_zero_value(
1582        &mut self,
1583        expr: Handle<Expression>,
1584        span: Span,
1585    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1586        match self.expressions[expr] {
1587            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
1588            _ => Ok(expr),
1589        }
1590    }
1591
1592    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
1593    ///
1594    /// [`ZeroValue`]: Expression::ZeroValue
1595    /// [`Literal`]: Expression::Literal
1596    /// [`Compose`]: Expression::Compose
1597    fn eval_zero_value_impl(
1598        &mut self,
1599        ty: Handle<Type>,
1600        span: Span,
1601    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1602        match self.types[ty].inner {
1603            TypeInner::Scalar(scalar) => {
1604                let expr = Expression::Literal(
1605                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
1606                );
1607                self.register_evaluated_expr(expr, span)
1608            }
1609            TypeInner::Vector { size, scalar } => {
1610                let scalar_ty = self.types.insert(
1611                    Type {
1612                        name: None,
1613                        inner: TypeInner::Scalar(scalar),
1614                    },
1615                    span,
1616                );
1617                let el = self.eval_zero_value_impl(scalar_ty, span)?;
1618                let expr = Expression::Compose {
1619                    ty,
1620                    components: vec![el; size as usize],
1621                };
1622                self.register_evaluated_expr(expr, span)
1623            }
1624            TypeInner::Matrix {
1625                columns,
1626                rows,
1627                scalar,
1628            } => {
1629                let vec_ty = self.types.insert(
1630                    Type {
1631                        name: None,
1632                        inner: TypeInner::Vector { size: rows, scalar },
1633                    },
1634                    span,
1635                );
1636                let el = self.eval_zero_value_impl(vec_ty, span)?;
1637                let expr = Expression::Compose {
1638                    ty,
1639                    components: vec![el; columns as usize],
1640                };
1641                self.register_evaluated_expr(expr, span)
1642            }
1643            TypeInner::Array {
1644                base,
1645                size: ArraySize::Constant(size),
1646                ..
1647            } => {
1648                let el = self.eval_zero_value_impl(base, span)?;
1649                let expr = Expression::Compose {
1650                    ty,
1651                    components: vec![el; size.get() as usize],
1652                };
1653                self.register_evaluated_expr(expr, span)
1654            }
1655            TypeInner::Struct { ref members, .. } => {
1656                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
1657                let mut components = Vec::with_capacity(members.len());
1658                for ty in types {
1659                    components.push(self.eval_zero_value_impl(ty, span)?);
1660                }
1661                let expr = Expression::Compose { ty, components };
1662                self.register_evaluated_expr(expr, span)
1663            }
1664            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
1665        }
1666    }
1667
1668    /// Convert the scalar components of `expr` to `target`.
1669    ///
1670    /// Treat `span` as the location of the resulting expression.
1671    pub fn cast(
1672        &mut self,
1673        expr: Handle<Expression>,
1674        target: crate::Scalar,
1675        span: Span,
1676    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1677        use crate::Scalar as Sc;
1678
1679        let expr = self.eval_zero_value(expr, span)?;
1680
1681        let make_error = || -> Result<_, ConstantEvaluatorError> {
1682            let from = format!("{:?} {:?}", expr, self.expressions[expr]);
1683
1684            #[cfg(feature = "wgsl-in")]
1685            let to = target.to_wgsl_for_diagnostics();
1686
1687            #[cfg(not(feature = "wgsl-in"))]
1688            let to = format!("{target:?}");
1689
1690            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
1691        };
1692
1693        use crate::proc::type_methods::IntFloatLimits;
1694
1695        let expr = match self.expressions[expr] {
1696            Expression::Literal(literal) => {
1697                let literal = match target {
1698                    Sc::I32 => Literal::I32(match literal {
1699                        Literal::I32(v) => v,
1700                        Literal::U32(v) => v as i32,
1701                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
1702                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
1703                        Literal::Bool(v) => v as i32,
1704                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1705                            return make_error();
1706                        }
1707                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
1708                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
1709                    }),
1710                    Sc::U32 => Literal::U32(match literal {
1711                        Literal::I32(v) => v as u32,
1712                        Literal::U32(v) => v,
1713                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
1714                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
1715                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
1716                        Literal::Bool(v) => v as u32,
1717                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1718                            return make_error();
1719                        }
1720                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
1721                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
1722                    }),
1723                    Sc::I64 => Literal::I64(match literal {
1724                        Literal::I32(v) => v as i64,
1725                        Literal::U32(v) => v as i64,
1726                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
1727                        Literal::Bool(v) => v as i64,
1728                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
1729                        Literal::I64(v) => v,
1730                        Literal::U64(v) => v as i64,
1731                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
1732                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
1733                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
1734                    }),
1735                    Sc::U64 => Literal::U64(match literal {
1736                        Literal::I32(v) => v as u64,
1737                        Literal::U32(v) => v as u64,
1738                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
1739                        Literal::Bool(v) => v as u64,
1740                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
1741                        Literal::I64(v) => v as u64,
1742                        Literal::U64(v) => v,
1743                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
1744                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
1745                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
1746                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
1747                    }),
1748                    Sc::F16 => Literal::F16(match literal {
1749                        Literal::F16(v) => v,
1750                        Literal::F32(v) => f16::from_f32(v),
1751                        Literal::F64(v) => f16::from_f64(v),
1752                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
1753                        Literal::I64(v) => f16::from_i64(v).unwrap(),
1754                        Literal::U64(v) => f16::from_u64(v).unwrap(),
1755                        Literal::I32(v) => f16::from_i32(v).unwrap(),
1756                        Literal::U32(v) => f16::from_u32(v).unwrap(),
1757                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
1758                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
1759                    }),
1760                    Sc::F32 => Literal::F32(match literal {
1761                        Literal::I32(v) => v as f32,
1762                        Literal::U32(v) => v as f32,
1763                        Literal::F32(v) => v,
1764                        Literal::Bool(v) => v as u32 as f32,
1765                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1766                            return make_error();
1767                        }
1768                        Literal::F16(v) => f16::to_f32(v),
1769                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
1770                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
1771                    }),
1772                    Sc::F64 => Literal::F64(match literal {
1773                        Literal::I32(v) => v as f64,
1774                        Literal::U32(v) => v as f64,
1775                        Literal::F16(v) => f16::to_f64(v),
1776                        Literal::F32(v) => v as f64,
1777                        Literal::F64(v) => v,
1778                        Literal::Bool(v) => v as u32 as f64,
1779                        Literal::I64(_) | Literal::U64(_) => return make_error(),
1780                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
1781                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
1782                    }),
1783                    Sc::BOOL => Literal::Bool(match literal {
1784                        Literal::I32(v) => v != 0,
1785                        Literal::U32(v) => v != 0,
1786                        Literal::F32(v) => v != 0.0,
1787                        Literal::F16(v) => v != f16::zero(),
1788                        Literal::Bool(v) => v,
1789                        Literal::AbstractInt(v) => v != 0,
1790                        Literal::AbstractFloat(v) => v != 0.0,
1791                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
1792                            return make_error();
1793                        }
1794                    }),
1795                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
1796                        Literal::AbstractInt(v) => {
1797                            // Overflow is forbidden, but inexact conversions
1798                            // are fine. The range of f64 is far larger than
1799                            // that of i64, so we don't have to check anything
1800                            // here.
1801                            v as f64
1802                        }
1803                        Literal::AbstractFloat(v) => v,
1804                        _ => return make_error(),
1805                    }),
1806                    _ => {
1807                        log::debug!("Constant evaluator refused to convert value to {target:?}");
1808                        return make_error();
1809                    }
1810                };
1811                Expression::Literal(literal)
1812            }
1813            Expression::Compose {
1814                ty,
1815                components: ref src_components,
1816            } => {
1817                let ty_inner = match self.types[ty].inner {
1818                    TypeInner::Vector { size, .. } => TypeInner::Vector {
1819                        size,
1820                        scalar: target,
1821                    },
1822                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
1823                        columns,
1824                        rows,
1825                        scalar: target,
1826                    },
1827                    _ => return make_error(),
1828                };
1829
1830                let mut components = src_components.clone();
1831                for component in &mut components {
1832                    *component = self.cast(*component, target, span)?;
1833                }
1834
1835                let ty = self.types.insert(
1836                    Type {
1837                        name: None,
1838                        inner: ty_inner,
1839                    },
1840                    span,
1841                );
1842
1843                Expression::Compose { ty, components }
1844            }
1845            Expression::Splat { size, value } => {
1846                let value_span = self.expressions.get_span(value);
1847                let cast_value = self.cast(value, target, value_span)?;
1848                Expression::Splat {
1849                    size,
1850                    value: cast_value,
1851                }
1852            }
1853            _ => return make_error(),
1854        };
1855
1856        self.register_evaluated_expr(expr, span)
1857    }
1858
1859    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
1860    ///
1861    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
1862    /// matrix, or nested arrays of such.
1863    ///
1864    /// This is basically the same as the [`cast`] method, except that that
1865    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
1866    ///
1867    /// Treat `span` as the location of the resulting expression.
1868    ///
1869    /// [`cast`]: ConstantEvaluator::cast
1870    /// [`As`]: crate::Expression::As
1871    pub fn cast_array(
1872        &mut self,
1873        expr: Handle<Expression>,
1874        target: crate::Scalar,
1875        span: Span,
1876    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1877        let expr = self.check_and_get(expr)?;
1878
1879        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
1880            return self.cast(expr, target, span);
1881        };
1882
1883        let TypeInner::Array {
1884            base: _,
1885            size,
1886            stride: _,
1887        } = self.types[ty].inner
1888        else {
1889            return self.cast(expr, target, span);
1890        };
1891
1892        let mut components = components.clone();
1893        for component in &mut components {
1894            *component = self.cast_array(*component, target, span)?;
1895        }
1896
1897        let first = components.first().unwrap();
1898        let new_base = match self.resolve_type(*first)? {
1899            crate::proc::TypeResolution::Handle(ty) => ty,
1900            crate::proc::TypeResolution::Value(inner) => {
1901                self.types.insert(Type { name: None, inner }, span)
1902            }
1903        };
1904        let mut layouter = core::mem::take(self.layouter);
1905        layouter.update(self.to_ctx()).unwrap();
1906        *self.layouter = layouter;
1907
1908        let new_base_stride = self.layouter[new_base].to_stride();
1909        let new_array_ty = self.types.insert(
1910            Type {
1911                name: None,
1912                inner: TypeInner::Array {
1913                    base: new_base,
1914                    size,
1915                    stride: new_base_stride,
1916                },
1917            },
1918            span,
1919        );
1920
1921        let compose = Expression::Compose {
1922            ty: new_array_ty,
1923            components,
1924        };
1925        self.register_evaluated_expr(compose, span)
1926    }
1927
1928    fn unary_op(
1929        &mut self,
1930        op: UnaryOperator,
1931        expr: Handle<Expression>,
1932        span: Span,
1933    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1934        let expr = self.eval_zero_value_and_splat(expr, span)?;
1935
1936        let expr = match self.expressions[expr] {
1937            Expression::Literal(value) => Expression::Literal(match op {
1938                UnaryOperator::Negate => match value {
1939                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
1940                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
1941                    Literal::F32(v) => Literal::F32(-v),
1942                    Literal::F16(v) => Literal::F16(-v),
1943                    Literal::F64(v) => Literal::F64(-v),
1944                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
1945                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
1946                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1947                },
1948                UnaryOperator::LogicalNot => match value {
1949                    Literal::Bool(v) => Literal::Bool(!v),
1950                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1951                },
1952                UnaryOperator::BitwiseNot => match value {
1953                    Literal::I32(v) => Literal::I32(!v),
1954                    Literal::I64(v) => Literal::I64(!v),
1955                    Literal::U32(v) => Literal::U32(!v),
1956                    Literal::U64(v) => Literal::U64(!v),
1957                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
1958                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1959                },
1960            }),
1961            Expression::Compose {
1962                ty,
1963                components: ref src_components,
1964            } => {
1965                match self.types[ty].inner {
1966                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
1967                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1968                }
1969
1970                let mut components = src_components.clone();
1971                for component in &mut components {
1972                    *component = self.unary_op(op, *component, span)?;
1973                }
1974
1975                Expression::Compose { ty, components }
1976            }
1977            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
1978        };
1979
1980        self.register_evaluated_expr(expr, span)
1981    }
1982
1983    fn binary_op(
1984        &mut self,
1985        op: BinaryOperator,
1986        left: Handle<Expression>,
1987        right: Handle<Expression>,
1988        span: Span,
1989    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1990        let left = self.eval_zero_value_and_splat(left, span)?;
1991        let right = self.eval_zero_value_and_splat(right, span)?;
1992
1993        let expr = match (&self.expressions[left], &self.expressions[right]) {
1994            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
1995                let literal = match op {
1996                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
1997                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
1998                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
1999                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2000                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2001                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2002
2003                    _ => match (left_value, right_value) {
2004                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2005                            BinaryOperator::Add => a.wrapping_add(b),
2006                            BinaryOperator::Subtract => a.wrapping_sub(b),
2007                            BinaryOperator::Multiply => a.wrapping_mul(b),
2008                            BinaryOperator::Divide => {
2009                                if b == 0 {
2010                                    return Err(ConstantEvaluatorError::DivisionByZero);
2011                                } else {
2012                                    a.wrapping_div(b)
2013                                }
2014                            }
2015                            BinaryOperator::Modulo => {
2016                                if b == 0 {
2017                                    return Err(ConstantEvaluatorError::RemainderByZero);
2018                                } else {
2019                                    a.wrapping_rem(b)
2020                                }
2021                            }
2022                            BinaryOperator::And => a & b,
2023                            BinaryOperator::ExclusiveOr => a ^ b,
2024                            BinaryOperator::InclusiveOr => a | b,
2025                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2026                        }),
2027                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2028                            BinaryOperator::ShiftLeft => {
2029                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2030                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2031                                }
2032                                a.checked_shl(b)
2033                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2034                            }
2035                            BinaryOperator::ShiftRight => a
2036                                .checked_shr(b)
2037                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2038                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2039                        }),
2040                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2041                            BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2042                                ConstantEvaluatorError::Overflow("addition".into())
2043                            })?,
2044                            BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2045                                ConstantEvaluatorError::Overflow("subtraction".into())
2046                            })?,
2047                            BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2048                                ConstantEvaluatorError::Overflow("multiplication".into())
2049                            })?,
2050                            BinaryOperator::Divide => a
2051                                .checked_div(b)
2052                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2053                            BinaryOperator::Modulo => a
2054                                .checked_rem(b)
2055                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2056                            BinaryOperator::And => a & b,
2057                            BinaryOperator::ExclusiveOr => a ^ b,
2058                            BinaryOperator::InclusiveOr => a | b,
2059                            BinaryOperator::ShiftLeft => a
2060                                .checked_mul(
2061                                    1u32.checked_shl(b)
2062                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2063                                )
2064                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2065                            BinaryOperator::ShiftRight => a
2066                                .checked_shr(b)
2067                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2068                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2069                        }),
2070                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2071                            BinaryOperator::Add => a + b,
2072                            BinaryOperator::Subtract => a - b,
2073                            BinaryOperator::Multiply => a * b,
2074                            BinaryOperator::Divide => a / b,
2075                            BinaryOperator::Modulo => a % b,
2076                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2077                        }),
2078                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2079                            Literal::AbstractInt(match op {
2080                                BinaryOperator::ShiftLeft => {
2081                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2082                                        return Err(ConstantEvaluatorError::Overflow(
2083                                            "<<".to_string(),
2084                                        ));
2085                                    }
2086                                    a.checked_shl(b).unwrap_or(0)
2087                                }
2088                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2089                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2090                            })
2091                        }
2092                        (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op {
2093                            BinaryOperator::Add => a + b,
2094                            BinaryOperator::Subtract => a - b,
2095                            BinaryOperator::Multiply => a * b,
2096                            BinaryOperator::Divide => a / b,
2097                            BinaryOperator::Modulo => a % b,
2098                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2099                        }),
2100                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2101                            Literal::AbstractInt(match op {
2102                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2103                                    ConstantEvaluatorError::Overflow("addition".into())
2104                                })?,
2105                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2106                                    ConstantEvaluatorError::Overflow("subtraction".into())
2107                                })?,
2108                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2109                                    ConstantEvaluatorError::Overflow("multiplication".into())
2110                                })?,
2111                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2112                                    if b == 0 {
2113                                        ConstantEvaluatorError::DivisionByZero
2114                                    } else {
2115                                        ConstantEvaluatorError::Overflow("division".into())
2116                                    }
2117                                })?,
2118                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2119                                    if b == 0 {
2120                                        ConstantEvaluatorError::RemainderByZero
2121                                    } else {
2122                                        ConstantEvaluatorError::Overflow("remainder".into())
2123                                    }
2124                                })?,
2125                                BinaryOperator::And => a & b,
2126                                BinaryOperator::ExclusiveOr => a ^ b,
2127                                BinaryOperator::InclusiveOr => a | b,
2128                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2129                            })
2130                        }
2131                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2132                            Literal::AbstractFloat(match op {
2133                                BinaryOperator::Add => a + b,
2134                                BinaryOperator::Subtract => a - b,
2135                                BinaryOperator::Multiply => a * b,
2136                                BinaryOperator::Divide => a / b,
2137                                BinaryOperator::Modulo => a % b,
2138                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2139                            })
2140                        }
2141                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2142                            BinaryOperator::LogicalAnd => a && b,
2143                            BinaryOperator::LogicalOr => a || b,
2144                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2145                        }),
2146                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2147                    },
2148                };
2149                Expression::Literal(literal)
2150            }
2151            (
2152                &Expression::Compose {
2153                    components: ref src_components,
2154                    ty,
2155                },
2156                &Expression::Literal(_),
2157            ) => {
2158                let mut components = src_components.clone();
2159                for component in &mut components {
2160                    *component = self.binary_op(op, *component, right, span)?;
2161                }
2162                Expression::Compose { ty, components }
2163            }
2164            (
2165                &Expression::Literal(_),
2166                &Expression::Compose {
2167                    components: ref src_components,
2168                    ty,
2169                },
2170            ) => {
2171                let mut components = src_components.clone();
2172                for component in &mut components {
2173                    *component = self.binary_op(op, left, *component, span)?;
2174                }
2175                Expression::Compose { ty, components }
2176            }
2177            (
2178                &Expression::Compose {
2179                    components: ref left_components,
2180                    ty: left_ty,
2181                },
2182                &Expression::Compose {
2183                    components: ref right_components,
2184                    ty: right_ty,
2185                },
2186            ) => {
2187                // We have to make a copy of the component lists, because the
2188                // call to `binary_op_vector` needs `&mut self`, but `self` owns
2189                // the component lists.
2190                let left_flattened = crate::proc::flatten_compose(
2191                    left_ty,
2192                    left_components,
2193                    self.expressions,
2194                    self.types,
2195                );
2196                let right_flattened = crate::proc::flatten_compose(
2197                    right_ty,
2198                    right_components,
2199                    self.expressions,
2200                    self.types,
2201                );
2202
2203                // `flatten_compose` doesn't return an `ExactSizeIterator`, so
2204                // make a reasonable guess of the capacity we'll need.
2205                let mut flattened = Vec::with_capacity(left_components.len());
2206                flattened.extend(left_flattened.zip(right_flattened));
2207
2208                match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2209                    (
2210                        &TypeInner::Vector {
2211                            size: left_size, ..
2212                        },
2213                        &TypeInner::Vector {
2214                            size: right_size, ..
2215                        },
2216                    ) if left_size == right_size => {
2217                        self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
2218                    }
2219                    _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2220                }
2221            }
2222            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2223        };
2224
2225        self.register_evaluated_expr(expr, span)
2226    }
2227
2228    fn binary_op_vector(
2229        &mut self,
2230        op: BinaryOperator,
2231        size: crate::VectorSize,
2232        components: &[(Handle<Expression>, Handle<Expression>)],
2233        left_ty: Handle<Type>,
2234        span: Span,
2235    ) -> Result<Expression, ConstantEvaluatorError> {
2236        let ty = match op {
2237            // Relational operators produce vectors of booleans.
2238            BinaryOperator::Equal
2239            | BinaryOperator::NotEqual
2240            | BinaryOperator::Less
2241            | BinaryOperator::LessEqual
2242            | BinaryOperator::Greater
2243            | BinaryOperator::GreaterEqual => self.types.insert(
2244                Type {
2245                    name: None,
2246                    inner: TypeInner::Vector {
2247                        size,
2248                        scalar: crate::Scalar::BOOL,
2249                    },
2250                },
2251                span,
2252            ),
2253
2254            // Other operators produce the same type as their left
2255            // operand.
2256            BinaryOperator::Add
2257            | BinaryOperator::Subtract
2258            | BinaryOperator::Multiply
2259            | BinaryOperator::Divide
2260            | BinaryOperator::Modulo
2261            | BinaryOperator::And
2262            | BinaryOperator::ExclusiveOr
2263            | BinaryOperator::InclusiveOr
2264            | BinaryOperator::ShiftLeft
2265            | BinaryOperator::ShiftRight => left_ty,
2266
2267            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
2268                // Not supported on vectors
2269                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2270            }
2271        };
2272
2273        let components = components
2274            .iter()
2275            .map(|&(left, right)| self.binary_op(op, left, right, span))
2276            .collect::<Result<Vec<_>, _>>()?;
2277
2278        Ok(Expression::Compose { ty, components })
2279    }
2280
2281    fn relational(
2282        &mut self,
2283        fun: RelationalFunction,
2284        arg: Handle<Expression>,
2285        span: Span,
2286    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2287        let arg = self.eval_zero_value_and_splat(arg, span)?;
2288        match fun {
2289            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
2290                Expression::Literal(Literal::Bool(_)) => Ok(arg),
2291                Expression::Compose { ty, ref components }
2292                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
2293                {
2294                    let components =
2295                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2296                            .map(|component| match self.expressions[component] {
2297                                Expression::Literal(Literal::Bool(val)) => Ok(val),
2298                                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2299                            })
2300                            .collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
2301                    let result = match fun {
2302                        RelationalFunction::All => components.iter().all(|c| *c),
2303                        RelationalFunction::Any => components.iter().any(|c| *c),
2304                        _ => unreachable!(),
2305                    };
2306                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
2307                }
2308                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
2309            },
2310            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
2311                "{fun:?} built-in function"
2312            ))),
2313        }
2314    }
2315
2316    /// Deep copy `expr` from `expressions` into `self.expressions`.
2317    ///
2318    /// Return the root of the new copy.
2319    ///
2320    /// This is used when we're evaluating expressions in a function's
2321    /// expression arena that refer to a constant: we need to copy the
2322    /// constant's value into the function's arena so we can operate on it.
2323    fn copy_from(
2324        &mut self,
2325        expr: Handle<Expression>,
2326        expressions: &Arena<Expression>,
2327    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2328        let span = expressions.get_span(expr);
2329        match expressions[expr] {
2330            ref expr @ (Expression::Literal(_)
2331            | Expression::Constant(_)
2332            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
2333            Expression::Compose { ty, ref components } => {
2334                let mut components = components.clone();
2335                for component in &mut components {
2336                    *component = self.copy_from(*component, expressions)?;
2337                }
2338                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
2339            }
2340            Expression::Splat { size, value } => {
2341                let value = self.copy_from(value, expressions)?;
2342                self.register_evaluated_expr(Expression::Splat { size, value }, span)
2343            }
2344            _ => {
2345                log::debug!("copy_from: SubexpressionsAreNotConstant");
2346                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
2347            }
2348        }
2349    }
2350
2351    /// Returns the total number of components, after flattening, of a vector compose expression.
2352    fn vector_compose_flattened_size(
2353        &self,
2354        components: &[Handle<Expression>],
2355    ) -> Result<usize, ConstantEvaluatorError> {
2356        components
2357            .iter()
2358            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
2359                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
2360                    TypeInner::Scalar(_) => 1,
2361                    // We trust that the vector size of `component` is correct,
2362                    // as it will have already been validated when `component`
2363                    // was registered.
2364                    TypeInner::Vector { size, .. } => size as usize,
2365                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
2366                };
2367                Ok(acc + size)
2368            })
2369    }
2370
2371    fn register_evaluated_expr(
2372        &mut self,
2373        expr: Expression,
2374        span: Span,
2375    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2376        // It suffices to only check_literal_value() for `Literal` expressions,
2377        // since we only register one expression at a time, `Compose`
2378        // expressions can only refer to other expressions, and `ZeroValue`
2379        // expressions are always okay.
2380        if let Expression::Literal(literal) = expr {
2381            crate::valid::check_literal_value(literal)?;
2382        }
2383
2384        // Ensure vector composes contain the correct number of components. We
2385        // do so here when each compose is registered to avoid having to deal
2386        // with the mess each time the compose is used in another expression.
2387        if let Expression::Compose { ty, ref components } = expr {
2388            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
2389                let expected = size as usize;
2390                let actual = self.vector_compose_flattened_size(components)?;
2391                if expected != actual {
2392                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
2393                        expected,
2394                        actual,
2395                    });
2396                }
2397            }
2398        }
2399
2400        Ok(self.append_expr(expr, span, ExpressionKind::Const))
2401    }
2402
    /// Append `expr` (with `span`) to `self.expressions`, record it in
    /// `self.expression_kind_tracker` as `expr_type`, and return its handle.
    ///
    /// When evaluation is happening inside a function body (WGSL `Runtime`,
    /// WGSL `Const` with function-local data, or GLSL `Runtime`), the function's
    /// `emitter` may currently be running. If it is, and `expr.needs_pre_emit()`
    /// is true, the emitter is finished (its output is pushed onto the function's
    /// `block`) before the append and restarted right after it. The new
    /// expression therefore lies outside the range the emitter was tracking.
    fn append_expr(
        &mut self,
        expr: Expression,
        span: Span,
        expr_type: ExpressionKind,
    ) -> Handle<Expression> {
        let h = match self.behavior {
            Behavior::Wgsl(
                WgslRestrictions::Runtime(ref mut function_local_data)
                | WgslRestrictions::Const(Some(ref mut function_local_data)),
            )
            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
                let is_running = function_local_data.emitter.is_running();
                // NOTE(review): presumably `needs_pre_emit` is true for
                // expressions that must not be covered by an emitted range
                // (e.g. literals/constants) — confirm against its definition.
                let needs_pre_emit = expr.needs_pre_emit();
                if is_running && needs_pre_emit {
                    // Close the current emit range, flushing its output into
                    // the block, so `expr` is appended outside of it...
                    function_local_data
                        .block
                        .extend(function_local_data.emitter.finish(self.expressions));
                    let h = self.expressions.append(expr, span);
                    // ...then start a fresh range after the new expression.
                    function_local_data.emitter.start(self.expressions);
                    h
                } else {
                    self.expressions.append(expr, span)
                }
            }
            // No function-local emitter to coordinate with: append directly.
            _ => self.expressions.append(expr, span),
        };
        self.expression_kind_tracker.insert(h, expr_type);
        h
    }
2433
2434    fn resolve_type(
2435        &self,
2436        expr: Handle<Expression>,
2437    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
2438        use crate::proc::TypeResolution as Tr;
2439        use crate::Expression as Ex;
2440        let resolution = match self.expressions[expr] {
2441            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
2442            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
2443            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
2444            Ex::Splat { size, value } => {
2445                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
2446                    return Err(ConstantEvaluatorError::SplatScalarOnly);
2447                };
2448                Tr::Value(TypeInner::Vector { scalar, size })
2449            }
2450            _ => {
2451                log::debug!("resolve_type: SubexpressionsAreNotConstant");
2452                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
2453            }
2454        };
2455
2456        Ok(resolution)
2457    }
2458}
2459
2460fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2461    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
2462    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
2463    // return a right-to-left bit index of 0.
2464    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
2465        match e {
2466            idx @ 0..=31 => idx,
2467            32 => u32::MAX,
2468            _ => unreachable!(),
2469        }
2470    };
2471    match concrete_int {
2472        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
2473        ConcreteInt::I32([e]) => {
2474            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
2475        }
2476    }
2477}
2478
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected index of the lowest set bit)
    let signed: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for (input, expected) in signed {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    // A single set bit reports its own position.
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    let unsigned: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for (input, expected) in unsigned {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2539
2540fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
2541    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
2542    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
2543    // index of 0.
2544    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
2545        match e {
2546            idx @ 0..=31 => 31 - idx,
2547            32 => u32::MAX,
2548            _ => unreachable!(),
2549        }
2550    };
2551    match concrete_int {
2552        ConcreteInt::I32([e]) => ConcreteInt::I32([{
2553            let rtl_bit_index = if e.is_negative() {
2554                e.leading_ones()
2555            } else {
2556                e.leading_zeros()
2557            };
2558            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
2559        }]),
2560        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
2561    }
2562}
2563
#[test]
fn first_leading_bit_smoke() {
    // (input, expected index of the highest bit differing from the sign)
    let signed: [(i32, i32); 7] = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // NOTE: Skip the sign bit, which is covered by the table above.
    for idx in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }
    for idx in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << idx)])),
            ConcreteInt::I32([idx - 1])
        );
    }

    let unsigned: [(u32, u32); 3] = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
2627
2628/// Trait for conversions of abstract values to concrete types.
2629trait TryFromAbstract<T>: Sized {
2630    /// Convert an abstract literal `value` to `Self`.
2631    ///
2632    /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support
2633    /// WGSL, we follow WGSL's conversion rules here:
2634    ///
2635    /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
2636    ///   from `AbstractInt` to an integer type are either lossless or an
2637    ///   error.
2638    ///
2639    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
2640    ///   to floating point in constant expressions and override
2641    ///   expressions are errors if the value is out of range for the
2642    ///   destination type, but rounding is okay.
2643    ///
2644    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
2645    ///   other floating point type, following the scalar floating point to
2646    ///   integral conversion algorithm (§15.7.6). There is no automatic
2647    ///   conversion from AbstractFloat to integer types.
2648    ///
2649    /// [`AbstractInt`]: crate::Literal::AbstractInt
2650    /// [`Float`]: crate::Literal::Float
2651    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
2652}
2653
2654impl TryFromAbstract<i64> for i32 {
2655    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
2656        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2657            value: format!("{value:?}"),
2658            to_type: "i32",
2659        })
2660    }
2661}
2662
2663impl TryFromAbstract<i64> for u32 {
2664    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
2665        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2666            value: format!("{value:?}"),
2667            to_type: "u32",
2668        })
2669    }
2670}
2671
2672impl TryFromAbstract<i64> for u64 {
2673    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
2674        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
2675            value: format!("{value:?}"),
2676            to_type: "u64",
2677        })
2678    }
2679}
2680
2681impl TryFromAbstract<i64> for i64 {
2682    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
2683        Ok(value)
2684    }
2685}
2686
2687impl TryFromAbstract<i64> for f32 {
2688    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2689        let f = value as f32;
2690        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2691        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
2692        // overflow here.
2693        Ok(f)
2694    }
2695}
2696
2697impl TryFromAbstract<f64> for f32 {
2698    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
2699        let f = value as f32;
2700        if f.is_infinite() {
2701            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2702                value: format!("{value:?}"),
2703                to_type: "f32",
2704            });
2705        }
2706        Ok(f)
2707    }
2708}
2709
2710impl TryFromAbstract<i64> for f64 {
2711    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
2712        let f = value as f64;
2713        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
2714        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
2715        // overflow here.
2716        Ok(f)
2717    }
2718}
2719
2720impl TryFromAbstract<f64> for f64 {
2721    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
2722        Ok(value)
2723    }
2724}
2725
2726impl TryFromAbstract<f64> for i32 {
2727    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2728        // https://www.w3.org/TR/WGSL/#floating-point-conversion
2729        // To convert a floating point scalar value X to an integer scalar type T:
2730        // * If X is a NaN, the result is an indeterminate value in T.
2731        // * If X is exactly representable in the target type T, then the
2732        //   result is that value.
2733        // * Otherwise, the result is the value in T closest to truncate(X) and
2734        //   also exactly representable in the original floating point type.
2735        //
2736        // A rust cast satisfies these requirements apart from "the result
2737        // is... exactly representable in the original floating point type".
2738        // However, i32::MIN and i32::MAX are exactly representable by f64, so
2739        // we're all good.
2740        Ok(value as i32)
2741    }
2742}
2743
2744impl TryFromAbstract<f64> for u32 {
2745    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2746        // As above, u32::MIN and u32::MAX are exactly representable by f64,
2747        // so a simple rust cast is sufficient.
2748        Ok(value as u32)
2749    }
2750}
2751
2752impl TryFromAbstract<f64> for i64 {
2753    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2754        // As above, except we clamp to the minimum and maximum values
2755        // representable by both f64 and i64.
2756        use crate::proc::type_methods::IntFloatLimits;
2757        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
2758    }
2759}
2760
2761impl TryFromAbstract<f64> for u64 {
2762    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
2763        // As above, this time clamping to the minimum and maximum values
2764        // representable by both f64 and u64.
2765        use crate::proc::type_methods::IntFloatLimits;
2766        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
2767    }
2768}
2769
2770impl TryFromAbstract<f64> for f16 {
2771    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
2772        let f = f16::from_f64(value);
2773        if f.is_infinite() {
2774            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2775                value: format!("{value:?}"),
2776                to_type: "f16",
2777            });
2778        }
2779        Ok(f)
2780    }
2781}
2782
2783impl TryFromAbstract<i64> for f16 {
2784    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
2785        let f = f16::from_i64(value);
2786        if f.is_none() {
2787            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
2788                value: format!("{value:?}"),
2789                to_type: "f16",
2790            });
2791        }
2792        Ok(f.unwrap())
2793    }
2794}
2795
/// Compute the 3-D cross product `a × b` component-wise.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy + core::ops::Mul<T, Output = T> + core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
2808
2809#[cfg(test)]
2810mod tests {
2811    use alloc::{vec, vec::Vec};
2812
2813    use crate::{
2814        Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
2815        UniqueArena, VectorSize,
2816    };
2817
2818    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
2819
2820    #[test]
2821    fn unary_op() {
2822        let mut types = UniqueArena::new();
2823        let mut constants = Arena::new();
2824        let overrides = Arena::new();
2825        let mut global_expressions = Arena::new();
2826
2827        let scalar_ty = types.insert(
2828            Type {
2829                name: None,
2830                inner: TypeInner::Scalar(crate::Scalar::I32),
2831            },
2832            Default::default(),
2833        );
2834
2835        let vec_ty = types.insert(
2836            Type {
2837                name: None,
2838                inner: TypeInner::Vector {
2839                    size: VectorSize::Bi,
2840                    scalar: crate::Scalar::I32,
2841                },
2842            },
2843            Default::default(),
2844        );
2845
2846        let h = constants.append(
2847            Constant {
2848                name: None,
2849                ty: scalar_ty,
2850                init: global_expressions
2851                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
2852            },
2853            Default::default(),
2854        );
2855
2856        let h1 = constants.append(
2857            Constant {
2858                name: None,
2859                ty: scalar_ty,
2860                init: global_expressions
2861                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
2862            },
2863            Default::default(),
2864        );
2865
2866        let vec_h = constants.append(
2867            Constant {
2868                name: None,
2869                ty: vec_ty,
2870                init: global_expressions.append(
2871                    Expression::Compose {
2872                        ty: vec_ty,
2873                        components: vec![constants[h].init, constants[h1].init],
2874                    },
2875                    Default::default(),
2876                ),
2877            },
2878            Default::default(),
2879        );
2880
2881        let expr = global_expressions.append(Expression::Constant(h), Default::default());
2882        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
2883
2884        let expr2 = Expression::Unary {
2885            op: UnaryOperator::Negate,
2886            expr,
2887        };
2888
2889        let expr3 = Expression::Unary {
2890            op: UnaryOperator::BitwiseNot,
2891            expr,
2892        };
2893
2894        let expr4 = Expression::Unary {
2895            op: UnaryOperator::BitwiseNot,
2896            expr: expr1,
2897        };
2898
2899        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2900        let mut solver = ConstantEvaluator {
2901            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
2902            types: &mut types,
2903            constants: &constants,
2904            overrides: &overrides,
2905            expressions: &mut global_expressions,
2906            expression_kind_tracker,
2907            layouter: &mut crate::proc::Layouter::default(),
2908        };
2909
2910        let res1 = solver
2911            .try_eval_and_append(expr2, Default::default())
2912            .unwrap();
2913        let res2 = solver
2914            .try_eval_and_append(expr3, Default::default())
2915            .unwrap();
2916        let res3 = solver
2917            .try_eval_and_append(expr4, Default::default())
2918            .unwrap();
2919
2920        assert_eq!(
2921            global_expressions[res1],
2922            Expression::Literal(Literal::I32(-4))
2923        );
2924
2925        assert_eq!(
2926            global_expressions[res2],
2927            Expression::Literal(Literal::I32(!4))
2928        );
2929
2930        let res3_inner = &global_expressions[res3];
2931
2932        match *res3_inner {
2933            Expression::Compose {
2934                ref ty,
2935                ref components,
2936            } => {
2937                assert_eq!(*ty, vec_ty);
2938                let mut components_iter = components.iter().copied();
2939                assert_eq!(
2940                    global_expressions[components_iter.next().unwrap()],
2941                    Expression::Literal(Literal::I32(!4))
2942                );
2943                assert_eq!(
2944                    global_expressions[components_iter.next().unwrap()],
2945                    Expression::Literal(Literal::I32(!8))
2946                );
2947                assert!(components_iter.next().is_none());
2948            }
2949            _ => panic!("Expected vector"),
2950        }
2951    }
2952
2953    #[test]
2954    fn cast() {
2955        let mut types = UniqueArena::new();
2956        let mut constants = Arena::new();
2957        let overrides = Arena::new();
2958        let mut global_expressions = Arena::new();
2959
2960        let scalar_ty = types.insert(
2961            Type {
2962                name: None,
2963                inner: TypeInner::Scalar(crate::Scalar::I32),
2964            },
2965            Default::default(),
2966        );
2967
2968        let h = constants.append(
2969            Constant {
2970                name: None,
2971                ty: scalar_ty,
2972                init: global_expressions
2973                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
2974            },
2975            Default::default(),
2976        );
2977
2978        let expr = global_expressions.append(Expression::Constant(h), Default::default());
2979
2980        let root = Expression::As {
2981            expr,
2982            kind: ScalarKind::Bool,
2983            convert: Some(crate::BOOL_WIDTH),
2984        };
2985
2986        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
2987        let mut solver = ConstantEvaluator {
2988            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
2989            types: &mut types,
2990            constants: &constants,
2991            overrides: &overrides,
2992            expressions: &mut global_expressions,
2993            expression_kind_tracker,
2994            layouter: &mut crate::proc::Layouter::default(),
2995        };
2996
2997        let res = solver
2998            .try_eval_and_append(root, Default::default())
2999            .unwrap();
3000
3001        assert_eq!(
3002            global_expressions[res],
3003            Expression::Literal(Literal::Bool(true))
3004        );
3005    }
3006
3007    #[test]
3008    fn access() {
3009        let mut types = UniqueArena::new();
3010        let mut constants = Arena::new();
3011        let overrides = Arena::new();
3012        let mut global_expressions = Arena::new();
3013
3014        let matrix_ty = types.insert(
3015            Type {
3016                name: None,
3017                inner: TypeInner::Matrix {
3018                    columns: VectorSize::Bi,
3019                    rows: VectorSize::Tri,
3020                    scalar: crate::Scalar::F32,
3021                },
3022            },
3023            Default::default(),
3024        );
3025
3026        let vec_ty = types.insert(
3027            Type {
3028                name: None,
3029                inner: TypeInner::Vector {
3030                    size: VectorSize::Tri,
3031                    scalar: crate::Scalar::F32,
3032                },
3033            },
3034            Default::default(),
3035        );
3036
3037        let mut vec1_components = Vec::with_capacity(3);
3038        let mut vec2_components = Vec::with_capacity(3);
3039
3040        for i in 0..3 {
3041            let h = global_expressions.append(
3042                Expression::Literal(Literal::F32(i as f32)),
3043                Default::default(),
3044            );
3045
3046            vec1_components.push(h)
3047        }
3048
3049        for i in 3..6 {
3050            let h = global_expressions.append(
3051                Expression::Literal(Literal::F32(i as f32)),
3052                Default::default(),
3053            );
3054
3055            vec2_components.push(h)
3056        }
3057
3058        let vec1 = constants.append(
3059            Constant {
3060                name: None,
3061                ty: vec_ty,
3062                init: global_expressions.append(
3063                    Expression::Compose {
3064                        ty: vec_ty,
3065                        components: vec1_components,
3066                    },
3067                    Default::default(),
3068                ),
3069            },
3070            Default::default(),
3071        );
3072
3073        let vec2 = constants.append(
3074            Constant {
3075                name: None,
3076                ty: vec_ty,
3077                init: global_expressions.append(
3078                    Expression::Compose {
3079                        ty: vec_ty,
3080                        components: vec2_components,
3081                    },
3082                    Default::default(),
3083                ),
3084            },
3085            Default::default(),
3086        );
3087
3088        let h = constants.append(
3089            Constant {
3090                name: None,
3091                ty: matrix_ty,
3092                init: global_expressions.append(
3093                    Expression::Compose {
3094                        ty: matrix_ty,
3095                        components: vec![constants[vec1].init, constants[vec2].init],
3096                    },
3097                    Default::default(),
3098                ),
3099            },
3100            Default::default(),
3101        );
3102
3103        let base = global_expressions.append(Expression::Constant(h), Default::default());
3104
3105        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3106        let mut solver = ConstantEvaluator {
3107            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3108            types: &mut types,
3109            constants: &constants,
3110            overrides: &overrides,
3111            expressions: &mut global_expressions,
3112            expression_kind_tracker,
3113            layouter: &mut crate::proc::Layouter::default(),
3114        };
3115
3116        let root1 = Expression::AccessIndex { base, index: 1 };
3117
3118        let res1 = solver
3119            .try_eval_and_append(root1, Default::default())
3120            .unwrap();
3121
3122        let root2 = Expression::AccessIndex {
3123            base: res1,
3124            index: 2,
3125        };
3126
3127        let res2 = solver
3128            .try_eval_and_append(root2, Default::default())
3129            .unwrap();
3130
3131        match global_expressions[res1] {
3132            Expression::Compose {
3133                ref ty,
3134                ref components,
3135            } => {
3136                assert_eq!(*ty, vec_ty);
3137                let mut components_iter = components.iter().copied();
3138                assert_eq!(
3139                    global_expressions[components_iter.next().unwrap()],
3140                    Expression::Literal(Literal::F32(3.))
3141                );
3142                assert_eq!(
3143                    global_expressions[components_iter.next().unwrap()],
3144                    Expression::Literal(Literal::F32(4.))
3145                );
3146                assert_eq!(
3147                    global_expressions[components_iter.next().unwrap()],
3148                    Expression::Literal(Literal::F32(5.))
3149                );
3150                assert!(components_iter.next().is_none());
3151            }
3152            _ => panic!("Expected vector"),
3153        }
3154
3155        assert_eq!(
3156            global_expressions[res2],
3157            Expression::Literal(Literal::F32(5.))
3158        );
3159    }
3160
3161    #[test]
3162    fn compose_of_constants() {
3163        let mut types = UniqueArena::new();
3164        let mut constants = Arena::new();
3165        let overrides = Arena::new();
3166        let mut global_expressions = Arena::new();
3167
3168        let i32_ty = types.insert(
3169            Type {
3170                name: None,
3171                inner: TypeInner::Scalar(crate::Scalar::I32),
3172            },
3173            Default::default(),
3174        );
3175
3176        let vec2_i32_ty = types.insert(
3177            Type {
3178                name: None,
3179                inner: TypeInner::Vector {
3180                    size: VectorSize::Bi,
3181                    scalar: crate::Scalar::I32,
3182                },
3183            },
3184            Default::default(),
3185        );
3186
3187        let h = constants.append(
3188            Constant {
3189                name: None,
3190                ty: i32_ty,
3191                init: global_expressions
3192                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3193            },
3194            Default::default(),
3195        );
3196
3197        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3198
3199        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3200        let mut solver = ConstantEvaluator {
3201            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3202            types: &mut types,
3203            constants: &constants,
3204            overrides: &overrides,
3205            expressions: &mut global_expressions,
3206            expression_kind_tracker,
3207            layouter: &mut crate::proc::Layouter::default(),
3208        };
3209
3210        let solved_compose = solver
3211            .try_eval_and_append(
3212                Expression::Compose {
3213                    ty: vec2_i32_ty,
3214                    components: vec![h_expr, h_expr],
3215                },
3216                Default::default(),
3217            )
3218            .unwrap();
3219        let solved_negate = solver
3220            .try_eval_and_append(
3221                Expression::Unary {
3222                    op: UnaryOperator::Negate,
3223                    expr: solved_compose,
3224                },
3225                Default::default(),
3226            )
3227            .unwrap();
3228
3229        let pass = match global_expressions[solved_negate] {
3230            Expression::Compose { ty, ref components } => {
3231                ty == vec2_i32_ty
3232                    && components.iter().all(|&component| {
3233                        let component = &global_expressions[component];
3234                        matches!(*component, Expression::Literal(Literal::I32(-4)))
3235                    })
3236            }
3237            _ => false,
3238        };
3239        if !pass {
3240            panic!("unexpected evaluation result")
3241        }
3242    }
3243
3244    #[test]
3245    fn splat_of_constant() {
3246        let mut types = UniqueArena::new();
3247        let mut constants = Arena::new();
3248        let overrides = Arena::new();
3249        let mut global_expressions = Arena::new();
3250
3251        let i32_ty = types.insert(
3252            Type {
3253                name: None,
3254                inner: TypeInner::Scalar(crate::Scalar::I32),
3255            },
3256            Default::default(),
3257        );
3258
3259        let vec2_i32_ty = types.insert(
3260            Type {
3261                name: None,
3262                inner: TypeInner::Vector {
3263                    size: VectorSize::Bi,
3264                    scalar: crate::Scalar::I32,
3265                },
3266            },
3267            Default::default(),
3268        );
3269
3270        let h = constants.append(
3271            Constant {
3272                name: None,
3273                ty: i32_ty,
3274                init: global_expressions
3275                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3276            },
3277            Default::default(),
3278        );
3279
3280        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
3281
3282        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3283        let mut solver = ConstantEvaluator {
3284            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3285            types: &mut types,
3286            constants: &constants,
3287            overrides: &overrides,
3288            expressions: &mut global_expressions,
3289            expression_kind_tracker,
3290            layouter: &mut crate::proc::Layouter::default(),
3291        };
3292
3293        let solved_compose = solver
3294            .try_eval_and_append(
3295                Expression::Splat {
3296                    size: VectorSize::Bi,
3297                    value: h_expr,
3298                },
3299                Default::default(),
3300            )
3301            .unwrap();
3302        let solved_negate = solver
3303            .try_eval_and_append(
3304                Expression::Unary {
3305                    op: UnaryOperator::Negate,
3306                    expr: solved_compose,
3307                },
3308                Default::default(),
3309            )
3310            .unwrap();
3311
3312        let pass = match global_expressions[solved_negate] {
3313            Expression::Compose { ty, ref components } => {
3314                ty == vec2_i32_ty
3315                    && components.iter().all(|&component| {
3316                        let component = &global_expressions[component];
3317                        matches!(*component, Expression::Literal(Literal::I32(-4)))
3318                    })
3319            }
3320            _ => false,
3321        };
3322        if !pass {
3323            panic!("unexpected evaluation result")
3324        }
3325    }
3326
3327    #[test]
3328    fn splat_of_zero_value() {
3329        let mut types = UniqueArena::new();
3330        let constants = Arena::new();
3331        let overrides = Arena::new();
3332        let mut global_expressions = Arena::new();
3333
3334        let f32_ty = types.insert(
3335            Type {
3336                name: None,
3337                inner: TypeInner::Scalar(crate::Scalar::F32),
3338            },
3339            Default::default(),
3340        );
3341
3342        let vec2_f32_ty = types.insert(
3343            Type {
3344                name: None,
3345                inner: TypeInner::Vector {
3346                    size: VectorSize::Bi,
3347                    scalar: crate::Scalar::F32,
3348                },
3349            },
3350            Default::default(),
3351        );
3352
3353        let five =
3354            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
3355        let five_splat = global_expressions.append(
3356            Expression::Splat {
3357                size: VectorSize::Bi,
3358                value: five,
3359            },
3360            Default::default(),
3361        );
3362        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
3363        let zero_splat = global_expressions.append(
3364            Expression::Splat {
3365                size: VectorSize::Bi,
3366                value: zero,
3367            },
3368            Default::default(),
3369        );
3370
3371        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3372        let mut solver = ConstantEvaluator {
3373            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3374            types: &mut types,
3375            constants: &constants,
3376            overrides: &overrides,
3377            expressions: &mut global_expressions,
3378            expression_kind_tracker,
3379            layouter: &mut crate::proc::Layouter::default(),
3380        };
3381
3382        let solved_add = solver
3383            .try_eval_and_append(
3384                Expression::Binary {
3385                    op: crate::BinaryOperator::Add,
3386                    left: zero_splat,
3387                    right: five_splat,
3388                },
3389                Default::default(),
3390            )
3391            .unwrap();
3392
3393        let pass = match global_expressions[solved_add] {
3394            Expression::Compose { ty, ref components } => {
3395                ty == vec2_f32_ty
3396                    && components.iter().all(|&component| {
3397                        let component = &global_expressions[component];
3398                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
3399                    })
3400            }
3401            _ => false,
3402        };
3403        if !pass {
3404            panic!("unexpected evaluation result")
3405        }
3406    }
3407}