naga/valid/
analyzer.rs

1//! Module analyzer.
2//!
3//! Figures out the following properties:
4//! - control flow uniformity
5//! - texture/sampler pairs
6//! - expression reference counts
7
8use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15    arena::{Arena, Handle},
16    proc::{ResolveContext, TypeResolution},
17};
18
/// The origin of a value's non-uniformity, if any.
///
/// `None` means the value is uniform; `Some(handle)` names the expression
/// from which the non-uniformity originates.
pub type NonUniformResult = Option<Handle<crate::Expression>>;

/// When `true`, the `DERIVATIVE` and `IMPLICIT_LEVEL` flags of
/// `UniformityRequirements` are defined as `0`, so derivative and
/// implicit-level sampling operations impose no uniform-control-flow
/// requirement.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Kinds of expressions that require uniform control flow.
    ///
    /// While `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`, the
    /// `DERIVATIVE` and `IMPLICIT_LEVEL` flags are defined as `0`, so
    /// recording them adds no bits.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// A workgroup barrier.
        ///
        /// NOTE(review): the code that records this flag is not in this
        /// part of the file.
        const WORK_GROUP_BARRIER = 0x1;
        /// An explicit derivative (`Expression::Derivative`).
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// An image sample whose level uses implicit derivatives.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniform control flow characteristics.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// A child expression with non-uniform result.
    ///
    /// `None` means the result is uniform. Otherwise this names the
    /// expression from which the non-uniformity originates.
    ///
    /// This means, when the relevant invocations are scheduled on a compute unit,
    /// they have to use vector registers to store an individual value
    /// per invocation.
    ///
    /// Whenever the control flow is conditioned on such value,
    /// the hardware needs to keep track of the mask of invocations,
    /// and process all branches of the control flow.
    ///
    /// Any operations that depend on non-uniform results also produce non-uniform.
    pub non_uniform_result: NonUniformResult,
    /// The reasons, if any, why this expression requires uniform control flow.
    ///
    /// Empty when there is no such requirement.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58    const fn new() -> Self {
59        Uniformity {
60            non_uniform_result: None,
61            requirements: UniformityRequirements::empty(),
62        }
63    }
64}
65
bitflags::bitflags! {
    /// Possible early exits from a function's control flow, as tracked by
    /// the uniformity analysis.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Control flow may return from the function, which causes all
        /// subsequent statements within the current function (only!)
        /// to be executed in non-uniform control flow.
        const MAY_RETURN = 0x1;
        /// Control flow may be killed. Anything after [`Statement::Kill`] is
        /// considered to be in a non-uniform context.
        ///
        /// [`Statement::Kill`]: crate::Statement::Kill
        const MAY_KILL = 0x2;
    }
}
80
/// Uniformity characteristics of a function.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Combined non-uniformity origin and uniform-control-flow requirements.
    result: Uniformity,
    /// Early exits (return / kill) that may have happened; see `ExitFlags`.
    exit: ExitFlags,
}
87
88impl ops::BitOr for FunctionUniformity {
89    type Output = Self;
90    fn bitor(self, other: Self) -> Self {
91        FunctionUniformity {
92            result: Uniformity {
93                non_uniform_result: self
94                    .result
95                    .non_uniform_result
96                    .or(other.result.non_uniform_result),
97                requirements: self.result.requirements | other.result.requirements,
98            },
99            exit: self.exit | other.exit,
100        }
101    }
102}
103
104impl FunctionUniformity {
105    const fn new() -> Self {
106        FunctionUniformity {
107            result: Uniformity::new(),
108            exit: ExitFlags::empty(),
109        }
110    }
111
112    /// Returns a disruptor based on the stored exit flags, if any.
113    const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114        if self.exit.contains(ExitFlags::MAY_RETURN) {
115            Some(UniformityDisruptor::Return)
116        } else if self.exit.contains(ExitFlags::MAY_KILL) {
117            Some(UniformityDisruptor::Discard)
118        } else {
119            None
120        }
121    }
122}
123
bitflags::bitflags! {
    /// Indicates how a global variable is used.
    ///
    /// Uses accumulate: each newly discovered use is OR-ed into the
    /// global's existing flags.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Data will be read from the variable.
        const READ = 0x1;
        /// Data will be written to the variable.
        const WRITE = 0x2;
        /// The information about the data is queried (e.g. an image query).
        const QUERY = 0x4;
        /// Atomic operations will be performed on the variable.
        const ATOMIC = 0x8;
    }
}
140
/// A (texture, sampler) pair of global variables that may be used together
/// in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// The global holding the image (texture).
    pub image: Handle<crate::GlobalVariable>,
    /// The global holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
148
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
/// Information about an expression in a function body.
pub struct ExpressionInfo {
    /// Whether this expression is uniform, and why.
    ///
    /// If this expression's value is not uniform,
    /// `uniformity.non_uniform_result` is the handle of the expression from
    /// which its non-uniformity originates; otherwise it is `None`.
    /// `uniformity.requirements` records any uniform-control-flow
    /// requirements this expression imposes.
    pub uniformity: Uniformity,

