Skip to main content

naga/proc/
constant_evaluator.rs

1// Code in this file intentionally uses `for` loops and `.push()` rather than
2// `ArrayVec::from_iter`, because the latter is monomorphized by all three of
3// the item type, the capacity, and the iterator type, which can easily bloat
4// the compiled executable (by ~260 KiB, when it was removed).
5
6use alloc::{
7    format,
8    string::{String, ToString},
9    vec,
10    vec::Vec,
11};
12use core::iter;
13
14use arrayvec::ArrayVec;
15use half::f16;
16use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero};
17
18use crate::{
19    arena::{Arena, Handle, HandleVec, UniqueArena},
20    ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
21    ScalarKind, Span, Type, TypeInner, UnaryOperator,
22};
23
24#[cfg(feature = "wgsl-in")]
25use crate::common::wgsl::TryToWgsl;
26
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
///
/// The tokens passed in become the arm list of a temporary helper macro,
/// `__with_dollar_sign`, which is then immediately invoked with a bare `$` token.
/// An arm written as `($d:tt) => { ... }` therefore sees `$d` bound to `$`, and can
/// write `$d name` / `$d (...)` wherever the generated macro needs its own
/// metavariables or repetitions.
///
/// Technique stolen directly from
/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
macro_rules! with_dollar_sign {
    ($($body:tt)*) => {
        // Define the helper with the caller's arms, then invoke it with a literal `$`.
        macro_rules! __with_dollar_sign { $($body)* }
        __with_dollar_sign!($);
    }
}
38
/// Generates a component-wise evaluation helper for a subset of [`Literal`] variants.
///
/// An invocation of the form
///
/// ```rust,ignore
/// gen_component_wise_extractor! {
///     ident -> Target,
///     literals: [LiteralVariant => Mapping: ty, ...],
///     scalar_kinds: [Kind, ...],
/// }
/// ```
///
/// emits:
///
/// - `enum Target<const N: usize>`, with one variant `Mapping([ty; N])` per
///   `literals` entry;
/// - `impl From<Target<1>> for Expression`, turning a single value back into an
///   [`Expression::Literal`];
/// - `fn ident(eval, span, exprs, handler)`, which gathers the `N` argument
///   expressions into a `Target<N>` (for vector arguments, one `Target<N>` per
///   component), passes it to `handler`, and registers the result;
/// - a `macro_rules! ident` convenience wrapper that applies a single closure body
///   to every `Target` variant.
///
/// `scalar_kinds` lists the [`ScalarKind`]s whose vector types the `Compose`
/// branch of `ident` accepts; each component is then evaluated again through
/// `ident`, so it must itself match one of the `literals` arms.
macro_rules! gen_component_wise_extractor {
    (
        $ident:ident -> $target:ident,
        literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
        scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
    ) => {
        /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
        #[derive(Debug)]
        #[cfg_attr(test, derive(PartialEq))]
        enum $target<const N: usize> {
            $(
                #[doc = concat!(
                    "Maps to [`Literal::",
                    stringify!($literal),
                    "`]",
                )]
                $mapping([$ty; N]),
            )+
        }

        impl From<$target<1>> for Expression {
            fn from(value: $target<1>) -> Self {
                match value {
                    $(
                        $target::$mapping([value]) => {
                            Expression::Literal(Literal::$literal(value))
                        }
                    )+
                }
            }
        }

        #[doc = concat!(
            "Attempts to evaluate multiple `exprs` as a combined [`",
            stringify!($target),
            "`] to pass to `handler`. ",
        )]
        /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
        /// component of each vector.
        ///
        /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
        /// same length, a new vector expression is registered, composed of each component emitted
        /// by `handler`.
        ///
        /// Arguments of an unlisted literal variant or scalar kind, or vector arguments
        /// whose type differs from the first argument's, yield
        /// [`ConstantEvaluatorError::InvalidMathArg`].
        fn $ident<const N: usize, const M: usize>(
            eval: &mut ConstantEvaluator<'_>,
            span: Span,
            exprs: [Handle<Expression>; N],
            handler: fn($target<N>) -> Result<$target<M>, ConstantEvaluatorError>,
        ) -> Result<Handle<Expression>, ConstantEvaluatorError>
        where
            $target<M>: Into<Expression>,
        {
            // The first argument is taken unconditionally below, so there must be one.
            assert!(N > 0);
            let err = ConstantEvaluatorError::InvalidMathArg;
            let mut exprs = exprs.into_iter();

            // Normalizes an argument with `eval_zero_value_and_splat` (presumably
            // expanding `ZeroValue`/`Splat` forms - confirm) and borrows the
            // resulting expression from the arena.
            macro_rules! sanitize {
                ($expr:expr) => {
                    eval.eval_zero_value_and_splat($expr, span)
                        .map(|expr| &eval.expressions[expr])
                };
            }

            // Dispatch on the shape of the first argument; all others must match it.
            let new_expr: Result<Expression, ConstantEvaluatorError> = match sanitize!(exprs.next().unwrap())? {
                $(
                    // Scalar case: every argument must be this same literal variant.
                    &Expression::Literal(Literal::$literal(x)) => {
                        let mut arr = ArrayVec::<_, N>::new();
                        arr.push(x);
                        for expr in exprs {
                            match sanitize!(expr)? {
                                &Expression::Literal(Literal::$literal(val)) => arr.push(val),
                                _ => return Err(err),
                            }
                        }
                        // Cannot fail: exactly `N` values were pushed.
                        let comps = $target::$mapping(arr.into_inner().unwrap());
                        Ok(handler(comps)?.into())
                    },
                )+
                // Vector case: flatten each argument, then evaluate component by component.
                &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
                    &TypeInner::Vector { size, scalar } => match scalar.kind {
                        $(ScalarKind::$scalar_kind)|* => {
                            // `ty` is rebound by the pattern in the loop below, so keep
                            // the first argument's type under its own name.
                            let first_ty = ty;
                            let mut component_groups =
                                ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
                            {
                                let mut inner = ArrayVec::new();
                                for item in crate::proc::flatten_compose(
                                    first_ty,
                                    components,
                                    eval.expressions,
                                    eval.types,
                                ) {
                                    inner.push(item);
                                }
                                component_groups.push(inner);
                            }
                            // Remaining arguments must be composes of exactly the same type.
                            for expr in exprs {
                                match sanitize!(expr)? {
                                    &Expression::Compose { ty, ref components }
                                        if &eval.types[ty].inner
                                            == &eval.types[first_ty].inner =>
                                    {
                                        let mut inner = ArrayVec::new();
                                        for item in crate::proc::flatten_compose(
                                            ty,
                                            components,
                                            eval.expressions,
                                            eval.types,
                                        ) {
                                            inner.push(item);
                                        }
                                        component_groups.push(inner);
                                    }
                                    _ => return Err(err),
                                }
                            }
                            // Cannot fail: one group was pushed per argument.
                            let component_groups = component_groups.into_inner().unwrap();
                            let mut new_components =
                                ArrayVec::<_, { crate::VectorSize::MAX }>::new();
                            // Transpose: gather component `idx` of every argument into one
                            // `N`-element group and recurse on that group.
                            for idx in 0..(size as u8).into() {
                                let mut group_arr = ArrayVec::<_, N>::new();
                                for cs in component_groups.iter() {
                                    // An argument with fewer flattened components than
                                    // `size` is reported as an invalid math argument.
                                    group_arr.push(
                                        cs.get(idx).cloned().ok_or_else(|| err.clone())?,
                                    );
                                }
                                let group = group_arr.into_inner().unwrap();
                                new_components.push($ident(
                                    eval,
                                    span,
                                    group,
                                    handler,
                                )?);
                            }
                            Ok(Expression::Compose {
                                ty: first_ty,
                                components: new_components.into_iter().collect(),
                            })
                        }
                        _ => return Err(err),
                    },
                    _ => return Err(err),
                },
                _ => return Err(err),
            };
            eval.register_evaluated_expr(new_expr?, span)
        }

        with_dollar_sign! {
            ($d:tt) => {
                // `$d` stands for a literal `$`, letting the generated macro declare its
                // own repetitions (`$d (...)`) without this macro expanding them.
                #[allow(unused)]
                #[doc = concat!(
                    "A convenience macro for using the same RHS for each [`",
                    stringify!($target),
                    "`] variant in a call to [`",
                    stringify!($ident),
                    "`].",
                )]
                macro_rules! $ident {
                    (
                        $eval:expr,
                        $span:expr,
                        [$d ($d expr:expr),+ $d (,)?],
                        |$d ($d arg:ident),+| $d tt:tt
                    ) => {
                        $ident($eval, $span, [$d ($d expr),+], |args| match args {
                            $(
                                $target::$mapping([$d ($d arg),+]) => {
                                    let res = $d tt;
                                    Result::map(res, $target::$mapping)
                                },
                            )+
                        })
                    };
                }
            };
        }
    };
}
218
// `component_wise_scalar` / `Scalar`: every numeric literal kind (booleans are not
// included), accepting vectors of any numeric scalar kind.
gen_component_wise_extractor! {
    component_wise_scalar -> Scalar,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        F32 => F32: f32,
        F16 => F16: f16,
        AbstractInt => AbstractInt: i64,
        U32 => U32: u32,
        I32 => I32: i32,
        U64 => U64: u64,
        I64 => I64: i64,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
        Sint,
        Uint,
        AbstractInt,
    ],
}
239
// `component_wise_float` / `Float`: floating-point literals only.
// Note that `Literal::AbstractFloat` maps to the variant `Float::Abstract`, while the
// other extractors in this file name that variant `AbstractFloat`.
gen_component_wise_extractor! {
    component_wise_float -> Float,
    literals: [
        AbstractFloat => Abstract: f64,
        F32 => F32: f32,
        F16 => F16: f16,
    ],
    scalar_kinds: [
        Float,
        AbstractFloat,
    ],
}
252
// `component_wise_concrete_int` / `ConcreteInt`: 32-bit concrete integers only.
// NOTE(review): `U64`/`I64` literals are not listed, so 64-bit integer arguments
// (scalars, or components of accepted `Sint`/`Uint` vectors) end up in the
// `InvalidMathArg` fallback - confirm this is intended.
gen_component_wise_extractor! {
    component_wise_concrete_int -> ConcreteInt,
    literals: [
        U32 => U32: u32,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        Uint,
    ],
}
264
// `component_wise_signed` / `Signed`: signed integer and floating-point literals.
// NOTE(review): `I64` is not listed even though `Sint` vectors are accepted, so
// 64-bit signed arguments end up in the `InvalidMathArg` fallback - confirm.
gen_component_wise_extractor! {
    component_wise_signed -> Signed,
    literals: [
        AbstractFloat => AbstractFloat: f64,
        AbstractInt => AbstractInt: i64,
        F32 => F32: f32,
        F16 => F16: f16,
        I32 => I32: i32,
    ],
    scalar_kinds: [
        Sint,
        AbstractInt,
        Float,
        AbstractFloat,
    ],
}
281
/// A vector of [`Literal`]s that all share one variant, with the unwrapped element
/// values stored inline (at most [`crate::VectorSize::MAX`] of them).
///
/// A single scalar can be represented as a one-component vector
/// (see `LiteralVector::from_literal`).
#[derive(Debug)]
enum LiteralVector {
    /// Components of [`Literal::F64`].
    F64(ArrayVec<f64, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::F32`].
    F32(ArrayVec<f32, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::F16`].
    F16(ArrayVec<f16, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::U32`].
    U32(ArrayVec<u32, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::I32`].
    I32(ArrayVec<i32, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::U64`].
    U64(ArrayVec<u64, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::I64`].
    I64(ArrayVec<i64, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::Bool`].
    Bool(ArrayVec<bool, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::AbstractInt`].
    AbstractInt(ArrayVec<i64, { crate::VectorSize::MAX }>),
    /// Components of [`Literal::AbstractFloat`].
    AbstractFloat(ArrayVec<f64, { crate::VectorSize::MAX }>),
}
296
297impl LiteralVector {
298    #[allow(clippy::missing_const_for_fn, reason = "MSRV")]
299    fn len(&self) -> usize {
300        match *self {
301            LiteralVector::F64(ref v) => v.len(),
302            LiteralVector::F32(ref v) => v.len(),
303            LiteralVector::F16(ref v) => v.len(),
304            LiteralVector::U32(ref v) => v.len(),
305            LiteralVector::I32(ref v) => v.len(),
306            LiteralVector::U64(ref v) => v.len(),
307            LiteralVector::I64(ref v) => v.len(),
308            LiteralVector::Bool(ref v) => v.len(),
309            LiteralVector::AbstractInt(ref v) => v.len(),
310            LiteralVector::AbstractFloat(ref v) => v.len(),
311        }
312    }
313
314    /// Creates [`LiteralVector`] of size 1 from single [`Literal`]
315    fn from_literal(literal: Literal) -> Self {
316        fn arrayvec_of<T, const N: usize>(val: T) -> ArrayVec<T, N> {
317            let mut v = ArrayVec::new();
318            v.push(val);
319            v
320        }
321        match literal {
322            Literal::F64(e) => Self::F64(arrayvec_of(e)),
323            Literal::F32(e) => Self::F32(arrayvec_of(e)),
324            Literal::U32(e) => Self::U32(arrayvec_of(e)),
325            Literal::I32(e) => Self::I32(arrayvec_of(e)),
326            Literal::U64(e) => Self::U64(arrayvec_of(e)),
327            Literal::I64(e) => Self::I64(arrayvec_of(e)),
328            Literal::Bool(e) => Self::Bool(arrayvec_of(e)),
329            Literal::AbstractInt(e) => Self::AbstractInt(arrayvec_of(e)),
330            Literal::AbstractFloat(e) => Self::AbstractFloat(arrayvec_of(e)),
331            Literal::F16(e) => Self::F16(arrayvec_of(e)),
332        }
333    }
334
335    /// Creates [`LiteralVector`] from [`ArrayVec`] of [`Literal`]s.
336    /// Returns error if components types do not match.
337    /// # Panics
338    /// Panics if vector is empty
339    fn from_literal_vec(
340        components: ArrayVec<Literal, { crate::VectorSize::MAX }>,
341    ) -> Result<Self, ConstantEvaluatorError> {
342        assert!(!components.is_empty());
343        // TODO: should a vector of i32 be constructible from abstract int?
344        macro_rules! compose_literals {
345            ($components:expr, $variant:ident, $self_variant:ident) => {{
346                let mut out = ArrayVec::new();
347                for l in &$components {
348                    match l {
349                        &Literal::$variant(v) => out.push(v),
350                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
351                    }
352                }
353                Self::$self_variant(out)
354            }};
355        }
356        Ok(match components[0] {
357            Literal::I32(_) => compose_literals!(components, I32, I32),
358            Literal::U32(_) => compose_literals!(components, U32, U32),
359            Literal::I64(_) => compose_literals!(components, I64, I64),
360            Literal::U64(_) => compose_literals!(components, U64, U64),
361            Literal::F32(_) => compose_literals!(components, F32, F32),
362            Literal::F64(_) => compose_literals!(components, F64, F64),
363            Literal::Bool(_) => compose_literals!(components, Bool, Bool),
364            Literal::AbstractInt(_) => compose_literals!(components, AbstractInt, AbstractInt),
365            Literal::AbstractFloat(_) => {
366                compose_literals!(components, AbstractFloat, AbstractFloat)
367            }
368            Literal::F16(_) => compose_literals!(components, F16, F16),
369        })
370    }
371
372    #[allow(dead_code)]
373    /// Returns [`ArrayVec`] of [`Literal`]s
374    fn to_literal_vec(&self) -> ArrayVec<Literal, { crate::VectorSize::MAX }> {
375        macro_rules! decompose_literals {
376            ($v:expr, $variant:ident) => {{
377                let mut out = ArrayVec::new();
378                for e in $v {
379                    out.push(Literal::$variant(*e));
380                }
381                out
382            }};
383        }
384        match *self {
385            LiteralVector::F64(ref v) => decompose_literals!(v, F64),
386            LiteralVector::F32(ref v) => decompose_literals!(v, F32),
387            LiteralVector::F16(ref v) => decompose_literals!(v, F16),
388            LiteralVector::U32(ref v) => decompose_literals!(v, U32),
389            LiteralVector::I32(ref v) => decompose_literals!(v, I32),
390            LiteralVector::U64(ref v) => decompose_literals!(v, U64),
391            LiteralVector::I64(ref v) => decompose_literals!(v, I64),
392            LiteralVector::Bool(ref v) => decompose_literals!(v, Bool),
393            LiteralVector::AbstractInt(ref v) => decompose_literals!(v, AbstractInt),
394            LiteralVector::AbstractFloat(ref v) => decompose_literals!(v, AbstractFloat),
395        }
396    }
397
398    #[allow(dead_code)]
399    /// Puts self into eval's expressions arena and returns handle to it
400    fn register_as_evaluated_expr(
401        &self,
402        eval: &mut ConstantEvaluator<'_>,
403        span: Span,
404    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
405        let lit_vec = self.to_literal_vec();
406        assert!(!lit_vec.is_empty());
407        let expr = if lit_vec.len() == 1 {
408            Expression::Literal(lit_vec[0])
409        } else {
410            Expression::Compose {
411                ty: eval.types.insert(
412                    Type {
413                        name: None,
414                        inner: TypeInner::Vector {
415                            size: match lit_vec.len() {
416                                2 => crate::VectorSize::Bi,
417                                3 => crate::VectorSize::Tri,
418                                4 => crate::VectorSize::Quad,
419                                _ => unreachable!(),
420                            },
421                            scalar: lit_vec[0].scalar(),
422                        },
423                    },
424                    Span::UNDEFINED,
425                ),
426                components: lit_vec
427                    .iter()
428                    .map(|&l| eval.register_evaluated_expr(Expression::Literal(l), span))
429                    .collect::<Result<_, _>>()?,
430            }
431        };
432        eval.register_evaluated_expr(expr, span)
433    }
434}
435
/// A macro for matching on [`LiteralVector`] variants.
///
/// `Float` variant expands to `F16`, `F32`, `F64` and `AbstractFloat`.
/// `Integer` variant expands to `I32`, `I64`, `U32`, `U64` and `AbstractInt`.
///
/// For output both [`Literal`] (fold) and [`LiteralVector`] (map) are supported.
///
/// The expansion is a `match` whose arms evaluate to `Ok($out::Variant(body))`.
/// Any input not covered by an arm evaluates to
/// `Err(ConstantEvaluatorError::InvalidMathArg)`.
///
/// Example usage:
///
/// ```rust,ignore
/// match_literal_vector!(match v => Literal {
///     F16 => |v| {v.sum()},
///     Integer => |v| {v.sum()},
///     U32 => |v| -> I32 {v.sum()}, // optionally override return type
/// })
/// ```
///
/// ```rust,ignore
/// match_literal_vector!(match (e1, e2) => LiteralVector {
///     F16 => |e1, e2| {e1+e2},
///     Integer => |e1, e2| {e1+e2},
///     U32 => |e1, e2| -> I32 {e1+e2}, // optionally override return type
/// })
/// ```
macro_rules! match_literal_vector {
    // Entry point: split the user's arms into a list of variant names and a
    // parallel list of `{ vars ; ret ; body }` payloads.
    (match $lit_vec:expr => $out:ident {
        $(
            $ty:ident => |$($var:ident),+| $(-> $ret:ident)? { $body:expr }
        ),+
        $(,)?
    }) => {
        match_literal_vector!(@inner_start $lit_vec; $out; [$($ty),+]; [$({ $($var),+ ; $($ret)? ; $body }),+])
    };

    // Start the muncher with an empty list of finished `{ ty ; vars ; ret ; body }`
    // arms on the left of `<>`; unprocessed payloads stay on the right.
    (@inner_start
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),+];
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [$($ty),+];
            [] <> [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };

    // Pop the next variant name and dispatch on it: the `Integer`/`Float` rules
    // below expand groups; the remaining rules handle a concrete variant.
    (@inner
        $lit_vec:expr;
        $out:ident;
        [$ty:ident $(, $ty1:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [$({ $($var:ident),+ ; $($ret:ident)? ; $body:expr }),+]
    ) => {
        match_literal_vector!(@inner
            $ty;
            $lit_vec;
            $out;
            [$($ty1),*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [$({ $($var),+ ; $($ret)? ; $body }),+]
        )
    };
    // `Integer` group: push its five concrete variants onto the front of the
    // remaining names and repeat the current payload once for each of them.
    (@inner
        Integer;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [U32, I32, U64, I64, AbstractInt $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // U32
                { $($var),+ ; $($ret)? ; $body }, // I32
                { $($var),+ ; $($ret)? ; $body }, // U64
                { $($var),+ ; $($ret)? ; $body }, // I64
                { $($var),+ ; $($ret)? ; $body }  // AbstractInt
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // `Float` group: push its four concrete variants onto the front of the
    // remaining names and repeat the current payload once for each of them.
    (@inner
        Float;
        $lit_vec:expr;
        $out:ident;
        [$($ty:ident),*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(,{ $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $lit_vec;
            $out;
            [F16, F32, F64, AbstractFloat $(, $ty)*];
            [$({$_ty ; $($_var),+ ; $($_ret)? ; $_body}),*] <>
            [
                { $($var),+ ; $($ret)? ; $body }, // F16
                { $($var),+ ; $($ret)? ; $body }, // F32
                { $($var),+ ; $($ret)? ; $body }, // F64
                { $($var),+ ; $($ret)? ; $body }  // AbstractFloat
                $(,{ $($var1),+ ; $($ret1)? ; $body1 })*
            ]
        )
    };
    // Concrete variant with more names left: move it, paired with its payload,
    // into the finished list and continue directly with the next name.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [$ty1:ident $(,$ty2:ident)*];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <> [
            { $($var:ident),+ ; $($ret:ident)? ; $body:expr }
            $(, { $($var1:ident),+ ; $($ret1:ident)? ; $body1:expr })*
        ]
    ) => {
        match_literal_vector!(@inner
            $ty1;
            $lit_vec;
            $out;
            [$($ty2),*];
            [
                $({$_ty ; $($_var),+ ; $($_ret)? ; $_body},)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ] <>
            [$({ $($var1),+ ; $($ret1)? ; $body1 }),*]

        )
    };
    // Last concrete variant (no names left): move it into the finished list and
    // emit the final `match`.
    (@inner
        $ty:ident;
        $lit_vec:expr;
        $out:ident;
        [];
        [$({$_ty:ident ; $($_var:ident),+ ; $($_ret:ident)? ; $_body:expr}),*] <>
        [{ $($var:ident),+ ; $($ret:ident)? ; $body:expr }]
    ) => {
        match_literal_vector!(@inner_finish
            $lit_vec;
            $out;
            [
                $({ $_ty ; $($_var),+ ; $($_ret)? ; $_body },)*
                { $ty; $($var),+ ; $($ret)? ; $body }
            ]
        )
    };
    // Emit the `match`. With several scrutinees (a tuple), each arm matches a
    // tuple of same-variant patterns; `#[allow(unused_parens)]` covers the
    // single-scrutinee case, where the pattern is just parenthesized.
    (@inner_finish
        $lit_vec:expr;
        $out:ident;
        [$({$ty:ident ; $($var:ident),+ ; $($ret:ident)? ; $body:expr}),+]
    ) => {
        match $lit_vec {
            $(
                #[allow(unused_parens)]
                ($(LiteralVector::$ty(ref $var)),+) => { Ok(match_literal_vector!(@expand_ret $out; $ty $(; $ret)? ; $body)) }
            )+
            _ => Err(ConstantEvaluatorError::InvalidMathArg),
        }
    };
    // Wrap the body in `$out::<matched variant>` ...
    (@expand_ret $out:ident; $ty:ident; $body:expr) => {
        $out::$ty($body)
    };
    // ... or in `$out::$ret` when the arm overrides the output variant.
    (@expand_ret $out:ident; $_ty:ident; $ret:ident; $body:expr) => {
        $out::$ret($body)
    };
}
611
/// Which source language's evaluation rules a [`ConstantEvaluator`] follows,
/// together with that language's restriction level.
#[derive(Debug)]
enum Behavior<'a> {
    /// Follow WGSL's rules, at the given restriction level.
    Wgsl(WgslRestrictions<'a>),
    /// Follow GLSL's rules, at the given restriction level.
    Glsl(GlslRestrictions<'a>),
}
617
618impl Behavior<'_> {
619    /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
620    const fn has_runtime_restrictions(&self) -> bool {
621        matches!(
622            self,
623            &Behavior::Wgsl(WgslRestrictions::Runtime(_))
624                | &Behavior::Glsl(GlslRestrictions::Runtime(_))
625        )
626    }
627}
628
/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
/// of Naga [`Expression`] you like, and if its value can be computed at compile
/// time, `try_eval_and_append` appends an expression representing the computed
/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
///
/// A `ConstantEvaluator` also holds whatever information we need to carry out
/// that evaluation: types, other constants, and so on.
///
/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
/// [`Compose`]: Expression::Compose
/// [`ZeroValue`]: Expression::ZeroValue
/// [`Literal`]: Expression::Literal
/// [`Swizzle`]: Expression::Swizzle
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
    /// Which language's evaluation rules we should follow.
    behavior: Behavior<'a>,

    /// The module's type arena.
    ///
    /// Because expressions like [`Splat`] contain type handles, we need to be
    /// able to add new types to produce those expressions.
    ///
    /// [`Splat`]: Expression::Splat
    types: &'a mut UniqueArena<Type>,

    /// The module's constant arena.
    constants: &'a Arena<Constant>,

    /// The module's override arena.
    overrides: &'a Arena<Override>,

    /// The arena to which we are contributing expressions.
    expressions: &'a mut Arena<Expression>,

    /// Tracks the constness of expressions residing in [`Self::expressions`].
    expression_kind_tracker: &'a mut ExpressionKindTracker,

    /// A [`Layouter`](crate::proc::Layouter), borrowed mutably for as long as the
    /// evaluator exists.
    // NOTE(review): presumably the module's layouter, used for type size/alignment
    // queries during evaluation - confirm against its uses further down the file.
    layouter: &'a mut crate::proc::Layouter,
}
673
/// WGSL restriction level of a [`ConstantEvaluator`]: which kinds of expressions
/// are evaluated, and which are only inserted into the arena.
#[derive(Debug)]
enum WgslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    ///
    /// Optionally carries [`FunctionLocalData`] (presumably `Some` when the
    /// expression appears inside a function body - confirm).
    Const(Option<FunctionLocalData<'a>>),
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    Override,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
686
/// GLSL restriction level of a [`ConstantEvaluator`]: which kinds of expressions
/// are evaluated, and which are only inserted into the arena.
#[derive(Debug)]
enum GlslRestrictions<'a> {
    /// - const-expressions will be evaluated and inserted in the arena
    Const,
    /// - const-expressions will be evaluated and inserted in the arena
    /// - override-expressions will be inserted in the arena
    /// - runtime-expressions will be inserted in the arena
    Runtime(FunctionLocalData<'a>),
}
696
/// Function-local state carried by [`WgslRestrictions::Runtime`],
/// [`GlslRestrictions::Runtime`], and optionally [`WgslRestrictions::Const`].
#[derive(Debug)]
struct FunctionLocalData<'a> {
    /// Global constant expressions
    global_expressions: &'a Arena<Expression>,
    /// An expression emitter, borrowed mutably.
    // NOTE(review): presumably tracks which newly appended expressions need
    // `Emit` statements in `block` - confirm.
    emitter: &'a mut super::Emitter,
    /// A statement block, borrowed mutably.
    // NOTE(review): presumably the function body block that receives emitted
    // statements - confirm.
    block: &'a mut crate::Block,
}
704
/// How constant an expression is.
///
/// Variants are declared from most to least constant, and the derived [`Ord`]
/// follows declaration order (`Const < Override < Runtime`). Combining operand
/// kinds with [`Ord::max`] therefore yields the least constant of them.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
    /// A const-expression.
    Const,
    /// An override-expression.
    Override,
    /// A runtime expression.
    Runtime,
}
711
/// Records an [`ExpressionKind`] for each expression of an expression arena,
/// keyed by the expression's [`Handle`].
#[derive(Debug)]
pub struct ExpressionKindTracker {
    /// Recorded kind of each expression, indexed by handle.
    inner: HandleVec<Expression, ExpressionKind>,
}
716
717impl ExpressionKindTracker {
718    pub const fn new() -> Self {
719        Self {
720            inner: HandleVec::new(),
721        }
722    }
723
724    /// Forces the the expression to not be const
725    pub fn force_non_const(&mut self, value: Handle<Expression>) {
726        self.inner[value] = ExpressionKind::Runtime;
727    }
728
729    pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
730        self.inner.insert(value, expr_type);
731    }
732
733    pub fn is_const(&self, h: Handle<Expression>) -> bool {
734        matches!(self.type_of(h), ExpressionKind::Const)
735    }
736
737    pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
738        matches!(
739            self.type_of(h),
740            ExpressionKind::Const | ExpressionKind::Override
741        )
742    }
743
744    fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
745        self.inner[value]
746    }
747
748    pub fn from_arena(arena: &Arena<Expression>) -> Self {
749        let mut tracker = Self {
750            inner: HandleVec::with_capacity(arena.len()),
751        };
752        for (handle, expr) in arena.iter() {
753            tracker
754                .inner
755                .insert(handle, tracker.type_of_with_expr(expr));
756        }
757        tracker
758    }
759
760    fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
761        match *expr {
762            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
763                ExpressionKind::Const
764            }
765            Expression::Override(_) => ExpressionKind::Override,
766            Expression::Compose { ref components, .. } => {
767                let mut expr_type = ExpressionKind::Const;
768                for component in components {
769                    expr_type = expr_type.max(self.type_of(*component))
770                }
771                expr_type
772            }
773            Expression::Splat { value, .. } => self.type_of(value),
774            Expression::AccessIndex { base, .. } => self.type_of(base),
775            Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
776            Expression::Swizzle { vector, .. } => self.type_of(vector),
777            Expression::Unary { expr, .. } => self.type_of(expr),
778            Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
779            Expression::Math {
780                arg,
781                arg1,
782                arg2,
783                arg3,
784                ..
785            } => self
786                .type_of(arg)
787                .max(
788                    arg1.map(|arg| self.type_of(arg))
789                        .unwrap_or(ExpressionKind::Const),
790                )
791                .max(
792                    arg2.map(|arg| self.type_of(arg))
793                        .unwrap_or(ExpressionKind::Const),
794                )
795                .max(
796                    arg3.map(|arg| self.type_of(arg))
797                        .unwrap_or(ExpressionKind::Const),
798                ),
799            Expression::As { expr, .. } => self.type_of(expr),
800            Expression::Select {
801                condition,
802                accept,
803                reject,
804            } => self
805                .type_of(condition)
806                .max(self.type_of(accept))
807                .max(self.type_of(reject)),
808            Expression::Relational { argument, .. } => self.type_of(argument),
809            Expression::ArrayLength(expr) => self.type_of(expr),
810            _ => ExpressionKind::Runtime,
811        }
812    }
813}
814
815#[derive(Clone, Debug, thiserror::Error)]
816#[cfg_attr(test, derive(PartialEq))]
817pub enum ConstantEvaluatorError {
818    #[error("Constants cannot access function arguments")]
819    FunctionArg,
820    #[error("Constants cannot access global variables")]
821    GlobalVariable,
822    #[error("Constants cannot access local variables")]
823    LocalVariable,
824    #[error("Cannot get the array length of a non array type")]
825    InvalidArrayLengthArg,
826    #[error("Constants cannot get the array length of a dynamically sized array")]
827    ArrayLengthDynamic,
828    #[error("Cannot call arrayLength on array sized by override-expression")]
829    ArrayLengthOverridden,
830    #[error("Constants cannot call functions")]
831    Call,
832    #[error("Constants don't support workGroupUniformLoad")]
833    WorkGroupUniformLoadResult,
834    #[error("Constants don't support atomic functions")]
835    Atomic,
836    #[error("Constants don't support derivative functions")]
837    Derivative,
838    #[error("Constants don't support load expressions")]
839    Load,
840    #[error("Constants don't support image expressions")]
841    ImageExpression,
842    #[error("Constants don't support ray query expressions")]
843    RayQueryExpression,
844    #[error("Constants don't support subgroup expressions")]
845    SubgroupExpression,
846    #[error("Cannot access the type")]
847    InvalidAccessBase,
848    #[error("Cannot access at the index")]
849    InvalidAccessIndex,
850    #[error("Cannot access with index of type")]
851    InvalidAccessIndexTy,
852    #[error("Constants don't support array length expressions")]
853    ArrayLength,
854    #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
855    InvalidCastArg { from: String, to: String },
856    #[error("Cannot apply the unary op to the argument")]
857    InvalidUnaryOpArg,
858    #[error("Cannot apply the binary op to the arguments")]
859    InvalidBinaryOpArgs,
860    #[error("Cannot apply math function to type")]
861    InvalidMathArg,
862    #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
863    InvalidMathArgCount(crate::MathFunction, usize, usize),
864    #[error("{0} built-in function argument is out of valid range")]
865    InvalidMathArgValue(String),
866    #[error("Cannot apply relational function to type")]
867    InvalidRelationalArg(RelationalFunction),
868    #[error("value of `low` is greater than `high` for clamp built-in function")]
869    InvalidClamp,
870    #[error("Constructor expects {expected} components, found {actual}")]
871    InvalidVectorComposeLength { expected: usize, actual: usize },
872    #[error("Constructor must only contain vector or scalar arguments")]
873    InvalidVectorComposeComponent,
874    #[error("Splat is defined only on scalar values")]
875    SplatScalarOnly,
876    #[error("Can only swizzle vector constants")]
877    SwizzleVectorOnly,
878    #[error("swizzle component not present in source expression")]
879    SwizzleOutOfBounds,
880    #[error("Type is not constructible")]
881    TypeNotConstructible,
882    #[error("Subexpression(s) are not constant")]
883    SubexpressionsAreNotConstant,
884    #[error("Not implemented as constant expression: {0}")]
885    NotImplemented(String),
886    #[error("{0} operation overflowed")]
887    Overflow(String),
888    #[error(
889        "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
890    )]
891    AutomaticConversionLossy {
892        value: String,
893        to_type: &'static str,
894    },
895    #[error("Division by zero")]
896    DivisionByZero,
897    #[error("Remainder by zero")]
898    RemainderByZero,
899    #[error("RHS of shift operation is greater than or equal to 32")]
900    ShiftedMoreThan32Bits,
901    #[error(transparent)]
902    Literal(#[from] crate::valid::LiteralError),
903    #[error("Can't use pipeline-overridable constants in const-expressions")]
904    Override,
905    #[error("Unexpected runtime-expression")]
906    RuntimeExpr,
907    #[error("Unexpected override-expression")]
908    OverrideExpr,
909    #[error("Expected boolean expression for condition argument of `select`, got something else")]
910    SelectScalarConditionNotABool,
911    #[error(
912        "Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
913        reject,
914        accept
915    )]
916    SelectVecRejectAcceptSizeMismatch {
917        reject: crate::VectorSize,
918        accept: crate::VectorSize,
919    },
920    #[error("Expected boolean vector for condition arg., got something else")]
921    SelectConditionNotAVecBool,
922    #[error(
923        "Expected same number of vector components between condition, accept, and reject args., got something else",
924    )]
925    SelectConditionVecSizeMismatch,
926    #[error(
927        "Expected reject and accept args. to be scalars of vectors of the same type, got something else",
928    )]
929    SelectAcceptRejectTypeMismatch,
930    #[error("Cooperative operations can't be constant")]
931    CooperativeOperation,
932}
933
934impl<'a> ConstantEvaluator<'a> {
935    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
936    /// constant expression arena.
937    ///
938    /// Report errors according to WGSL's rules for constant evaluation.
939    pub const fn for_wgsl_module(
940        module: &'a mut crate::Module,
941        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
942        layouter: &'a mut crate::proc::Layouter,
943        in_override_ctx: bool,
944    ) -> Self {
945        Self::for_module(
946            Behavior::Wgsl(if in_override_ctx {
947                WgslRestrictions::Override
948            } else {
949                WgslRestrictions::Const(None)
950            }),
951            module,
952            global_expression_kind_tracker,
953            layouter,
954        )
955    }
956
957    /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
958    /// constant expression arena.
959    ///
960    /// Report errors according to GLSL's rules for constant evaluation.
961    pub const fn for_glsl_module(
962        module: &'a mut crate::Module,
963        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
964        layouter: &'a mut crate::proc::Layouter,
965    ) -> Self {
966        Self::for_module(
967            Behavior::Glsl(GlslRestrictions::Const),
968            module,
969            global_expression_kind_tracker,
970            layouter,
971        )
972    }
973
974    const fn for_module(
975        behavior: Behavior<'a>,
976        module: &'a mut crate::Module,
977        global_expression_kind_tracker: &'a mut ExpressionKindTracker,
978        layouter: &'a mut crate::proc::Layouter,
979    ) -> Self {
980        Self {
981            behavior,
982            types: &mut module.types,
983            constants: &module.constants,
984            overrides: &module.overrides,
985            expressions: &mut module.global_expressions,
986            expression_kind_tracker: global_expression_kind_tracker,
987            layouter,
988        }
989    }
990
991    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
992    /// expression arena.
993    ///
994    /// Report errors according to WGSL's rules for constant evaluation.
995    pub const fn for_wgsl_function(
996        module: &'a mut crate::Module,
997        expressions: &'a mut Arena<Expression>,
998        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
999        layouter: &'a mut crate::proc::Layouter,
1000        emitter: &'a mut super::Emitter,
1001        block: &'a mut crate::Block,
1002        is_const: bool,
1003    ) -> Self {
1004        let local_data = FunctionLocalData {
1005            global_expressions: &module.global_expressions,
1006            emitter,
1007            block,
1008        };
1009        Self {
1010            behavior: Behavior::Wgsl(if is_const {
1011                WgslRestrictions::Const(Some(local_data))
1012            } else {
1013                WgslRestrictions::Runtime(local_data)
1014            }),
1015            types: &mut module.types,
1016            constants: &module.constants,
1017            overrides: &module.overrides,
1018            expressions,
1019            expression_kind_tracker: local_expression_kind_tracker,
1020            layouter,
1021        }
1022    }
1023
1024    /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
1025    /// expression arena.
1026    ///
1027    /// Report errors according to GLSL's rules for constant evaluation.
1028    pub const fn for_glsl_function(
1029        module: &'a mut crate::Module,
1030        expressions: &'a mut Arena<Expression>,
1031        local_expression_kind_tracker: &'a mut ExpressionKindTracker,
1032        layouter: &'a mut crate::proc::Layouter,
1033        emitter: &'a mut super::Emitter,
1034        block: &'a mut crate::Block,
1035    ) -> Self {
1036        Self {
1037            behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
1038                global_expressions: &module.global_expressions,
1039                emitter,
1040                block,
1041            })),
1042            types: &mut module.types,
1043            constants: &module.constants,
1044            overrides: &module.overrides,
1045            expressions,
1046            expression_kind_tracker: local_expression_kind_tracker,
1047            layouter,
1048        }
1049    }
1050
1051    pub const fn to_ctx(&self) -> crate::proc::GlobalCtx<'_> {
1052        crate::proc::GlobalCtx {
1053            types: self.types,
1054            constants: self.constants,
1055            overrides: self.overrides,
1056            global_expressions: match self.function_local_data() {
1057                Some(data) => data.global_expressions,
1058                None => self.expressions,
1059            },
1060        }
1061    }
1062
1063    fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
1064        if !self.expression_kind_tracker.is_const(expr) {
1065            log::debug!("check: SubexpressionsAreNotConstant");
1066            return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
1067        }
1068        Ok(())
1069    }
1070
1071    fn check_and_get(
1072        &mut self,
1073        expr: Handle<Expression>,
1074    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1075        match self.expressions[expr] {
1076            Expression::Constant(c) => {
1077                // Are we working in a function's expression arena, or the
1078                // module's constant expression arena?
1079                if let Some(function_local_data) = self.function_local_data() {
1080                    // Deep-copy the constant's value into our arena.
1081                    self.copy_from(
1082                        self.constants[c].init,
1083                        function_local_data.global_expressions,
1084                    )
1085                } else {
1086                    // "See through" the constant and use its initializer.
1087                    Ok(self.constants[c].init)
1088                }
1089            }
1090            _ => {
1091                self.check(expr)?;
1092                Ok(expr)
1093            }
1094        }
1095    }
1096
1097    /// Try to evaluate `expr` at compile time.
1098    ///
1099    /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
1100    /// we can determine its value at compile time, we append an expression
1101    /// representing its value - a tree of [`Literal`], [`Compose`],
1102    /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
1103    /// `self` contributes to.
1104    ///
1105    /// If `expr`'s value cannot be determined at compile time, and `self` is
1106    /// contributing to some function's expression arena, then append `expr` to
1107    /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
1108    /// contributing to the module's constant expression arena; since `expr`'s
1109    /// value is not a constant, return an error.
1110    ///
1111    /// We only consider `expr` itself, without recursing into its operands. Its
1112    /// operands must all have been produced by prior calls to
1113    /// `try_eval_and_append`, to ensure that they have already been reduced to
1114    /// an evaluated form if possible.
1115    ///
1116    /// [`Literal`]: Expression::Literal
1117    /// [`Compose`]: Expression::Compose
1118    /// [`ZeroValue`]: Expression::ZeroValue
1119    /// [`Swizzle`]: Expression::Swizzle
1120    pub fn try_eval_and_append(
1121        &mut self,
1122        expr: Expression,
1123        span: Span,
1124    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1125        match self.expression_kind_tracker.type_of_with_expr(&expr) {
1126            ExpressionKind::Const => {
1127                let eval_result = self.try_eval_and_append_impl(&expr, span);
1128                // We should be able to evaluate `Const` expressions at this
1129                // point. If we failed to, then that probably means we just
1130                // haven't implemented that part of constant evaluation. Work
1131                // around this by simply emitting it as a run-time expression.
1132                if self.behavior.has_runtime_restrictions()
1133                    && matches!(
1134                        eval_result,
1135                        Err(ConstantEvaluatorError::NotImplemented(_)
1136                            | ConstantEvaluatorError::InvalidBinaryOpArgs,)
1137                    )
1138                {
1139                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1140                } else {
1141                    eval_result
1142                }
1143            }
1144            ExpressionKind::Override => match self.behavior {
1145                Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
1146                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1147                }
1148                Behavior::Wgsl(WgslRestrictions::Const(_)) => {
1149                    Err(ConstantEvaluatorError::OverrideExpr)
1150                }
1151
1152                // GLSL specialization constants (constant_id) become Override expressions
1153                Behavior::Glsl(GlslRestrictions::Runtime(_)) => {
1154                    Ok(self.append_expr(expr, span, ExpressionKind::Override))
1155                }
1156                Behavior::Glsl(GlslRestrictions::Const) => {
1157                    Err(ConstantEvaluatorError::OverrideExpr)
1158                }
1159            },
1160            ExpressionKind::Runtime => {
1161                if self.behavior.has_runtime_restrictions() {
1162                    Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
1163                } else {
1164                    Err(ConstantEvaluatorError::RuntimeExpr)
1165                }
1166            }
1167        }
1168    }
1169
1170    /// Is the [`Self::expressions`] arena the global module expression arena?
1171    const fn is_global_arena(&self) -> bool {
1172        matches!(
1173            self.behavior,
1174            Behavior::Wgsl(WgslRestrictions::Const(None) | WgslRestrictions::Override)
1175                | Behavior::Glsl(GlslRestrictions::Const)
1176        )
1177    }
1178
1179    const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
1180        match self.behavior {
1181            Behavior::Wgsl(
1182                WgslRestrictions::Runtime(ref function_local_data)
1183                | WgslRestrictions::Const(Some(ref function_local_data)),
1184            )
1185            | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
1186                Some(function_local_data)
1187            }
1188            _ => None,
1189        }
1190    }
1191
    /// Evaluate `expr` immediately, appending the result to [`Self::expressions`].
    ///
    /// This does the work for [`Self::try_eval_and_append`], which calls it for
    /// expressions classified as [`ExpressionKind::Const`].
    ///
    /// Each operand first goes through [`Self::check_and_get`]. That call
    /// rejects non-const operands and replaces named constants with their
    /// initializers, deep-copying them into a function's arena when needed.
    /// Because `check_and_get` may add expressions to the arena, the order in
    /// which each arm resolves its operands affects the arena's contents.
    /// Preserve that order when editing.
    ///
    /// Expression kinds that can never be constant map to a dedicated
    /// [`ConstantEvaluatorError`] variant.
    fn try_eval_and_append_impl(
        &mut self,
        expr: &Expression,
        span: Span,
    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
        log::trace!("try_eval_and_append: {expr:?}");
        match *expr {
            Expression::Constant(c) if self.is_global_arena() => {
                // "See through" the constant and use its initializer.
                // This is mainly done to avoid having constants pointing to other constants.
                Ok(self.constants[c].init)
            }
            Expression::Override(_) => Err(ConstantEvaluatorError::Override),
            // Leaves are registered unchanged. In a function arena this also
            // covers references to named constants (the guarded arm above
            // handles the global arena).
            Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
                self.register_evaluated_expr(expr.clone(), span)
            }
            Expression::Compose { ty, ref components } => {
                let components = components
                    .iter()
                    .map(|component| self.check_and_get(*component))
                    .collect::<Result<Vec<_>, _>>()?;
                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
            }
            Expression::Splat { size, value } => {
                let value = self.check_and_get(value)?;
                self.register_evaluated_expr(Expression::Splat { size, value }, span)
            }
            Expression::AccessIndex { base, index } => {
                let base = self.check_and_get(base)?;

                self.access(base, index as usize, span)
            }
            Expression::Access { base, index } => {
                let base = self.check_and_get(base)?;
                let index = self.check_and_get(index)?;

                // The dynamic index must resolve to a constant `u32`; any
                // failure is reported as an invalid index type.
                let index_val: u32 = self
                    .to_ctx()
                    .get_const_val_from(index, self.expressions)
                    .map_err(|_| ConstantEvaluatorError::InvalidAccessIndexTy)?;
                self.access(base, index_val as usize, span)
            }
            Expression::Swizzle {
                size,
                vector,
                pattern,
            } => {
                let vector = self.check_and_get(vector)?;

                self.swizzle(size, span, vector, pattern)
            }
            Expression::Unary { expr, op } => {
                let expr = self.check_and_get(expr)?;

                self.unary_op(op, expr, span)
            }
            Expression::Binary { left, right, op } => {
                let left = self.check_and_get(left)?;
                let right = self.check_and_get(right)?;

                self.binary_op(op, left, right, span)
            }
            Expression::Math {
                fun,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg = self.check_and_get(arg)?;
                let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
                let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;

                self.math(arg, arg1, arg2, arg3, fun, span)
            }
            Expression::As {
                convert,
                expr,
                kind,
            } => {
                let expr = self.check_and_get(expr)?;

                // Only value conversions are evaluated. Bitcasts (`convert:
                // None`) report `NotImplemented`, which `try_eval_and_append`
                // may turn into a runtime expression.
                match convert {
                    Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
                    None => Err(ConstantEvaluatorError::NotImplemented(
                        "bitcast built-in function".into(),
                    )),
                }
            }
            Expression::Select {
                reject,
                accept,
                condition,
            } => {
                // Resolve in the order reject, accept, condition.
                let mut arg = |expr| self.check_and_get(expr);

                let reject = arg(reject)?;
                let accept = arg(accept)?;
                let condition = arg(condition)?;

                self.select(reject, accept, condition, span)
            }
            Expression::Relational { fun, argument } => {
                let argument = self.check_and_get(argument)?;
                self.relational(fun, argument, span)
            }
            // Array length is evaluated only under GLSL behavior. Under WGSL it
            // is always rejected here.
            Expression::ArrayLength(expr) => match self.behavior {
                Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
                Behavior::Glsl(_) => {
                    let expr = self.check_and_get(expr)?;
                    self.array_length(expr, span)
                }
            },
            // Everything below can never be a constant expression.
            Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
            Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
            Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
            Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
            Expression::WorkGroupUniformLoadResult { .. } => {
                Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
            }
            Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
            Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
            Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
            Expression::ImageSample { .. }
            | Expression::ImageLoad { .. }
            | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
            Expression::RayQueryProceedResult
            | Expression::RayQueryGetIntersection { .. }
            | Expression::RayQueryVertexPositions { .. } => {
                Err(ConstantEvaluatorError::RayQueryExpression)
            }
            Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupExpression),
            Expression::SubgroupOperationResult { .. } => {
                Err(ConstantEvaluatorError::SubgroupExpression)
            }
            Expression::CooperativeLoad { .. } | Expression::CooperativeMultiplyAdd { .. } => {
                Err(ConstantEvaluatorError::CooperativeOperation)
            }
        }
    }
