naga/valid/
function.rs

1use alloc::{format, string::String};
2
3use super::validate_atomic_compare_exchange_struct;
4use super::{
5    analyzer::{UniformityDisruptor, UniformityRequirements},
6    ExpressionError, FunctionInfo, ModuleInfo,
7};
8use crate::arena::{Arena, UniqueArena};
9use crate::arena::{Handle, HandleSet};
10use crate::proc::TypeResolution;
11use crate::span::WithSpan;
12use crate::span::{AddSpan as _, MapErrWithSpan as _};
13
/// Reasons a `Call` statement can fail validation (see `Validator::validate_call`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
    /// Argument `index` could not be resolved; `source` says why (for example,
    /// the argument expression is not in scope).
    #[error("Argument {index} expression is invalid")]
    Argument {
        index: usize,
        source: ExpressionError,
    },
    /// The call's result expression was already in scope before the call.
    #[error("Result expression {0:?} has already been introduced earlier")]
    ResultAlreadyInScope(Handle<crate::Expression>),
    /// The call's result expression was already claimed by another `Call` statement.
    #[error("Result expression {0:?} is populated by multiple `Call` statements")]
    ResultAlreadyPopulated(Handle<crate::Expression>),
    /// The number of arguments differs from the callee's parameter count.
    #[error("Requires {required} arguments, but {seen} are provided")]
    ArgumentCount { required: usize, seen: usize },
    /// Argument `index` does not have the callee's declared parameter type `required`.
    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
    ArgumentType {
        index: usize,
        required: Handle<crate::Type>,
        seen_expression: Handle<crate::Expression>,
    },
    /// The call's result does not fit the callee. Either the result is not a
    /// `CallResult` for this callee, or the callee returns nothing, or the callee
    /// returns a value but no result expression was given. The payload is the
    /// statement's result, if any.
    #[error("The emitted expression doesn't match the call")]
    ExpressionMismatch(Option<Handle<crate::Expression>>),
}
37
/// Reasons an `Atomic` statement can fail validation (see `Validator::validate_atomic`).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AtomicError {
    /// The pointer operand is not a pointer to an atomic type.
    #[error("Pointer {0:?} to atomic is invalid.")]
    InvalidPointer(Handle<crate::Expression>),
    /// The atomic lives in an address space the operation does not allow.
    /// This is raised for 32-bit float atomics outside the `Storage` space.
    #[error("Address space {0:?} is not supported.")]
    InvalidAddressSpace(crate::AddressSpace),
    /// The value operand, or a compare-exchange's comparison operand, is not a
    /// scalar of the atomic's type.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The atomic function is not allowed for this atomic type. This is raised for
    /// 32-bit float atomics other than `Add`, `Subtract` and a plain `Exchange`.
    #[error("Operator {0:?} is not supported.")]
    InvalidOperator(crate::AtomicFunction),
    /// The result handle does not refer to an `AtomicResult` expression.
    #[error("Result expression {0:?} is not an `AtomicResult` expression")]
    InvalidResultExpression(Handle<crate::Expression>),
    /// The `AtomicResult` has `comparison: true`, but the function is not a
    /// compare-exchange.
    #[error("Result expression {0:?} is marked as an `exchange`")]
    ResultExpressionExchange(Handle<crate::Expression>),
    /// The function is a compare-exchange, but the `AtomicResult` has
    /// `comparison: false`.
    #[error("Result expression {0:?} is not marked as an `exchange`")]
    ResultExpressionNotExchange(Handle<crate::Expression>),
    /// The `AtomicResult`'s type does not fit the operation. It must be a scalar of
    /// the atomic's type, or a compare-exchange result struct.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// A plain `Exchange` (one with no comparison) was given no result expression.
    #[error("Exchange operations must return a value")]
    MissingReturnValue,
    /// The operation needs the contained capability, which the validator lacks.
    #[error("Capability {0:?} is required")]
    MissingCapability(super::Capabilities),
    /// The `AtomicResult` expression was already claimed by an earlier statement.
    #[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
    ResultAlreadyPopulated(Handle<crate::Expression>),
}
64
/// Reasons a subgroup statement can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum SubgroupError {
    /// An operand has a type the operation does not accept. Gathers also use this
    /// variant for a non-`u32` index, and in that case it carries the gathered
    /// argument's handle, not the index's.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result expression is not a `SubgroupOperationResult` whose type equals
    /// the argument's type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// NOTE(review): not raised in the visible part of this file; presumably reports
    /// subgroup operations missing from the validator's supported set — confirm.
    #[error("Support for subgroup operation {0:?} is required")]
    UnsupportedOperation(super::SubgroupOperationSet),
    /// The pairing of collective operation (reduce or scan) and subgroup operation
    /// is not supported.
    #[error("Unknown operation")]
    UnknownOperation,
    /// The invocation index of a `Broadcast` or `QuadBroadcast` gather is not a
    /// constant expression. The payload is the offending index expression.
    #[error("Invocation ID must be a const-expression")]
    InvalidInvocationIdExprType(Handle<crate::Expression>),
}
79
/// Reasons a local variable declaration can fail validation.
///
/// NOTE(review): none of these are raised in the visible part of this file. The
/// descriptions below only restate the error messages.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
    /// The variable's type cannot be stored in a local variable.
    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
    InvalidType(Handle<crate::Type>),
    /// The initializer's type differs from the variable's declared type.
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
    /// The initializer is neither a const expression nor an override expression.
    #[error("Initializer is not a const or override expression")]
    NonConstOrOverrideInitializer,
}
90
/// Errors reported while validating the body of a function.
///
/// The validation routines in this file return these wrapped in [`WithSpan`], so
/// each error can carry the span of the offending expression or statement.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
    /// An expression used by a statement is invalid, and `source` says why. For
    /// example, it is [`ExpressionError::NotInScope`] when the expression has not
    /// been emitted yet.
    #[error("Expression {handle:?} is invalid")]
    Expression {
        handle: Handle<crate::Expression>,
        source: ExpressionError,
    },
    /// An expression was brought into scope a second time.
    #[error("Expression {0:?} can't be introduced - it's already in scope")]
    ExpressionAlreadyInScope(Handle<crate::Expression>),
    #[error("Local variable {handle:?} '{name}' is invalid")]
    LocalVariable {
        handle: Handle<crate::LocalVariable>,
        name: String,
        source: LocalVariableError,
    },
    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
    InvalidArgumentType { index: usize, name: String },
    #[error("The function's given return type cannot be returned from functions")]
    NonConstructibleReturnType,
    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
    InvalidArgumentPointerSpace {
        index: usize,
        name: String,
        space: crate::AddressSpace,
    },
    #[error("The `break` is used outside of a `loop` or `switch` context")]
    BreakOutsideOfLoopOrSwitch,
    #[error("The `continue` is used outside of a `loop` context")]
    ContinueOutsideOfLoop,
    #[error("The `return` is called within a `continuing` block")]
    InvalidReturnSpot,
    #[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
    InvalidReturnType {
        expression: Option<Handle<crate::Expression>>,
        expected_ty: Option<Handle<crate::Type>>,
    },
    #[error("The `if` condition {0:?} is not a boolean scalar")]
    InvalidIfType(Handle<crate::Expression>),
    #[error("The `switch` value {0:?} is not an integer scalar")]
    InvalidSwitchType(Handle<crate::Expression>),
    #[error("Multiple `switch` cases for {0:?} are present")]
    ConflictingSwitchCase(crate::SwitchValue),
    #[error("The `switch` contains cases with conflicting types")]
    ConflictingCaseType,
    #[error("The `switch` is missing a `default` case")]
    MissingDefaultCase,
    #[error("Multiple `default` cases are present")]
    MultipleDefaultCases,
    /// The variant name misspells "FallThrough". It is kept as-is because this
    /// enum is `pub` and renaming the variant would break users.
    #[error("The last `switch` case contains a `fallthrough`")]
    LastCaseFallTrough,
    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
    InvalidStorePointer(Handle<crate::Expression>),
    #[error("Image store texture parameter type mismatch")]
    InvalidStoreTexture {
        actual: Handle<crate::Expression>,
        actual_ty: crate::TypeInner,
    },
    #[error("Image store value parameter type mismatch")]
    InvalidStoreValue {
        actual: Handle<crate::Expression>,
        actual_ty: crate::TypeInner,
        expected_ty: crate::TypeInner,
    },
    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
    InvalidStoreTypes {
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
    },
    #[error("Image store parameters are invalid")]
    InvalidImageStore(#[source] ExpressionError),
    #[error("Image atomic parameters are invalid")]
    InvalidImageAtomic(#[source] ExpressionError),
    #[error("Image atomic function is invalid")]
    InvalidImageAtomicFunction(crate::AtomicFunction),
    #[error("Image atomic value is invalid")]
    InvalidImageAtomicValue(Handle<crate::Expression>),
    /// A call to `function` is invalid; `error` holds the specific [`CallError`].
    #[error("Call to {function:?} is invalid")]
    InvalidCall {
        function: Handle<crate::Function>,
        #[source]
        error: CallError,
    },
    /// An `Atomic` statement is invalid. `validate_atomic` produces this by
    /// converting an [`AtomicError`] with `into_other()`.
    #[error("Atomic operation is invalid")]
    InvalidAtomic(#[from] AtomicError),
    #[error("Ray Query {0:?} is not a local variable")]
    InvalidRayQueryExpression(Handle<crate::Expression>),
    #[error("Acceleration structure {0:?} is not a matching expression")]
    InvalidAccelerationStructure(Handle<crate::Expression>),
    #[error(
        "Acceleration structure {0:?} is missing flag vertex_return while Ray Query {1:?} does"
    )]
    MissingAccelerationStructureVertexReturn(Handle<crate::Expression>, Handle<crate::Expression>),
    #[error("Ray Query {0:?} is missing flag vertex_return")]
    MissingRayQueryVertexReturn(Handle<crate::Expression>),
    #[error("Ray descriptor {0:?} is not a matching expression")]
    InvalidRayDescriptor(Handle<crate::Expression>),
    #[error("Ray Query {0:?} does not have a matching type")]
    InvalidRayQueryType(Handle<crate::Type>),
    #[error("Hit distance {0:?} must be an f32")]
    InvalidHitDistanceType(Handle<crate::Expression>),
    #[error("Shader requires capability {0:?}")]
    MissingCapability(super::Capabilities),
    #[error(
        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
    )]
    NonUniformControlFlow(
        UniformityRequirements,
        Handle<crate::Expression>,
        UniformityDisruptor,
    ),
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
    PipelineInputRegularFunction { name: String },
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
    PipelineOutputRegularFunction,
    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
    // The actual load statement will be "pointed to" by the span
    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
    // This is only possible with a misbehaving frontend
    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
    /// A subgroup statement is invalid. The subgroup validators produce this by
    /// converting a [`SubgroupError`] with `into_other()`.
    #[error("Subgroup operation is invalid")]
    InvalidSubgroup(#[from] SubgroupError),
    #[error("Emit statement should not cover \"result\" expressions like {0:?}")]
    EmitResult(Handle<crate::Expression>),
    /// NOTE(review): presumably raised for result expressions, such as `CallResult`
    /// or `AtomicResult`, that no statement ever claimed. Those claims go through
    /// `needs_visit`, which is not shown here — confirm.
    #[error("Expression not visited by the appropriate statement")]
    UnvisitedExpression(Handle<crate::Expression>),
}
221
222bitflags::bitflags! {
223    #[repr(transparent)]
224    #[derive(Clone, Copy)]
225    struct ControlFlowAbility: u8 {
226        /// The control can return out of this block.
227        const RETURN = 0x1;
228        /// The control can break.
229        const BREAK = 0x2;
230        /// The control can continue.
231        const CONTINUE = 0x4;
232    }
233}
234
/// Information collected about a block of statements.
///
/// NOTE(review): not constructed in the visible part of this file; presumably
/// produced by block validation — confirm at the construction site.
struct BlockInfo {
    /// NOTE(review): presumably the shader stages in which the block's statements
    /// may execute. Compare `validate_call`, which returns a callee's
    /// `available_stages` — confirm.
    stages: super::ShaderStages,
}
238
/// Read-only context threaded through the validation of a block of statements.
///
/// It is built by `BlockContext::new` from the function and module being
/// validated. `BlockContext::with_abilities` copies it with a different set of
/// control-flow abilities.
struct BlockContext<'a> {
    /// Which control-flow exits (`return`, `break`, `continue`) are permitted here.
    abilities: ControlFlowAbility,
    /// Analysis results for the function; this is where each expression's resolved
    /// type is read from.
    info: &'a FunctionInfo,
    /// The function's expression arena.
    expressions: &'a Arena<crate::Expression>,
    /// The module's type arena.
    types: &'a UniqueArena<crate::Type>,
    /// The function's local variables.
    local_vars: &'a Arena<crate::LocalVariable>,
    /// The module's global variables.
    global_vars: &'a Arena<crate::GlobalVariable>,
    /// The module's functions; `validate_call` looks callees up here.
    functions: &'a Arena<crate::Function>,
    /// The module's special types.
    special_types: &'a crate::SpecialTypes,
    /// Function infos indexed by function handle index. `validate_call` reads the
    /// callee's `available_stages` from here. NOTE(review): presumably this holds
    /// only functions validated before the current one — confirm at the call site.
    prev_infos: &'a [FunctionInfo],
    /// The type of the function's result, or `None` if it returns nothing.
    return_type: Option<Handle<crate::Type>>,
    /// Classifies the function's expressions by kind. The gather validator uses
    /// its `is_const` check on broadcast indices.
    local_expr_kind: &'a crate::proc::ExpressionKindTracker,
}
252
253impl<'a> BlockContext<'a> {
254    fn new(
255        fun: &'a crate::Function,
256        module: &'a crate::Module,
257        info: &'a FunctionInfo,
258        prev_infos: &'a [FunctionInfo],
259        local_expr_kind: &'a crate::proc::ExpressionKindTracker,
260    ) -> Self {
261        Self {
262            abilities: ControlFlowAbility::RETURN,
263            info,
264            expressions: &fun.expressions,
265            types: &module.types,
266            local_vars: &fun.local_variables,
267            global_vars: &module.global_variables,
268            functions: &module.functions,
269            special_types: &module.special_types,
270            prev_infos,
271            return_type: fun.result.as_ref().map(|fr| fr.ty),
272            local_expr_kind,
273        }
274    }
275
276    const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
277        BlockContext { abilities, ..*self }
278    }
279
280    fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
281        &self.expressions[handle]
282    }
283
284    fn resolve_type_impl(
285        &self,
286        handle: Handle<crate::Expression>,
287        valid_expressions: &HandleSet<crate::Expression>,
288    ) -> Result<&TypeResolution, WithSpan<ExpressionError>> {
289        if !valid_expressions.contains(handle) {
290            Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
291        } else {
292            Ok(&self.info[handle].ty)
293        }
294    }
295
296    fn resolve_type(
297        &self,
298        handle: Handle<crate::Expression>,
299        valid_expressions: &HandleSet<crate::Expression>,
300    ) -> Result<&TypeResolution, WithSpan<FunctionError>> {
301        self.resolve_type_impl(handle, valid_expressions)
302            .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
303    }
304
305    fn resolve_type_inner(
306        &self,
307        handle: Handle<crate::Expression>,
308        valid_expressions: &HandleSet<crate::Expression>,
309    ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
310        self.resolve_type(handle, valid_expressions)
311            .map(|tr| tr.inner_with(self.types))
312    }
313
314    fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
315        self.info[handle].ty.inner_with(self.types)
316    }
317
318    fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
319        crate::proc::compare_types(lhs, rhs, self.types)
320    }
321}
322
323impl super::Validator {
324    fn validate_call(
325        &mut self,
326        function: Handle<crate::Function>,
327        arguments: &[Handle<crate::Expression>],
328        result: Option<Handle<crate::Expression>>,
329        context: &BlockContext,
330    ) -> Result<super::ShaderStages, WithSpan<CallError>> {
331        let fun = &context.functions[function];
332        if fun.arguments.len() != arguments.len() {
333            return Err(CallError::ArgumentCount {
334                required: fun.arguments.len(),
335                seen: arguments.len(),
336            }
337            .with_span());
338        }
339        for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
340            let ty = context
341                .resolve_type_impl(expr, &self.valid_expression_set)
342                .map_err_inner(|source| {
343                    CallError::Argument { index, source }
344                        .with_span_handle(expr, context.expressions)
345                })?;
346            if !context.compare_types(&TypeResolution::Handle(arg.ty), ty) {
347                return Err(CallError::ArgumentType {
348                    index,
349                    required: arg.ty,
350                    seen_expression: expr,
351                }
352                .with_span_handle(expr, context.expressions));
353            }
354        }
355
356        if let Some(expr) = result {
357            if self.valid_expression_set.insert(expr) {
358                self.valid_expression_list.push(expr);
359            } else {
360                return Err(CallError::ResultAlreadyInScope(expr)
361                    .with_span_handle(expr, context.expressions));
362            }
363            match context.expressions[expr] {
364                crate::Expression::CallResult(callee)
365                    if fun.result.is_some() && callee == function =>
366                {
367                    if !self.needs_visit.remove(expr) {
368                        return Err(CallError::ResultAlreadyPopulated(expr)
369                            .with_span_handle(expr, context.expressions));
370                    }
371                }
372                _ => {
373                    return Err(CallError::ExpressionMismatch(result)
374                        .with_span_handle(expr, context.expressions))
375                }
376            }
377        } else if fun.result.is_some() {
378            return Err(CallError::ExpressionMismatch(result).with_span());
379        }
380
381        let callee_info = &context.prev_infos[function.index()];
382        Ok(callee_info.available_stages)
383    }
384
385    fn emit_expression(
386        &mut self,
387        handle: Handle<crate::Expression>,
388        context: &BlockContext,
389    ) -> Result<(), WithSpan<FunctionError>> {
390        if self.valid_expression_set.insert(handle) {
391            self.valid_expression_list.push(handle);
392            Ok(())
393        } else {
394            Err(FunctionError::ExpressionAlreadyInScope(handle)
395                .with_span_handle(handle, context.expressions))
396        }
397    }
398
399    fn validate_atomic(
400        &mut self,
401        pointer: Handle<crate::Expression>,
402        fun: &crate::AtomicFunction,
403        value: Handle<crate::Expression>,
404        result: Option<Handle<crate::Expression>>,
405        span: crate::Span,
406        context: &BlockContext,
407    ) -> Result<(), WithSpan<FunctionError>> {
408        // The `pointer` operand must be a pointer to an atomic value.
409        let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?;
410        let crate::TypeInner::Pointer {
411            base: pointer_base,
412            space: pointer_space,
413        } = *pointer_inner
414        else {
415            log::error!("Atomic operation on type {:?}", *pointer_inner);
416            return Err(AtomicError::InvalidPointer(pointer)
417                .with_span_handle(pointer, context.expressions)
418                .into_other());
419        };
420        let crate::TypeInner::Atomic(pointer_scalar) = context.types[pointer_base].inner else {
421            log::error!(
422                "Atomic pointer to type {:?}",
423                context.types[pointer_base].inner
424            );
425            return Err(AtomicError::InvalidPointer(pointer)
426                .with_span_handle(pointer, context.expressions)
427                .into_other());
428        };
429
430        // The `value` operand must be a scalar of the same type as the atomic.
431        let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?;
432        let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
433            log::error!("Atomic operand type {:?}", *value_inner);
434            return Err(AtomicError::InvalidOperand(value)
435                .with_span_handle(value, context.expressions)
436                .into_other());
437        };
438        if pointer_scalar != value_scalar {
439            log::error!("Atomic operand type {:?}", *value_inner);
440            return Err(AtomicError::InvalidOperand(value)
441                .with_span_handle(value, context.expressions)
442                .into_other());
443        }
444
445        match pointer_scalar {
446            // Check for the special restrictions on 64-bit atomic operations.
447            //
448            // We don't need to consider other widths here: this function has already checked
449            // that `pointer`'s type is an `Atomic`, and `validate_type` has already checked
450            // that `Atomic` type has a permitted scalar width.
451            crate::Scalar::I64 | crate::Scalar::U64 => {
452                // `Capabilities::SHADER_INT64_ATOMIC_ALL_OPS` enables all sorts of 64-bit
453                // atomic operations.
454                if self
455                    .capabilities
456                    .contains(super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS)
457                {
458                    // okay
459                } else {
460                    // `Capabilities::SHADER_INT64_ATOMIC_MIN_MAX` allows `Min` and
461                    // `Max` on operations in `Storage`, without a return value.
462                    if matches!(
463                        *fun,
464                        crate::AtomicFunction::Min | crate::AtomicFunction::Max
465                    ) && matches!(pointer_space, crate::AddressSpace::Storage { .. })
466                        && result.is_none()
467                    {
468                        if !self
469                            .capabilities
470                            .contains(super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX)
471                        {
472                            log::error!("Int64 min-max atomic operations are not supported");
473                            return Err(AtomicError::MissingCapability(
474                                super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
475                            )
476                            .with_span_handle(value, context.expressions)
477                            .into_other());
478                        }
479                    } else {
480                        // Otherwise, we require the full 64-bit atomic capability.
481                        log::error!("Int64 atomic operations are not supported");
482                        return Err(AtomicError::MissingCapability(
483                            super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
484                        )
485                        .with_span_handle(value, context.expressions)
486                        .into_other());
487                    }
488                }
489            }
490            // Check for the special restrictions on 32-bit floating-point atomic operations.
491            crate::Scalar::F32 => {
492                // `Capabilities::SHADER_FLOAT32_ATOMIC` allows 32-bit floating-point
493                // atomic operations `Add`, `Subtract`, and `Exchange`
494                // in the `Storage` address space.
495                if !self
496                    .capabilities
497                    .contains(super::Capabilities::SHADER_FLOAT32_ATOMIC)
498                {
499                    log::error!("Float32 atomic operations are not supported");
500                    return Err(AtomicError::MissingCapability(
501                        super::Capabilities::SHADER_FLOAT32_ATOMIC,
502                    )
503                    .with_span_handle(value, context.expressions)
504                    .into_other());
505                }
506                if !matches!(
507                    *fun,
508                    crate::AtomicFunction::Add
509                        | crate::AtomicFunction::Subtract
510                        | crate::AtomicFunction::Exchange { compare: None }
511                ) {
512                    log::error!("Float32 atomic operation {:?} is not supported", fun);
513                    return Err(AtomicError::InvalidOperator(*fun)
514                        .with_span_handle(value, context.expressions)
515                        .into_other());
516                }
517                if !matches!(pointer_space, crate::AddressSpace::Storage { .. }) {
518                    log::error!(
519                        "Float32 atomic operations are only supported in the Storage address space"
520                    );
521                    return Err(AtomicError::InvalidAddressSpace(pointer_space)
522                        .with_span_handle(value, context.expressions)
523                        .into_other());
524                }
525            }
526            _ => {}
527        }
528
529        // The result expression must be appropriate to the operation.
530        match result {
531            Some(result) => {
532                // The `result` handle must refer to an `AtomicResult` expression.
533                let crate::Expression::AtomicResult {
534                    ty: result_ty,
535                    comparison,
536                } = context.expressions[result]
537                else {
538                    return Err(AtomicError::InvalidResultExpression(result)
539                        .with_span_handle(result, context.expressions)
540                        .into_other());
541                };
542
543                // Note that this expression has been visited by the proper kind
544                // of statement.
545                if !self.needs_visit.remove(result) {
546                    return Err(AtomicError::ResultAlreadyPopulated(result)
547                        .with_span_handle(result, context.expressions)
548                        .into_other());
549                }
550
551                // The constraints on the result type depend on the atomic function.
552                if let crate::AtomicFunction::Exchange {
553                    compare: Some(compare),
554                } = *fun
555                {
556                    // The comparison value must be a scalar of the same type as the
557                    // atomic we're operating on.
558                    let compare_inner =
559                        context.resolve_type_inner(compare, &self.valid_expression_set)?;
560                    if !compare_inner.non_struct_equivalent(value_inner, context.types) {
561                        log::error!(
562                            "Atomic exchange comparison has a different type from the value"
563                        );
564                        return Err(AtomicError::InvalidOperand(compare)
565                            .with_span_handle(compare, context.expressions)
566                            .into_other());
567                    }
568
569                    // The result expression must be an `__atomic_compare_exchange_result`
570                    // struct whose `old_value` member is of the same type as the atomic
571                    // we're operating on.
572                    let crate::TypeInner::Struct { ref members, .. } =
573                        context.types[result_ty].inner
574                    else {
575                        return Err(AtomicError::ResultTypeMismatch(result)
576                            .with_span_handle(result, context.expressions)
577                            .into_other());
578                    };
579                    if !validate_atomic_compare_exchange_struct(
580                        context.types,
581                        members,
582                        |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar),
583                    ) {
584                        return Err(AtomicError::ResultTypeMismatch(result)
585                            .with_span_handle(result, context.expressions)
586                            .into_other());
587                    }
588
589                    // The result expression must be for a comparison operation.
590                    if !comparison {
591                        return Err(AtomicError::ResultExpressionNotExchange(result)
592                            .with_span_handle(result, context.expressions)
593                            .into_other());
594                    }
595                } else {
596                    // The result expression must be a scalar of the same type as the
597                    // atomic we're operating on.
598                    let result_inner = &context.types[result_ty].inner;
599                    if !result_inner.non_struct_equivalent(value_inner, context.types) {
600                        return Err(AtomicError::ResultTypeMismatch(result)
601                            .with_span_handle(result, context.expressions)
602                            .into_other());
603                    }
604
605                    // The result expression must not be for a comparison.
606                    if comparison {
607                        return Err(AtomicError::ResultExpressionExchange(result)
608                            .with_span_handle(result, context.expressions)
609                            .into_other());
610                    }
611                }
612                self.emit_expression(result, context)?;
613            }
614
615            None => {
616                // Exchange operations must always produce a value.
617                if let crate::AtomicFunction::Exchange { compare: None } = *fun {
618                    log::error!("Atomic exchange's value is unused");
619                    return Err(AtomicError::MissingReturnValue
620                        .with_span_static(span, "atomic exchange operation")
621                        .into_other());
622                }
623            }
624        }
625
626        Ok(())
627    }
628    fn validate_subgroup_operation(
629        &mut self,
630        op: &crate::SubgroupOperation,
631        collective_op: &crate::CollectiveOperation,
632        argument: Handle<crate::Expression>,
633        result: Handle<crate::Expression>,
634        context: &BlockContext,
635    ) -> Result<(), WithSpan<FunctionError>> {
636        let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
637
638        let (is_scalar, scalar) = match *argument_inner {
639            crate::TypeInner::Scalar(scalar) => (true, scalar),
640            crate::TypeInner::Vector { scalar, .. } => (false, scalar),
641            _ => {
642                log::error!("Subgroup operand type {:?}", argument_inner);
643                return Err(SubgroupError::InvalidOperand(argument)
644                    .with_span_handle(argument, context.expressions)
645                    .into_other());
646            }
647        };
648
649        use crate::ScalarKind as sk;
650        use crate::SubgroupOperation as sg;
651        match (scalar.kind, *op) {
652            (sk::Bool, sg::All | sg::Any) if is_scalar => {}
653            (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
654            (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
655
656            (_, _) => {
657                log::error!("Subgroup operand type {:?}", argument_inner);
658                return Err(SubgroupError::InvalidOperand(argument)
659                    .with_span_handle(argument, context.expressions)
660                    .into_other());
661            }
662        };
663
664        use crate::CollectiveOperation as co;
665        match (*collective_op, *op) {
666            (
667                co::Reduce,
668                sg::All
669                | sg::Any
670                | sg::Add
671                | sg::Mul
672                | sg::Min
673                | sg::Max
674                | sg::And
675                | sg::Or
676                | sg::Xor,
677            ) => {}
678            (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
679
680            (_, _) => {
681                return Err(SubgroupError::UnknownOperation.with_span().into_other());
682            }
683        };
684
685        self.emit_expression(result, context)?;
686        match context.expressions[result] {
687            crate::Expression::SubgroupOperationResult { ty }
688                if { &context.types[ty].inner == argument_inner } => {}
689            _ => {
690                return Err(SubgroupError::ResultTypeMismatch(result)
691                    .with_span_handle(result, context.expressions)
692                    .into_other())
693            }
694        }
695        Ok(())
696    }
697    fn validate_subgroup_gather(
698        &mut self,
699        mode: &crate::GatherMode,
700        argument: Handle<crate::Expression>,
701        result: Handle<crate::Expression>,
702        context: &BlockContext,
703    ) -> Result<(), WithSpan<FunctionError>> {
704        match *mode {
705            crate::GatherMode::BroadcastFirst => {}
706            crate::GatherMode::Broadcast(index)
707            | crate::GatherMode::Shuffle(index)
708            | crate::GatherMode::ShuffleDown(index)
709            | crate::GatherMode::ShuffleUp(index)
710            | crate::GatherMode::ShuffleXor(index)
711            | crate::GatherMode::QuadBroadcast(index) => {
712                let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
713                match *index_ty {
714                    crate::TypeInner::Scalar(crate::Scalar::U32) => {}
715                    _ => {
716                        log::error!(
717                            "Subgroup gather index type {:?}, expected unsigned int",
718                            index_ty
719                        );
720                        return Err(SubgroupError::InvalidOperand(argument)
721                            .with_span_handle(index, context.expressions)
722                            .into_other());
723                    }
724                }
725            }
726            crate::GatherMode::QuadSwap(_) => {}
727        }
728        match *mode {
729            crate::GatherMode::Broadcast(index) | crate::GatherMode::QuadBroadcast(index) => {
730                if !context.local_expr_kind.is_const(index) {
731                    return Err(SubgroupError::InvalidInvocationIdExprType(index)
732                        .with_span_handle(index, context.expressions)
733                        .into_other());
734                }
735            }
736            _ => {}
737        }
738        let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
739        if !matches!(*argument_inner,
740            crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
741            if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
742        ) {
743            log::error!("Subgroup gather operand type {:?}", argument_inner);
744            return Err(SubgroupError::InvalidOperand(argument)
745                .with_span_handle(argument, context.expressions)
746                .into_other());
747        }
748
749        self.emit_expression(result, context)?;
750        match context.expressions[result] {
751            crate::Expression::SubgroupOperationResult { ty }
752                if { &context.types[ty].inner == argument_inner } => {}
753            _ => {
754                return Err(SubgroupError::ResultTypeMismatch(result)
755                    .with_span_handle(result, context.expressions)
756                    .into_other())
757            }
758        }
759        Ok(())
760    }
761
762    fn validate_block_impl(
763        &mut self,
764        statements: &crate::Block,
765        context: &BlockContext,
766    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
767        use crate::{AddressSpace, Statement as S, TypeInner as Ti};
768        let mut stages = super::ShaderStages::all();
769        for (statement, &span) in statements.span_iter() {
770            match *statement {
771                S::Emit(ref range) => {
772                    for handle in range.clone() {
773                        use crate::Expression as Ex;
774                        match context.expressions[handle] {
775                            Ex::Literal(_)
776                            | Ex::Constant(_)
777                            | Ex::Override(_)
778                            | Ex::ZeroValue(_)
779                            | Ex::Compose { .. }
780                            | Ex::Access { .. }
781                            | Ex::AccessIndex { .. }
782                            | Ex::Splat { .. }
783                            | Ex::Swizzle { .. }
784                            | Ex::FunctionArgument(_)
785                            | Ex::GlobalVariable(_)
786                            | Ex::LocalVariable(_)
787                            | Ex::Load { .. }
788                            | Ex::ImageSample { .. }
789                            | Ex::ImageLoad { .. }
790                            | Ex::ImageQuery { .. }
791                            | Ex::Unary { .. }
792                            | Ex::Binary { .. }
793                            | Ex::Select { .. }
794                            | Ex::Derivative { .. }
795                            | Ex::Relational { .. }
796                            | Ex::Math { .. }
797                            | Ex::As { .. }
798                            | Ex::ArrayLength(_)
799                            | Ex::RayQueryGetIntersection { .. }
800                            | Ex::RayQueryVertexPositions { .. } => {
801                                self.emit_expression(handle, context)?
802                            }
803                            Ex::CallResult(_)
804                            | Ex::AtomicResult { .. }
805                            | Ex::WorkGroupUniformLoadResult { .. }
806                            | Ex::RayQueryProceedResult
807                            | Ex::SubgroupBallotResult
808                            | Ex::SubgroupOperationResult { .. } => {
809                                return Err(FunctionError::EmitResult(handle)
810                                    .with_span_handle(handle, context.expressions));
811                            }
812                        }
813                    }
814                }
815                S::Block(ref block) => {
816                    let info = self.validate_block(block, context)?;
817                    stages &= info.stages;
818                }
819                S::If {
820                    condition,
821                    ref accept,
822                    ref reject,
823                } => {
824                    match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
825                        Ti::Scalar(crate::Scalar {
826                            kind: crate::ScalarKind::Bool,
827                            width: _,
828                        }) => {}
829                        _ => {
830                            return Err(FunctionError::InvalidIfType(condition)
831                                .with_span_handle(condition, context.expressions))
832                        }
833                    }
834                    stages &= self.validate_block(accept, context)?.stages;
835                    stages &= self.validate_block(reject, context)?.stages;
836                }
837                S::Switch {
838                    selector,
839                    ref cases,
840                } => {
841                    let uint = match context
842                        .resolve_type_inner(selector, &self.valid_expression_set)?
843                        .scalar_kind()
844                    {
845                        Some(crate::ScalarKind::Uint) => true,
846                        Some(crate::ScalarKind::Sint) => false,
847                        _ => {
848                            return Err(FunctionError::InvalidSwitchType(selector)
849                                .with_span_handle(selector, context.expressions))
850                        }
851                    };
852                    self.switch_values.clear();
853                    for case in cases {
854                        match case.value {
855                            crate::SwitchValue::I32(_) if !uint => {}
856                            crate::SwitchValue::U32(_) if uint => {}
857                            crate::SwitchValue::Default => {}
858                            _ => {
859                                return Err(FunctionError::ConflictingCaseType.with_span_static(
860                                    case.body
861                                        .span_iter()
862                                        .next()
863                                        .map_or(Default::default(), |(_, s)| *s),
864                                    "conflicting switch arm here",
865                                ));
866                            }
867                        };
868                        if !self.switch_values.insert(case.value) {
869                            return Err(match case.value {
870                                crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
871                                    .with_span_static(
872                                        case.body
873                                            .span_iter()
874                                            .next()
875                                            .map_or(Default::default(), |(_, s)| *s),
876                                        "duplicated switch arm here",
877                                    ),
878                                _ => FunctionError::ConflictingSwitchCase(case.value)
879                                    .with_span_static(
880                                        case.body
881                                            .span_iter()
882                                            .next()
883                                            .map_or(Default::default(), |(_, s)| *s),
884                                        "conflicting switch arm here",
885                                    ),
886                            });
887                        }
888                    }
889                    if !self.switch_values.contains(&crate::SwitchValue::Default) {
890                        return Err(FunctionError::MissingDefaultCase
891                            .with_span_static(span, "missing default case"));
892                    }
893                    if let Some(case) = cases.last() {
894                        if case.fall_through {
895                            return Err(FunctionError::LastCaseFallTrough.with_span_static(
896                                case.body
897                                    .span_iter()
898                                    .next()
899                                    .map_or(Default::default(), |(_, s)| *s),
900                                "bad switch arm here",
901                            ));
902                        }
903                    }
904                    let pass_through_abilities = context.abilities
905                        & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
906                    let sub_context =
907                        context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
908                    for case in cases {
909                        stages &= self.validate_block(&case.body, &sub_context)?.stages;
910                    }
911                }
912                S::Loop {
913                    ref body,
914                    ref continuing,
915                    break_if,
916                } => {
917                    // special handling for block scoping is needed here,
918                    // because the continuing{} block inherits the scope
919                    let base_expression_count = self.valid_expression_list.len();
920                    let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
921                    stages &= self
922                        .validate_block_impl(
923                            body,
924                            &context.with_abilities(
925                                pass_through_abilities
926                                    | ControlFlowAbility::BREAK
927                                    | ControlFlowAbility::CONTINUE,
928                            ),
929                        )?
930                        .stages;
931                    stages &= self
932                        .validate_block_impl(
933                            continuing,
934                            &context.with_abilities(ControlFlowAbility::empty()),
935                        )?
936                        .stages;
937
938                    if let Some(condition) = break_if {
939                        match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
940                            Ti::Scalar(crate::Scalar {
941                                kind: crate::ScalarKind::Bool,
942                                width: _,
943                            }) => {}
944                            _ => {
945                                return Err(FunctionError::InvalidIfType(condition)
946                                    .with_span_handle(condition, context.expressions))
947                            }
948                        }
949                    }
950
951                    for handle in self.valid_expression_list.drain(base_expression_count..) {
952                        self.valid_expression_set.remove(handle);
953                    }
954                }
955                S::Break => {
956                    if !context.abilities.contains(ControlFlowAbility::BREAK) {
957                        return Err(FunctionError::BreakOutsideOfLoopOrSwitch
958                            .with_span_static(span, "invalid break"));
959                    }
960                }
961                S::Continue => {
962                    if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
963                        return Err(FunctionError::ContinueOutsideOfLoop
964                            .with_span_static(span, "invalid continue"));
965                    }
966                }
967                S::Return { value } => {
968                    if !context.abilities.contains(ControlFlowAbility::RETURN) {
969                        return Err(FunctionError::InvalidReturnSpot
970                            .with_span_static(span, "invalid return"));
971                    }
972                    let value_ty = value
973                        .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
974                        .transpose()?;
975                    // We can't return pointers, but it seems best not to embed that
976                    // assumption here, so use `TypeInner::equivalent` for comparison.
977                    let okay = match (value_ty, context.return_type) {
978                        (None, None) => true,
979                        (Some(value_inner), Some(expected_ty)) => {
980                            context.compare_types(value_inner, &TypeResolution::Handle(expected_ty))
981                        }
982                        (_, _) => false,
983                    };
984
985                    if !okay {
986                        log::error!(
987                            "Returning {:?} where {:?} is expected",
988                            value_ty,
989                            context.return_type,
990                        );
991                        if let Some(handle) = value {
992                            return Err(FunctionError::InvalidReturnType {
993                                expression: value,
994                                expected_ty: context.return_type,
995                            }
996                            .with_span_handle(handle, context.expressions));
997                        } else {
998                            return Err(FunctionError::InvalidReturnType {
999                                expression: value,
1000                                expected_ty: context.return_type,
1001                            }
1002                            .with_span_static(span, "invalid return"));
1003                        }
1004                    }
1005                }
1006                S::Kill => {
1007                    stages &= super::ShaderStages::FRAGMENT;
1008                }
1009                S::ControlBarrier(barrier) | S::MemoryBarrier(barrier) => {
1010                    stages &= super::ShaderStages::COMPUTE;
1011                    if barrier.contains(crate::Barrier::SUB_GROUP) {
1012                        if !self.capabilities.contains(
1013                            super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
1014                        ) {
1015                            return Err(FunctionError::MissingCapability(
1016                                super::Capabilities::SUBGROUP
1017                                    | super::Capabilities::SUBGROUP_BARRIER,
1018                            )
1019                            .with_span_static(span, "missing capability for this operation"));
1020                        }
1021                        if !self
1022                            .subgroup_operations
1023                            .contains(super::SubgroupOperationSet::BASIC)
1024                        {
1025                            return Err(FunctionError::InvalidSubgroup(
1026                                SubgroupError::UnsupportedOperation(
1027                                    super::SubgroupOperationSet::BASIC,
1028                                ),
1029                            )
1030                            .with_span_static(span, "support for this operation is not present"));
1031                        }
1032                    }
1033                }
1034                S::Store { pointer, value } => {
1035                    let mut current = pointer;
1036                    loop {
1037                        match context.expressions[current] {
1038                            crate::Expression::Access { base, .. }
1039                            | crate::Expression::AccessIndex { base, .. } => current = base,
1040                            crate::Expression::LocalVariable(_)
1041                            | crate::Expression::GlobalVariable(_)
1042                            | crate::Expression::FunctionArgument(_) => break,
1043                            _ => {
1044                                return Err(FunctionError::InvalidStorePointer(current)
1045                                    .with_span_handle(pointer, context.expressions))
1046                            }
1047                        }
1048                    }
1049
1050                    let value_tr = context.resolve_type(value, &self.valid_expression_set)?;
1051                    let value_ty = value_tr.inner_with(context.types);
1052                    match *value_ty {
1053                        Ti::Image { .. } | Ti::Sampler { .. } => {
1054                            return Err(FunctionError::InvalidStoreTexture {
1055                                actual: value,
1056                                actual_ty: value_ty.clone(),
1057                            }
1058                            .with_span_context((
1059                                context.expressions.get_span(value),
1060                                format!("this value is of type {value_ty:?}"),
1061                            ))
1062                            .with_span(span, "expects a texture argument"));
1063                        }
1064                        _ => {}
1065                    }
1066
1067                    let pointer_ty = context.resolve_pointer_type(pointer);
1068                    let pointer_base_tr = pointer_ty.pointer_base_type();
1069                    let pointer_base_ty = pointer_base_tr
1070                        .as_ref()
1071                        .map(|ty| ty.inner_with(context.types));
1072                    let good = if let Some(&Ti::Atomic(ref scalar)) = pointer_base_ty {
1073                        // The Naga IR allows storing a scalar to an atomic.
1074                        *value_ty == Ti::Scalar(*scalar)
1075                    } else if let Some(tr) = pointer_base_tr {
1076                        context.compare_types(value_tr, &tr)
1077                    } else {
1078                        false
1079                    };
1080
1081                    if !good {
1082                        return Err(FunctionError::InvalidStoreTypes { pointer, value }
1083                            .with_span()
1084                            .with_handle(pointer, context.expressions)
1085                            .with_handle(value, context.expressions));
1086                    }
1087
1088                    if let Some(space) = pointer_ty.pointer_space() {
1089                        if !space.access().contains(crate::StorageAccess::STORE) {
1090                            return Err(FunctionError::InvalidStorePointer(pointer)
1091                                .with_span_static(
1092                                    context.expressions.get_span(pointer),
1093                                    "writing to this location is not permitted",
1094                                ));
1095                        }
1096                    }
1097                }
1098                S::ImageStore {
1099                    image,
1100                    coordinate,
1101                    array_index,
1102                    value,
1103                } => {
1104                    //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
1105                    // and could probably be refactored.
1106                    let global_var;
1107                    let image_ty;
1108                    match *context.get_expression(image) {
1109                        crate::Expression::GlobalVariable(var_handle) => {
1110                            global_var = &context.global_vars[var_handle];
1111                            image_ty = global_var.ty;
1112                        }
1113                        // The `image` operand is indexing into a binding array,
1114                        // so punch through the `Access`* expression and look at
1115                        // the global behind it.
1116                        crate::Expression::Access { base, .. }
1117                        | crate::Expression::AccessIndex { base, .. } => {
1118                            let crate::Expression::GlobalVariable(var_handle) =
1119                                *context.get_expression(base)
1120                            else {
1121                                return Err(FunctionError::InvalidImageStore(
1122                                    ExpressionError::ExpectedGlobalVariable,
1123                                )
1124                                .with_span_handle(image, context.expressions));
1125                            };
1126                            global_var = &context.global_vars[var_handle];
1127
1128                            // The global variable must be a binding array.
1129                            let Ti::BindingArray { base, .. } = context.types[global_var.ty].inner
1130                            else {
1131                                return Err(FunctionError::InvalidImageStore(
1132                                    ExpressionError::ExpectedBindingArrayType(global_var.ty),
1133                                )
1134                                .with_span_handle(global_var.ty, context.types));
1135                            };
1136
1137                            image_ty = base;
1138                        }
1139                        _ => {
1140                            return Err(FunctionError::InvalidImageStore(
1141                                ExpressionError::ExpectedGlobalVariable,
1142                            )
1143                            .with_span_handle(image, context.expressions))
1144                        }
1145                    };
1146
1147                    // The `image` operand must be an `Image`.
1148                    let Ti::Image {
1149                        class,
1150                        arrayed,
1151                        dim,
1152                    } = context.types[image_ty].inner
1153                    else {
1154                        return Err(FunctionError::InvalidImageStore(
1155                            ExpressionError::ExpectedImageType(global_var.ty),
1156                        )
1157                        .with_span()
1158                        .with_handle(global_var.ty, context.types)
1159                        .with_handle(image, context.expressions));
1160                    };
1161
1162                    // It had better be a storage image, since we're writing to it.
1163                    let crate::ImageClass::Storage { format, .. } = class else {
1164                        return Err(FunctionError::InvalidImageStore(
1165                            ExpressionError::InvalidImageClass(class),
1166                        )
1167                        .with_span_handle(image, context.expressions));
1168                    };
1169
1170                    // The `coordinate` operand must be a vector of the appropriate size.
1171                    if context
1172                        .resolve_type_inner(coordinate, &self.valid_expression_set)?
1173                        .image_storage_coordinates()
1174                        .is_none_or(|coord_dim| coord_dim != dim)
1175                    {
1176                        return Err(FunctionError::InvalidImageStore(
1177                            ExpressionError::InvalidImageCoordinateType(dim, coordinate),
1178                        )
1179                        .with_span_handle(coordinate, context.expressions));
1180                    }
1181
1182                    // The `array_index` operand should be present if and only if
1183                    // the image itself is arrayed.
1184                    if arrayed != array_index.is_some() {
1185                        return Err(FunctionError::InvalidImageStore(
1186                            ExpressionError::InvalidImageArrayIndex,
1187                        )
1188                        .with_span_handle(coordinate, context.expressions));
1189                    }
1190
1191                    // If present, `array_index` must be a scalar integer type.
1192                    if let Some(expr) = array_index {
1193                        if !matches!(
1194                            *context.resolve_type_inner(expr, &self.valid_expression_set)?,
1195                            Ti::Scalar(crate::Scalar {
1196                                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1197                                width: _,
1198                            })
1199                        ) {
1200                            return Err(FunctionError::InvalidImageStore(
1201                                ExpressionError::InvalidImageArrayIndexType(expr),
1202                            )
1203                            .with_span_handle(expr, context.expressions));
1204                        }
1205                    }
1206
1207                    let value_ty = crate::TypeInner::Vector {
1208                        size: crate::VectorSize::Quad,
1209                        scalar: format.into(),
1210                    };
1211
1212                    // The value we're writing had better match the scalar type
1213                    // for `image`'s format.
1214                    let actual_value_ty =
1215                        context.resolve_type_inner(value, &self.valid_expression_set)?;
1216                    if actual_value_ty != &value_ty {
1217                        return Err(FunctionError::InvalidStoreValue {
1218                            actual: value,
1219                            actual_ty: actual_value_ty.clone(),
1220                            expected_ty: value_ty.clone(),
1221                        }
1222                        .with_span_context((
1223                            context.expressions.get_span(value),
1224                            format!("this value is of type {actual_value_ty:?}"),
1225                        ))
1226                        .with_span(
1227                            span,
1228                            format!("expects a value argument of type {value_ty:?}"),
1229                        ));
1230                    }
1231                }
1232                S::Call {
1233                    function,
1234                    ref arguments,
1235                    result,
1236                } => match self.validate_call(function, arguments, result, context) {
1237                    Ok(callee_stages) => stages &= callee_stages,
1238                    Err(error) => {
1239                        return Err(error.and_then(|error| {
1240                            FunctionError::InvalidCall { function, error }
1241                                .with_span_static(span, "invalid function call")
1242                        }))
1243                    }
1244                },
1245                S::Atomic {
1246                    pointer,
1247                    ref fun,
1248                    value,
1249                    result,
1250                } => {
1251                    self.validate_atomic(pointer, fun, value, result, span, context)?;
1252                }
1253                S::ImageAtomic {
1254                    image,
1255                    coordinate,
1256                    array_index,
1257                    fun,
1258                    value,
1259                } => {
1260                    let var = match *context.get_expression(image) {
1261                        crate::Expression::GlobalVariable(var_handle) => {
1262                            &context.global_vars[var_handle]
1263                        }
1264                        // We're looking at a binding index situation, so punch through the index and look at the global behind it.
1265                        crate::Expression::Access { base, .. }
1266                        | crate::Expression::AccessIndex { base, .. } => {
1267                            match *context.get_expression(base) {
1268                                crate::Expression::GlobalVariable(var_handle) => {
1269                                    &context.global_vars[var_handle]
1270                                }
1271                                _ => {
1272                                    return Err(FunctionError::InvalidImageAtomic(
1273                                        ExpressionError::ExpectedGlobalVariable,
1274                                    )
1275                                    .with_span_handle(image, context.expressions))
1276                                }
1277                            }
1278                        }
1279                        _ => {
1280                            return Err(FunctionError::InvalidImageAtomic(
1281                                ExpressionError::ExpectedGlobalVariable,
1282                            )
1283                            .with_span_handle(image, context.expressions))
1284                        }
1285                    };
1286
1287                    // Punch through a binding array to get the underlying type
1288                    let global_ty = match context.types[var.ty].inner {
1289                        Ti::BindingArray { base, .. } => &context.types[base].inner,
1290                        ref inner => inner,
1291                    };
1292
1293                    let value_ty = match *global_ty {
1294                        Ti::Image {
1295                            class,
1296                            arrayed,
1297                            dim,
1298                        } => {
1299                            match context
1300                                .resolve_type_inner(coordinate, &self.valid_expression_set)?
1301                                .image_storage_coordinates()
1302                            {
1303                                Some(coord_dim) if coord_dim == dim => {}
1304                                _ => {
1305                                    return Err(FunctionError::InvalidImageAtomic(
1306                                        ExpressionError::InvalidImageCoordinateType(
1307                                            dim, coordinate,
1308                                        ),
1309                                    )
1310                                    .with_span_handle(coordinate, context.expressions));
1311                                }
1312                            };
1313                            if arrayed != array_index.is_some() {
1314                                return Err(FunctionError::InvalidImageAtomic(
1315                                    ExpressionError::InvalidImageArrayIndex,
1316                                )
1317                                .with_span_handle(coordinate, context.expressions));
1318                            }
1319                            if let Some(expr) = array_index {
1320                                match *context
1321                                    .resolve_type_inner(expr, &self.valid_expression_set)?
1322                                {
1323                                    Ti::Scalar(crate::Scalar {
1324                                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1325                                        width: _,
1326                                    }) => {}
1327                                    _ => {
1328                                        return Err(FunctionError::InvalidImageAtomic(
1329                                            ExpressionError::InvalidImageArrayIndexType(expr),
1330                                        )
1331                                        .with_span_handle(expr, context.expressions));
1332                                    }
1333                                }
1334                            }
1335                            match class {
1336                                crate::ImageClass::Storage { format, access } => {
1337                                    if !access.contains(crate::StorageAccess::ATOMIC) {
1338                                        return Err(FunctionError::InvalidImageAtomic(
1339                                            ExpressionError::InvalidImageStorageAccess(access),
1340                                        )
1341                                        .with_span_handle(image, context.expressions));
1342                                    }
1343                                    match format {
1344                                        crate::StorageFormat::R64Uint => {
1345                                            if !self.capabilities.intersects(
1346                                                super::Capabilities::TEXTURE_INT64_ATOMIC,
1347                                            ) {
1348                                                return Err(FunctionError::MissingCapability(
1349                                                    super::Capabilities::TEXTURE_INT64_ATOMIC,
1350                                                )
1351                                                .with_span_static(
1352                                                    span,
1353                                                    "missing capability for this operation",
1354                                                ));
1355                                            }
1356                                            match fun {
1357                                                crate::AtomicFunction::Min
1358                                                | crate::AtomicFunction::Max => {}
1359                                                _ => {
1360                                                    return Err(
1361                                                        FunctionError::InvalidImageAtomicFunction(
1362                                                            fun,
1363                                                        )
1364                                                        .with_span_handle(
1365                                                            image,
1366                                                            context.expressions,
1367                                                        ),
1368                                                    );
1369                                                }
1370                                            }
1371                                        }
1372                                        crate::StorageFormat::R32Sint
1373                                        | crate::StorageFormat::R32Uint => {
1374                                            if !self
1375                                                .capabilities
1376                                                .intersects(super::Capabilities::TEXTURE_ATOMIC)
1377                                            {
1378                                                return Err(FunctionError::MissingCapability(
1379                                                    super::Capabilities::TEXTURE_ATOMIC,
1380                                                )
1381                                                .with_span_static(
1382                                                    span,
1383                                                    "missing capability for this operation",
1384                                                ));
1385                                            }
1386                                            match fun {
1387                                                crate::AtomicFunction::Add
1388                                                | crate::AtomicFunction::And
1389                                                | crate::AtomicFunction::ExclusiveOr
1390                                                | crate::AtomicFunction::InclusiveOr
1391                                                | crate::AtomicFunction::Min
1392                                                | crate::AtomicFunction::Max => {}
1393                                                _ => {
1394                                                    return Err(
1395                                                        FunctionError::InvalidImageAtomicFunction(
1396                                                            fun,
1397                                                        )
1398                                                        .with_span_handle(
1399                                                            image,
1400                                                            context.expressions,
1401                                                        ),
1402                                                    );
1403                                                }
1404                                            }
1405                                        }
1406                                        _ => {
1407                                            return Err(FunctionError::InvalidImageAtomic(
1408                                                ExpressionError::InvalidImageFormat(format),
1409                                            )
1410                                            .with_span_handle(image, context.expressions));
1411                                        }
1412                                    }
1413                                    crate::TypeInner::Scalar(format.into())
1414                                }
1415                                _ => {
1416                                    return Err(FunctionError::InvalidImageAtomic(
1417                                        ExpressionError::InvalidImageClass(class),
1418                                    )
1419                                    .with_span_handle(image, context.expressions));
1420                                }
1421                            }
1422                        }
1423                        _ => {
1424                            return Err(FunctionError::InvalidImageAtomic(
1425                                ExpressionError::ExpectedImageType(var.ty),
1426                            )
1427                            .with_span()
1428                            .with_handle(var.ty, context.types)
1429                            .with_handle(image, context.expressions))
1430                        }
1431                    };
1432
1433                    if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty {
1434                        return Err(FunctionError::InvalidImageAtomicValue(value)
1435                            .with_span_handle(value, context.expressions));
1436                    }
1437                }
1438                S::WorkGroupUniformLoad { pointer, result } => {
1439                    stages &= super::ShaderStages::COMPUTE;
1440                    let pointer_inner =
1441                        context.resolve_type_inner(pointer, &self.valid_expression_set)?;
1442                    match *pointer_inner {
1443                        Ti::Pointer {
1444                            space: AddressSpace::WorkGroup,
1445                            ..
1446                        } => {}
1447                        Ti::ValuePointer {
1448                            space: AddressSpace::WorkGroup,
1449                            ..
1450                        } => {}
1451                        _ => {
1452                            return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1453                                .with_span_static(span, "WorkGroupUniformLoad"))
1454                        }
1455                    }
1456                    self.emit_expression(result, context)?;
1457                    let ty = match &context.expressions[result] {
1458                        &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
1459                        _ => {
1460                            return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
1461                                result,
1462                            )
1463                            .with_span_static(span, "WorkGroupUniformLoad"));
1464                        }
1465                    };
1466                    let expected_pointer_inner = Ti::Pointer {
1467                        base: ty,
1468                        space: AddressSpace::WorkGroup,
1469                    };
1470                    if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types) {
1471                        return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1472                            .with_span_static(span, "WorkGroupUniformLoad"));
1473                    }
1474                }
1475                S::RayQuery { query, ref fun } => {
1476                    let query_var = match *context.get_expression(query) {
1477                        crate::Expression::LocalVariable(var) => &context.local_vars[var],
1478                        ref other => {
1479                            log::error!("Unexpected ray query expression {other:?}");
1480                            return Err(FunctionError::InvalidRayQueryExpression(query)
1481                                .with_span_static(span, "invalid query expression"));
1482                        }
1483                    };
1484                    let rq_vertex_return = match context.types[query_var.ty].inner {
1485                        Ti::RayQuery { vertex_return } => vertex_return,
1486                        ref other => {
1487                            log::error!("Unexpected ray query type {other:?}");
1488                            return Err(FunctionError::InvalidRayQueryType(query_var.ty)
1489                                .with_span_static(span, "invalid query type"));
1490                        }
1491                    };
1492                    match *fun {
1493                        crate::RayQueryFunction::Initialize {
1494                            acceleration_structure,
1495                            descriptor,
1496                        } => {
1497                            match *context.resolve_type_inner(
1498                                acceleration_structure,
1499                                &self.valid_expression_set,
1500                            )? {
1501                                Ti::AccelerationStructure { vertex_return } => {
1502                                    if (!vertex_return) && rq_vertex_return {
1503                                        return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
1504                                    }
1505                                }
1506                                _ => {
1507                                    return Err(FunctionError::InvalidAccelerationStructure(
1508                                        acceleration_structure,
1509                                    )
1510                                    .with_span_static(span, "invalid acceleration structure"))
1511                                }
1512                            }
1513                            let desc_ty_given = context
1514                                .resolve_type_inner(descriptor, &self.valid_expression_set)?;
1515                            let desc_ty_expected = context
1516                                .special_types
1517                                .ray_desc
1518                                .map(|handle| &context.types[handle].inner);
1519                            if Some(desc_ty_given) != desc_ty_expected {
1520                                return Err(FunctionError::InvalidRayDescriptor(descriptor)
1521                                    .with_span_static(span, "invalid ray descriptor"));
1522                            }
1523                        }
1524                        crate::RayQueryFunction::Proceed { result } => {
1525                            self.emit_expression(result, context)?;
1526                        }
1527                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1528                            match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? {
1529                                Ti::Scalar(crate::Scalar {
1530                                    kind: crate::ScalarKind::Float,
1531                                    width: _,
1532                                }) => {}
1533                                _ => {
1534                                    return Err(FunctionError::InvalidHitDistanceType(hit_t)
1535                                        .with_span_static(span, "invalid hit_t"))
1536                                }
1537                            }
1538                        }
1539                        crate::RayQueryFunction::ConfirmIntersection => {}
1540                        crate::RayQueryFunction::Terminate => {}
1541                    }
1542                }
1543                S::SubgroupBallot { result, predicate } => {
1544                    stages &= self.subgroup_stages;
1545                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1546                        return Err(FunctionError::MissingCapability(
1547                            super::Capabilities::SUBGROUP,
1548                        )
1549                        .with_span_static(span, "missing capability for this operation"));
1550                    }
1551                    if !self
1552                        .subgroup_operations
1553                        .contains(super::SubgroupOperationSet::BALLOT)
1554                    {
1555                        return Err(FunctionError::InvalidSubgroup(
1556                            SubgroupError::UnsupportedOperation(
1557                                super::SubgroupOperationSet::BALLOT,
1558                            ),
1559                        )
1560                        .with_span_static(span, "support for this operation is not present"));
1561                    }
1562                    if let Some(predicate) = predicate {
1563                        let predicate_inner =
1564                            context.resolve_type_inner(predicate, &self.valid_expression_set)?;
1565                        if !matches!(
1566                            *predicate_inner,
1567                            crate::TypeInner::Scalar(crate::Scalar::BOOL,)
1568                        ) {
1569                            log::error!(
1570                                "Subgroup ballot predicate type {:?} expected bool",
1571                                predicate_inner
1572                            );
1573                            return Err(SubgroupError::InvalidOperand(predicate)
1574                                .with_span_handle(predicate, context.expressions)
1575                                .into_other());
1576                        }
1577                    }
1578                    self.emit_expression(result, context)?;
1579                }
1580                S::SubgroupCollectiveOperation {
1581                    ref op,
1582                    ref collective_op,
1583                    argument,
1584                    result,
1585                } => {
1586                    stages &= self.subgroup_stages;
1587                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1588                        return Err(FunctionError::MissingCapability(
1589                            super::Capabilities::SUBGROUP,
1590                        )
1591                        .with_span_static(span, "missing capability for this operation"));
1592                    }
1593                    let operation = op.required_operations();
1594                    if !self.subgroup_operations.contains(operation) {
1595                        return Err(FunctionError::InvalidSubgroup(
1596                            SubgroupError::UnsupportedOperation(operation),
1597                        )
1598                        .with_span_static(span, "support for this operation is not present"));
1599                    }
1600                    self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
1601                }
1602                S::SubgroupGather {
1603                    ref mode,
1604                    argument,
1605                    result,
1606                } => {
1607                    stages &= self.subgroup_stages;
1608                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1609                        return Err(FunctionError::MissingCapability(
1610                            super::Capabilities::SUBGROUP,
1611                        )
1612                        .with_span_static(span, "missing capability for this operation"));
1613                    }
1614                    let operation = mode.required_operations();
1615                    if !self.subgroup_operations.contains(operation) {
1616                        return Err(FunctionError::InvalidSubgroup(
1617                            SubgroupError::UnsupportedOperation(operation),
1618                        )
1619                        .with_span_static(span, "support for this operation is not present"));
1620                    }
1621                    self.validate_subgroup_gather(mode, argument, result, context)?;
1622                }
1623            }
1624        }
1625        Ok(BlockInfo { stages })
1626    }
1627
1628    fn validate_block(
1629        &mut self,
1630        statements: &crate::Block,
1631        context: &BlockContext,
1632    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
1633        let base_expression_count = self.valid_expression_list.len();
1634        let info = self.validate_block_impl(statements, context)?;
1635        for handle in self.valid_expression_list.drain(base_expression_count..) {
1636            self.valid_expression_set.remove(handle);
1637        }
1638        Ok(info)
1639    }
1640
1641    fn validate_local_var(
1642        &self,
1643        var: &crate::LocalVariable,
1644        gctx: crate::proc::GlobalCtx,
1645        fun_info: &FunctionInfo,
1646        local_expr_kind: &crate::proc::ExpressionKindTracker,
1647    ) -> Result<(), LocalVariableError> {
1648        log::debug!("var {:?}", var);
1649        let type_info = self
1650            .types
1651            .get(var.ty.index())
1652            .ok_or(LocalVariableError::InvalidType(var.ty))?;
1653        if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
1654            return Err(LocalVariableError::InvalidType(var.ty));
1655        }
1656
1657        if let Some(init) = var.init {
1658            if !gctx.compare_types(&TypeResolution::Handle(var.ty), &fun_info[init].ty) {
1659                return Err(LocalVariableError::InitializerType);
1660            }
1661
1662            if !local_expr_kind.is_const_or_override(init) {
1663                return Err(LocalVariableError::NonConstOrOverrideInitializer);
1664            }
1665        }
1666
1667        Ok(())
1668    }
1669
1670    pub(super) fn validate_function(
1671        &mut self,
1672        fun: &crate::Function,
1673        module: &crate::Module,
1674        mod_info: &ModuleInfo,
1675        entry_point: bool,
1676    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1677        let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
1678
1679        let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
1680
1681        for (var_handle, var) in fun.local_variables.iter() {
1682            self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
1683                .map_err(|source| {
1684                    FunctionError::LocalVariable {
1685                        handle: var_handle,
1686                        name: var.name.clone().unwrap_or_default(),
1687                        source,
1688                    }
1689                    .with_span_handle(var.ty, &module.types)
1690                    .with_handle(var_handle, &fun.local_variables)
1691                })?;
1692        }
1693
1694        for (index, argument) in fun.arguments.iter().enumerate() {
1695            match module.types[argument.ty].inner.pointer_space() {
1696                Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
1697                Some(other) => {
1698                    return Err(FunctionError::InvalidArgumentPointerSpace {
1699                        index,
1700                        name: argument.name.clone().unwrap_or_default(),
1701                        space: other,
1702                    }
1703                    .with_span_handle(argument.ty, &module.types))
1704                }
1705            }
1706            // Check for the least informative error last.
1707            if !self.types[argument.ty.index()]
1708                .flags
1709                .contains(super::TypeFlags::ARGUMENT)
1710            {
1711                return Err(FunctionError::InvalidArgumentType {
1712                    index,
1713                    name: argument.name.clone().unwrap_or_default(),
1714                }
1715                .with_span_handle(argument.ty, &module.types));
1716            }
1717
1718            if !entry_point && argument.binding.is_some() {
1719                return Err(FunctionError::PipelineInputRegularFunction {
1720                    name: argument.name.clone().unwrap_or_default(),
1721                }
1722                .with_span_handle(argument.ty, &module.types));
1723            }
1724        }
1725
1726        if let Some(ref result) = fun.result {
1727            if !self.types[result.ty.index()]
1728                .flags
1729                .contains(super::TypeFlags::CONSTRUCTIBLE)
1730            {
1731                return Err(FunctionError::NonConstructibleReturnType
1732                    .with_span_handle(result.ty, &module.types));
1733            }
1734
1735            if !entry_point && result.binding.is_some() {
1736                return Err(FunctionError::PipelineOutputRegularFunction
1737                    .with_span_handle(result.ty, &module.types));
1738            }
1739        }
1740
1741        self.valid_expression_set.clear_for_arena(&fun.expressions);
1742        self.valid_expression_list.clear();
1743        self.needs_visit.clear_for_arena(&fun.expressions);
1744        for (handle, expr) in fun.expressions.iter() {
1745            if expr.needs_pre_emit() {
1746                self.valid_expression_set.insert(handle);
1747            }
1748            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1749                // Mark expressions that need to be visited by a particular kind of
1750                // statement.
1751                if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
1752                    *expr
1753                {
1754                    self.needs_visit.insert(handle);
1755                }
1756
1757                match self.validate_expression(
1758                    handle,
1759                    expr,
1760                    fun,
1761                    module,
1762                    &info,
1763                    mod_info,
1764                    &local_expr_kind,
1765                ) {
1766                    Ok(stages) => info.available_stages &= stages,
1767                    Err(source) => {
1768                        return Err(FunctionError::Expression { handle, source }
1769                            .with_span_handle(handle, &fun.expressions))
1770                    }
1771                }
1772            }
1773        }
1774
1775        if self.flags.contains(super::ValidationFlags::BLOCKS) {
1776            let stages = self
1777                .validate_block(
1778                    &fun.body,
1779                    &BlockContext::new(fun, module, &info, &mod_info.functions, &local_expr_kind),
1780                )?
1781                .stages;
1782            info.available_stages &= stages;
1783
1784            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1785                if let Some(handle) = self.needs_visit.iter().next() {
1786                    return Err(FunctionError::UnvisitedExpression(handle)
1787                        .with_span_handle(handle, &fun.expressions));
1788                }
1789            }
1790        }
1791        Ok(info)
1792    }
1793}