    /// The number of statements and other expressions using this
    /// expression's value.
    pub ref_count: usize,

    /// The global variable into which this expression produces a pointer.
    ///
    /// This is `None` unless this expression is either a
    /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that
    /// ultimately refers to some part of a global.
    ///
    /// [`Load`] expressions applied to pointer-typed arguments could
    /// refer to globals, but we leave this as `None` for them.
    ///
    /// [`GlobalVariable`]: crate::Expression::GlobalVariable
    /// [`Access`]: crate::Expression::Access
    /// [`AccessIndex`]: crate::Expression::AccessIndex
    /// [`Load`]: crate::Expression::Load
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// The type of this expression.
    ///
    /// Holds a dummy value until the expression has been processed.
    pub ty: TypeResolution,
}
183
184impl ExpressionInfo {
185    const fn new() -> Self {
186        ExpressionInfo {
187            uniformity: Uniformity::new(),
188            ref_count: 0,
189            assignable_global: None,
190            // this doesn't matter at this point, will be overwritten
191            ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
192                kind: crate::ScalarKind::Bool,
193                width: 0,
194            })),
195        }
196    }
197}
198
/// A texture or sampler taking part in sampling, identified either as a
/// specific global variable or as an argument of the current function.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// A module global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The current function's argument at this index.
    Argument(u32),
}
206
207impl GlobalOrArgument {
208    fn from_expression(
209        expression_arena: &Arena<crate::Expression>,
210        expression: Handle<crate::Expression>,
211    ) -> Result<GlobalOrArgument, ExpressionError> {
212        Ok(match expression_arena[expression] {
213            crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
214            crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
215            crate::Expression::Access { base, .. }
216            | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
217                crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
218                _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
219            },
220            _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
221        })
222    }
223}
224
/// A (texture, sampler) pair that may be used together in sampling, where
/// either side may still be an argument of the current function rather
/// than a known global.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// The image (texture) side of the pair.
    image: GlobalOrArgument,
    /// The sampler side of the pair.
    sampler: GlobalOrArgument,
}
232
/// Analysis results for a single function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags.
    // NOTE(review): the reason for `allow(dead_code)` is not visible here;
    // presumably this field is not read in every configuration.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Set of shader stages where calling this function is valid.
    pub available_stages: ShaderStages,
    /// Uniformity characteristics.
    ///
    /// Callers inherit this when they analyze a call to this function.
    pub uniformity: Uniformity,
    /// Function may kill the invocation.
    pub may_kill: bool,

    /// All pairs of (texture, sampler) globals that may be used together in
    /// sampling operations by this function and its callees. This includes
    /// pairings that arise when this function passes textures and samplers as
    /// arguments to its callees.
    ///
    /// This table does not include uses of textures and samplers passed as
    /// arguments to this function itself, since we do not know which globals
    /// those will be. However, this table *is* exhaustive when computed for an
    /// entry point function: entry points never receive textures or samplers as
    /// arguments, so all an entry point's sampling can be reported in terms of
    /// globals.
    ///
    /// The GLSL back end uses this table to construct reflection info that
    /// clients need to construct texture-combined sampler values.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// How this function and its callees use this module's globals.
    ///
    /// This is indexed by `Handle<GlobalVariable>` indices. However,
    /// `FunctionInfo` implements `core::ops::Index<Handle<GlobalVariable>>`,
    /// so you can simply index this struct with a global handle to retrieve
    /// its usage information.
    global_uses: Box<[GlobalUse]>,

    /// Information about each expression in this function's body.
    ///
    /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
    /// implements `core::ops::Index<Handle<Expression>>`, so you can simply
    /// index this struct with an expression handle to retrieve its
    /// `ExpressionInfo`.
    expressions: Box<[ExpressionInfo]>,

    /// All (texture, sampler) pairs that may be used together in sampling
    /// operations by this function and its callees, whether they are accessed
    /// as globals or passed as arguments.
    ///
    /// Participants are represented by [`GlobalVariable`] handles whenever
    /// possible, and otherwise by indices of this function's arguments.
    ///
    /// When analyzing a function call, we combine this data about the callee
    /// with the actual arguments being passed to produce the caller's own
    /// `sampling_set` and `sampling` tables.
    ///
    /// [`GlobalVariable`]: crate::GlobalVariable
    sampling: crate::FastHashSet<Sampling>,

    /// Indicates that the function is using dual source blending.
    pub dual_source_blending: bool,

    /// The leaf of the tree of module-wide diagnostic filter rules parsed
    /// from directives in this module.
    ///
    /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in
    /// validation.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