1333
1334    /// Splat `value` to `size`, without using [`Splat`] expressions.
1335    ///
1336    /// This constructs [`Compose`] or [`ZeroValue`] expressions to
1337    /// build a vector with the given `size` whose components are all
1338    /// `value`.
1339    ///
1340    /// Use `span` as the span of the inserted expressions and
1341    /// resulting types.
1342    ///
1343    /// [`Splat`]: Expression::Splat
1344    /// [`Compose`]: Expression::Compose
1345    /// [`ZeroValue`]: Expression::ZeroValue
1346    fn splat(
1347        &mut self,
1348        value: Handle<Expression>,
1349        size: crate::VectorSize,
1350        span: Span,
1351    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1352        match self.expressions[value] {
1353            Expression::Literal(literal) => {
1354                let scalar = literal.scalar();
1355                let ty = self.types.insert(
1356                    Type {
1357                        name: None,
1358                        inner: TypeInner::Vector { size, scalar },
1359                    },
1360                    span,
1361                );
1362                let expr = Expression::Compose {
1363                    ty,
1364                    components: vec![value; size as usize],
1365                };
1366                self.register_evaluated_expr(expr, span)
1367            }
1368            Expression::ZeroValue(ty) => {
1369                let inner = match self.types[ty].inner {
1370                    TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
1371                    _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
1372                };
1373                let res_ty = self.types.insert(Type { name: None, inner }, span);
1374                let expr = Expression::ZeroValue(res_ty);
1375                self.register_evaluated_expr(expr, span)
1376            }
1377            _ => Err(ConstantEvaluatorError::SplatScalarOnly),
1378        }
1379    }
1380
1381    fn swizzle(
1382        &mut self,
1383        size: crate::VectorSize,
1384        span: Span,
1385        src_constant: Handle<Expression>,
1386        pattern: [crate::SwizzleComponent; 4],
1387    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1388        let mut get_dst_ty = |ty| match self.types[ty].inner {
1389            TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
1390                Type {
1391                    name: None,
1392                    inner: TypeInner::Vector { size, scalar },
1393                },
1394                span,
1395            )),
1396            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1397        };
1398
1399        match self.expressions[src_constant] {
1400            Expression::ZeroValue(ty) => {
1401                let dst_ty = get_dst_ty(ty)?;
1402                let expr = Expression::ZeroValue(dst_ty);
1403                self.register_evaluated_expr(expr, span)
1404            }
1405            Expression::Splat { value, .. } => {
1406                let expr = Expression::Splat { size, value };
1407                self.register_evaluated_expr(expr, span)
1408            }
1409            Expression::Compose { ty, ref components } => {
1410                let dst_ty = get_dst_ty(ty)?;
1411
1412                let mut flattened = [src_constant; 4]; // dummy value
1413                let len =
1414                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
1415                        .zip(flattened.iter_mut())
1416                        .map(|(component, elt)| *elt = component)
1417                        .count();
1418                let flattened = &flattened[..len];
1419
1420                let swizzled_components = pattern[..size as usize]
1421                    .iter()
1422                    .map(|&sc| {
1423                        let sc = sc as usize;
1424                        if let Some(elt) = flattened.get(sc) {
1425                            Ok(*elt)
1426                        } else {
1427                            Err(ConstantEvaluatorError::SwizzleOutOfBounds)
1428                        }
1429                    })
1430                    .collect::<Result<Vec<Handle<Expression>>, _>>()?;
1431                let expr = Expression::Compose {
1432                    ty: dst_ty,
1433                    components: swizzled_components,
1434                };
1435                self.register_evaluated_expr(expr, span)
1436            }
1437            _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
1438        }
1439    }
1440
1441    fn math(
1442        &mut self,
1443        arg: Handle<Expression>,
1444        arg1: Option<Handle<Expression>>,
1445        arg2: Option<Handle<Expression>>,
1446        arg3: Option<Handle<Expression>>,
1447        fun: crate::MathFunction,
1448        span: Span,
1449    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1450        let expected = fun.argument_count();
1451        let given = Some(arg)
1452            .into_iter()
1453            .chain(arg1)
1454            .chain(arg2)
1455            .chain(arg3)
1456            .count();
1457        if expected != given {
1458            return Err(ConstantEvaluatorError::InvalidMathArgCount(
1459                fun, expected, given,
1460            ));
1461        }
1462
1463        // NOTE: We try to match the declaration order of `MathFunction` here.
1464        match fun {
1465            // comparison
1466            crate::MathFunction::Abs => {
1467                component_wise_scalar(self, span, [arg], |args| match args {
1468                    Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
1469                    Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
1470                    Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])),
1471                    Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.wrapping_abs()])),
1472                    Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
1473                    Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
1474                    Scalar::I64([e]) => Ok(Scalar::I64([e.wrapping_abs()])),
1475                    Scalar::U64([e]) => Ok(Scalar::U64([e])),
1476                })
1477            }
1478            crate::MathFunction::Min => {
1479                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1480                    Ok([e1.min(e2)])
1481                })
1482            }
1483            crate::MathFunction::Max => {
1484                component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1485                    Ok([e1.max(e2)])
1486                })
1487            }
1488            crate::MathFunction::Clamp => {
1489                component_wise_scalar!(
1490                    self,
1491                    span,
1492                    [arg, arg1.unwrap(), arg2.unwrap()],
1493                    |e, low, high| {
1494                        if low > high {
1495                            Err(ConstantEvaluatorError::InvalidClamp)
1496                        } else {
1497                            Ok([e.clamp(low, high)])
1498                        }
1499                    }
1500                )
1501            }
1502            crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e {
1503                Float::F16([e]) => Ok(Float::F16(
1504                    [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))],
1505                )),
1506                Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])),
1507                Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])),
1508            }),
1509
1510            // trigonometry
1511            crate::MathFunction::Cos => {
1512                component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
1513            }
1514            crate::MathFunction::Cosh => {
1515                component_wise_float!(self, span, [arg], |e| {
1516                    let result = e.cosh();
1517                    if result.is_finite() {
1518                        Ok([result])
1519                    } else {
1520                        Err(ConstantEvaluatorError::Overflow("cosh".into()))
1521                    }
1522                })
1523            }
1524            crate::MathFunction::Sin => {
1525                component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
1526            }
1527            crate::MathFunction::Sinh => {
1528                component_wise_float!(self, span, [arg], |e| {
1529                    let result = e.sinh();
1530                    if result.is_finite() {
1531                        Ok([result])
1532                    } else {
1533                        Err(ConstantEvaluatorError::Overflow("sinh".into()))
1534                    }
1535                })
1536            }
1537            crate::MathFunction::Tan => {
1538                component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
1539            }
1540            crate::MathFunction::Tanh => {
1541                component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
1542            }
1543            crate::MathFunction::Acos => {
1544                component_wise_float!(self, span, [arg], |e| {
1545                    if e.abs() <= One::one() {
1546                        Ok([e.acos()])
1547                    } else {
1548                        Err(ConstantEvaluatorError::InvalidMathArgValue("acos".into()))
1549                    }
1550                })
1551            }
1552            crate::MathFunction::Asin => {
1553                component_wise_float!(self, span, [arg], |e| {
1554                    if e.abs() <= One::one() {
1555                        Ok([e.asin()])
1556                    } else {
1557                        Err(ConstantEvaluatorError::InvalidMathArgValue("asin".into()))
1558                    }
1559                })
1560            }
1561            crate::MathFunction::Atan => {
1562                component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
1563            }
1564            crate::MathFunction::Atan2 => {
1565                component_wise_float!(self, span, [arg, arg1.unwrap()], |y, x| {
1566                    Ok([y.atan2(x)])
1567                })
1568            }
1569            crate::MathFunction::Asinh => component_wise_float(self, span, [arg], |e| match e {
1570                Float::Abstract([e]) => Ok(Float::Abstract([libm::asinh(e)])),
1571                Float::F32([e]) => Ok(Float::F32([(e as f64).asinh() as f32])),
1572                Float::F16([e]) => Ok(Float::F16([e.asinh()])),
1573            }),
1574            crate::MathFunction::Acosh => {
1575                component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
1576            }
1577            crate::MathFunction::Atanh => {
1578                component_wise_float!(self, span, [arg], |e| {
1579                    if e.abs() < One::one() {
1580                        Ok([e.atanh()])
1581                    } else {
1582                        Err(ConstantEvaluatorError::InvalidMathArgValue("atanh".into()))
1583                    }
1584                })
1585            }
1586            crate::MathFunction::Radians => {
1587                component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
1588            }
1589            crate::MathFunction::Degrees => {
1590                component_wise_float!(self, span, [arg], |e| {
1591                    let result = e.to_degrees();
1592                    if result.is_finite() {
1593                        Ok([result])
1594                    } else {
1595                        Err(ConstantEvaluatorError::Overflow("degrees".into()))
1596                    }
1597                })
1598            }
1599
1600            // decomposition
1601            crate::MathFunction::Ceil => {
1602                component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
1603            }
1604            crate::MathFunction::Floor => {
1605                component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
1606            }
1607            crate::MathFunction::Round => {
1608                component_wise_float(self, span, [arg], |e| match e {
1609                    Float::Abstract([e]) => Ok(Float::Abstract([libm::rint(e)])),
1610                    Float::F32([e]) => Ok(Float::F32([libm::rintf(e)])),
1611                    Float::F16([e]) => {
1612                        // TODO: `round_ties_even` is not available on `half::f16` yet.
1613                        //
1614                        // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
1615                        // which has licensing compatible with ours. See also
1616                        // <https://github.com/rust-lang/rust/issues/96710>.
1617                        //
1618                        // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
1619                        fn round_ties_even(x: f64) -> f64 {
1620                            let i = x as i64;
1621                            let f = (x - i as f64).abs();
1622                            if f == 0.5 {
1623                                if i & 1 == 1 {
1624                                    // -1.5, 1.5, 3.5, ...
1625                                    (x.abs() + 0.5).copysign(x)
1626                                } else {
1627                                    (x.abs() - 0.5).copysign(x)
1628                                }
1629                            } else {
1630                                x.round()
1631                            }
1632                        }
1633
1634                        Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))]))
1635                    }
1636                })
1637            }
1638            crate::MathFunction::Fract => {
1639                component_wise_float!(self, span, [arg], |e| {
1640                    // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
1641                    // here.
1642                    Ok([e - e.floor()])
1643                })
1644            }
1645            crate::MathFunction::Trunc => {
1646                component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
1647            }
1648
1649            // exponent
1650            crate::MathFunction::Exp => {
1651                component_wise_float!(self, span, [arg], |e| {
1652                    let result = e.exp();
1653                    if result.is_finite() {
1654                        Ok([result])
1655                    } else {
1656                        Err(ConstantEvaluatorError::Overflow("exp".into()))
1657                    }
1658                })
1659            }
1660            crate::MathFunction::Exp2 => {
1661                component_wise_float!(self, span, [arg], |e| {
1662                    let result = e.exp2();
1663                    if result.is_finite() {
1664                        Ok([result])
1665                    } else {
1666                        Err(ConstantEvaluatorError::Overflow("exp2".into()))
1667                    }
1668                })
1669            }
1670            crate::MathFunction::Log => {
1671                component_wise_float!(self, span, [arg], |e| {
1672                    if e > Zero::zero() {
1673                        Ok([e.ln()])
1674                    } else {
1675                        Err(ConstantEvaluatorError::InvalidMathArgValue("log".into()))
1676                    }
1677                })
1678            }
1679            crate::MathFunction::Log2 => {
1680                component_wise_float!(self, span, [arg], |e| {
1681                    if e > Zero::zero() {
1682                        Ok([e.log2()])
1683                    } else {
1684                        Err(ConstantEvaluatorError::InvalidMathArgValue("log2".into()))
1685                    }
1686                })
1687            }
1688            crate::MathFunction::Pow => {
1689                component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
1690                    Ok([e1.powf(e2)])
1691                })
1692            }
1693
1694            // computational
1695            crate::MathFunction::Sign => {
1696                component_wise_signed!(self, span, [arg], |e| {
1697                    Ok([if e.is_zero() {
1698                        Zero::zero()
1699                    } else {
1700                        e.signum()
1701                    }])
1702                })
1703            }
1704            crate::MathFunction::Fma => {
1705                component_wise_float!(
1706                    self,
1707                    span,
1708                    [arg, arg1.unwrap(), arg2.unwrap()],
1709                    |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
1710                )
1711            }
1712            crate::MathFunction::Step => {
1713                component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x {
1714                    Float::Abstract([edge, x]) => {
1715                        Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }]))
1716                    }
1717                    Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])),
1718                    Float::F16([edge, x]) => Ok(Float::F16([if edge <= x {
1719                        f16::one()
1720                    } else {
1721                        f16::zero()
1722                    }])),
1723                })
1724            }
1725            crate::MathFunction::Sqrt => {
1726                component_wise_float!(self, span, [arg], |e| {
1727                    if e >= Zero::zero() {
1728                        Ok([e.sqrt()])
1729                    } else {
1730                        Err(ConstantEvaluatorError::InvalidMathArgValue("sqrt".into()))
1731                    }
1732                })
1733            }
1734            crate::MathFunction::InverseSqrt => {
1735                component_wise_float(self, span, [arg], |e| match e {
1736                    Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])),
1737                    Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])),
1738                    Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])),
1739                })
1740            }
1741
1742            // bits
1743            crate::MathFunction::CountTrailingZeros => {
1744                component_wise_concrete_int!(self, span, [arg], |e| {
1745                    #[allow(clippy::useless_conversion)]
1746                    Ok([e
1747                        .trailing_zeros()
1748                        .try_into()
1749                        .expect("bit count overflowed 32 bits, somehow!?")])
1750                })
1751            }
1752            crate::MathFunction::CountLeadingZeros => {
1753                component_wise_concrete_int!(self, span, [arg], |e| {
1754                    #[allow(clippy::useless_conversion)]
1755                    Ok([e
1756                        .leading_zeros()
1757                        .try_into()
1758                        .expect("bit count overflowed 32 bits, somehow!?")])
1759                })
1760            }
1761            crate::MathFunction::CountOneBits => {
1762                component_wise_concrete_int!(self, span, [arg], |e| {
1763                    #[allow(clippy::useless_conversion)]
1764                    Ok([e
1765                        .count_ones()
1766                        .try_into()
1767                        .expect("bit count overflowed 32 bits, somehow!?")])
1768                })
1769            }
1770            crate::MathFunction::ReverseBits => {
1771                component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
1772            }
1773            crate::MathFunction::FirstTrailingBit => {
1774                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_trailing_bit(ci)))
1775            }
1776            crate::MathFunction::FirstLeadingBit => {
1777                component_wise_concrete_int(self, span, [arg], |ci| Ok(first_leading_bit(ci)))
1778            }
1779
1780            // vector
1781            crate::MathFunction::Dot4I8Packed => {
1782                self.packed_dot_product(arg, arg1.unwrap(), span, true)
1783            }
1784            crate::MathFunction::Dot4U8Packed => {
1785                self.packed_dot_product(arg, arg1.unwrap(), span, false)
1786            }
1787            crate::MathFunction::Cross => self.cross_product(arg, arg1.unwrap(), span),
1788            crate::MathFunction::Dot => {
1789                // https://www.w3.org/TR/WGSL/#dot-builtin
1790                let e1 = self.extract_vec(arg, false)?;
1791                let e2 = self.extract_vec(arg1.unwrap(), false)?;
1792                if e1.len() != e2.len() {
1793                    return Err(ConstantEvaluatorError::InvalidMathArg);
1794                }
1795
1796                fn float_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1797                where
1798                    P: num_traits::Float,
1799                {
1800                    let result = a
1801                        .iter()
1802                        .zip(b.iter())
1803                        .map(|(&aa, &bb)| aa * bb)
1804                        .fold(P::zero(), |acc, x| acc + x);
1805                    if result.is_finite() {
1806                        Ok(result)
1807                    } else {
1808                        Err(ConstantEvaluatorError::Overflow("in dot built-in".into()))
1809                    }
1810                }
1811
1812                fn int_dot_checked<P>(a: &[P], b: &[P]) -> Result<P, ConstantEvaluatorError>
1813                where
1814                    P: num_traits::PrimInt + num_traits::CheckedAdd + num_traits::CheckedMul,
1815                {
1816                    a.iter()
1817                        .zip(b.iter())
1818                        .map(|(&aa, bb)| aa.checked_mul(bb))
1819                        .try_fold(P::zero(), |acc, x| {
1820                            if let Some(x) = x {
1821                                acc.checked_add(&x)
1822                            } else {
1823                                None
1824                            }
1825                        })
1826                        .ok_or(ConstantEvaluatorError::Overflow(
1827                            "in dot built-in".to_string(),
1828                        ))
1829                }
1830
1831                fn int_dot_wrapping<P>(a: &[P], b: &[P]) -> P
1832                where
1833                    P: num_traits::PrimInt + num_traits::WrappingAdd + num_traits::WrappingMul,
1834                {
1835                    a.iter()
1836                        .zip(b.iter())
1837                        .map(|(&aa, bb)| aa.wrapping_mul(bb))
1838                        .fold(P::zero(), |acc, x| acc.wrapping_add(&x))
1839                }
1840
1841                let result = match_literal_vector!(match (e1, e2) => Literal {
1842                    Float => |e1, e2| { float_dot_checked(e1, e2)? },
1843                    AbstractInt => |e1, e2 | { int_dot_checked(e1, e2)? },
1844                    I32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1845                    U32 => |e1, e2| { int_dot_wrapping(e1, e2) },
1846                })?;
1847                self.register_evaluated_expr(Expression::Literal(result), span)
1848            }
1849            crate::MathFunction::Length => {
1850                // https://www.w3.org/TR/WGSL/#length-builtin
1851                let e1 = self.extract_vec(arg, true)?;
1852
1853                fn float_length<F>(e: &[F]) -> F
1854                where
1855                    F: core::ops::Mul<F>,
1856                    F: num_traits::Float + iter::Sum,
1857                {
1858                    if e.len() == 1 {
1859                        // Avoids possible overflow in squaring
1860                        e[0].abs()
1861                    } else {
1862                        e.iter().map(|&ei| ei * ei).sum::<F>().sqrt()
1863                    }
1864                }
1865
1866                let result = match_literal_vector!(match e1 => Literal {
1867                    Float => |e1| { float_length(e1) },
1868                })?;
1869                self.register_evaluated_expr(Expression::Literal(result), span)
1870            }
1871            crate::MathFunction::Distance => {
1872                // https://www.w3.org/TR/WGSL/#distance-builtin
1873                let e1 = self.extract_vec(arg, true)?;
1874                let e2 = self.extract_vec(arg1.unwrap(), true)?;
1875                if e1.len() != e2.len() {
1876                    return Err(ConstantEvaluatorError::InvalidMathArg);
1877                }
1878
1879                fn float_distance<F>(a: &[F], b: &[F]) -> F
1880                where
1881                    F: core::ops::Mul<F>,
1882                    F: num_traits::Float + iter::Sum + core::ops::Sub,
1883                {
1884                    if a.len() == 1 {
1885                        // Avoids possible overflow in squaring
1886                        (a[0] - b[0]).abs()
1887                    } else {
1888                        a.iter()
1889                            .zip(b.iter())
1890                            .map(|(&aa, &bb)| aa - bb)
1891                            .map(|ei| ei * ei)
1892                            .sum::<F>()
1893                            .sqrt()
1894                    }
1895                }
1896                let result = match_literal_vector!(match (e1, e2) => Literal {
1897                    Float => |e1, e2| { float_distance(e1, e2) },
1898                })?;
1899                self.register_evaluated_expr(Expression::Literal(result), span)
1900            }
1901            crate::MathFunction::Normalize => {
1902                // https://www.w3.org/TR/WGSL/#normalize-builtin
1903                let e1 = self.extract_vec(arg, true)?;
1904
1905                fn float_normalize<F>(e: &[F]) -> ArrayVec<F, { crate::VectorSize::MAX }>
1906                where
1907                    F: core::ops::Mul<F>,
1908                    F: num_traits::Float + iter::Sum,
1909                {
1910                    let len = e.iter().map(|&ei| ei * ei).sum::<F>().sqrt();
1911                    let mut out = ArrayVec::new();
1912                    for &ei in e {
1913                        out.push(ei / len);
1914                    }
1915                    out
1916                }
1917
1918                let result = match_literal_vector!(match e1 => LiteralVector {
1919                    Float => |e1| { float_normalize(e1) },
1920                })?;
1921                result.register_as_evaluated_expr(self, span)
1922            }
1923
1924            // unimplemented
1925            crate::MathFunction::Modf
1926            | crate::MathFunction::Frexp
1927            | crate::MathFunction::Ldexp
1928            | crate::MathFunction::Outer
1929            | crate::MathFunction::FaceForward
1930            | crate::MathFunction::Reflect
1931            | crate::MathFunction::Refract
1932            | crate::MathFunction::Mix
1933            | crate::MathFunction::SmoothStep
1934            | crate::MathFunction::Inverse
1935            | crate::MathFunction::Transpose
1936            | crate::MathFunction::Determinant
1937            | crate::MathFunction::QuantizeToF16
1938            | crate::MathFunction::ExtractBits
1939            | crate::MathFunction::InsertBits
1940            | crate::MathFunction::Pack4x8snorm
1941            | crate::MathFunction::Pack4x8unorm
1942            | crate::MathFunction::Pack2x16snorm
1943            | crate::MathFunction::Pack2x16unorm
1944            | crate::MathFunction::Pack2x16float
1945            | crate::MathFunction::Pack4xI8
1946            | crate::MathFunction::Pack4xU8
1947            | crate::MathFunction::Pack4xI8Clamp
1948            | crate::MathFunction::Pack4xU8Clamp
1949            | crate::MathFunction::Unpack4x8snorm
1950            | crate::MathFunction::Unpack4x8unorm
1951            | crate::MathFunction::Unpack2x16snorm
1952            | crate::MathFunction::Unpack2x16unorm
1953            | crate::MathFunction::Unpack2x16float
1954            | crate::MathFunction::Unpack4xI8
1955            | crate::MathFunction::Unpack4xU8 => Err(ConstantEvaluatorError::NotImplemented(
1956                format!("{fun:?} built-in function"),
1957            )),
1958        }
1959    }
1960
1961    /// Dot product of two packed vectors (`dot4I8Packed` and `dot4U8Packed`)
1962    fn packed_dot_product(
1963        &mut self,
1964        a: Handle<Expression>,
1965        b: Handle<Expression>,
1966        span: Span,
1967        signed: bool,
1968    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
1969        let Expression::Literal(Literal::U32(a)) = self.expressions[a] else {
1970            return Err(ConstantEvaluatorError::InvalidMathArg);
1971        };
1972        let Expression::Literal(Literal::U32(b)) = self.expressions[b] else {
1973            return Err(ConstantEvaluatorError::InvalidMathArg);
1974        };
1975
1976        let result = if signed {
1977            Literal::I32(
1978                (a & 0xFF) as i8 as i32 * (b & 0xFF) as i8 as i32
1979                    + ((a >> 8) & 0xFF) as i8 as i32 * ((b >> 8) & 0xFF) as i8 as i32
1980                    + ((a >> 16) & 0xFF) as i8 as i32 * ((b >> 16) & 0xFF) as i8 as i32
1981                    + ((a >> 24) & 0xFF) as i8 as i32 * ((b >> 24) & 0xFF) as i8 as i32,
1982            )
1983        } else {
1984            Literal::U32(
1985                (a & 0xFF) * (b & 0xFF)
1986                    + ((a >> 8) & 0xFF) * ((b >> 8) & 0xFF)
1987                    + ((a >> 16) & 0xFF) * ((b >> 16) & 0xFF)
1988                    + ((a >> 24) & 0xFF) * ((b >> 24) & 0xFF),
1989            )
1990        };
1991
1992        self.register_evaluated_expr(Expression::Literal(result), span)
1993    }
1994
1995    /// Vector cross product.
1996    fn cross_product(
1997        &mut self,
1998        a: Handle<Expression>,
1999        b: Handle<Expression>,
2000        span: Span,
2001    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2002        use Literal as Li;
2003
2004        let (a, ty) = self.extract_vec_with_size::<3>(a)?;
2005        let (b, _) = self.extract_vec_with_size::<3>(b)?;
2006
2007        let product = match (a, b) {
2008            (
2009                [Li::AbstractInt(a0), Li::AbstractInt(a1), Li::AbstractInt(a2)],
2010                [Li::AbstractInt(b0), Li::AbstractInt(b1), Li::AbstractInt(b2)],
2011            ) => {
2012                // `cross` has no overload for AbstractInt, so AbstractInt
2013                // arguments are automatically converted to AbstractFloat. Since
2014                // `f64` has a much wider range than `i64`, there's no danger of
2015                // overflow here.
2016                let p = cross_product(
2017                    [a0 as f64, a1 as f64, a2 as f64],
2018                    [b0 as f64, b1 as f64, b2 as f64],
2019                );
2020                [
2021                    Li::AbstractFloat(p[0]),
2022                    Li::AbstractFloat(p[1]),
2023                    Li::AbstractFloat(p[2]),
2024                ]
2025            }
2026            (
2027                [Li::AbstractFloat(a0), Li::AbstractFloat(a1), Li::AbstractFloat(a2)],
2028                [Li::AbstractFloat(b0), Li::AbstractFloat(b1), Li::AbstractFloat(b2)],
2029            ) => {
2030                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2031                [
2032                    Li::AbstractFloat(p[0]),
2033                    Li::AbstractFloat(p[1]),
2034                    Li::AbstractFloat(p[2]),
2035                ]
2036            }
2037            ([Li::F16(a0), Li::F16(a1), Li::F16(a2)], [Li::F16(b0), Li::F16(b1), Li::F16(b2)]) => {
2038                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2039                [Li::F16(p[0]), Li::F16(p[1]), Li::F16(p[2])]
2040            }
2041            ([Li::F32(a0), Li::F32(a1), Li::F32(a2)], [Li::F32(b0), Li::F32(b1), Li::F32(b2)]) => {
2042                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2043                [Li::F32(p[0]), Li::F32(p[1]), Li::F32(p[2])]
2044            }
2045            ([Li::F64(a0), Li::F64(a1), Li::F64(a2)], [Li::F64(b0), Li::F64(b1), Li::F64(b2)]) => {
2046                let p = cross_product([a0, a1, a2], [b0, b1, b2]);
2047                [Li::F64(p[0]), Li::F64(p[1]), Li::F64(p[2])]
2048            }
2049            _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2050        };
2051
2052        let p0 = self.register_evaluated_expr(Expression::Literal(product[0]), span)?;
2053        let p1 = self.register_evaluated_expr(Expression::Literal(product[1]), span)?;
2054        let p2 = self.register_evaluated_expr(Expression::Literal(product[2]), span)?;
2055
2056        self.register_evaluated_expr(
2057            Expression::Compose {
2058                ty,
2059                components: vec![p0, p1, p2],
2060            },
2061            span,
2062        )
2063    }
2064
2065    /// Extract the values of a `vecN` from `expr`.
2066    ///
2067    /// Return the value of `expr`, whose type is `vecN<S>` for some
2068    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2069    /// values.
2070    ///
2071    /// Also return the type handle from the `Compose` expression.
2072    fn extract_vec_with_size<const N: usize>(
2073        &mut self,
2074        expr: Handle<Expression>,
2075    ) -> Result<([Literal; N], Handle<Type>), ConstantEvaluatorError> {
2076        let span = self.expressions.get_span(expr);
2077        let expr = self.eval_zero_value_and_splat(expr, span)?;
2078        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2079            return Err(ConstantEvaluatorError::InvalidMathArg);
2080        };
2081
2082        let mut value = [Literal::Bool(false); N];
2083        for (component, elt) in
2084            crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2085                .zip(value.iter_mut())
2086        {
2087            let Expression::Literal(literal) = self.expressions[component] else {
2088                return Err(ConstantEvaluatorError::InvalidMathArg);
2089            };
2090            *elt = literal;
2091        }
2092
2093        Ok((value, ty))
2094    }
2095
2096    /// Extract the values of a `vecN` from `expr`.
2097    ///
2098    /// Return the value of `expr`, whose type is `vecN<S>` for some
2099    /// vector size `N` and scalar `S`, as an array of `N` [`Literal`]
2100    /// values.
2101    ///
2102    /// Also return the type handle from the `Compose` expression.
2103    fn extract_vec(
2104        &mut self,
2105        expr: Handle<Expression>,
2106        allow_single: bool,
2107    ) -> Result<LiteralVector, ConstantEvaluatorError> {
2108        let span = self.expressions.get_span(expr);
2109        let expr = self.eval_zero_value_and_splat(expr, span)?;
2110
2111        match self.expressions[expr] {
2112            Expression::Literal(literal) if allow_single => {
2113                Ok(LiteralVector::from_literal(literal))
2114            }
2115            Expression::Compose { ty, ref components } => {
2116                let mut components_out = ArrayVec::<Literal, { crate::VectorSize::MAX }>::new();
2117                for expr in
2118                    crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2119                {
2120                    match self.expressions[expr] {
2121                        Expression::Literal(l) => components_out.push(l),
2122                        _ => return Err(ConstantEvaluatorError::InvalidMathArg),
2123                    }
2124                }
2125                LiteralVector::from_literal_vec(components_out)
2126            }
2127            _ => Err(ConstantEvaluatorError::InvalidMathArg),
2128        }
2129    }
2130
2131    fn array_length(
2132        &mut self,
2133        array: Handle<Expression>,
2134        span: Span,
2135    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2136        match self.expressions[array] {
2137            Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
2138                match self.types[ty].inner {
2139                    TypeInner::Array { size, .. } => match size {
2140                        ArraySize::Constant(len) => {
2141                            let expr = Expression::Literal(Literal::U32(len.get()));
2142                            self.register_evaluated_expr(expr, span)
2143                        }
2144                        ArraySize::Pending(_) => Err(ConstantEvaluatorError::ArrayLengthOverridden),
2145                        ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
2146                    },
2147                    _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2148                }
2149            }
2150            _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
2151        }
2152    }
2153
2154    fn access(
2155        &mut self,
2156        base: Handle<Expression>,
2157        index: usize,
2158        span: Span,
2159    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2160        match self.expressions[base] {
2161            Expression::ZeroValue(ty) => {
2162                let ty_inner = &self.types[ty].inner;
2163                let components = ty_inner
2164                    .components()
2165                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2166
2167                if index >= components as usize {
2168                    Err(ConstantEvaluatorError::InvalidAccessBase)
2169                } else {
2170                    let ty_res = ty_inner
2171                        .component_type(index)
2172                        .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
2173                    let ty = match ty_res {
2174                        crate::proc::TypeResolution::Handle(ty) => ty,
2175                        crate::proc::TypeResolution::Value(inner) => {
2176                            self.types.insert(Type { name: None, inner }, span)
2177                        }
2178                    };
2179                    self.register_evaluated_expr(Expression::ZeroValue(ty), span)
2180                }
2181            }
2182            Expression::Splat { size, value } => {
2183                if index >= size as usize {
2184                    Err(ConstantEvaluatorError::InvalidAccessBase)
2185                } else {
2186                    Ok(value)
2187                }
2188            }
2189            Expression::Compose { ty, ref components } => {
2190                let _ = self.types[ty]
2191                    .inner
2192                    .components()
2193                    .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
2194
2195                crate::proc::flatten_compose(ty, components, self.expressions, self.types)
2196                    .nth(index)
2197                    .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
2198            }
2199            _ => Err(ConstantEvaluatorError::InvalidAccessBase),
2200        }
2201    }
2202
2203    /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
2204    ///
2205    /// [`ZeroValue`]: Expression::ZeroValue
2206    /// [`Splat`]: Expression::Splat
2207    /// [`Literal`]: Expression::Literal
2208    /// [`Compose`]: Expression::Compose
2209    fn eval_zero_value_and_splat(
2210        &mut self,
2211        mut expr: Handle<Expression>,
2212        span: Span,
2213    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2214        // If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
2215        // each of its components.
2216        if let Expression::Compose { ty, ref components } = self.expressions[expr] {
2217            let components = components
2218                .clone()
2219                .iter()
2220                .map(|component| self.eval_zero_value_and_splat(*component, span))
2221                .collect::<Result<_, _>>()?;
2222            expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
2223        }
2224
2225        // The result of the splat() for a Splat of a scalar ZeroValue is a
2226        // vector ZeroValue, so we must call eval_zero_value_impl() after
2227        // splat() in order to ensure we have no ZeroValues remaining.
2228        if let Expression::Splat { size, value } = self.expressions[expr] {
2229            expr = self.splat(value, size, span)?;
2230        }
2231        if let Expression::ZeroValue(ty) = self.expressions[expr] {
2232            expr = self.eval_zero_value_impl(ty, span)?;
2233        }
2234        Ok(expr)
2235    }
2236
2237    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2238    ///
2239    /// [`ZeroValue`]: Expression::ZeroValue
2240    /// [`Literal`]: Expression::Literal
2241    /// [`Compose`]: Expression::Compose
2242    fn eval_zero_value(
2243        &mut self,
2244        expr: Handle<Expression>,
2245        span: Span,
2246    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2247        match self.expressions[expr] {
2248            Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
2249            _ => Ok(expr),
2250        }
2251    }
2252
2253    /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
2254    ///
2255    /// [`ZeroValue`]: Expression::ZeroValue
2256    /// [`Literal`]: Expression::Literal
2257    /// [`Compose`]: Expression::Compose
2258    fn eval_zero_value_impl(
2259        &mut self,
2260        ty: Handle<Type>,
2261        span: Span,
2262    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2263        match self.types[ty].inner {
2264            TypeInner::Scalar(scalar) => {
2265                let expr = Expression::Literal(
2266                    Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
2267                );
2268                self.register_evaluated_expr(expr, span)
2269            }
2270            TypeInner::Vector { size, scalar } => {
2271                let scalar_ty = self.types.insert(
2272                    Type {
2273                        name: None,
2274                        inner: TypeInner::Scalar(scalar),
2275                    },
2276                    span,
2277                );
2278                let el = self.eval_zero_value_impl(scalar_ty, span)?;
2279                let expr = Expression::Compose {
2280                    ty,
2281                    components: vec![el; size as usize],
2282                };
2283                self.register_evaluated_expr(expr, span)
2284            }
2285            TypeInner::Matrix {
2286                columns,
2287                rows,
2288                scalar,
2289            } => {
2290                let vec_ty = self.types.insert(
2291                    Type {
2292                        name: None,
2293                        inner: TypeInner::Vector { size: rows, scalar },
2294                    },
2295                    span,
2296                );
2297                let el = self.eval_zero_value_impl(vec_ty, span)?;
2298                let expr = Expression::Compose {
2299                    ty,
2300                    components: vec![el; columns as usize],
2301                };
2302                self.register_evaluated_expr(expr, span)
2303            }
2304            TypeInner::Array {
2305                base,
2306                size: ArraySize::Constant(size),
2307                ..
2308            } => {
2309                let el = self.eval_zero_value_impl(base, span)?;
2310                let expr = Expression::Compose {
2311                    ty,
2312                    components: vec![el; size.get() as usize],
2313                };
2314                self.register_evaluated_expr(expr, span)
2315            }
2316            TypeInner::Struct { ref members, .. } => {
2317                let types: Vec<_> = members.iter().map(|m| m.ty).collect();
2318                let mut components = Vec::with_capacity(members.len());
2319                for ty in types {
2320                    components.push(self.eval_zero_value_impl(ty, span)?);
2321                }
2322                let expr = Expression::Compose { ty, components };
2323                self.register_evaluated_expr(expr, span)
2324            }
2325            _ => Err(ConstantEvaluatorError::TypeNotConstructible),
2326        }
2327    }
2328
2329    /// Convert the scalar components of `expr` to `target`.
2330    ///
2331    /// Treat `span` as the location of the resulting expression.
2332    pub fn cast(
2333        &mut self,
2334        expr: Handle<Expression>,
2335        target: crate::Scalar,
2336        span: Span,
2337    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2338        use crate::Scalar as Sc;
2339
2340        let expr = self.eval_zero_value(expr, span)?;
2341
2342        let make_error = || -> Result<_, ConstantEvaluatorError> {
2343            let from = format!("{:?} {:?}", expr, self.expressions[expr]);
2344
2345            #[cfg(feature = "wgsl-in")]
2346            let to = target.to_wgsl_for_diagnostics();
2347
2348            #[cfg(not(feature = "wgsl-in"))]
2349            let to = format!("{target:?}");
2350
2351            Err(ConstantEvaluatorError::InvalidCastArg { from, to })
2352        };
2353
2354        use crate::proc::type_methods::IntFloatLimits;
2355
2356        let expr = match self.expressions[expr] {
2357            Expression::Literal(literal) => {
2358                let literal = match target {
2359                    Sc::I32 => Literal::I32(match literal {
2360                        Literal::I32(v) => v,
2361                        Literal::U32(v) => v as i32,
2362                        Literal::F32(v) => v.clamp(i32::min_float(), i32::max_float()) as i32,
2363                        Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf
2364                        Literal::Bool(v) => v as i32,
2365                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2366                            return make_error();
2367                        }
2368                        Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
2369                        Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
2370                    }),
2371                    Sc::U32 => Literal::U32(match literal {
2372                        Literal::I32(v) => v as u32,
2373                        Literal::U32(v) => v,
2374                        Literal::F32(v) => v.clamp(u32::min_float(), u32::max_float()) as u32,
2375                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
2376                        Literal::F16(v) => f16::to_u32(&v.max(f16::ZERO)).unwrap(),
2377                        Literal::Bool(v) => v as u32,
2378                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2379                            return make_error();
2380                        }
2381                        Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
2382                        Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
2383                    }),
2384                    Sc::I64 => Literal::I64(match literal {
2385                        Literal::I32(v) => v as i64,
2386                        Literal::U32(v) => v as i64,
2387                        Literal::F32(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2388                        Literal::Bool(v) => v as i64,
2389                        Literal::F64(v) => v.clamp(i64::min_float(), i64::max_float()) as i64,
2390                        Literal::I64(v) => v,
2391                        Literal::U64(v) => v as i64,
2392                        Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf
2393                        Literal::AbstractInt(v) => i64::try_from_abstract(v)?,
2394                        Literal::AbstractFloat(v) => i64::try_from_abstract(v)?,
2395                    }),
2396                    Sc::U64 => Literal::U64(match literal {
2397                        Literal::I32(v) => v as u64,
2398                        Literal::U32(v) => v as u64,
2399                        Literal::F32(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2400                        Literal::Bool(v) => v as u64,
2401                        Literal::F64(v) => v.clamp(u64::min_float(), u64::max_float()) as u64,
2402                        Literal::I64(v) => v as u64,
2403                        Literal::U64(v) => v,
2404                        // max(0) avoids None due to negative, therefore only None on NaN or Inf
2405                        Literal::F16(v) => f16::to_u64(&v.max(f16::ZERO)).unwrap(),
2406                        Literal::AbstractInt(v) => u64::try_from_abstract(v)?,
2407                        Literal::AbstractFloat(v) => u64::try_from_abstract(v)?,
2408                    }),
2409                    Sc::F16 => Literal::F16(match literal {
2410                        Literal::F16(v) => v,
2411                        Literal::F32(v) => f16::from_f32(v),
2412                        Literal::F64(v) => f16::from_f64(v),
2413                        Literal::Bool(v) => f16::from_u32(v as u32).unwrap(),
2414                        Literal::I64(v) => f16::from_i64(v).unwrap(),
2415                        Literal::U64(v) => f16::from_u64(v).unwrap(),
2416                        Literal::I32(v) => f16::from_i32(v).unwrap(),
2417                        Literal::U32(v) => f16::from_u32(v).unwrap(),
2418                        Literal::AbstractFloat(v) => f16::try_from_abstract(v)?,
2419                        Literal::AbstractInt(v) => f16::try_from_abstract(v)?,
2420                    }),
2421                    Sc::F32 => Literal::F32(match literal {
2422                        Literal::I32(v) => v as f32,
2423                        Literal::U32(v) => v as f32,
2424                        Literal::F32(v) => v,
2425                        Literal::Bool(v) => v as u32 as f32,
2426                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2427                            return make_error();
2428                        }
2429                        Literal::F16(v) => f16::to_f32(v),
2430                        Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
2431                        Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
2432                    }),
2433                    Sc::F64 => Literal::F64(match literal {
2434                        Literal::I32(v) => v as f64,
2435                        Literal::U32(v) => v as f64,
2436                        Literal::F16(v) => f16::to_f64(v),
2437                        Literal::F32(v) => v as f64,
2438                        Literal::F64(v) => v,
2439                        Literal::Bool(v) => v as u32 as f64,
2440                        Literal::I64(_) | Literal::U64(_) => return make_error(),
2441                        Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
2442                        Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
2443                    }),
2444                    Sc::BOOL => Literal::Bool(match literal {
2445                        Literal::I32(v) => v != 0,
2446                        Literal::U32(v) => v != 0,
2447                        Literal::F32(v) => v != 0.0,
2448                        Literal::F16(v) => v != f16::zero(),
2449                        Literal::Bool(v) => v,
2450                        Literal::AbstractInt(v) => v != 0,
2451                        Literal::AbstractFloat(v) => v != 0.0,
2452                        Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => {
2453                            return make_error();
2454                        }
2455                    }),
2456                    Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
2457                        Literal::AbstractInt(v) => {
2458                            // Overflow is forbidden, but inexact conversions
2459                            // are fine. The range of f64 is far larger than
2460                            // that of i64, so we don't have to check anything
2461                            // here.
2462                            v as f64
2463                        }
2464                        Literal::AbstractFloat(v) => v,
2465                        _ => return make_error(),
2466                    }),
2467                    Sc::ABSTRACT_INT => Literal::AbstractInt(match literal {
2468                        Literal::AbstractInt(v) => v,
2469                        _ => return make_error(),
2470                    }),
2471                    _ => {
2472                        log::debug!("Constant evaluator refused to convert value to {target:?}");
2473                        return make_error();
2474                    }
2475                };
2476                Expression::Literal(literal)
2477            }
2478            Expression::Compose {
2479                ty,
2480                components: ref src_components,
2481            } => {
2482                let ty_inner = match self.types[ty].inner {
2483                    TypeInner::Vector { size, .. } => TypeInner::Vector {
2484                        size,
2485                        scalar: target,
2486                    },
2487                    TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
2488                        columns,
2489                        rows,
2490                        scalar: target,
2491                    },
2492                    _ => return make_error(),
2493                };
2494
2495                let mut components = src_components.clone();
2496                for component in &mut components {
2497                    *component = self.cast(*component, target, span)?;
2498                }
2499
2500                let ty = self.types.insert(
2501                    Type {
2502                        name: None,
2503                        inner: ty_inner,
2504                    },
2505                    span,
2506                );
2507
2508                Expression::Compose { ty, components }
2509            }
2510            Expression::Splat { size, value } => {
2511                let value_span = self.expressions.get_span(value);
2512                let cast_value = self.cast(value, target, value_span)?;
2513                Expression::Splat {
2514                    size,
2515                    value: cast_value,
2516                }
2517            }
2518            _ => return make_error(),
2519        };
2520
2521        self.register_evaluated_expr(expr, span)
2522    }
2523
2524    /// Convert the scalar leaves of  `expr` to `target`, handling arrays.
2525    ///
2526    /// `expr` must be a `Compose` expression whose type is a scalar, vector,
2527    /// matrix, or nested arrays of such.
2528    ///
2529    /// This is basically the same as the [`cast`] method, except that that
2530    /// should only handle Naga [`As`] expressions, which cannot convert arrays.
2531    ///
2532    /// Treat `span` as the location of the resulting expression.
2533    ///
2534    /// [`cast`]: ConstantEvaluator::cast
2535    /// [`As`]: crate::Expression::As
2536    pub fn cast_array(
2537        &mut self,
2538        expr: Handle<Expression>,
2539        target: crate::Scalar,
2540        span: Span,
2541    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2542        let expr = self.check_and_get(expr)?;
2543
2544        let Expression::Compose { ty, ref components } = self.expressions[expr] else {
2545            return self.cast(expr, target, span);
2546        };
2547
2548        let TypeInner::Array {
2549            base: _,
2550            size,
2551            stride: _,
2552        } = self.types[ty].inner
2553        else {
2554            return self.cast(expr, target, span);
2555        };
2556
2557        let mut components = components.clone();
2558        for component in &mut components {
2559            *component = self.cast_array(*component, target, span)?;
2560        }
2561
2562        let first = components.first().unwrap();
2563        let new_base = match self.resolve_type(*first)? {
2564            crate::proc::TypeResolution::Handle(ty) => ty,
2565            crate::proc::TypeResolution::Value(inner) => {
2566                self.types.insert(Type { name: None, inner }, span)
2567            }
2568        };
2569        let mut layouter = core::mem::take(self.layouter);
2570        layouter.update(self.to_ctx()).unwrap();
2571        *self.layouter = layouter;
2572
2573        let new_base_stride = self.layouter[new_base].to_stride();
2574        let new_array_ty = self.types.insert(
2575            Type {
2576                name: None,
2577                inner: TypeInner::Array {
2578                    base: new_base,
2579                    size,
2580                    stride: new_base_stride,
2581                },
2582            },
2583            span,
2584        );
2585
2586        let compose = Expression::Compose {
2587            ty: new_array_ty,
2588            components,
2589        };
2590        self.register_evaluated_expr(compose, span)
2591    }
2592
2593    fn unary_op(
2594        &mut self,
2595        op: UnaryOperator,
2596        expr: Handle<Expression>,
2597        span: Span,
2598    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2599        let expr = self.eval_zero_value_and_splat(expr, span)?;
2600
2601        let expr = match self.expressions[expr] {
2602            Expression::Literal(value) => Expression::Literal(match op {
2603                UnaryOperator::Negate => match value {
2604                    Literal::I32(v) => Literal::I32(v.wrapping_neg()),
2605                    Literal::I64(v) => Literal::I64(v.wrapping_neg()),
2606                    Literal::F32(v) => Literal::F32(-v),
2607                    Literal::F16(v) => Literal::F16(-v),
2608                    Literal::F64(v) => Literal::F64(-v),
2609                    Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
2610                    Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
2611                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2612                },
2613                UnaryOperator::LogicalNot => match value {
2614                    Literal::Bool(v) => Literal::Bool(!v),
2615                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2616                },
2617                UnaryOperator::BitwiseNot => match value {
2618                    Literal::I32(v) => Literal::I32(!v),
2619                    Literal::I64(v) => Literal::I64(!v),
2620                    Literal::U32(v) => Literal::U32(!v),
2621                    Literal::U64(v) => Literal::U64(!v),
2622                    Literal::AbstractInt(v) => Literal::AbstractInt(!v),
2623                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2624                },
2625            }),
2626            Expression::Compose {
2627                ty,
2628                components: ref src_components,
2629            } => {
2630                match self.types[ty].inner {
2631                    TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
2632                    _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2633                }
2634
2635                let mut components = src_components.clone();
2636                for component in &mut components {
2637                    *component = self.unary_op(op, *component, span)?;
2638                }
2639
2640                Expression::Compose { ty, components }
2641            }
2642            _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
2643        };
2644
2645        self.register_evaluated_expr(expr, span)
2646    }
2647
2648    fn binary_op(
2649        &mut self,
2650        op: BinaryOperator,
2651        left: Handle<Expression>,
2652        right: Handle<Expression>,
2653        span: Span,
2654    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
2655        let left = self.eval_zero_value_and_splat(left, span)?;
2656        let right = self.eval_zero_value_and_splat(right, span)?;
2657
2658        // Note: in most cases constant evaluation checks for overflow, but for
2659        // i32/u32, it uses wrapping arithmetic. See
2660        // <https://gpuweb.github.io/gpuweb/wgsl/#integer-types>.
2661
2662        let expr = match (&self.expressions[left], &self.expressions[right]) {
2663            (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
2664                if !matches!(op, BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight)
2665                    && core::mem::discriminant(&left_value) != core::mem::discriminant(&right_value)
2666                {
2667                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2668                }
2669
2670                let literal = match op {
2671                    BinaryOperator::Equal => Literal::Bool(left_value == right_value),
2672                    BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
2673                    BinaryOperator::Less => Literal::Bool(left_value < right_value),
2674                    BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
2675                    BinaryOperator::Greater => Literal::Bool(left_value > right_value),
2676                    BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
2677
2678                    _ => match (left_value, right_value) {
2679                        (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
2680                            BinaryOperator::Add => a.wrapping_add(b),
2681                            BinaryOperator::Subtract => a.wrapping_sub(b),
2682                            BinaryOperator::Multiply => a.wrapping_mul(b),
2683                            BinaryOperator::Divide => {
2684                                if b == 0 {
2685                                    return Err(ConstantEvaluatorError::DivisionByZero);
2686                                } else {
2687                                    a.wrapping_div(b)
2688                                }
2689                            }
2690                            BinaryOperator::Modulo => {
2691                                if b == 0 {
2692                                    return Err(ConstantEvaluatorError::RemainderByZero);
2693                                } else {
2694                                    a.wrapping_rem(b)
2695                                }
2696                            }
2697                            BinaryOperator::And => a & b,
2698                            BinaryOperator::ExclusiveOr => a ^ b,
2699                            BinaryOperator::InclusiveOr => a | b,
2700                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2701                        }),
2702                        (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
2703                            BinaryOperator::ShiftLeft => {
2704                                if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2705                                    return Err(ConstantEvaluatorError::Overflow("<<".to_string()));
2706                                }
2707                                a.checked_shl(b)
2708                                    .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?
2709                            }
2710                            BinaryOperator::ShiftRight => a
2711                                .checked_shr(b)
2712                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2713                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2714                        }),
2715                        (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
2716                            BinaryOperator::Add => a.wrapping_add(b),
2717                            BinaryOperator::Subtract => a.wrapping_sub(b),
2718                            BinaryOperator::Multiply => a.wrapping_mul(b),
2719                            BinaryOperator::Divide => a
2720                                .checked_div(b)
2721                                .ok_or(ConstantEvaluatorError::DivisionByZero)?,
2722                            BinaryOperator::Modulo => a
2723                                .checked_rem(b)
2724                                .ok_or(ConstantEvaluatorError::RemainderByZero)?,
2725                            BinaryOperator::And => a & b,
2726                            BinaryOperator::ExclusiveOr => a ^ b,
2727                            BinaryOperator::InclusiveOr => a | b,
2728                            BinaryOperator::ShiftLeft => a
2729                                .checked_mul(
2730                                    1u32.checked_shl(b)
2731                                        .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2732                                )
2733                                .ok_or(ConstantEvaluatorError::Overflow("<<".to_string()))?,
2734                            BinaryOperator::ShiftRight => a
2735                                .checked_shr(b)
2736                                .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
2737                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2738                        }),
2739                        (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
2740                            BinaryOperator::Add => a + b,
2741                            BinaryOperator::Subtract => a - b,
2742                            BinaryOperator::Multiply => a * b,
2743                            BinaryOperator::Divide => a / b,
2744                            BinaryOperator::Modulo => a % b,
2745                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2746                        }),
2747                        (Literal::AbstractInt(a), Literal::U32(b)) => {
2748                            Literal::AbstractInt(match op {
2749                                BinaryOperator::ShiftLeft => {
2750                                    if (if a.is_negative() { !a } else { a }).leading_zeros() <= b {
2751                                        return Err(ConstantEvaluatorError::Overflow(
2752                                            "<<".to_string(),
2753                                        ));
2754                                    }
2755                                    a.checked_shl(b).unwrap_or(0)
2756                                }
2757                                BinaryOperator::ShiftRight => a.checked_shr(b).unwrap_or(0),
2758                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2759                            })
2760                        }
2761                        (Literal::F16(a), Literal::F16(b)) => {
2762                            let result = match op {
2763                                BinaryOperator::Add => a + b,
2764                                BinaryOperator::Subtract => a - b,
2765                                BinaryOperator::Multiply => a * b,
2766                                BinaryOperator::Divide => a / b,
2767                                BinaryOperator::Modulo => a % b,
2768                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2769                            };
2770                            if !result.is_finite() {
2771                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2772                            }
2773                            Literal::F16(result)
2774                        }
2775                        (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
2776                            Literal::AbstractInt(match op {
2777                                BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
2778                                    ConstantEvaluatorError::Overflow("addition".into())
2779                                })?,
2780                                BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
2781                                    ConstantEvaluatorError::Overflow("subtraction".into())
2782                                })?,
2783                                BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
2784                                    ConstantEvaluatorError::Overflow("multiplication".into())
2785                                })?,
2786                                BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
2787                                    if b == 0 {
2788                                        ConstantEvaluatorError::DivisionByZero
2789                                    } else {
2790                                        ConstantEvaluatorError::Overflow("division".into())
2791                                    }
2792                                })?,
2793                                BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
2794                                    if b == 0 {
2795                                        ConstantEvaluatorError::RemainderByZero
2796                                    } else {
2797                                        ConstantEvaluatorError::Overflow("remainder".into())
2798                                    }
2799                                })?,
2800                                BinaryOperator::And => a & b,
2801                                BinaryOperator::ExclusiveOr => a ^ b,
2802                                BinaryOperator::InclusiveOr => a | b,
2803                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2804                            })
2805                        }
2806                        (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
2807                            let result = match op {
2808                                BinaryOperator::Add => a + b,
2809                                BinaryOperator::Subtract => a - b,
2810                                BinaryOperator::Multiply => a * b,
2811                                BinaryOperator::Divide => a / b,
2812                                BinaryOperator::Modulo => a % b,
2813                                _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2814                            };
2815                            if !result.is_finite() {
2816                                return Err(ConstantEvaluatorError::Overflow(format!("{op:?}")));
2817                            }
2818                            Literal::AbstractFloat(result)
2819                        }
2820                        (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
2821                            BinaryOperator::LogicalAnd => a && b,
2822                            BinaryOperator::LogicalOr => a || b,
2823                            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2824                        }),
2825                        _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2826                    },
2827                };
2828                Expression::Literal(literal)
2829            }
2830            (
2831                &Expression::Compose {
2832                    components: ref src_components,
2833                    ty,
2834                },
2835                &Expression::Literal(_),
2836            ) => {
2837                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2838                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2839                }
2840                let mut components = src_components.clone();
2841                for component in &mut components {
2842                    *component = self.binary_op(op, *component, right, span)?;
2843                }
2844                Expression::Compose { ty, components }
2845            }
2846            (
2847                &Expression::Literal(_),
2848                &Expression::Compose {
2849                    components: ref src_components,
2850                    ty,
2851                },
2852            ) => {
2853                if !is_allowed_compose_literal_op(&self.types[ty].inner, op) {
2854                    return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
2855                }
2856                let mut components = src_components.clone();
2857                for component in &mut components {
2858                    *component = self.binary_op(op, left, *component, span)?;
2859                }
2860                Expression::Compose { ty, components }
2861            }
2862            (
2863                &Expression::Compose {
2864                    components: ref left_components,
2865                    ty: left_ty,
2866                },
2867                &Expression::Compose {
2868                    components: ref right_components,
2869                    ty: right_ty,
2870                },
2871            ) => {
2872                // We have to make a copy of the component lists, because the
2873                // call to `binary_op_vector` needs `&mut self`, but `self` owns
2874                // the component lists.
2875                let left_flattened = crate::proc::flatten_compose(
2876                    left_ty,
2877                    left_components,
2878                    self.expressions,
2879                    self.types,
2880                )
2881                .collect::<Vec<_>>();
2882                let right_flattened = crate::proc::flatten_compose(
2883                    right_ty,
2884                    right_components,
2885                    self.expressions,
2886                    self.types,
2887                )
2888                .collect::<Vec<_>>();
2889
2890                self.binary_op_compose(
2891                    op,
2892                    &left_flattened,
2893                    &right_flattened,
2894                    left_ty,
2895                    right_ty,
2896                    span,
2897                )?
2898            }
2899            _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
2900        };
2901
2902        return self.register_evaluated_expr(expr, span);
2903
2904        fn is_allowed_compose_literal_op(compose_ty: &TypeInner, op: BinaryOperator) -> bool {
2905            let is_numeric_vec = matches!(
2906                compose_ty, TypeInner::Vector { scalar, .. }
2907                if scalar.kind != ScalarKind::Bool
2908            );
2909            let is_allowed_vec_scalar_op = matches!(
2910                op,
2911                BinaryOperator::Add
2912                    | BinaryOperator::Subtract
2913                    | BinaryOperator::Multiply
2914                    | BinaryOperator::Divide
2915                    | BinaryOperator::Modulo
2916            );
2917            let is_mat = matches!(compose_ty, TypeInner::Matrix { .. });
2918            let is_allowed_mat_scalar_op = matches!(op, BinaryOperator::Multiply);
2919            is_numeric_vec && is_allowed_vec_scalar_op || is_mat && is_allowed_mat_scalar_op
2920        }
2921    }
2922
2923    fn binary_op_compose(
2924        &mut self,
2925        op: BinaryOperator,
2926        left_components: &[Handle<Expression>],
2927        right_components: &[Handle<Expression>],
2928        left_ty: Handle<Type>,
2929        right_ty: Handle<Type>,
2930        span: Span,
2931    ) -> Result<Expression, ConstantEvaluatorError> {
2932        match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
2933            // Binary operation on vector-vector
2934            (
2935                &TypeInner::Vector {
2936                    size: left_size, ..
2937                },
2938                &TypeInner::Vector {
2939                    size: right_size, ..
2940                },
2941            ) if left_size == right_size => self.binary_op_vector(
2942                op,
2943                left_size,
2944                left_components,
2945                right_components,
2946                left_ty,
2947                span,
2948            ),
2949            // Binary operation on vector-matrix
2950            (
2951                &TypeInner::Vector { size, .. },
2952                &TypeInner::Matrix {
2953                    columns,
2954                    rows,
2955                    scalar,
2956                },
2957            ) if op == BinaryOperator::Multiply && size == rows => self.multiply_vector_matrix(
2958                left_components,
2959                right_components,
2960                columns,
2961                scalar,
2962                span,
2963            ),
2964            // Binary operation on matrix-vector
2965            (
2966                &TypeInner::Matrix {
2967                    columns,
2968                    rows,
2969                    scalar,
2970                },
2971                &TypeInner::Vector { size, .. },
2972            ) if op == BinaryOperator::Multiply && size == columns => {
2973                self.multiply_matrix_vector(left_components, right_components, rows, scalar, span)
2974            }
2975            // Binary operation on matrix-matrix
2976            (
2977                &TypeInner::Matrix {
2978                    columns: left_columns,
2979                    rows: left_rows,
2980                    scalar,
2981                },
2982                &TypeInner::Matrix {
2983                    columns: right_columns,
2984                    rows: right_rows,
2985                    ..
2986                },
2987            ) => match op {
2988                BinaryOperator::Add | BinaryOperator::Subtract
2989                    if left_columns == right_columns && left_rows == right_rows =>
2990                {
2991                    let components = left_components
2992                        .iter()
2993                        .zip(right_components)
2994                        .map(|(&left, &right)| self.binary_op(op, left, right, span))
2995                        .collect::<Result<Vec<_>, _>>()?;
2996                    Ok(Expression::Compose {
2997                        ty: left_ty,
2998                        components,
2999                    })
3000                }
3001                BinaryOperator::Multiply if left_columns == right_rows => self
3002                    .multiply_matrix_matrix(
3003                        left_components,
3004                        right_components,
3005                        left_rows,
3006                        right_columns,
3007                        scalar,
3008                        span,
3009                    ),
3010                _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3011            },
3012            _ => Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
3013        }
3014    }
3015
3016    fn binary_op_vector(
3017        &mut self,
3018        op: BinaryOperator,
3019        size: crate::VectorSize,
3020        left_components: &[Handle<Expression>],
3021        right_components: &[Handle<Expression>],
3022        left_ty: Handle<Type>,
3023        span: Span,
3024    ) -> Result<Expression, ConstantEvaluatorError> {
3025        let ty = match op {
3026            // Relational operators produce vectors of booleans.
3027            BinaryOperator::Equal
3028            | BinaryOperator::NotEqual
3029            | BinaryOperator::Less
3030            | BinaryOperator::LessEqual
3031            | BinaryOperator::Greater
3032            | BinaryOperator::GreaterEqual => self.types.insert(
3033                Type {
3034                    name: None,
3035                    inner: TypeInner::Vector {
3036                        size,
3037                        scalar: crate::Scalar::BOOL,
3038                    },
3039                },
3040                span,
3041            ),
3042
3043            // Other operators produce the same type as their left
3044            // operand.
3045            BinaryOperator::Add
3046            | BinaryOperator::Subtract
3047            | BinaryOperator::Multiply
3048            | BinaryOperator::Divide
3049            | BinaryOperator::Modulo
3050            | BinaryOperator::And
3051            | BinaryOperator::ExclusiveOr
3052            | BinaryOperator::InclusiveOr
3053            | BinaryOperator::ShiftLeft
3054            | BinaryOperator::ShiftRight => left_ty,
3055
3056            BinaryOperator::LogicalAnd | BinaryOperator::LogicalOr => {
3057                // Not supported on vectors
3058                return Err(ConstantEvaluatorError::InvalidBinaryOpArgs);
3059            }
3060        };
3061
3062        let components = left_components
3063            .iter()
3064            .zip(right_components)
3065            .map(|(&left, &right)| self.binary_op(op, left, right, span))
3066            .collect::<Result<Vec<_>, _>>()?;
3067
3068        Ok(Expression::Compose { ty, components })
3069    }
3070
3071    fn multiply_vector_matrix(
3072        &mut self,
3073        vec_components: &[Handle<Expression>],
3074        mat_components: &[Handle<Expression>],
3075        mat_columns: crate::VectorSize,
3076        scalar: crate::Scalar,
3077        span: Span,
3078    ) -> Result<Expression, ConstantEvaluatorError> {
3079        let ty = self.types.insert(
3080            Type {
3081                name: None,
3082                inner: TypeInner::Vector {
3083                    size: mat_columns,
3084                    scalar,
3085                },
3086            },
3087            span,
3088        );
3089        let components = mat_components
3090            .iter()
3091            .map(|&column| {
3092                let Expression::Compose { ref components, .. } = self.expressions[column] else {
3093                    unreachable!()
3094                };
3095                self.dot_exprs(
3096                    vec_components.iter().cloned(),
3097                    components.clone().into_iter(),
3098                    span,
3099                )
3100            })
3101            .collect::<Result<Vec<_>, _>>()?;
3102        Ok(Expression::Compose { ty, components })
3103    }
3104
3105    fn multiply_matrix_vector(
3106        &mut self,
3107        mat_components: &[Handle<Expression>],
3108        vec_components: &[Handle<Expression>],
3109        mat_rows: crate::VectorSize,
3110        scalar: crate::Scalar,
3111        span: Span,
3112    ) -> Result<Expression, ConstantEvaluatorError> {
3113        let ty = self.types.insert(
3114            Type {
3115                name: None,
3116                inner: TypeInner::Vector {
3117                    size: mat_rows,
3118                    scalar,
3119                },
3120            },
3121            span,
3122        );
3123
3124        let flatten = self.flatten_matrix(mat_components);
3125        let nr = mat_rows as usize;
3126        let components = (0..nr)
3127            .map(|r| {
3128                let row = flatten.iter().skip(r).step_by(nr).cloned();
3129                self.dot_exprs(row, vec_components.iter().cloned(), span)
3130            })
3131            .collect::<Result<Vec<_>, _>>()?;
3132        Ok(Expression::Compose { ty, components })
3133    }
3134
3135    fn multiply_matrix_matrix(
3136        &mut self,
3137        left_components: &[Handle<Expression>],
3138        right_components: &[Handle<Expression>],
3139        left_rows: crate::VectorSize,
3140        right_columns: crate::VectorSize,
3141        scalar: crate::Scalar,
3142        span: Span,
3143    ) -> Result<Expression, ConstantEvaluatorError> {
3144        let left_nc = left_components.len();
3145        let left_nr = left_rows as usize;
3146        let right_nc = right_columns as usize;
3147        let right_nr = left_nc;
3148
3149        let mut result = Vec::with_capacity(right_nc);
3150        let result_ty = self.types.insert(
3151            Type {
3152                name: None,
3153                inner: TypeInner::Matrix {
3154                    columns: right_columns,
3155                    rows: left_rows,
3156                    scalar,
3157                },
3158            },
3159            span,
3160        );
3161        let result_column_ty = self.types.insert(
3162            Type {
3163                name: None,
3164                inner: TypeInner::Vector {
3165                    size: left_rows,
3166                    scalar,
3167                },
3168            },
3169            span,
3170        );
3171
3172        let left_flattened = self.flatten_matrix(left_components);
3173        let right_flattened = self.flatten_matrix(right_components);
3174        for c in 0..right_nc {
3175            let result_column = (0..left_nr)
3176                .map(|r| {
3177                    let row = left_flattened.iter().skip(r).step_by(left_nr);
3178                    let column = right_flattened.iter().skip(c * right_nr).take(right_nr);
3179                    self.dot_exprs(row.cloned(), column.cloned(), span)
3180                })
3181                .collect::<Result<Vec<_>, _>>()?;
3182            let expr = Expression::Compose {
3183                ty: result_column_ty,
3184                components: result_column,
3185            };
3186            let handle = self.register_evaluated_expr(expr, span)?;
3187            result.push(handle);
3188        }
3189        Ok(Expression::Compose {
3190            ty: result_ty,
3191            components: result,
3192        })
3193    }
3194
3195    fn flatten_matrix(&self, columns: &[Handle<Expression>]) -> ArrayVec<Handle<Expression>, 16> {
3196        let mut flattened = ArrayVec::<_, 16>::new();
3197        for &column in columns {
3198            let Expression::Compose { ref components, .. } = self.expressions[column] else {
3199                unreachable!()
3200            };
3201            flattened.extend(components.iter().cloned());
3202        }
3203        flattened
3204    }
3205
3206    fn dot_exprs(
3207        &mut self,
3208        left: impl Iterator<Item = Handle<Expression>>,
3209        right: impl Iterator<Item = Handle<Expression>>,
3210        span: Span,
3211    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3212        let mut acc = None;
3213        for (l, r) in left.zip(right) {
3214            let result = self.binary_op(BinaryOperator::Multiply, l, r, span)?;
3215            match acc.as_mut() {
3216                Some(acc) => *acc = self.binary_op(BinaryOperator::Add, *acc, result, span)?,
3217                None => acc = Some(result),
3218            }
3219        }
3220        Ok(acc.unwrap())
3221    }
3222
3223    fn relational(
3224        &mut self,
3225        fun: RelationalFunction,
3226        arg: Handle<Expression>,
3227        span: Span,
3228    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3229        let arg = self.eval_zero_value_and_splat(arg, span)?;
3230        match fun {
3231            RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
3232                Expression::Literal(Literal::Bool(_)) => Ok(arg),
3233                Expression::Compose { ty, ref components }
3234                    if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
3235                {
3236                    let mut bool_components = ArrayVec::<bool, { crate::VectorSize::MAX }>::new();
3237                    for component in
3238                        crate::proc::flatten_compose(ty, components, self.expressions, self.types)
3239                    {
3240                        match self.expressions[component] {
3241                            Expression::Literal(Literal::Bool(val)) => {
3242                                bool_components.push(val);
3243                            }
3244                            _ => {
3245                                return Err(ConstantEvaluatorError::InvalidRelationalArg(fun));
3246                            }
3247                        }
3248                    }
3249                    let components = bool_components;
3250                    let result = match fun {
3251                        RelationalFunction::All => components.iter().all(|c| *c),
3252                        RelationalFunction::Any => components.iter().any(|c| *c),
3253                        _ => unreachable!(),
3254                    };
3255                    self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
3256                }
3257                _ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
3258            },
3259            _ => Err(ConstantEvaluatorError::NotImplemented(format!(
3260                "{fun:?} built-in function"
3261            ))),
3262        }
3263    }
3264
3265    /// Deep copy `expr` from `expressions` into `self.expressions`.
3266    ///
3267    /// Return the root of the new copy.
3268    ///
3269    /// This is used when we're evaluating expressions in a function's
3270    /// expression arena that refer to a constant: we need to copy the
3271    /// constant's value into the function's arena so we can operate on it.
3272    fn copy_from(
3273        &mut self,
3274        expr: Handle<Expression>,
3275        expressions: &Arena<Expression>,
3276    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3277        let span = expressions.get_span(expr);
3278        match expressions[expr] {
3279            ref expr @ (Expression::Literal(_)
3280            | Expression::Constant(_)
3281            | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
3282            Expression::Compose { ty, ref components } => {
3283                let mut components = components.clone();
3284                for component in &mut components {
3285                    *component = self.copy_from(*component, expressions)?;
3286                }
3287                self.register_evaluated_expr(Expression::Compose { ty, components }, span)
3288            }
3289            Expression::Splat { size, value } => {
3290                let value = self.copy_from(value, expressions)?;
3291                self.register_evaluated_expr(Expression::Splat { size, value }, span)
3292            }
3293            _ => {
3294                log::debug!("copy_from: SubexpressionsAreNotConstant");
3295                Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
3296            }
3297        }
3298    }
3299
3300    /// Returns the total number of components, after flattening, of a vector compose expression.
3301    fn vector_compose_flattened_size(
3302        &self,
3303        components: &[Handle<Expression>],
3304    ) -> Result<usize, ConstantEvaluatorError> {
3305        components
3306            .iter()
3307            .try_fold(0, |acc, c| -> Result<_, ConstantEvaluatorError> {
3308                let size = match *self.resolve_type(*c)?.inner_with(self.types) {
3309                    TypeInner::Scalar(_) => 1,
3310                    // We trust that the vector size of `component` is correct,
3311                    // as it will have already been validated when `component`
3312                    // was registered.
3313                    TypeInner::Vector { size, .. } => size as usize,
3314                    _ => return Err(ConstantEvaluatorError::InvalidVectorComposeComponent),
3315                };
3316                Ok(acc + size)
3317            })
3318    }
3319
3320    fn register_evaluated_expr(
3321        &mut self,
3322        expr: Expression,
3323        span: Span,
3324    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3325        // It suffices to only check_literal_value() for `Literal` expressions,
3326        // since we only register one expression at a time, `Compose`
3327        // expressions can only refer to other expressions, and `ZeroValue`
3328        // expressions are always okay.
3329        if let Expression::Literal(literal) = expr {
3330            crate::valid::check_literal_value(literal)?;
3331        }
3332
3333        // Ensure vector composes contain the correct number of components. We
3334        // do so here when each compose is registered to avoid having to deal
3335        // with the mess each time the compose is used in another expression.
3336        if let Expression::Compose { ty, ref components } = expr {
3337            if let TypeInner::Vector { size, scalar: _ } = self.types[ty].inner {
3338                let expected = size as usize;
3339                let actual = self.vector_compose_flattened_size(components)?;
3340                if expected != actual {
3341                    return Err(ConstantEvaluatorError::InvalidVectorComposeLength {
3342                        expected,
3343                        actual,
3344                    });
3345                }
3346            }
3347        }
3348
3349        Ok(self.append_expr(expr, span, ExpressionKind::Const))
3350    }
3351
3352    fn append_expr(
3353        &mut self,
3354        expr: Expression,
3355        span: Span,
3356        expr_type: ExpressionKind,
3357    ) -> Handle<Expression> {
3358        let h = match self.behavior {
3359            Behavior::Wgsl(
3360                WgslRestrictions::Runtime(ref mut function_local_data)
3361                | WgslRestrictions::Const(Some(ref mut function_local_data)),
3362            )
3363            | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
3364                let is_running = function_local_data.emitter.is_running();
3365                let needs_pre_emit = expr.needs_pre_emit();
3366                if is_running && needs_pre_emit {
3367                    function_local_data
3368                        .block
3369                        .extend(function_local_data.emitter.finish(self.expressions));
3370                    let h = self.expressions.append(expr, span);
3371                    function_local_data.emitter.start(self.expressions);
3372                    h
3373                } else {
3374                    self.expressions.append(expr, span)
3375                }
3376            }
3377            _ => self.expressions.append(expr, span),
3378        };
3379        self.expression_kind_tracker.insert(h, expr_type);
3380        h
3381    }
3382
3383    /// Resolve the type of `expr` if it is a constant expression.
3384    ///
3385    /// If `expr` was evaluated to a constant, returns its type.
3386    /// Otherwise, returns an error.
3387    fn resolve_type(
3388        &self,
3389        expr: Handle<Expression>,
3390    ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
3391        use crate::proc::TypeResolution as Tr;
3392        use crate::Expression as Ex;
3393        let resolution = match self.expressions[expr] {
3394            Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
3395            Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
3396            Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
3397            Ex::Splat { size, value } => {
3398                let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
3399                    return Err(ConstantEvaluatorError::SplatScalarOnly);
3400                };
3401                Tr::Value(TypeInner::Vector { scalar, size })
3402            }
3403            _ => {
3404                log::debug!("resolve_type: SubexpressionsAreNotConstant");
3405                return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
3406            }
3407        };
3408
3409        Ok(resolution)
3410    }
3411
3412    fn select(
3413        &mut self,
3414        reject: Handle<Expression>,
3415        accept: Handle<Expression>,
3416        condition: Handle<Expression>,
3417        span: Span,
3418    ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
3419        let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
3420
3421        let reject = arg(reject)?;
3422        let accept = arg(accept)?;
3423        let condition = arg(condition)?;
3424
3425        let select_single_component =
3426            |this: &mut Self, reject_scalar, reject, accept, condition| {
3427                let accept = this.cast(accept, reject_scalar, span)?;
3428                if condition {
3429                    Ok(accept)
3430                } else {
3431                    Ok(reject)
3432                }
3433            };
3434
3435        match (&self.expressions[reject], &self.expressions[accept]) {
3436            (&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
3437                let reject_scalar = reject_lit.scalar();
3438                let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
3439                else {
3440                    return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
3441                };
3442                select_single_component(self, reject_scalar, reject, accept, condition)
3443            }
3444            (
3445                &Expression::Compose {
3446                    ty: reject_ty,
3447                    components: ref reject_components,
3448                },
3449                &Expression::Compose {
3450                    ty: accept_ty,
3451                    components: ref accept_components,
3452                },
3453            ) => {
3454                let ty_deets = |ty| {
3455                    let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
3456                    (size.unwrap(), scalar)
3457                };
3458
3459                let expected_vec_size = {
3460                    let [(reject_vec_size, _), (accept_vec_size, _)] =
3461                        [reject_ty, accept_ty].map(ty_deets);
3462
3463                    if reject_vec_size != accept_vec_size {
3464                        return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
3465                            reject: reject_vec_size,
3466                            accept: accept_vec_size,
3467                        });
3468                    }
3469                    reject_vec_size
3470                };
3471
3472                let condition_components = match self.expressions[condition] {
3473                    Expression::Literal(Literal::Bool(condition)) => {
3474                        vec![condition; (expected_vec_size as u8).into()]
3475                    }
3476                    Expression::Compose {
3477                        ty: condition_ty,
3478                        components: ref condition_components,
3479                    } => {
3480                        let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
3481                        if condition_scalar.kind != ScalarKind::Bool {
3482                            return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
3483                        }
3484                        if condition_vec_size != expected_vec_size {
3485                            return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
3486                        }
3487                        condition_components
3488                            .iter()
3489                            .copied()
3490                            .map(|component| match &self.expressions[component] {
3491                                &Expression::Literal(Literal::Bool(condition)) => condition,
3492                                _ => unreachable!(),
3493                            })
3494                            .collect()
3495                    }
3496
3497                    _ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
3498                };
3499
3500                let evaluated = Expression::Compose {
3501                    ty: reject_ty,
3502                    components: reject_components
3503                        .clone()
3504                        .into_iter()
3505                        .zip(accept_components.clone().into_iter())
3506                        .zip(condition_components.into_iter())
3507                        .map(|((reject, accept), condition)| {
3508                            let reject_scalar = match &self.expressions[reject] {
3509                                &Expression::Literal(lit) => lit.scalar(),
3510                                _ => unreachable!(),
3511                            };
3512                            select_single_component(self, reject_scalar, reject, accept, condition)
3513                        })
3514                        .collect::<Result<_, _>>()?,
3515                };
3516                self.register_evaluated_expr(evaluated, span)
3517            }
3518            _ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
3519        }
3520    }
3521}
3522
3523fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3524    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, a value
3525    // of 1 means the least significant bit is set. Therefore, an input of `0x[80 00…]` would
3526    // return a right-to-left bit index of 0.
3527    let trailing_zeros_to_bit_idx = |e: u32| -> u32 {
3528        match e {
3529            idx @ 0..=31 => idx,
3530            32 => u32::MAX,
3531            _ => unreachable!(),
3532        }
3533    };
3534    match concrete_int {
3535        ConcreteInt::U32([e]) => ConcreteInt::U32([trailing_zeros_to_bit_idx(e.trailing_zeros())]),
3536        ConcreteInt::I32([e]) => {
3537            ConcreteInt::I32([trailing_zeros_to_bit_idx(e.trailing_zeros()) as i32])
3538        }
3539    }
3540}
3541
#[test]
fn first_trailing_bit_smoke() {
    // (input, expected) pairs for signed lanes; 0 has no set bit and maps to -1.
    let signed_cases: [(i32, i32); 6] = [
        (0, -1),
        (1, 0),
        (2, 1),
        (-1, 0),
        (i32::MIN, 31),
        (i32::MAX, 0),
    ];
    for &(input, expected) in signed_cases.iter() {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::I32([1 << idx])),
            ConcreteInt::I32([idx])
        );
    }

    // (input, expected) pairs for unsigned lanes; 0 maps to u32::MAX.
    let unsigned_cases: [(u32, u32); 5] = [
        (0, u32::MAX),
        (1, 0),
        (2, 1),
        (1 << 31, 31),
        (u32::MAX, 0),
    ];
    for &(input, expected) in unsigned_cases.iter() {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected]),
        );
    }
    for idx in 0..32 {
        assert_eq!(
            first_trailing_bit(ConcreteInt::U32([1 << idx])),
            ConcreteInt::U32([idx])
        );
    }
}
3602
3603fn first_leading_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {
3604    // NOTE: Bit indices for this built-in start at 0 at the "right" (or LSB). For example, 1 means
3605    // the least significant bit is set. Therefore, an input of 1 would return a right-to-left bit
3606    // index of 0.
3607    let rtl_to_ltr_bit_idx = |e: u32| -> u32 {
3608        match e {
3609            idx @ 0..=31 => 31 - idx,
3610            32 => u32::MAX,
3611            _ => unreachable!(),
3612        }
3613    };
3614    match concrete_int {
3615        ConcreteInt::I32([e]) => ConcreteInt::I32([{
3616            let rtl_bit_index = if e.is_negative() {
3617                e.leading_ones()
3618            } else {
3619                e.leading_zeros()
3620            };
3621            rtl_to_ltr_bit_idx(rtl_bit_index) as i32
3622        }]),
3623        ConcreteInt::U32([e]) => ConcreteInt::U32([rtl_to_ltr_bit_idx(e.leading_zeros())]),
3624    }
3625}
3626
#[test]
fn first_leading_bit_smoke() {
    // Signed inputs report the highest bit that differs from the sign bit, or
    // -1 when no bit differs (`0` and `-1`).
    let signed_cases = [
        (-1, -1),
        (0, -1),
        (1, 0),
        (-2, 0),
        (1234 + 4567, 12),
        (i32::MAX, 30),
        (i32::MIN, 30),
    ];
    for (input, expected) in signed_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([input])),
            ConcreteInt::I32([expected])
        );
    }
    // Single positive bits below the sign bit. The sign bit alone is the
    // `i32::MIN` case above.
    for bit in 0..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([1 << bit])),
            ConcreteInt::I32([bit])
        );
    }
    // `-(2^bit)` has its highest non-sign bit at `bit - 1`. `bit = 31` is
    // skipped because negating `i32::MIN` would overflow.
    for bit in 1..31 {
        assert_eq!(
            first_leading_bit(ConcreteInt::I32([-(1 << bit)])),
            ConcreteInt::I32([bit - 1])
        );
    }

    // Unsigned inputs report the highest set bit, or `u32::MAX` for zero.
    let unsigned_cases = [(0, u32::MAX), (1, 0), (u32::MAX, 31)];
    for (input, expected) in unsigned_cases {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([input])),
            ConcreteInt::U32([expected])
        );
    }
    for bit in 0..32 {
        assert_eq!(
            first_leading_bit(ConcreteInt::U32([1 << bit])),
            ConcreteInt::U32([bit])
        );
    }
}
3690
/// Trait for conversions of abstract values to concrete types.
///
/// `T` is the Rust type carrying the abstract value. The impls in this file
/// use `T = i64` and `T = f64`, presumably the payload types of
/// `Literal::AbstractInt` and `Literal::AbstractFloat` respectively
/// (NOTE(review): confirm against the `Literal` definition).
trait TryFromAbstract<T>: Sized {
    /// Convert an abstract literal `value` to `Self`.
    ///
    /// Since Naga's [`AbstractInt`] and [`AbstractFloat`] exist to support
    /// WGSL, we follow WGSL's conversion rules here:
    ///
    /// - WGSL §6.1.2 Conversion Rank says that automatic conversions
    ///   from [`AbstractInt`] to an integer type are either lossless or an
    ///   error.
    ///
    /// - WGSL §15.7.6 Floating Point Conversion says that conversions
    ///   to floating point in constant expressions and override
    ///   expressions are errors if the value is out of range for the
    ///   destination type, but rounding is okay.
    ///
    /// - WGSL §17.1.2 i32()/u32() constructors treat AbstractFloat as any
    ///   other floating point type, following the scalar floating point to
    ///   integral conversion algorithm (§15.7.6). There is no automatic
    ///   conversion from AbstractFloat to integer types.
    ///
    /// # Errors
    ///
    /// Returns a [`ConstantEvaluatorError`] when `value` cannot be converted
    /// under the rules above. The impls in this file report such failures as
    /// `ConstantEvaluatorError::AutomaticConversionLossy`, carrying the
    /// offending value (formatted with `{:?}`) and the target type's name.
    ///
    /// [`AbstractInt`]: crate::Literal::AbstractInt
    /// [`AbstractFloat`]: crate::Literal::AbstractFloat
    fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
}
3716
3717impl TryFromAbstract<i64> for i32 {
3718    fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
3719        i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3720            value: format!("{value:?}"),
3721            to_type: "i32",
3722        })
3723    }
3724}
3725
3726impl TryFromAbstract<i64> for u32 {
3727    fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
3728        u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3729            value: format!("{value:?}"),
3730            to_type: "u32",
3731        })
3732    }
3733}
3734
3735impl TryFromAbstract<i64> for u64 {
3736    fn try_from_abstract(value: i64) -> Result<u64, ConstantEvaluatorError> {
3737        u64::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
3738            value: format!("{value:?}"),
3739            to_type: "u64",
3740        })
3741    }
3742}
3743
3744impl TryFromAbstract<i64> for i64 {
3745    fn try_from_abstract(value: i64) -> Result<i64, ConstantEvaluatorError> {
3746        Ok(value)
3747    }
3748}
3749
3750impl TryFromAbstract<i64> for f32 {
3751    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3752        let f = value as f32;
3753        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3754        // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
3755        // overflow here.
3756        Ok(f)
3757    }
3758}
3759
3760impl TryFromAbstract<f64> for f32 {
3761    fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
3762        let f = value as f32;
3763        if f.is_infinite() {
3764            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3765                value: format!("{value:?}"),
3766                to_type: "f32",
3767            });
3768        }
3769        Ok(f)
3770    }
3771}
3772
3773impl TryFromAbstract<i64> for f64 {
3774    fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
3775        let f = value as f64;
3776        // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
3777        // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
3778        // overflow here.
3779        Ok(f)
3780    }
3781}
3782
3783impl TryFromAbstract<f64> for f64 {
3784    fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
3785        Ok(value)
3786    }
3787}
3788
3789impl TryFromAbstract<f64> for i32 {
3790    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3791        // https://www.w3.org/TR/WGSL/#floating-point-conversion
3792        // To convert a floating point scalar value X to an integer scalar type T:
3793        // * If X is a NaN, the result is an indeterminate value in T.
3794        // * If X is exactly representable in the target type T, then the
3795        //   result is that value.
3796        // * Otherwise, the result is the value in T closest to truncate(X) and
3797        //   also exactly representable in the original floating point type.
3798        //
3799        // A rust cast satisfies these requirements apart from "the result
3800        // is... exactly representable in the original floating point type".
3801        // However, i32::MIN and i32::MAX are exactly representable by f64, so
3802        // we're all good.
3803        Ok(value as i32)
3804    }
3805}
3806
3807impl TryFromAbstract<f64> for u32 {
3808    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3809        // As above, u32::MIN and u32::MAX are exactly representable by f64,
3810        // so a simple rust cast is sufficient.
3811        Ok(value as u32)
3812    }
3813}
3814
3815impl TryFromAbstract<f64> for i64 {
3816    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3817        // As above, except we clamp to the minimum and maximum values
3818        // representable by both f64 and i64.
3819        use crate::proc::type_methods::IntFloatLimits;
3820        Ok(value.clamp(i64::min_float(), i64::max_float()) as i64)
3821    }
3822}
3823
3824impl TryFromAbstract<f64> for u64 {
3825    fn try_from_abstract(value: f64) -> Result<Self, ConstantEvaluatorError> {
3826        // As above, this time clamping to the minimum and maximum values
3827        // representable by both f64 and u64.
3828        use crate::proc::type_methods::IntFloatLimits;
3829        Ok(value.clamp(u64::min_float(), u64::max_float()) as u64)
3830    }
3831}
3832
3833impl TryFromAbstract<f64> for f16 {
3834    fn try_from_abstract(value: f64) -> Result<f16, ConstantEvaluatorError> {
3835        let f = f16::from_f64(value);
3836        if f.is_infinite() {
3837            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3838                value: format!("{value:?}"),
3839                to_type: "f16",
3840            });
3841        }
3842        Ok(f)
3843    }
3844}
3845
3846impl TryFromAbstract<i64> for f16 {
3847    fn try_from_abstract(value: i64) -> Result<f16, ConstantEvaluatorError> {
3848        let f = f16::from_i64(value);
3849        if f.is_none() {
3850            return Err(ConstantEvaluatorError::AutomaticConversionLossy {
3851                value: format!("{value:?}"),
3852                to_type: "f16",
3853            });
3854        }
3855        Ok(f.unwrap())
3856    }
3857}
3858
/// Computes the cross product `a × b` of two three-component vectors.
///
/// Each component is computed as `p * q - r * s`, using only the `Mul`
/// and `Sub` implementations of `T`.
fn cross_product<T>(a: [T; 3], b: [T; 3]) -> [T; 3]
where
    T: Copy,
    T: core::ops::Mul<T, Output = T>,
    T: core::ops::Sub<T, Output = T>,
{
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
3871
3872#[cfg(test)]
3873mod tests {
3874    use alloc::{vec, vec::Vec};
3875
3876    use crate::{
3877        Arena, BinaryOperator, Constant, Expression, FastHashMap, Handle, Literal, ScalarKind,
3878        Type, TypeInner, UnaryOperator, UniqueArena, VectorSize,
3879    };
3880
3881    use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
3882
3883    #[test]
3884    fn unary_op() {
3885        let mut types = UniqueArena::new();
3886        let mut constants = Arena::new();
3887        let overrides = Arena::new();
3888        let mut global_expressions = Arena::new();
3889
3890        let scalar_ty = types.insert(
3891            Type {
3892                name: None,
3893                inner: TypeInner::Scalar(crate::Scalar::I32),
3894            },
3895            Default::default(),
3896        );
3897
3898        let vec_ty = types.insert(
3899            Type {
3900                name: None,
3901                inner: TypeInner::Vector {
3902                    size: VectorSize::Bi,
3903                    scalar: crate::Scalar::I32,
3904                },
3905            },
3906            Default::default(),
3907        );
3908
3909        let h = constants.append(
3910            Constant {
3911                name: None,
3912                ty: scalar_ty,
3913                init: global_expressions
3914                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
3915            },
3916            Default::default(),
3917        );
3918
3919        let h1 = constants.append(
3920            Constant {
3921                name: None,
3922                ty: scalar_ty,
3923                init: global_expressions
3924                    .append(Expression::Literal(Literal::I32(8)), Default::default()),
3925            },
3926            Default::default(),
3927        );
3928
3929        let vec_h = constants.append(
3930            Constant {
3931                name: None,
3932                ty: vec_ty,
3933                init: global_expressions.append(
3934                    Expression::Compose {
3935                        ty: vec_ty,
3936                        components: vec![constants[h].init, constants[h1].init],
3937                    },
3938                    Default::default(),
3939                ),
3940            },
3941            Default::default(),
3942        );
3943
3944        let expr = global_expressions.append(Expression::Constant(h), Default::default());
3945        let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
3946
3947        let expr2 = Expression::Unary {
3948            op: UnaryOperator::Negate,
3949            expr,
3950        };
3951
3952        let expr3 = Expression::Unary {
3953            op: UnaryOperator::BitwiseNot,
3954            expr,
3955        };
3956
3957        let expr4 = Expression::Unary {
3958            op: UnaryOperator::BitwiseNot,
3959            expr: expr1,
3960        };
3961
3962        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
3963        let mut solver = ConstantEvaluator {
3964            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
3965            types: &mut types,
3966            constants: &constants,
3967            overrides: &overrides,
3968            expressions: &mut global_expressions,
3969            expression_kind_tracker,
3970            layouter: &mut crate::proc::Layouter::default(),
3971        };
3972
3973        let res1 = solver
3974            .try_eval_and_append(expr2, Default::default())
3975            .unwrap();
3976        let res2 = solver
3977            .try_eval_and_append(expr3, Default::default())
3978            .unwrap();
3979        let res3 = solver
3980            .try_eval_and_append(expr4, Default::default())
3981            .unwrap();
3982
3983        assert_eq!(
3984            global_expressions[res1],
3985            Expression::Literal(Literal::I32(-4))
3986        );
3987
3988        assert_eq!(
3989            global_expressions[res2],
3990            Expression::Literal(Literal::I32(!4))
3991        );
3992
3993        let res3_inner = &global_expressions[res3];
3994
3995        match *res3_inner {
3996            Expression::Compose {
3997                ref ty,
3998                ref components,
3999            } => {
4000                assert_eq!(*ty, vec_ty);
4001                let mut components_iter = components.iter().copied();
4002                assert_eq!(
4003                    global_expressions[components_iter.next().unwrap()],
4004                    Expression::Literal(Literal::I32(!4))
4005                );
4006                assert_eq!(
4007                    global_expressions[components_iter.next().unwrap()],
4008                    Expression::Literal(Literal::I32(!8))
4009                );
4010                assert!(components_iter.next().is_none());
4011            }
4012            _ => panic!("Expected vector"),
4013        }
4014    }
4015
4016    #[test]
4017    fn matrix_op() {
4018        let mut helper = MatrixTestHelper::new();
4019
4020        for nc in 2..=4 {
4021            for nr in 2..=4 {
4022                // Validates multiplication on vector-matrix.
4023                // vecR(0, 1, .., r) * matCxR(0, 1, .., nc * nr)
4024                let evaluated = helper.eval_vector_multiply_matrix(nc, nr);
4025                let expected = (0..nc)
4026                    .map(|c| (0..nr).map(|r| (r * (c * nr + r)) as f32).sum())
4027                    .collect::<Vec<f32>>();
4028                assert_eq!(evaluated, expected);
4029
4030                // Validates multiplication on matrix-vector.
4031                // matCxR(0, 1, .., nc * nr) * vecC(0, 1, .., nc)
4032                let evaluated = helper.eval_matrix_multiply_vector(nc, nr);
4033                let expected = (0..nr)
4034                    .map(|r| (0..nc).map(|c| (c * (c * nr + r)) as f32).sum())
4035                    .collect::<Vec<f32>>();
4036                assert_eq!(evaluated, expected);
4037
4038                for k in 2..=4 {
4039                    // Validates multiplication on matrix-matrix.
4040                    // matKxR(0, 1, .., k * nr) * matCxK(0, 1, .., nc * k)
4041                    let evaluated = helper.eval_matrix_multiply_matrix(nr, nc, k);
4042                    let expected = (0..nc)
4043                        .flat_map(|c| {
4044                            (0..nr).map(move |r| {
4045                                (0..k).map(|v| ((v * nr + r) * (c * k + v)) as f32).sum()
4046                            })
4047                        })
4048                        .collect::<Vec<f32>>();
4049                    assert_eq!(evaluated, expected);
4050                }
4051            }
4052        }
4053    }
4054
4055    /// Test fixture providing pre-built f32 vector and matrix constant
4056    /// expressions with sequential element values, used to evaluate and verify
4057    /// matrix operations.
4058    struct MatrixTestHelper {
4059        types: UniqueArena<Type>,
4060        expressions: Arena<Expression>,
4061        /// Vector expressions from [0, 1] to [0, 1, 2, 3].
4062        vec_exprs: FastHashMap<usize, Handle<Expression>>,
4063        /// Matrix expressions from [0, .., 3] to [0, .., 15].
4064        mat_exprs: FastHashMap<(usize, usize), Handle<Expression>>,
4065    }
4066
4067    impl MatrixTestHelper {
4068        fn new() -> Self {
4069            let mut types = UniqueArena::new();
4070            let mut expressions = Arena::new();
4071            let span = crate::Span::default();
4072
4073            let (mut vec_tys, mut mat_tys) = (FastHashMap::default(), FastHashMap::default());
4074            for c in 2..=4 {
4075                let vec_ty = types.insert(
4076                    Type {
4077                        name: None,
4078                        inner: TypeInner::Vector {
4079                            size: Self::int_to_vector_size(c),
4080                            scalar: crate::Scalar::F32,
4081                        },
4082                    },
4083                    span,
4084                );
4085                vec_tys.insert(c, vec_ty);
4086                for r in 2..=4 {
4087                    let mat_ty = types.insert(
4088                        Type {
4089                            name: None,
4090                            inner: TypeInner::Matrix {
4091                                columns: Self::int_to_vector_size(c),
4092                                rows: Self::int_to_vector_size(r),
4093                                scalar: crate::Scalar::F32,
4094                            },
4095                        },
4096                        span,
4097                    );
4098                    mat_tys.insert((c, r), mat_ty);
4099                }
4100            }
4101
4102            let mut lit_exprs = FastHashMap::default();
4103            for i in 0..16 {
4104                let expr = expressions.append(Expression::Literal(Literal::F32(i as f32)), span);
4105                lit_exprs.insert(i, expr);
4106            }
4107
4108            let mut vec_exprs = FastHashMap::default();
4109            for c in 2..=4 {
4110                let expr = expressions.append(
4111                    Expression::Compose {
4112                        ty: *vec_tys.get(&c).unwrap(),
4113                        components: (0..c)
4114                            .map(|i| *lit_exprs.get(&i).unwrap())
4115                            .collect::<Vec<_>>(),
4116                    },
4117                    span,
4118                );
4119                vec_exprs.insert(c, expr);
4120            }
4121
4122            let mut mat_exprs = FastHashMap::default();
4123            for c in 2..=4 {
4124                for r in 2..=4 {
4125                    let mut columns = Vec::with_capacity(c);
4126                    for cc in 0..c {
4127                        let start = cc * r;
4128                        let expr = expressions.append(
4129                            Expression::Compose {
4130                                ty: *vec_tys.get(&r).unwrap(),
4131                                components: (start..start + r)
4132                                    .map(|i| *lit_exprs.get(&i).unwrap())
4133                                    .collect::<Vec<_>>(),
4134                            },
4135                            span,
4136                        );
4137                        columns.push(expr);
4138                    }
4139
4140                    let expr = expressions.append(
4141                        Expression::Compose {
4142                            ty: *mat_tys.get(&(c, r)).unwrap(),
4143                            components: columns,
4144                        },
4145                        span,
4146                    );
4147                    mat_exprs.insert((c, r), expr);
4148                }
4149            }
4150
4151            Self {
4152                types,
4153                expressions,
4154                vec_exprs,
4155                mat_exprs,
4156            }
4157        }
4158
4159        /// Evaluates vec[0..nr] * mat[0..nc*nr] and returns the result as f32s.
4160        fn eval_vector_multiply_matrix(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4161            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4162            let mut solver = ConstantEvaluator {
4163                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4164                types: &mut self.types,
4165                constants: &Arena::new(),
4166                overrides: &Arena::new(),
4167                expressions: &mut self.expressions,
4168                expression_kind_tracker,
4169                layouter: &mut crate::proc::Layouter::default(),
4170            };
4171
4172            let result = solver
4173                .try_eval_and_append(
4174                    Expression::Binary {
4175                        op: BinaryOperator::Multiply,
4176                        left: *self.vec_exprs.get(&nr).unwrap(),
4177                        right: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4178                    },
4179                    Default::default(),
4180                )
4181                .unwrap();
4182            self.flatten(result)
4183        }
4184
4185        /// Evaluates mat[0..nc*nr] * vec[0..nc] and returns the result as f32s.
4186        fn eval_matrix_multiply_vector(&mut self, nc: usize, nr: usize) -> Vec<f32> {
4187            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4188            let mut solver = ConstantEvaluator {
4189                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4190                types: &mut self.types,
4191                constants: &Arena::new(),
4192                overrides: &Arena::new(),
4193                expressions: &mut self.expressions,
4194                expression_kind_tracker,
4195                layouter: &mut crate::proc::Layouter::default(),
4196            };
4197
4198            let result = solver
4199                .try_eval_and_append(
4200                    Expression::Binary {
4201                        op: BinaryOperator::Multiply,
4202                        left: *self.mat_exprs.get(&(nc, nr)).unwrap(),
4203                        right: *self.vec_exprs.get(&nc).unwrap(),
4204                    },
4205                    Default::default(),
4206                )
4207                .unwrap();
4208            self.flatten(result)
4209        }
4210
4211        /// Evaluates mat[0..k*l_nr] * mat[0..r_nc*k] and returns the result as
4212        /// f32s.
4213        fn eval_matrix_multiply_matrix(&mut self, l_nr: usize, r_nc: usize, k: usize) -> Vec<f32> {
4214            let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&self.expressions);
4215            let mut solver = ConstantEvaluator {
4216                behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4217                types: &mut self.types,
4218                constants: &Arena::new(),
4219                overrides: &Arena::new(),
4220                expressions: &mut self.expressions,
4221                expression_kind_tracker,
4222                layouter: &mut crate::proc::Layouter::default(),
4223            };
4224
4225            let result = solver
4226                .try_eval_and_append(
4227                    Expression::Binary {
4228                        op: BinaryOperator::Multiply,
4229                        left: *self.mat_exprs.get(&(k, l_nr)).unwrap(),
4230                        right: *self.mat_exprs.get(&(r_nc, k)).unwrap(),
4231                    },
4232                    Default::default(),
4233                )
4234                .unwrap();
4235            self.flatten(result)
4236        }
4237
4238        fn flatten(&self, expr: Handle<Expression>) -> Vec<f32> {
4239            let Expression::Compose {
4240                ref components,
4241                ref ty,
4242            } = self.expressions[expr]
4243            else {
4244                unreachable!()
4245            };
4246
4247            match self.types[*ty].inner {
4248                TypeInner::Vector { .. } => components
4249                    .iter()
4250                    .map(|&comp| {
4251                        let Expression::Literal(Literal::F32(v)) = self.expressions[comp] else {
4252                            unreachable!()
4253                        };
4254                        v
4255                    })
4256                    .collect(),
4257                TypeInner::Matrix { .. } => components
4258                    .iter()
4259                    .flat_map(|&comp| self.flatten(comp))
4260                    .collect(),
4261                _ => unreachable!(),
4262            }
4263        }
4264
4265        fn int_to_vector_size(int: usize) -> VectorSize {
4266            match int {
4267                2 => VectorSize::Bi,
4268                3 => VectorSize::Tri,
4269                4 => VectorSize::Quad,
4270                _ => unreachable!(),
4271            }
4272        }
4273    }
4274
4275    #[test]
4276    fn cast() {
4277        let mut types = UniqueArena::new();
4278        let mut constants = Arena::new();
4279        let overrides = Arena::new();
4280        let mut global_expressions = Arena::new();
4281
4282        let scalar_ty = types.insert(
4283            Type {
4284                name: None,
4285                inner: TypeInner::Scalar(crate::Scalar::I32),
4286            },
4287            Default::default(),
4288        );
4289
4290        let h = constants.append(
4291            Constant {
4292                name: None,
4293                ty: scalar_ty,
4294                init: global_expressions
4295                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4296            },
4297            Default::default(),
4298        );
4299
4300        let expr = global_expressions.append(Expression::Constant(h), Default::default());
4301
4302        let root = Expression::As {
4303            expr,
4304            kind: ScalarKind::Bool,
4305            convert: Some(crate::BOOL_WIDTH),
4306        };
4307
4308        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4309        let mut solver = ConstantEvaluator {
4310            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4311            types: &mut types,
4312            constants: &constants,
4313            overrides: &overrides,
4314            expressions: &mut global_expressions,
4315            expression_kind_tracker,
4316            layouter: &mut crate::proc::Layouter::default(),
4317        };
4318
4319        let res = solver
4320            .try_eval_and_append(root, Default::default())
4321            .unwrap();
4322
4323        assert_eq!(
4324            global_expressions[res],
4325            Expression::Literal(Literal::Bool(true))
4326        );
4327    }
4328
4329    #[test]
4330    fn access() {
4331        let mut types = UniqueArena::new();
4332        let mut constants = Arena::new();
4333        let overrides = Arena::new();
4334        let mut global_expressions = Arena::new();
4335
4336        let matrix_ty = types.insert(
4337            Type {
4338                name: None,
4339                inner: TypeInner::Matrix {
4340                    columns: VectorSize::Bi,
4341                    rows: VectorSize::Tri,
4342                    scalar: crate::Scalar::F32,
4343                },
4344            },
4345            Default::default(),
4346        );
4347
4348        let vec_ty = types.insert(
4349            Type {
4350                name: None,
4351                inner: TypeInner::Vector {
4352                    size: VectorSize::Tri,
4353                    scalar: crate::Scalar::F32,
4354                },
4355            },
4356            Default::default(),
4357        );
4358
4359        let mut vec1_components = Vec::with_capacity(3);
4360        let mut vec2_components = Vec::with_capacity(3);
4361
4362        for i in 0..3 {
4363            let h = global_expressions.append(
4364                Expression::Literal(Literal::F32(i as f32)),
4365                Default::default(),
4366            );
4367
4368            vec1_components.push(h)
4369        }
4370
4371        for i in 3..6 {
4372            let h = global_expressions.append(
4373                Expression::Literal(Literal::F32(i as f32)),
4374                Default::default(),
4375            );
4376
4377            vec2_components.push(h)
4378        }
4379
4380        let vec1 = constants.append(
4381            Constant {
4382                name: None,
4383                ty: vec_ty,
4384                init: global_expressions.append(
4385                    Expression::Compose {
4386                        ty: vec_ty,
4387                        components: vec1_components,
4388                    },
4389                    Default::default(),
4390                ),
4391            },
4392            Default::default(),
4393        );
4394
4395        let vec2 = constants.append(
4396            Constant {
4397                name: None,
4398                ty: vec_ty,
4399                init: global_expressions.append(
4400                    Expression::Compose {
4401                        ty: vec_ty,
4402                        components: vec2_components,
4403                    },
4404                    Default::default(),
4405                ),
4406            },
4407            Default::default(),
4408        );
4409
4410        let h = constants.append(
4411            Constant {
4412                name: None,
4413                ty: matrix_ty,
4414                init: global_expressions.append(
4415                    Expression::Compose {
4416                        ty: matrix_ty,
4417                        components: vec![constants[vec1].init, constants[vec2].init],
4418                    },
4419                    Default::default(),
4420                ),
4421            },
4422            Default::default(),
4423        );
4424
4425        let base = global_expressions.append(Expression::Constant(h), Default::default());
4426
4427        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4428        let mut solver = ConstantEvaluator {
4429            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4430            types: &mut types,
4431            constants: &constants,
4432            overrides: &overrides,
4433            expressions: &mut global_expressions,
4434            expression_kind_tracker,
4435            layouter: &mut crate::proc::Layouter::default(),
4436        };
4437
4438        let root1 = Expression::AccessIndex { base, index: 1 };
4439
4440        let res1 = solver
4441            .try_eval_and_append(root1, Default::default())
4442            .unwrap();
4443
4444        let root2 = Expression::AccessIndex {
4445            base: res1,
4446            index: 2,
4447        };
4448
4449        let res2 = solver
4450            .try_eval_and_append(root2, Default::default())
4451            .unwrap();
4452
4453        match global_expressions[res1] {
4454            Expression::Compose {
4455                ref ty,
4456                ref components,
4457            } => {
4458                assert_eq!(*ty, vec_ty);
4459                let mut components_iter = components.iter().copied();
4460                assert_eq!(
4461                    global_expressions[components_iter.next().unwrap()],
4462                    Expression::Literal(Literal::F32(3.))
4463                );
4464                assert_eq!(
4465                    global_expressions[components_iter.next().unwrap()],
4466                    Expression::Literal(Literal::F32(4.))
4467                );
4468                assert_eq!(
4469                    global_expressions[components_iter.next().unwrap()],
4470                    Expression::Literal(Literal::F32(5.))
4471                );
4472                assert!(components_iter.next().is_none());
4473            }
4474            _ => panic!("Expected vector"),
4475        }
4476
4477        assert_eq!(
4478            global_expressions[res2],
4479            Expression::Literal(Literal::F32(5.))
4480        );
4481    }
4482
4483    #[test]
4484    fn compose_of_constants() {
4485        let mut types = UniqueArena::new();
4486        let mut constants = Arena::new();
4487        let overrides = Arena::new();
4488        let mut global_expressions = Arena::new();
4489
4490        let i32_ty = types.insert(
4491            Type {
4492                name: None,
4493                inner: TypeInner::Scalar(crate::Scalar::I32),
4494            },
4495            Default::default(),
4496        );
4497
4498        let vec2_i32_ty = types.insert(
4499            Type {
4500                name: None,
4501                inner: TypeInner::Vector {
4502                    size: VectorSize::Bi,
4503                    scalar: crate::Scalar::I32,
4504                },
4505            },
4506            Default::default(),
4507        );
4508
4509        let h = constants.append(
4510            Constant {
4511                name: None,
4512                ty: i32_ty,
4513                init: global_expressions
4514                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4515            },
4516            Default::default(),
4517        );
4518
4519        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4520
4521        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4522        let mut solver = ConstantEvaluator {
4523            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4524            types: &mut types,
4525            constants: &constants,
4526            overrides: &overrides,
4527            expressions: &mut global_expressions,
4528            expression_kind_tracker,
4529            layouter: &mut crate::proc::Layouter::default(),
4530        };
4531
4532        let solved_compose = solver
4533            .try_eval_and_append(
4534                Expression::Compose {
4535                    ty: vec2_i32_ty,
4536                    components: vec![h_expr, h_expr],
4537                },
4538                Default::default(),
4539            )
4540            .unwrap();
4541        let solved_negate = solver
4542            .try_eval_and_append(
4543                Expression::Unary {
4544                    op: UnaryOperator::Negate,
4545                    expr: solved_compose,
4546                },
4547                Default::default(),
4548            )
4549            .unwrap();
4550
4551        let pass = match global_expressions[solved_negate] {
4552            Expression::Compose { ty, ref components } => {
4553                ty == vec2_i32_ty
4554                    && components.iter().all(|&component| {
4555                        let component = &global_expressions[component];
4556                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4557                    })
4558            }
4559            _ => false,
4560        };
4561        if !pass {
4562            panic!("unexpected evaluation result")
4563        }
4564    }
4565
4566    #[test]
4567    fn splat_of_constant() {
4568        let mut types = UniqueArena::new();
4569        let mut constants = Arena::new();
4570        let overrides = Arena::new();
4571        let mut global_expressions = Arena::new();
4572
4573        let i32_ty = types.insert(
4574            Type {
4575                name: None,
4576                inner: TypeInner::Scalar(crate::Scalar::I32),
4577            },
4578            Default::default(),
4579        );
4580
4581        let vec2_i32_ty = types.insert(
4582            Type {
4583                name: None,
4584                inner: TypeInner::Vector {
4585                    size: VectorSize::Bi,
4586                    scalar: crate::Scalar::I32,
4587                },
4588            },
4589            Default::default(),
4590        );
4591
4592        let h = constants.append(
4593            Constant {
4594                name: None,
4595                ty: i32_ty,
4596                init: global_expressions
4597                    .append(Expression::Literal(Literal::I32(4)), Default::default()),
4598            },
4599            Default::default(),
4600        );
4601
4602        let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
4603
4604        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4605        let mut solver = ConstantEvaluator {
4606            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4607            types: &mut types,
4608            constants: &constants,
4609            overrides: &overrides,
4610            expressions: &mut global_expressions,
4611            expression_kind_tracker,
4612            layouter: &mut crate::proc::Layouter::default(),
4613        };
4614
4615        let solved_compose = solver
4616            .try_eval_and_append(
4617                Expression::Splat {
4618                    size: VectorSize::Bi,
4619                    value: h_expr,
4620                },
4621                Default::default(),
4622            )
4623            .unwrap();
4624        let solved_negate = solver
4625            .try_eval_and_append(
4626                Expression::Unary {
4627                    op: UnaryOperator::Negate,
4628                    expr: solved_compose,
4629                },
4630                Default::default(),
4631            )
4632            .unwrap();
4633
4634        let pass = match global_expressions[solved_negate] {
4635            Expression::Compose { ty, ref components } => {
4636                ty == vec2_i32_ty
4637                    && components.iter().all(|&component| {
4638                        let component = &global_expressions[component];
4639                        matches!(*component, Expression::Literal(Literal::I32(-4)))
4640                    })
4641            }
4642            _ => false,
4643        };
4644        if !pass {
4645            panic!("unexpected evaluation result")
4646        }
4647    }
4648
4649    #[test]
4650    fn splat_of_zero_value() {
4651        let mut types = UniqueArena::new();
4652        let constants = Arena::new();
4653        let overrides = Arena::new();
4654        let mut global_expressions = Arena::new();
4655
4656        let f32_ty = types.insert(
4657            Type {
4658                name: None,
4659                inner: TypeInner::Scalar(crate::Scalar::F32),
4660            },
4661            Default::default(),
4662        );
4663
4664        let vec2_f32_ty = types.insert(
4665            Type {
4666                name: None,
4667                inner: TypeInner::Vector {
4668                    size: VectorSize::Bi,
4669                    scalar: crate::Scalar::F32,
4670                },
4671            },
4672            Default::default(),
4673        );
4674
4675        let five =
4676            global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
4677        let five_splat = global_expressions.append(
4678            Expression::Splat {
4679                size: VectorSize::Bi,
4680                value: five,
4681            },
4682            Default::default(),
4683        );
4684        let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
4685        let zero_splat = global_expressions.append(
4686            Expression::Splat {
4687                size: VectorSize::Bi,
4688                value: zero,
4689            },
4690            Default::default(),
4691        );
4692
4693        let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
4694        let mut solver = ConstantEvaluator {
4695            behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
4696            types: &mut types,
4697            constants: &constants,
4698            overrides: &overrides,
4699            expressions: &mut global_expressions,
4700            expression_kind_tracker,
4701            layouter: &mut crate::proc::Layouter::default(),
4702        };
4703
4704        let solved_add = solver
4705            .try_eval_and_append(
4706                Expression::Binary {
4707                    op: BinaryOperator::Add,
4708                    left: zero_splat,
4709                    right: five_splat,
4710                },
4711                Default::default(),
4712            )
4713            .unwrap();
4714
4715        let pass = match global_expressions[solved_add] {
4716            Expression::Compose { ty, ref components } => {
4717                ty == vec2_f32_ty
4718                    && components.iter().all(|&component| {
4719                        let component = &global_expressions[component];
4720                        matches!(*component, Expression::Literal(Literal::F32(5.0)))
4721                    })
4722            }
4723            _ => false,
4724        };
4725        if !pass {
4726            panic!("unexpected evaluation result")
4727        }
4728    }
4729}