naga/back/
pipeline_constants.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4};
5use core::mem;
6
7use hashbrown::HashSet;
8use thiserror::Error;
9
10use super::PipelineConstants;
11use crate::{
12    arena::HandleVec,
13    proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
14    valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
15    Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
16    Span, Statement, TypeInner, WithSpan,
17};
18
19#[derive(Error, Debug, Clone)]
20#[cfg_attr(test, derive(PartialEq))]
21pub enum PipelineConstantError {
22    #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
23    MissingValue(String),
24    #[error(
25        "Source f64 value needs to be finite ({}) for number destinations",
26        "NaNs and Inifinites are not allowed"
27    )]
28    SrcNeedsToBeFinite,
29    #[error("Source f64 value doesn't fit in destination")]
30    DstRangeTooSmall,
31    #[error(transparent)]
32    ConstantEvaluatorError(#[from] ConstantEvaluatorError),
33    #[error(transparent)]
34    ValidationError(#[from] WithSpan<ValidationError>),
35    #[error("workgroup_size override isn't strictly positive")]
36    NegativeWorkgroupSize,
37}
38
39/// Replace all overrides in `module` with constants.
40///
41/// If no changes are needed, this just returns `Cow::Borrowed`
42/// references to `module` and `module_info`. Otherwise, it clones
43/// `module`, edits its [`global_expressions`] arena to contain only
44/// fully-evaluated expressions, and returns `Cow::Owned` values
45/// holding the simplified module and its validation results.
46///
47/// In either case, the module returned has an empty `overrides`
48/// arena, and the `global_expressions` arena contains only
49/// fully-evaluated expressions.
50///
51/// [`global_expressions`]: Module::global_expressions
52pub fn process_overrides<'a>(
53    module: &'a Module,
54    module_info: &'a ModuleInfo,
55    pipeline_constants: &PipelineConstants,
56) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
57    if module.overrides.is_empty() {
58        return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
59    }
60
61    let mut module = module.clone();
62
63    // A map from override handles to the handles of the constants
64    // we've replaced them with.
65    let mut override_map = HandleVec::with_capacity(module.overrides.len());
66
67    // A map from `module`'s original global expression handles to
68    // handles in the new, simplified global expression arena.
69    let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());
70
71    // The set of constants whose initializer handles we've already
72    // updated to refer to the newly built global expression arena.
73    //
74    // All constants in `module` must have their `init` handles
75    // updated to point into the new, simplified global expression
76    // arena. Some of these we can most easily handle as a side effect
77    // during the simplification process, but we must handle the rest
78    // in a final fixup pass, guided by `adjusted_global_expressions`. We
79    // add their handles to this set, so that the final fixup step can
80    // leave them alone.
81    let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
82
83    let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
84    let mut layouter = crate::proc::Layouter::default();
85
86    // An iterator through the original overrides table, consumed in
87    // approximate tandem with the global expressions.
88    let mut overrides = mem::take(&mut module.overrides);
89    let mut override_iter = overrides.iter_mut_span();
90
91    // Do two things in tandem:
92    //
93    // - Rebuild the global expression arena from scratch, fully
94    //   evaluating all expressions, and replacing each `Override`
95    //   expression in `module.global_expressions` with a `Constant`
96    //   expression.
97    //
98    // - Build a new `Constant` in `module.constants` to take the
99    //   place of each `Override`.
100    //
101    // Build a map from old global expression handles to their
102    // fully-evaluated counterparts in `adjusted_global_expressions` as we
103    // go.
104    //
105    // Why in tandem? Overrides refer to expressions, and expressions
106    // refer to overrides, so we can't disentangle the two into
107    // separate phases. However, we can take advantage of the fact
108    // that the overrides and expressions must form a DAG, and work
109    // our way from the leaves to the roots, replacing and evaluating
110    // as we go.
111    //
112    // Although the two loops are nested, this is really two
113    // alternating phases: we adjust and evaluate constant expressions
114    // until we hit an `Override` expression, at which point we switch
115    // to building `Constant`s for `Overrides` until we've handled the
116    // one used by the expression. Then we switch back to processing
117    // expressions. Because we know they form a DAG, we know the
118    // `Override` expressions we encounter can only have initializers
119    // referring to global expressions we've already simplified.
120    for (old_h, expr, span) in module.global_expressions.drain() {
121        let mut expr = match expr {
122            Expression::Override(h) => {
123                let c_h = if let Some(new_h) = override_map.get(h) {
124                    *new_h
125                } else {
126                    let mut new_h = None;
127                    for entry in override_iter.by_ref() {
128                        let stop = entry.0 == h;
129                        new_h = Some(process_override(
130                            entry,
131                            pipeline_constants,
132                            &mut module,
133                            &mut override_map,
134                            &adjusted_global_expressions,
135                            &mut adjusted_constant_initializers,
136                            &mut global_expression_kind_tracker,
137                        )?);
138                        if stop {
139                            break;
140                        }
141                    }
142                    new_h.unwrap()
143                };
144                Expression::Constant(c_h)
145            }
146            Expression::Constant(c_h) => {
147                if adjusted_constant_initializers.insert(c_h) {
148                    let init = &mut module.constants[c_h].init;
149                    *init = adjusted_global_expressions[*init];
150                }
151                expr
152            }
153            expr => expr,
154        };
155        let mut evaluator = ConstantEvaluator::for_wgsl_module(
156            &mut module,
157            &mut global_expression_kind_tracker,
158            &mut layouter,
159            false,
160        );
161        adjust_expr(&adjusted_global_expressions, &mut expr);
162        let h = evaluator.try_eval_and_append(expr, span)?;
163        adjusted_global_expressions.insert(old_h, h);
164    }
165
166    // Finish processing any overrides we didn't visit in the loop above.
167    for entry in override_iter {
168        match *entry.1 {
169            Override { name: Some(_), .. } | Override { id: Some(_), .. } => {
170                process_override(
171                    entry,
172                    pipeline_constants,
173                    &mut module,
174                    &mut override_map,
175                    &adjusted_global_expressions,
176                    &mut adjusted_constant_initializers,
177                    &mut global_expression_kind_tracker,
178                )?;
179            }
180            Override {
181                init: Some(ref mut init),
182                ..
183            } => {
184                *init = adjusted_global_expressions[*init];
185            }
186            _ => {}
187        }
188    }
189
190    // Update the initialization expression handles of all `Constant`s
191    // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
192    // passant.
193    for (_, c) in module
194        .constants
195        .iter_mut()
196        .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
197    {
198        c.init = adjusted_global_expressions[c.init];
199    }
200
201    for (_, v) in module.global_variables.iter_mut() {
202        if let Some(ref mut init) = v.init {
203            *init = adjusted_global_expressions[*init];
204        }
205    }
206
207    let mut functions = mem::take(&mut module.functions);
208    for (_, function) in functions.iter_mut() {
209        process_function(&mut module, &override_map, &mut layouter, function)?;
210    }
211    module.functions = functions;
212
213    let mut entry_points = mem::take(&mut module.entry_points);
214    for ep in entry_points.iter_mut() {
215        process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?;
216        process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
217    }
218    module.entry_points = entry_points;
219    module.overrides = overrides;
220
221    // Now that we've rewritten all the expressions, we need to
222    // recompute their types and other metadata. For the time being,
223    // do a full re-validation.
224    let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
225    let module_info = validator.validate_resolved_overrides(&module)?;
226
227    Ok((Cow::Owned(module), Cow::Owned(module_info)))
228}
229
230fn process_workgroup_size_override(
231    module: &mut Module,
232    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
233    ep: &mut crate::EntryPoint,
234) -> Result<(), PipelineConstantError> {
235    match ep.workgroup_size_overrides {
236        None => {}
237        Some(overrides) => {
238            overrides.iter().enumerate().try_for_each(
239                |(i, overridden)| -> Result<(), PipelineConstantError> {
240                    match *overridden {
241                        None => Ok(()),
242                        Some(h) => {
243                            ep.workgroup_size[i] = module
244                                .to_ctx()
245                                .eval_expr_to_u32(adjusted_global_expressions[h])
246                                .map(|n| {
247                                    if n == 0 {
248                                        Err(PipelineConstantError::NegativeWorkgroupSize)
249                                    } else {
250                                        Ok(n)
251                                    }
252                                })
253                                .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
254                            Ok(())
255                        }
256                    }
257                },
258            )?;
259            ep.workgroup_size_overrides = None;
260        }
261    }
262    Ok(())
263}
264
265/// Add a [`Constant`] to `module` for the override `old_h`.
266///
267/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
268fn process_override(
269    (old_h, r#override, span): (Handle<Override>, &mut Override, &Span),
270    pipeline_constants: &PipelineConstants,
271    module: &mut Module,
272    override_map: &mut HandleVec<Override, Handle<Constant>>,
273    adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
274    adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
275    global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
276) -> Result<Handle<Constant>, PipelineConstantError> {
277    // Determine which key to use for `r#override` in `pipeline_constants`.
278    let key = if let Some(id) = r#override.id {
279        Cow::Owned(id.to_string())
280    } else if let Some(ref name) = r#override.name {
281        Cow::Borrowed(name)
282    } else {
283        unreachable!();
284    };
285
286    // Generate a global expression for `r#override`'s value, either
287    // from the provided `pipeline_constants` table or its initializer
288    // in the module.
289    let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
290        let literal = match module.types[r#override.ty].inner {
291            TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
292            _ => unreachable!(),
293        };
294        let expr = module
295            .global_expressions
296            .append(Expression::Literal(literal), Span::UNDEFINED);
297        global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
298        expr
299    } else if let Some(init) = r#override.init {
300        adjusted_global_expressions[init]
301    } else {
302        return Err(PipelineConstantError::MissingValue(key.to_string()));
303    };
304
305    // Generate a new `Constant` to represent the override's value.
306    let constant = Constant {
307        name: r#override.name.clone(),
308        ty: r#override.ty,
309        init,
310    };
311    let h = module.constants.append(constant, *span);
312    override_map.insert(old_h, h);
313    adjusted_constant_initializers.insert(h);
314    r#override.init = Some(init);
315    Ok(h)
316}
317
318/// Replace all override expressions in `function` with fully-evaluated constants.
319///
320/// Replace all `Expression::Override`s in `function`'s expression arena with
321/// the corresponding `Expression::Constant`s, as given in `override_map`.
322/// Replace any expressions whose values are now known with their fully
323/// evaluated form.
324///
325/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
326/// `Handle<Constant>` for the override's final value.
327fn process_function(
328    module: &mut Module,
329    override_map: &HandleVec<Override, Handle<Constant>>,
330    layouter: &mut crate::proc::Layouter,
331    function: &mut Function,
332) -> Result<(), ConstantEvaluatorError> {
333    // A map from original local expression handles to
334    // handles in the new, local expression arena.
335    let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());
336
337    let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
338
339    let mut expressions = mem::take(&mut function.expressions);
340
341    // Dummy `emitter` and `block` for the constant evaluator.
342    // We can ignore the concept of emitting expressions here since
343    // expressions have already been covered by a `Statement::Emit`
344    // in the frontend.
345    // The only thing we might have to do is remove some expressions
346    // that have been covered by a `Statement::Emit`. See the docs of
347    // `filter_emits_in_block` for the reasoning.
348    let mut emitter = Emitter::default();
349    let mut block = Block::new();
350
351    let mut evaluator = ConstantEvaluator::for_wgsl_function(
352        module,
353        &mut function.expressions,
354        &mut local_expression_kind_tracker,
355        layouter,
356        &mut emitter,
357        &mut block,
358        false,
359    );
360
361    for (old_h, mut expr, span) in expressions.drain() {
362        if let Expression::Override(h) = expr {
363            expr = Expression::Constant(override_map[h]);
364        }
365        adjust_expr(&adjusted_local_expressions, &mut expr);
366        let h = evaluator.try_eval_and_append(expr, span)?;
367        adjusted_local_expressions.insert(old_h, h);
368    }
369
370    adjust_block(&adjusted_local_expressions, &mut function.body);
371
372    filter_emits_in_block(&mut function.body, &function.expressions);
373
374    // Update local expression initializers.
375    for (_, local) in function.local_variables.iter_mut() {
376        if let &mut Some(ref mut init) = &mut local.init {
377            *init = adjusted_local_expressions[*init];
378        }
379    }
380
381    // We've changed the keys of `function.named_expression`, so we have to
382    // rebuild it from scratch.
383    let named_expressions = mem::take(&mut function.named_expressions);
384    for (expr_h, name) in named_expressions {
385        function
386            .named_expressions
387            .insert(adjusted_local_expressions[expr_h], name);
388    }
389
390    Ok(())
391}
392
393/// Replace every expression handle in `expr` with its counterpart
394/// given by `new_pos`.
395fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
396    let adjust = |expr: &mut Handle<Expression>| {
397        *expr = new_pos[*expr];
398    };
399    match *expr {
400        Expression::Compose {
401            ref mut components,
402            ty: _,
403        } => {
404            for c in components.iter_mut() {
405                adjust(c);
406            }
407        }
408        Expression::Access {
409            ref mut base,
410            ref mut index,
411        } => {
412            adjust(base);
413            adjust(index);
414        }
415        Expression::AccessIndex {
416            ref mut base,
417            index: _,
418        } => {
419            adjust(base);
420        }
421        Expression::Splat {
422            ref mut value,
423            size: _,
424        } => {
425            adjust(value);
426        }
427        Expression::Swizzle {
428            ref mut vector,
429            size: _,
430            pattern: _,
431        } => {
432            adjust(vector);
433        }
434        Expression::Load { ref mut pointer } => {
435            adjust(pointer);
436        }
437        Expression::ImageSample {
438            ref mut image,
439            ref mut sampler,
440            ref mut coordinate,
441            ref mut array_index,
442            ref mut offset,
443            ref mut level,
444            ref mut depth_ref,
445            gather: _,
446        } => {
447            adjust(image);
448            adjust(sampler);
449            adjust(coordinate);
450            if let Some(e) = array_index.as_mut() {
451                adjust(e);
452            }
453            if let Some(e) = offset.as_mut() {
454                adjust(e);
455            }
456            match *level {
457                crate::SampleLevel::Exact(ref mut expr)
458                | crate::SampleLevel::Bias(ref mut expr) => {
459                    adjust(expr);
460                }
461                crate::SampleLevel::Gradient {
462                    ref mut x,
463                    ref mut y,
464                } => {
465                    adjust(x);
466                    adjust(y);
467                }
468                _ => {}
469            }
470            if let Some(e) = depth_ref.as_mut() {
471                adjust(e);
472            }
473        }
474        Expression::ImageLoad {
475            ref mut image,
476            ref mut coordinate,
477            ref mut array_index,
478            ref mut sample,
479            ref mut level,
480        } => {
481            adjust(image);
482            adjust(coordinate);
483            if let Some(e) = array_index.as_mut() {
484                adjust(e);
485            }
486            if let Some(e) = sample.as_mut() {
487                adjust(e);
488            }
489            if let Some(e) = level.as_mut() {
490                adjust(e);
491            }
492        }
493        Expression::ImageQuery {
494            ref mut image,
495            ref mut query,
496        } => {
497            adjust(image);
498            match *query {
499                crate::ImageQuery::Size { ref mut level } => {
500                    if let Some(e) = level.as_mut() {
501                        adjust(e);
502                    }
503                }
504                crate::ImageQuery::NumLevels
505                | crate::ImageQuery::NumLayers
506                | crate::ImageQuery::NumSamples => {}
507            }
508        }
509        Expression::Unary {
510            ref mut expr,
511            op: _,
512        } => {
513            adjust(expr);
514        }
515        Expression::Binary {
516            ref mut left,
517            ref mut right,
518            op: _,
519        } => {
520            adjust(left);
521            adjust(right);
522        }
523        Expression::Select {
524            ref mut condition,
525            ref mut accept,
526            ref mut reject,
527        } => {
528            adjust(condition);
529            adjust(accept);
530            adjust(reject);
531        }
532        Expression::Derivative {
533            ref mut expr,
534            axis: _,
535            ctrl: _,
536        } => {
537            adjust(expr);
538        }
539        Expression::Relational {
540            ref mut argument,
541            fun: _,
542        } => {
543            adjust(argument);
544        }
545        Expression::Math {
546            ref mut arg,
547            ref mut arg1,
548            ref mut arg2,
549            ref mut arg3,
550            fun: _,
551        } => {
552            adjust(arg);
553            if let Some(e) = arg1.as_mut() {
554                adjust(e);
555            }
556            if let Some(e) = arg2.as_mut() {
557                adjust(e);
558            }
559            if let Some(e) = arg3.as_mut() {
560                adjust(e);
561            }
562        }
563        Expression::As {
564            ref mut expr,
565            kind: _,
566            convert: _,
567        } => {
568            adjust(expr);
569        }
570        Expression::ArrayLength(ref mut expr) => {
571            adjust(expr);
572        }
573        Expression::RayQueryGetIntersection {
574            ref mut query,
575            committed: _,
576        } => {
577            adjust(query);
578        }
579        Expression::Literal(_)
580        | Expression::FunctionArgument(_)
581        | Expression::GlobalVariable(_)
582        | Expression::LocalVariable(_)
583        | Expression::CallResult(_)
584        | Expression::RayQueryProceedResult
585        | Expression::Constant(_)
586        | Expression::Override(_)
587        | Expression::ZeroValue(_)
588        | Expression::AtomicResult {
589            ty: _,
590            comparison: _,
591        }
592        | Expression::WorkGroupUniformLoadResult { ty: _ }
593        | Expression::SubgroupBallotResult
594        | Expression::SubgroupOperationResult { .. } => {}
595        Expression::RayQueryVertexPositions {
596            ref mut query,
597            committed: _,
598        } => {
599            adjust(query);
600        }
601    }
602}
603
604/// Replace every expression handle in `block` with its counterpart
605/// given by `new_pos`.
606fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
607    for stmt in block.iter_mut() {
608        adjust_stmt(new_pos, stmt);
609    }
610}
611
612/// Replace every expression handle in `stmt` with its counterpart
613/// given by `new_pos`.
614fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
615    let adjust = |expr: &mut Handle<Expression>| {
616        *expr = new_pos[*expr];
617    };
618    match *stmt {
619        Statement::Emit(ref mut range) => {
620            if let Some((mut first, mut last)) = range.first_and_last() {
621                adjust(&mut first);
622                adjust(&mut last);
623                *range = Range::new_from_bounds(first, last);
624            }
625        }
626        Statement::Block(ref mut block) => {
627            adjust_block(new_pos, block);
628        }
629        Statement::If {
630            ref mut condition,
631            ref mut accept,
632            ref mut reject,
633        } => {
634            adjust(condition);
635            adjust_block(new_pos, accept);
636            adjust_block(new_pos, reject);
637        }
638        Statement::Switch {
639            ref mut selector,
640            ref mut cases,
641        } => {
642            adjust(selector);
643            for case in cases.iter_mut() {
644                adjust_block(new_pos, &mut case.body);
645            }
646        }
647        Statement::Loop {
648            ref mut body,
649            ref mut continuing,
650            ref mut break_if,
651        } => {
652            adjust_block(new_pos, body);
653            adjust_block(new_pos, continuing);
654            if let Some(e) = break_if.as_mut() {
655                adjust(e);
656            }
657        }
658        Statement::Return { ref mut value } => {
659            if let Some(e) = value.as_mut() {
660                adjust(e);
661            }
662        }
663        Statement::Store {
664            ref mut pointer,
665            ref mut value,
666        } => {
667            adjust(pointer);
668            adjust(value);
669        }
670        Statement::ImageStore {
671            ref mut image,
672            ref mut coordinate,
673            ref mut array_index,
674            ref mut value,
675        } => {
676            adjust(image);
677            adjust(coordinate);
678            if let Some(e) = array_index.as_mut() {
679                adjust(e);
680            }
681            adjust(value);
682        }
683        Statement::Atomic {
684            ref mut pointer,
685            ref mut value,
686            ref mut result,
687            ref mut fun,
688        } => {
689            adjust(pointer);
690            adjust(value);
691            if let Some(ref mut result) = *result {
692                adjust(result);
693            }
694            match *fun {
695                crate::AtomicFunction::Exchange {
696                    compare: Some(ref mut compare),
697                } => {
698                    adjust(compare);
699                }
700                crate::AtomicFunction::Add
701                | crate::AtomicFunction::Subtract
702                | crate::AtomicFunction::And
703                | crate::AtomicFunction::ExclusiveOr
704                | crate::AtomicFunction::InclusiveOr
705                | crate::AtomicFunction::Min
706                | crate::AtomicFunction::Max
707                | crate::AtomicFunction::Exchange { compare: None } => {}
708            }
709        }
710        Statement::ImageAtomic {
711            ref mut image,
712            ref mut coordinate,
713            ref mut array_index,
714            fun: _,
715            ref mut value,
716        } => {
717            adjust(image);
718            adjust(coordinate);
719            if let Some(ref mut array_index) = *array_index {
720                adjust(array_index);
721            }
722            adjust(value);
723        }
724        Statement::WorkGroupUniformLoad {
725            ref mut pointer,
726            ref mut result,
727        } => {
728            adjust(pointer);
729            adjust(result);
730        }
731        Statement::SubgroupBallot {
732            ref mut result,
733            ref mut predicate,
734        } => {
735            if let Some(ref mut predicate) = *predicate {
736                adjust(predicate);
737            }
738            adjust(result);
739        }
740        Statement::SubgroupCollectiveOperation {
741            ref mut argument,
742            ref mut result,
743            ..
744        } => {
745            adjust(argument);
746            adjust(result);
747        }
748        Statement::SubgroupGather {
749            ref mut mode,
750            ref mut argument,
751            ref mut result,
752        } => {
753            match *mode {
754                crate::GatherMode::BroadcastFirst => {}
755                crate::GatherMode::Broadcast(ref mut index)
756                | crate::GatherMode::Shuffle(ref mut index)
757                | crate::GatherMode::ShuffleDown(ref mut index)
758                | crate::GatherMode::ShuffleUp(ref mut index)
759                | crate::GatherMode::ShuffleXor(ref mut index) => {
760                    adjust(index);
761                }
762            }
763            adjust(argument);
764            adjust(result)
765        }
766        Statement::Call {
767            ref mut arguments,
768            ref mut result,
769            function: _,
770        } => {
771            for argument in arguments.iter_mut() {
772                adjust(argument);
773            }
774            if let Some(e) = result.as_mut() {
775                adjust(e);
776            }
777        }
778        Statement::RayQuery {
779            ref mut query,
780            ref mut fun,
781        } => {
782            adjust(query);
783            match *fun {
784                crate::RayQueryFunction::Initialize {
785                    ref mut acceleration_structure,
786                    ref mut descriptor,
787                } => {
788                    adjust(acceleration_structure);
789                    adjust(descriptor);
790                }
791                crate::RayQueryFunction::Proceed { ref mut result } => {
792                    adjust(result);
793                }
794                crate::RayQueryFunction::GenerateIntersection { ref mut hit_t } => {
795                    adjust(hit_t);
796                }
797                crate::RayQueryFunction::ConfirmIntersection => {}
798                crate::RayQueryFunction::Terminate => {}
799            }
800        }
801        Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
802    }
803}
804
805/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
806///
807/// According to validation, [`Emit`] statements must not cover any expressions
808/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
809/// by successful constant evaluation fall into that category, meaning that
810/// `process_function` will usually rewrite [`Override`] expressions and those
811/// that use their values into pre-emitted expressions, leaving any [`Emit`]
812/// statements that cover them invalid.
813///
814/// This function rewrites all [`Emit`] statements into zero or more new
815/// [`Emit`] statements covering only those expressions in the original range
816/// that are not pre-emitted.
817///
818/// [`Emit`]: Statement::Emit
819/// [`needs_pre_emit`]: Expression::needs_pre_emit
820/// [`Override`]: Expression::Override
821fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
822    let original = mem::replace(block, Block::with_capacity(block.len()));
823    for (stmt, span) in original.span_into_iter() {
824        match stmt {
825            Statement::Emit(range) => {
826                let mut current = None;
827                for expr_h in range {
828                    if expressions[expr_h].needs_pre_emit() {
829                        if let Some((first, last)) = current {
830                            block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
831                        }
832
833                        current = None;
834                    } else if let Some((_, ref mut last)) = current {
835                        *last = expr_h;
836                    } else {
837                        current = Some((expr_h, expr_h));
838                    }
839                }
840                if let Some((first, last)) = current {
841                    block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
842                }
843            }
844            Statement::Block(mut child) => {
845                filter_emits_in_block(&mut child, expressions);
846                block.push(Statement::Block(child), span);
847            }
848            Statement::If {
849                condition,
850                mut accept,
851                mut reject,
852            } => {
853                filter_emits_in_block(&mut accept, expressions);
854                filter_emits_in_block(&mut reject, expressions);
855                block.push(
856                    Statement::If {
857                        condition,
858                        accept,
859                        reject,
860                    },
861                    span,
862                );
863            }
864            Statement::Switch {
865                selector,
866                mut cases,
867            } => {
868                for case in &mut cases {
869                    filter_emits_in_block(&mut case.body, expressions);
870                }
871                block.push(Statement::Switch { selector, cases }, span);
872            }
873            Statement::Loop {
874                mut body,
875                mut continuing,
876                break_if,
877            } => {
878                filter_emits_in_block(&mut body, expressions);
879                filter_emits_in_block(&mut continuing, expressions);
880                block.push(
881                    Statement::Loop {
882                        body,
883                        continuing,
884                        break_if,
885                    },
886                    span,
887                );
888            }
889            stmt => block.push(stmt.clone(), span),
890        }
891    }
892}
893
894fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
895    // note that in rust 0.0 == -0.0
896    match scalar {
897        Scalar::BOOL => {
898            // https://webidl.spec.whatwg.org/#js-boolean
899            let value = value != 0.0 && !value.is_nan();
900            Ok(Literal::Bool(value))
901        }
902        Scalar::I32 => {
903            // https://webidl.spec.whatwg.org/#js-long
904            if !value.is_finite() {
905                return Err(PipelineConstantError::SrcNeedsToBeFinite);
906            }
907
908            let value = value.trunc();
909            if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
910                return Err(PipelineConstantError::DstRangeTooSmall);
911            }
912
913            let value = value as i32;
914            Ok(Literal::I32(value))
915        }
916        Scalar::U32 => {
917            // https://webidl.spec.whatwg.org/#js-unsigned-long
918            if !value.is_finite() {
919                return Err(PipelineConstantError::SrcNeedsToBeFinite);
920            }
921
922            let value = value.trunc();
923            if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
924                return Err(PipelineConstantError::DstRangeTooSmall);
925            }
926
927            let value = value as u32;
928            Ok(Literal::U32(value))
929        }
930        Scalar::F32 => {
931            // https://webidl.spec.whatwg.org/#js-float
932            if !value.is_finite() {
933                return Err(PipelineConstantError::SrcNeedsToBeFinite);
934            }
935
936            let value = value as f32;
937            if !value.is_finite() {
938                return Err(PipelineConstantError::DstRangeTooSmall);
939            }
940
941            Ok(Literal::F32(value))
942        }
943        Scalar::F64 => {
944            // https://webidl.spec.whatwg.org/#js-double
945            if !value.is_finite() {
946                return Err(PipelineConstantError::SrcNeedsToBeFinite);
947            }
948
949            Ok(Literal::F64(value))
950        }
951        _ => unreachable!(),
952    }
953}
954
#[test]
fn test_map_value_to_literal() {
    // bool: WebIDL ToBoolean — only ±0 and NaN are falsy.
    let bool_test_cases = [
        (0.0, false),
        (-0.0, false),
        (f64::NAN, false),
        (0.5, true),
        (1.0, true),
        (-1.0, true),
        (f64::INFINITY, true),
        (f64::NEG_INFINITY, true),
    ];
    for (value, out) in bool_test_cases {
        let res = Ok(Literal::Bool(out));
        assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
    }

    // Every numeric destination rejects non-finite sources.
    for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
            assert_eq!(map_value_to_literal(value, scalar), res);
        }
    }

    // i32
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Fractional values are truncated toward zero before the range check.
    assert_eq!(map_value_to_literal(1.9, Scalar::I32), Ok(Literal::I32(1)));
    assert_eq!(map_value_to_literal(-1.9, Scalar::I32), Ok(Literal::I32(-1)));
    assert_eq!(
        map_value_to_literal(f64::from(i32::MIN) - 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(i32::MAX) + 0.5, Scalar::I32),
        Ok(Literal::I32(i32::MAX))
    );

    // u32
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
        Ok(Literal::U32(u32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    // Small negative fractions truncate to -0.0, which is in range.
    assert_eq!(map_value_to_literal(-0.5, Scalar::U32), Ok(Literal::U32(0)));
    assert_eq!(map_value_to_literal(2.7, Scalar::U32), Ok(Literal::U32(2)));
    assert_eq!(
        map_value_to_literal(f64::from(u32::MAX) + 0.5, Scalar::U32),
        Ok(Literal::U32(u32::MAX))
    );

    // f32
    assert_eq!(
        map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
        Ok(Literal::F32(f32::MAX))
    );
    assert_eq!(
        map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(
        map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
        Err(PipelineConstantError::DstRangeTooSmall)
    );
    assert_eq!(map_value_to_literal(1.5, Scalar::F32), Ok(Literal::F32(1.5)));

    // f64
    assert_eq!(
        map_value_to_literal(f64::MIN, Scalar::F64),
        Ok(Literal::F64(f64::MIN))
    );
    assert_eq!(
        map_value_to_literal(f64::MAX, Scalar::F64),
        Ok(Literal::F64(f64::MAX))
    );
    assert_eq!(map_value_to_literal(-2.25, Scalar::F64), Ok(Literal::F64(-2.25)));
}
1048}