303
304impl FunctionInfo {
305    pub const fn global_variable_count(&self) -> usize {
306        self.global_uses.len()
307    }
308    pub const fn expression_count(&self) -> usize {
309        self.expressions.len()
310    }
311    pub fn dominates_global_use(&self, other: &Self) -> bool {
312        for (self_global_uses, other_global_uses) in
313            self.global_uses.iter().zip(other.global_uses.iter())
314        {
315            if !self_global_uses.contains(*other_global_uses) {
316                return false;
317            }
318        }
319        true
320    }
321}
322
323impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
324    type Output = GlobalUse;
325    fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
326        &self.global_uses[handle.index()]
327    }
328}
329
330impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
331    type Output = ExpressionInfo;
332    fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
333        &self.expressions[handle.index()]
334    }
335}
336
/// Disruptor of the uniform control flow.
///
/// Explains why control flow at some point cannot be considered uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// Control flow depends on the non-uniform result of this expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// The function may already have returned (`ExitFlags::MAY_RETURN`).
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// The invocation may already have been killed, possibly by a callee
    /// (`ExitFlags::MAY_KILL`).
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
348
349impl FunctionInfo {
350    /// Record a use of `expr` of the sort given by `global_use`.
351    ///
352    /// Bump `expr`'s reference count, and return its uniformity.
353    ///
354    /// If `expr` is a pointer to a global variable, or some part of
355    /// a global variable, add `global_use` to that global's set of
356    /// uses.
357    #[must_use]
358    fn add_ref_impl(
359        &mut self,
360        expr: Handle<crate::Expression>,
361        global_use: GlobalUse,
362    ) -> NonUniformResult {
363        let info = &mut self.expressions[expr.index()];
364        info.ref_count += 1;
365        // mark the used global as read
366        if let Some(global) = info.assignable_global {
367            self.global_uses[global.index()] |= global_use;
368        }
369        info.uniformity.non_uniform_result
370    }
371
372    /// Record a use of `expr` for its value.
373    ///
374    /// This is used for almost all expression references. Anything
375    /// that writes to the value `expr` points to, or otherwise wants
376    /// contribute flags other than `GlobalUse::READ`, should use
377    /// `add_ref_impl` directly.
378    #[must_use]
379    fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
380        self.add_ref_impl(expr, GlobalUse::READ)
381    }
382
383    /// Record a use of `expr`, and indicate which global variable it
384    /// refers to, if any.
385    ///
386    /// Bump `expr`'s reference count, and return its uniformity.
387    ///
388    /// If `expr` is a pointer to a global variable, or some part
389    /// thereof, store that global in `*assignable_global`. Leave the
390    /// global's uses unchanged.
391    ///
392    /// This is used to determine the [`assignable_global`] for
393    /// [`Access`] and [`AccessIndex`] expressions that ultimately
394    /// refer to a global variable. Those expressions don't contribute
395    /// any usage to the global themselves; that depends on how other
396    /// expressions use them.
397    ///
398    /// [`assignable_global`]: ExpressionInfo::assignable_global
399    /// [`Access`]: crate::Expression::Access
400    /// [`AccessIndex`]: crate::Expression::AccessIndex
401    #[must_use]
402    fn add_assignable_ref(
403        &mut self,
404        expr: Handle<crate::Expression>,
405        assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
406    ) -> NonUniformResult {
407        let info = &mut self.expressions[expr.index()];
408        info.ref_count += 1;
409        // propagate the assignable global up the chain, till it either hits
410        // a value-type expression, or the assignment statement.
411        if let Some(global) = info.assignable_global {
412            if let Some(_old) = assignable_global.replace(global) {
413                unreachable!()
414            }
415        }
416        info.uniformity.non_uniform_result
417    }
418
419    /// Inherit information from a called function.
420    fn process_call(
421        &mut self,
422        callee: &Self,
423        arguments: &[Handle<crate::Expression>],
424        expression_arena: &Arena<crate::Expression>,
425    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
426        self.sampling_set
427            .extend(callee.sampling_set.iter().cloned());
428        for sampling in callee.sampling.iter() {
429            // If the callee was passed the texture or sampler as an argument,
430            // we may now be able to determine which globals those referred to.
431            let image_storage = match sampling.image {
432                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
433                GlobalOrArgument::Argument(i) => {
434                    let Some(handle) = arguments.get(i as usize).cloned() else {
435                        // Argument count mismatch, will be reported later by validate_call
436                        break;
437                    };
438                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
439                        |source| {
440                            FunctionError::Expression { handle, source }
441                                .with_span_handle(handle, expression_arena)
442                        },
443                    )?
444                }
445            };
446
447            let sampler_storage = match sampling.sampler {
448                GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
449                GlobalOrArgument::Argument(i) => {
450                    let Some(handle) = arguments.get(i as usize).cloned() else {
451                        // Argument count mismatch, will be reported later by validate_call
452                        break;
453                    };
454                    GlobalOrArgument::from_expression(expression_arena, handle).map_err(
455                        |source| {
456                            FunctionError::Expression { handle, source }
457                                .with_span_handle(handle, expression_arena)
458                        },
459                    )?
460                }
461            };
462
463            // If we've managed to pin both the image and sampler down to
464            // specific globals, record that in our `sampling_set`. Otherwise,
465            // record as much as we do know in our own `sampling` table, for our
466            // callers to sort out.
467            match (image_storage, sampler_storage) {
468                (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
469                    self.sampling_set.insert(SamplingKey { image, sampler });
470                }
471                (image, sampler) => {
472                    self.sampling.insert(Sampling { image, sampler });
473                }
474            }
475        }
476
477        // Inherit global use from our callees.
478        for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
479            *mine |= *other;
480        }
481
482        Ok(FunctionUniformity {
483            result: callee.uniformity.clone(),
484            exit: if callee.may_kill {
485                ExitFlags::MAY_KILL
486            } else {
487                ExitFlags::empty()
488            },
489        })
490    }
491
492    /// Compute the [`ExpressionInfo`] for `handle`.
493    ///
494    /// Replace the dummy entry in [`self.expressions`] for `handle`
495    /// with a real `ExpressionInfo` value describing that expression.
496    ///
497    /// This function is called as part of a forward sweep through the
498    /// arena, so we can assume that all earlier expressions in the
499    /// arena already have valid info. Since expressions only depend
500    /// on earlier expressions, this includes all our subexpressions.
501    ///
502    /// Adjust the reference counts on all expressions we use.
503    ///
504    /// Also populate the [`sampling_set`], [`sampling`] and
505    /// [`global_uses`] fields of `self`.
506    ///
507    /// [`self.expressions`]: FunctionInfo::expressions
508    /// [`sampling_set`]: FunctionInfo::sampling_set
509    /// [`sampling`]: FunctionInfo::sampling
510    /// [`global_uses`]: FunctionInfo::global_uses
511    #[allow(clippy::or_fun_call)]
512    fn process_expression(
513        &mut self,
514        handle: Handle<crate::Expression>,
515        expression_arena: &Arena<crate::Expression>,
516        other_functions: &[FunctionInfo],
517        resolve_context: &ResolveContext,
518        capabilities: super::Capabilities,
519    ) -> Result<(), ExpressionError> {
520        use crate::{Expression as E, SampleLevel as Sl};
521
522        let expression = &expression_arena[handle];
523        let mut assignable_global = None;
524        let uniformity = match *expression {
525            E::Access { base, index } => {
526                let base_ty = self[base].ty.inner_with(resolve_context.types);
527
528                // build up the caps needed if this is indexed non-uniformly
529                let mut needed_caps = super::Capabilities::empty();
530                let is_binding_array = match *base_ty {
531                    crate::TypeInner::BindingArray {
532                        base: array_element_ty_handle,
533                        ..
534                    } => {
535                        // these are nasty aliases, but these idents are too long and break rustfmt
536                        let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
537                        let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
538                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
539                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
540
541                        // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
542                        let array_element_ty =
543                            &resolve_context.types[array_element_ty_handle].inner;
544
545                        needed_caps |= match *array_element_ty {
546                            // If we're an image, use the appropriate limit.
547                            crate::TypeInner::Image { class, .. } => match class {
548                                crate::ImageClass::Storage { .. } => sto,
549                                _ => st_sb,
550                            },
551                            crate::TypeInner::Sampler { .. } => sampler,
552                            // If we're anything but an image, assume we're a buffer and use the address space.
553                            _ => {
554                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
555                                    let global = &resolve_context.global_vars[global_handle];
556                                    match global.space {
557                                        crate::AddressSpace::Uniform => uni,
558                                        crate::AddressSpace::Storage { .. } => st_sb,
559                                        _ => unreachable!(),
560                                    }
561                                } else {
562                                    unreachable!()
563                                }
564                            }
565                        };
566
567                        true
568                    }
569                    _ => false,
570                };
571
572                if self[index].uniformity.non_uniform_result.is_some()
573                    && !capabilities.contains(needed_caps)
574                    && is_binding_array
575                {
576                    return Err(ExpressionError::MissingCapabilities(needed_caps));
577                }
578
579                Uniformity {
580                    non_uniform_result: self
581                        .add_assignable_ref(base, &mut assignable_global)
582                        .or(self.add_ref(index)),
583                    requirements: UniformityRequirements::empty(),
584                }
585            }
586            E::AccessIndex { base, .. } => Uniformity {
587                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
588                requirements: UniformityRequirements::empty(),
589            },
590            // always uniform
591            E::Splat { size: _, value } => Uniformity {
592                non_uniform_result: self.add_ref(value),
593                requirements: UniformityRequirements::empty(),
594            },
595            E::Swizzle { vector, .. } => Uniformity {
596                non_uniform_result: self.add_ref(vector),
597                requirements: UniformityRequirements::empty(),
598            },
599            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
600            E::Compose { ref components, .. } => {
601                let non_uniform_result = components
602                    .iter()
603                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
604                Uniformity {
605                    non_uniform_result,
606                    requirements: UniformityRequirements::empty(),
607                }
608            }
609            // depends on the builtin
610            E::FunctionArgument(index) => {
611                let arg = &resolve_context.arguments[index as usize];
612                let uniform = match arg.binding {
613                    Some(crate::Binding::BuiltIn(
614                        // per-work-group built-ins are uniform
615                        crate::BuiltIn::WorkGroupId
616                        | crate::BuiltIn::WorkGroupSize
617                        | crate::BuiltIn::NumWorkGroups,
618                    )) => true,
619                    _ => false,
620                };
621                Uniformity {
622                    non_uniform_result: if uniform { None } else { Some(handle) },
623                    requirements: UniformityRequirements::empty(),
624                }
625            }
626            // depends on the address space
627            E::GlobalVariable(gh) => {
628                use crate::AddressSpace as As;
629                assignable_global = Some(gh);
630                let var = &resolve_context.global_vars[gh];
631                let uniform = match var.space {
632                    // local data is non-uniform
633                    As::Function | As::Private => false,
634                    // workgroup memory is exclusively accessed by the group
635                    As::WorkGroup => true,
636                    // uniform data
637                    As::Uniform | As::PushConstant => true,
638                    // storage data is only uniform when read-only
639                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
640                    As::Handle => false,
641                };
642                Uniformity {
643                    non_uniform_result: if uniform { None } else { Some(handle) },
644                    requirements: UniformityRequirements::empty(),
645                }
646            }
647            E::LocalVariable(_) => Uniformity {
648                non_uniform_result: Some(handle),
649                requirements: UniformityRequirements::empty(),
650            },
651            E::Load { pointer } => Uniformity {
652                non_uniform_result: self.add_ref(pointer),
653                requirements: UniformityRequirements::empty(),
654            },
655            E::ImageSample {
656                image,
657                sampler,
658                gather: _,
659                coordinate,
660                array_index,
661                offset,
662                level,
663                depth_ref,
664            } => {
665                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
666                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
667
668                match (image_storage, sampler_storage) {
669                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
670                        self.sampling_set.insert(SamplingKey { image, sampler });
671                    }
672                    _ => {
673                        self.sampling.insert(Sampling {
674                            image: image_storage,
675                            sampler: sampler_storage,
676                        });
677                    }
678                }
679
680                // "nur" == "Non-Uniform Result"
681                let array_nur = array_index.and_then(|h| self.add_ref(h));
682                let level_nur = match level {
683                    Sl::Auto | Sl::Zero => None,
684                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
685                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
686                };
687                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
688                let offset_nur = offset.and_then(|h| self.add_ref(h));
689                Uniformity {
690                    non_uniform_result: self
691                        .add_ref(image)
692                        .or(self.add_ref(sampler))
693                        .or(self.add_ref(coordinate))
694                        .or(array_nur)
695                        .or(level_nur)
696                        .or(dref_nur)
697                        .or(offset_nur),
698                    requirements: if level.implicit_derivatives() {
699                        UniformityRequirements::IMPLICIT_LEVEL
700                    } else {
701                        UniformityRequirements::empty()
702                    },
703                }
704            }
705            E::ImageLoad {
706                image,
707                coordinate,
708                array_index,
709                sample,
710                level,
711            } => {
712                let array_nur = array_index.and_then(|h| self.add_ref(h));
713                let sample_nur = sample.and_then(|h| self.add_ref(h));
714                let level_nur = level.and_then(|h| self.add_ref(h));
715                Uniformity {
716                    non_uniform_result: self
717                        .add_ref(image)
718                        .or(self.add_ref(coordinate))
719                        .or(array_nur)
720                        .or(sample_nur)
721                        .or(level_nur),
722                    requirements: UniformityRequirements::empty(),
723                }
724            }
725            E::ImageQuery { image, query } => {
726                let query_nur = match query {
727                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
728                    _ => None,
729                };
730                Uniformity {
731                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
732                    requirements: UniformityRequirements::empty(),
733                }
734            }
735            E::Unary { expr, .. } => Uniformity {
736                non_uniform_result: self.add_ref(expr),
737                requirements: UniformityRequirements::empty(),
738            },
739            E::Binary { left, right, .. } => Uniformity {
740                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
741                requirements: UniformityRequirements::empty(),
742            },
743            E::Select {
744                condition,
745                accept,
746                reject,
747            } => Uniformity {
748                non_uniform_result: self
749                    .add_ref(condition)
750                    .or(self.add_ref(accept))
751                    .or(self.add_ref(reject)),
752                requirements: UniformityRequirements::empty(),
753            },
754            // explicit derivatives require uniform
755            E::Derivative { expr, .. } => Uniformity {
756                //Note: taking a derivative of a uniform doesn't make it non-uniform
757                non_uniform_result: self.add_ref(expr),
758                requirements: UniformityRequirements::DERIVATIVE,
759            },
760            E::Relational { argument, .. } => Uniformity {
761                non_uniform_result: self.add_ref(argument),
762                requirements: UniformityRequirements::empty(),
763            },
764            E::Math {
765                fun: _,
766                arg,
767                arg1,
768                arg2,
769                arg3,
770            } => {
771                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
772                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
773                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
774                Uniformity {
775                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
776                    requirements: UniformityRequirements::empty(),
777                }
778            }
779            E::As { expr, .. } => Uniformity {
780                non_uniform_result: self.add_ref(expr),
781                requirements: UniformityRequirements::empty(),
782            },
783            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
784            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
785                non_uniform_result: Some(handle),
786                requirements: UniformityRequirements::empty(),
787            },
788            E::WorkGroupUniformLoadResult { .. } => Uniformity {
789                // The result of WorkGroupUniformLoad is always uniform by definition
790                non_uniform_result: None,
791                // The call is what cares about uniformity, not the expression
792                // This expression is never emitted, so this requirement should never be used anyway?
793                requirements: UniformityRequirements::empty(),
794            },
795            E::ArrayLength(expr) => Uniformity {
796                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
797                requirements: UniformityRequirements::empty(),
798            },
799            E::RayQueryGetIntersection {
800                query,
801                committed: _,
802            } => Uniformity {
803                non_uniform_result: self.add_ref(query),
804                requirements: UniformityRequirements::empty(),
805            },
806            E::SubgroupBallotResult => Uniformity {
807                non_uniform_result: Some(handle),
808                requirements: UniformityRequirements::empty(),
809            },
810            E::SubgroupOperationResult { .. } => Uniformity {
811                non_uniform_result: Some(handle),
812                requirements: UniformityRequirements::empty(),
813            },
814            E::RayQueryVertexPositions {
815                query,
816                committed: _,
817            } => Uniformity {
818                non_uniform_result: self.add_ref(query),
819                requirements: UniformityRequirements::empty(),
820            },
821        };
822
823        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
824        self.expressions[handle.index()] = ExpressionInfo {
825            uniformity,
826            ref_count: 0,
827            assignable_global,
828            ty,
829        };
830        Ok(())
831    }
832
    /// Analyzes the uniformity requirements of a block (as a sequence of statements).
    ///
    /// Returns the uniformity characteristics at the *function* level, i.e.
    /// whether or not the function requires to be called in uniform control flow,
    /// and whether the produced result is not disrupting the control flow.
    /// Concretely, the per-statement results are folded together with `|`:
    /// - requirements contributed by emitted expressions and barriers, plus
    ///   whatever `process_call` reports for calls;
    /// - the non-uniformity of any `Return` value;
    /// - the `MAY_RETURN` / `MAY_KILL` exit flags, which are only raised by
    ///   `Return` / `Kill` statements reached under non-uniform control flow.
    ///
    /// The parent control flow is uniform if `disruptor.is_none()`. While the
    /// statements are walked, `disruptor` is updated from each statement's
    /// `exit_disruptor()`. The first disruptor recorded is kept. Presumably this
    /// makes the remainder of the block non-uniform after an early exit taken
    /// under non-uniform flow.
    ///
    /// Parameters:
    /// - `statements`: the block to analyze.
    /// - `other_functions`: info for previously analyzed functions, indexed by
    ///   `Handle<Function>::index()`; consulted for `Call` statements.
    /// - `disruptor`: why the enclosing control flow is non-uniform, if it is.
    /// - `expression_arena`: used to attach spans to reported errors.
    /// - `diagnostic_filter_arena`: searched, starting from
    ///   `self.diagnostic_filter_leaf`, for the severity of the
    ///   `derivative_uniformity` rule.
    ///
    /// # Errors
    ///
    /// The check applies when `ValidationFlags::CONTROL_FLOW_UNIFORMITY` is set
    /// and an emitted expression has non-empty requirements while the current
    /// flow is non-uniform. In that case a `NonUniformControlFlow` diagnostic
    /// is reported with the severity of the `derivative_uniformity` filter in
    /// effect, and is otherwise logged through `log`.
    /// NOTE(review): whether it becomes an `Err` is decided by `report_diag`.
    /// Presumably only error severity fails; confirm.
    /// Errors from nested blocks and from `process_call` are propagated.
    // The `.or(...)` calls below take cheaply computed arguments
    // (`condition_nur.map(..)`, `exit_disruptor()`), which this lint would flag.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                // Emitting is where expression-level requirements (e.g. derivatives,
                // implicit-level sampling) are checked against the current flow.
                S::Emit(ref range) => {
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    // TODO: Yes, this isn't contextualized with source, because
                                    // the user is supposed to render what would normally be an
                                    // error here. Once we actually support warning-level
                                    // diagnostic items, then we won't need this non-compliant hack:
                                    // <https://github.com/gfx-rs/wgpu/issues/6458>
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        // Requirements accumulate even if they were satisfied here, so
                        // the caller learns what the whole function needs.
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // Break/Continue add no requirements and set no exit flags here.
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only sets `MAY_KILL` when reached under non-uniform flow.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Barriers contribute a function-wide requirement; it is not checked
                // against `disruptor` at this point.
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is still counted as a use; its uniformity is ignored.
                    let _condition_nur = self.add_ref(pointer);

                    // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
                    // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
                    // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
                    // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.

                    /*
                    if self
                        .flags
                        .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                    {
                        let condition_nur = self.add_ref(pointer);
                        let this_disruptor =
                            disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                        if let Some(cause) = this_disruptor {
                            return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
                                .with_span_static(*span, "WorkGroupUniformLoad"));
                        }
                    } */
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                // A nested block inherits the current disruptor unchanged.
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    // Both branches run under the parent disruptor or, if the parent
                    // flow is uniform, under the condition's non-uniformity (if any).
                    let condition_nur = self.add_ref(condition);
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A fall-through case hands any disruption on to the next case;
                        // otherwise the next case starts again from `branch_disruptor`.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The continuing block also sees disruptions from exits in the body.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // `break_if` is only counted as a reference.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // The returned value's non-uniformity becomes the function result;
                // `MAY_RETURN` is only set when returning under non-uniform flow.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                // Here and below, the used expressions are already emitted,
                // and their results do not affect the function return value,
                // so we can ignore their non-uniformity.
                S::Store { pointer, value } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    //Note: the result is validated by the Validator, not here
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index) => {
                            let _ = self.add_ref(index);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // Carry any exit disruption forward to the following statements; the
            // first disruptor recorded wins (`Option::or`).
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1155}
1156
1157impl ModuleInfo {
1158    /// Populates `self.const_expression_types`
1159    pub(super) fn process_const_expression(
1160        &mut self,
1161        handle: Handle<crate::Expression>,
1162        resolve_context: &ResolveContext,
1163        gctx: crate::proc::GlobalCtx,
1164    ) -> Result<(), super::ConstExpressionError> {
1165        self.const_expression_types[handle.index()] =
1166            resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1167        Ok(())
1168    }
1169
1170    /// Builds the `FunctionInfo` based on the function, and validates the
1171    /// uniform control flow if required by the expressions of this function.
1172    pub(super) fn process_function(
1173        &self,
1174        fun: &crate::Function,
1175        module: &crate::Module,
1176        flags: ValidationFlags,
1177        capabilities: super::Capabilities,
1178    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1179        let mut info = FunctionInfo {
1180            flags,
1181            available_stages: ShaderStages::all(),
1182            uniformity: Uniformity::new(),
1183            may_kill: false,
1184            sampling_set: crate::FastHashSet::default(),
1185            global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1186            expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1187            sampling: crate::FastHashSet::default(),
1188            dual_source_blending: false,
1189            diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1190        };
1191        let resolve_context =
1192            ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1193
1194        for (handle, _) in fun.expressions.iter() {
1195            if let Err(source) = info.process_expression(
1196                handle,
1197                &fun.expressions,
1198                &self.functions,
1199                &resolve_context,
1200                capabilities,
1201            ) {
1202                return Err(FunctionError::Expression { handle, source }
1203                    .with_span_handle(handle, &fun.expressions));
1204            }
1205        }
1206
1207        for (_, expr) in fun.local_variables.iter() {
1208            if let Some(init) = expr.init {
1209                let _ = info.add_ref(init);
1210            }
1211        }
1212
1213        let uniformity = info.process_block(
1214            &fun.body,
1215            &self.functions,
1216            None,
1217            &fun.expressions,
1218            &module.diagnostic_filters,
1219        )?;
1220        info.uniformity = uniformity.result;
1221        info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1222
1223        Ok(info)
1224    }
1225
1226    pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1227        &self.entry_points[index]
1228    }
1229}
1230
/// Exercises expression reference counting, global-use flags (`QUERY`, `READ`,
/// `WRITE`) and the block-level uniformity analysis on uniform and non-uniform
/// `If` statements, returns and kills.
#[test]
fn uniform_control_flow() {
    use crate::{Expression as E, Statement as S};

    let mut type_arena = crate::UniqueArena::new();
    let ty = type_arena.insert(
        crate::Type {
            name: None,
            inner: crate::TypeInner::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::F32,
            },
        },
        Default::default(),
    );
    let mut global_var_arena = Arena::new();
    let non_uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            space: crate::AddressSpace::Handle,
            binding: None,
        },
        Default::default(),
    );
    let uniform_global = global_var_arena.append(
        crate::GlobalVariable {
            name: None,
            init: None,
            ty,
            binding: None,
            space: crate::AddressSpace::Uniform,
        },
        Default::default(),
    );

    let mut expressions = Arena::new();
    // checks the uniform control flow
    let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
    // checks the non-uniform control flow
    let derivative_expr = expressions.append(
        E::Derivative {
            axis: crate::DerivativeAxis::X,
            ctrl: crate::DerivativeControl::None,
            expr: constant_expr,
        },
        Default::default(),
    );
    let emit_range_constant_derivative = expressions.range_from(0);
    let non_uniform_global_expr =
        expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
    let uniform_global_expr =
        expressions.append(E::GlobalVariable(uniform_global), Default::default());
    let emit_range_globals = expressions.range_from(2);

    // checks the QUERY flag
    let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
    // checks the transitive WRITE flag
    let access_expr = expressions.append(
        E::AccessIndex {
            base: non_uniform_global_expr,
            index: 1,
        },
        Default::default(),
    );
    let emit_range_query_access_globals = expressions.range_from(2);

    let mut info = FunctionInfo {
        flags: ValidationFlags::all(),
        available_stages: ShaderStages::all(),
        uniformity: Uniformity::new(),
        may_kill: false,
        sampling_set: crate::FastHashSet::default(),
        global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
        expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
        sampling: crate::FastHashSet::default(),
        dual_source_blending: false,
        diagnostic_filter_leaf: None,
    };
    let resolve_context = ResolveContext {
        constants: &Arena::new(),
        overrides: &Arena::new(),
        types: &type_arena,
        special_types: &crate::SpecialTypes::default(),
        global_vars: &global_var_arena,
        local_vars: &Arena::new(),
        functions: &Arena::new(),
        arguments: &[],
    };
    for (handle, _) in expressions.iter() {
        info.process_expression(
            handle,
            &expressions,
            &[],
            &resolve_context,
            super::Capabilities::empty(),
        )
        .unwrap();
    }
    assert_eq!(info[non_uniform_global_expr].ref_count, 1);
    assert_eq!(info[uniform_global_expr].ref_count, 1);
    assert_eq!(info[query_expr].ref_count, 0);
    assert_eq!(info[access_expr].ref_count, 0);
    assert_eq!(info[non_uniform_global], GlobalUse::empty());
    assert_eq!(info[uniform_global], GlobalUse::QUERY);

    let stmt_emit1 = S::Emit(emit_range_globals.clone());
    let stmt_if_uniform = S::If {
        condition: uniform_global_expr,
        accept: crate::Block::new(),
        reject: vec![
            S::Emit(emit_range_constant_derivative.clone()),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit1, stmt_if_uniform].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::DERIVATIVE,
            },
            exit: ExitFlags::empty(),
        }),
    );
    assert_eq!(info[constant_expr].ref_count, 2);
    assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);

    let stmt_emit2 = S::Emit(emit_range_globals.clone());
    let stmt_if_non_uniform = S::If {
        condition: non_uniform_global_expr,
        accept: vec![
            S::Emit(emit_range_constant_derivative),
            S::Store {
                pointer: constant_expr,
                value: derivative_expr,
            },
        ]
        .into(),
        reject: crate::Block::new(),
    };
    {
        let block_info = info.process_block(
            &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
            &[],
            None,
            &expressions,
            &Arena::new(),
        );
        if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
            // With the fragment-stage requirements disabled, `DERIVATIVE` is empty,
            // so the derivative under non-uniform control flow must not be
            // reported and the block must analyze successfully.
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty(),
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        } else {
            assert_eq!(
                block_info,
                Err(FunctionError::NonUniformControlFlow(
                    UniformityRequirements::DERIVATIVE,
                    derivative_expr,
                    UniformityDisruptor::Expression(non_uniform_global_expr)
                )
                .with_span()),
            );
            assert_eq!(info[derivative_expr].ref_count, 1);

            // Test that the same thing passes when we disable the `derivative_uniformity`
            let mut diagnostic_filters = Arena::new();
            let diagnostic_filter_leaf = diagnostic_filters.append(
                DiagnosticFilterNode {
                    inner: crate::diagnostic_filter::DiagnosticFilter {
                        new_severity: crate::diagnostic_filter::Severity::Off,
                        triggering_rule:
                            crate::diagnostic_filter::FilterableTriggeringRule::Standard(
                                StandardFilterableTriggeringRule::DerivativeUniformity,
                            ),
                    },
                    parent: None,
                },
                crate::Span::default(),
            );
            let mut info = FunctionInfo {
                diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
                ..info.clone()
            };

            let block_info = info.process_block(
                &vec![stmt_emit2, stmt_if_non_uniform].into(),
                &[],
                None,
                &expressions,
                &diagnostic_filters,
            );
            assert_eq!(
                block_info,
                Ok(FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::DERIVATIVE,
                    },
                    exit: ExitFlags::empty()
                }),
            );
            assert_eq!(info[derivative_expr].ref_count, 2);
        }
    }
    assert_eq!(info[non_uniform_global], GlobalUse::READ);

    let stmt_emit3 = S::Emit(emit_range_globals);
    let stmt_return_non_uniform = S::Return {
        value: Some(non_uniform_global_expr),
    };
    assert_eq!(
        info.process_block(
            &vec![stmt_emit3, stmt_return_non_uniform].into(),
            &[],
            Some(UniformityDisruptor::Return),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::MAY_RETURN,
        }),
    );
    assert_eq!(info[non_uniform_global_expr].ref_count, 3);

    // Check that uniformity requirements reach through a pointer
    let stmt_emit4 = S::Emit(emit_range_query_access_globals);
    let stmt_assign = S::Store {
        pointer: access_expr,
        value: query_expr,
    };
    let stmt_return_pointer = S::Return {
        value: Some(access_expr),
    };
    let stmt_kill = S::Kill;
    assert_eq!(
        info.process_block(
            &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
            &[],
            Some(UniformityDisruptor::Discard),
            &expressions,
            &Arena::new(),
        ),
        Ok(FunctionUniformity {
            result: Uniformity {
                non_uniform_result: Some(non_uniform_global_expr),
                requirements: UniformityRequirements::empty(),
            },
            exit: ExitFlags::all(),
        }),
    );
    assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}