// naga/valid/function.rs

1use alloc::{format, string::String};
2
3use super::validate_atomic_compare_exchange_struct;
4use super::{
5    analyzer::{UniformityDisruptor, UniformityRequirements},
6    ExpressionError, FunctionInfo, ModuleInfo,
7};
8use crate::arena::{Arena, UniqueArena};
9use crate::arena::{Handle, HandleSet};
10use crate::proc::TypeResolution;
11use crate::span::WithSpan;
12use crate::span::{AddSpan as _, MapErrWithSpan as _};
13
14#[derive(Clone, Debug, thiserror::Error)]
15#[cfg_attr(test, derive(PartialEq))]
16pub enum CallError {
17    #[error("Argument {index} expression is invalid")]
18    Argument {
19        index: usize,
20        source: ExpressionError,
21    },
22    #[error("Result expression {0:?} has already been introduced earlier")]
23    ResultAlreadyInScope(Handle<crate::Expression>),
24    #[error("Result expression {0:?} is populated by multiple `Call` statements")]
25    ResultAlreadyPopulated(Handle<crate::Expression>),
26    #[error("Requires {required} arguments, but {seen} are provided")]
27    ArgumentCount { required: usize, seen: usize },
28    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
29    ArgumentType {
30        index: usize,
31        required: Handle<crate::Type>,
32        seen_expression: Handle<crate::Expression>,
33    },
34    #[error("The emitted expression doesn't match the call")]
35    ExpressionMismatch(Option<Handle<crate::Expression>>),
36}
37
38#[derive(Clone, Debug, thiserror::Error)]
39#[cfg_attr(test, derive(PartialEq))]
40pub enum AtomicError {
41    #[error("Pointer {0:?} to atomic is invalid.")]
42    InvalidPointer(Handle<crate::Expression>),
43    #[error("Address space {0:?} is not supported.")]
44    InvalidAddressSpace(crate::AddressSpace),
45    #[error("Operand {0:?} has invalid type.")]
46    InvalidOperand(Handle<crate::Expression>),
47    #[error("Operator {0:?} is not supported.")]
48    InvalidOperator(crate::AtomicFunction),
49    #[error("Result expression {0:?} is not an `AtomicResult` expression")]
50    InvalidResultExpression(Handle<crate::Expression>),
51    #[error("Result expression {0:?} is marked as an `exchange`")]
52    ResultExpressionExchange(Handle<crate::Expression>),
53    #[error("Result expression {0:?} is not marked as an `exchange`")]
54    ResultExpressionNotExchange(Handle<crate::Expression>),
55    #[error("Result type for {0:?} doesn't match the statement")]
56    ResultTypeMismatch(Handle<crate::Expression>),
57    #[error("Exchange operations must return a value")]
58    MissingReturnValue,
59    #[error("Capability {0:?} is required")]
60    MissingCapability(super::Capabilities),
61    #[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
62    ResultAlreadyPopulated(Handle<crate::Expression>),
63}
64
65#[derive(Clone, Debug, thiserror::Error)]
66#[cfg_attr(test, derive(PartialEq))]
67pub enum SubgroupError {
68    #[error("Operand {0:?} has invalid type.")]
69    InvalidOperand(Handle<crate::Expression>),
70    #[error("Result type for {0:?} doesn't match the statement")]
71    ResultTypeMismatch(Handle<crate::Expression>),
72    #[error("Support for subgroup operation {0:?} is required")]
73    UnsupportedOperation(super::SubgroupOperationSet),
74    #[error("Unknown operation")]
75    UnknownOperation,
76}
77
78#[derive(Clone, Debug, thiserror::Error)]
79#[cfg_attr(test, derive(PartialEq))]
80pub enum LocalVariableError {
81    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
82    InvalidType(Handle<crate::Type>),
83    #[error("Initializer doesn't match the variable type")]
84    InitializerType,
85    #[error("Initializer is not a const or override expression")]
86    NonConstOrOverrideInitializer,
87}
88
89#[derive(Clone, Debug, thiserror::Error)]
90#[cfg_attr(test, derive(PartialEq))]
91pub enum FunctionError {
92    #[error("Expression {handle:?} is invalid")]
93    Expression {
94        handle: Handle<crate::Expression>,
95        source: ExpressionError,
96    },
97    #[error("Expression {0:?} can't be introduced - it's already in scope")]
98    ExpressionAlreadyInScope(Handle<crate::Expression>),
99    #[error("Local variable {handle:?} '{name}' is invalid")]
100    LocalVariable {
101        handle: Handle<crate::LocalVariable>,
102        name: String,
103        source: LocalVariableError,
104    },
105    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
106    InvalidArgumentType { index: usize, name: String },
107    #[error("The function's given return type cannot be returned from functions")]
108    NonConstructibleReturnType,
109    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
110    InvalidArgumentPointerSpace {
111        index: usize,
112        name: String,
113        space: crate::AddressSpace,
114    },
115    #[error("There are instructions after `return`/`break`/`continue`")]
116    InstructionsAfterReturn,
117    #[error("The `break` is used outside of a `loop` or `switch` context")]
118    BreakOutsideOfLoopOrSwitch,
119    #[error("The `continue` is used outside of a `loop` context")]
120    ContinueOutsideOfLoop,
121    #[error("The `return` is called within a `continuing` block")]
122    InvalidReturnSpot,
123    #[error("The `return` value {0:?} does not match the function return value")]
124    InvalidReturnType(Option<Handle<crate::Expression>>),
125    #[error("The `if` condition {0:?} is not a boolean scalar")]
126    InvalidIfType(Handle<crate::Expression>),
127    #[error("The `switch` value {0:?} is not an integer scalar")]
128    InvalidSwitchType(Handle<crate::Expression>),
129    #[error("Multiple `switch` cases for {0:?} are present")]
130    ConflictingSwitchCase(crate::SwitchValue),
131    #[error("The `switch` contains cases with conflicting types")]
132    ConflictingCaseType,
133    #[error("The `switch` is missing a `default` case")]
134    MissingDefaultCase,
135    #[error("Multiple `default` cases are present")]
136    MultipleDefaultCases,
137    #[error("The last `switch` case contains a `fallthrough`")]
138    LastCaseFallTrough,
139    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
140    InvalidStorePointer(Handle<crate::Expression>),
141    #[error("Image store texture parameter type mismatch")]
142    InvalidStoreTexture {
143        actual: Handle<crate::Expression>,
144        actual_ty: crate::TypeInner,
145    },
146    #[error("Image store value parameter type mismatch")]
147    InvalidStoreValue {
148        actual: Handle<crate::Expression>,
149        actual_ty: crate::TypeInner,
150        expected_ty: crate::TypeInner,
151    },
152    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
153    InvalidStoreTypes {
154        pointer: Handle<crate::Expression>,
155        value: Handle<crate::Expression>,
156    },
157    #[error("Image store parameters are invalid")]
158    InvalidImageStore(#[source] ExpressionError),
159    #[error("Image atomic parameters are invalid")]
160    InvalidImageAtomic(#[source] ExpressionError),
161    #[error("Image atomic function is invalid")]
162    InvalidImageAtomicFunction(crate::AtomicFunction),
163    #[error("Image atomic value is invalid")]
164    InvalidImageAtomicValue(Handle<crate::Expression>),
165    #[error("Call to {function:?} is invalid")]
166    InvalidCall {
167        function: Handle<crate::Function>,
168        #[source]
169        error: CallError,
170    },
171    #[error("Atomic operation is invalid")]
172    InvalidAtomic(#[from] AtomicError),
173    #[error("Ray Query {0:?} is not a local variable")]
174    InvalidRayQueryExpression(Handle<crate::Expression>),
175    #[error("Acceleration structure {0:?} is not a matching expression")]
176    InvalidAccelerationStructure(Handle<crate::Expression>),
177    #[error(
178        "Acceleration structure {0:?} is missing flag vertex_return while Ray Query {1:?} does"
179    )]
180    MissingAccelerationStructureVertexReturn(Handle<crate::Expression>, Handle<crate::Expression>),
181    #[error("Ray Query {0:?} is missing flag vertex_return")]
182    MissingRayQueryVertexReturn(Handle<crate::Expression>),
183    #[error("Ray descriptor {0:?} is not a matching expression")]
184    InvalidRayDescriptor(Handle<crate::Expression>),
185    #[error("Ray Query {0:?} does not have a matching type")]
186    InvalidRayQueryType(Handle<crate::Type>),
187    #[error("Hit distance {0:?} must be an f32")]
188    InvalidHitDistanceType(Handle<crate::Expression>),
189    #[error("Shader requires capability {0:?}")]
190    MissingCapability(super::Capabilities),
191    #[error(
192        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
193    )]
194    NonUniformControlFlow(
195        UniformityRequirements,
196        Handle<crate::Expression>,
197        UniformityDisruptor,
198    ),
199    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
200    PipelineInputRegularFunction { name: String },
201    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
202    PipelineOutputRegularFunction,
203    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
204    // The actual load statement will be "pointed to" by the span
205    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
206    // This is only possible with a misbehaving frontend
207    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
208    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
209    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
210    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
211    #[error("Subgroup operation is invalid")]
212    InvalidSubgroup(#[from] SubgroupError),
213    #[error("Emit statement should not cover \"result\" expressions like {0:?}")]
214    EmitResult(Handle<crate::Expression>),
215    #[error("Expression not visited by the appropriate statement")]
216    UnvisitedExpression(Handle<crate::Expression>),
217}
218
219bitflags::bitflags! {
220    #[repr(transparent)]
221    #[derive(Clone, Copy)]
222    struct ControlFlowAbility: u8 {
223        /// The control can return out of this block.
224        const RETURN = 0x1;
225        /// The control can break.
226        const BREAK = 0x2;
227        /// The control can continue.
228        const CONTINUE = 0x4;
229    }
230}
231
232struct BlockInfo {
233    stages: super::ShaderStages,
234    finished: bool,
235}
236
237struct BlockContext<'a> {
238    abilities: ControlFlowAbility,
239    info: &'a FunctionInfo,
240    expressions: &'a Arena<crate::Expression>,
241    types: &'a UniqueArena<crate::Type>,
242    local_vars: &'a Arena<crate::LocalVariable>,
243    global_vars: &'a Arena<crate::GlobalVariable>,
244    functions: &'a Arena<crate::Function>,
245    special_types: &'a crate::SpecialTypes,
246    prev_infos: &'a [FunctionInfo],
247    return_type: Option<Handle<crate::Type>>,
248}
249
250impl<'a> BlockContext<'a> {
251    fn new(
252        fun: &'a crate::Function,
253        module: &'a crate::Module,
254        info: &'a FunctionInfo,
255        prev_infos: &'a [FunctionInfo],
256    ) -> Self {
257        Self {
258            abilities: ControlFlowAbility::RETURN,
259            info,
260            expressions: &fun.expressions,
261            types: &module.types,
262            local_vars: &fun.local_variables,
263            global_vars: &module.global_variables,
264            functions: &module.functions,
265            special_types: &module.special_types,
266            prev_infos,
267            return_type: fun.result.as_ref().map(|fr| fr.ty),
268        }
269    }
270
271    const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
272        BlockContext { abilities, ..*self }
273    }
274
275    fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
276        &self.expressions[handle]
277    }
278
279    fn resolve_type_impl(
280        &self,
281        handle: Handle<crate::Expression>,
282        valid_expressions: &HandleSet<crate::Expression>,
283    ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
284        if !valid_expressions.contains(handle) {
285            Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
286        } else {
287            Ok(self.info[handle].ty.inner_with(self.types))
288        }
289    }
290
291    fn resolve_type(
292        &self,
293        handle: Handle<crate::Expression>,
294        valid_expressions: &HandleSet<crate::Expression>,
295    ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
296        self.resolve_type_impl(handle, valid_expressions)
297            .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
298    }
299
300    fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
301        self.info[handle].ty.inner_with(self.types)
302    }
303
304    fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
305        ty.inner_with(self.types)
306    }
307}
308
309impl super::Validator {
310    fn validate_call(
311        &mut self,
312        function: Handle<crate::Function>,
313        arguments: &[Handle<crate::Expression>],
314        result: Option<Handle<crate::Expression>>,
315        context: &BlockContext,
316    ) -> Result<super::ShaderStages, WithSpan<CallError>> {
317        let fun = &context.functions[function];
318        if fun.arguments.len() != arguments.len() {
319            return Err(CallError::ArgumentCount {
320                required: fun.arguments.len(),
321                seen: arguments.len(),
322            }
323            .with_span());
324        }
325        for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
326            let ty = context
327                .resolve_type_impl(expr, &self.valid_expression_set)
328                .map_err_inner(|source| {
329                    CallError::Argument { index, source }
330                        .with_span_handle(expr, context.expressions)
331                })?;
332            let arg_inner = &context.types[arg.ty].inner;
333            if !ty.equivalent(arg_inner, context.types) {
334                return Err(CallError::ArgumentType {
335                    index,
336                    required: arg.ty,
337                    seen_expression: expr,
338                }
339                .with_span_handle(expr, context.expressions));
340            }
341        }
342
343        if let Some(expr) = result {
344            if self.valid_expression_set.insert(expr) {
345                self.valid_expression_list.push(expr);
346            } else {
347                return Err(CallError::ResultAlreadyInScope(expr)
348                    .with_span_handle(expr, context.expressions));
349            }
350            match context.expressions[expr] {
351                crate::Expression::CallResult(callee)
352                    if fun.result.is_some() && callee == function =>
353                {
354                    if !self.needs_visit.remove(expr) {
355                        return Err(CallError::ResultAlreadyPopulated(expr)
356                            .with_span_handle(expr, context.expressions));
357                    }
358                }
359                _ => {
360                    return Err(CallError::ExpressionMismatch(result)
361                        .with_span_handle(expr, context.expressions))
362                }
363            }
364        } else if fun.result.is_some() {
365            return Err(CallError::ExpressionMismatch(result).with_span());
366        }
367
368        let callee_info = &context.prev_infos[function.index()];
369        Ok(callee_info.available_stages)
370    }
371
372    fn emit_expression(
373        &mut self,
374        handle: Handle<crate::Expression>,
375        context: &BlockContext,
376    ) -> Result<(), WithSpan<FunctionError>> {
377        if self.valid_expression_set.insert(handle) {
378            self.valid_expression_list.push(handle);
379            Ok(())
380        } else {
381            Err(FunctionError::ExpressionAlreadyInScope(handle)
382                .with_span_handle(handle, context.expressions))
383        }
384    }
385
386    fn validate_atomic(
387        &mut self,
388        pointer: Handle<crate::Expression>,
389        fun: &crate::AtomicFunction,
390        value: Handle<crate::Expression>,
391        result: Option<Handle<crate::Expression>>,
392        span: crate::Span,
393        context: &BlockContext,
394    ) -> Result<(), WithSpan<FunctionError>> {
395        // The `pointer` operand must be a pointer to an atomic value.
396        let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
397        let crate::TypeInner::Pointer {
398            base: pointer_base,
399            space: pointer_space,
400        } = *pointer_inner
401        else {
402            log::error!("Atomic operation on type {:?}", *pointer_inner);
403            return Err(AtomicError::InvalidPointer(pointer)
404                .with_span_handle(pointer, context.expressions)
405                .into_other());
406        };
407        let crate::TypeInner::Atomic(pointer_scalar) = context.types[pointer_base].inner else {
408            log::error!(
409                "Atomic pointer to type {:?}",
410                context.types[pointer_base].inner
411            );
412            return Err(AtomicError::InvalidPointer(pointer)
413                .with_span_handle(pointer, context.expressions)
414                .into_other());
415        };
416
417        // The `value` operand must be a scalar of the same type as the atomic.
418        let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
419        let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
420            log::error!("Atomic operand type {:?}", *value_inner);
421            return Err(AtomicError::InvalidOperand(value)
422                .with_span_handle(value, context.expressions)
423                .into_other());
424        };
425        if pointer_scalar != value_scalar {
426            log::error!("Atomic operand type {:?}", *value_inner);
427            return Err(AtomicError::InvalidOperand(value)
428                .with_span_handle(value, context.expressions)
429                .into_other());
430        }
431
432        match pointer_scalar {
433            // Check for the special restrictions on 64-bit atomic operations.
434            //
435            // We don't need to consider other widths here: this function has already checked
436            // that `pointer`'s type is an `Atomic`, and `validate_type` has already checked
437            // that `Atomic` type has a permitted scalar width.
438            crate::Scalar::I64 | crate::Scalar::U64 => {
439                // `Capabilities::SHADER_INT64_ATOMIC_ALL_OPS` enables all sorts of 64-bit
440                // atomic operations.
441                if self
442                    .capabilities
443                    .contains(super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS)
444                {
445                    // okay
446                } else {
447                    // `Capabilities::SHADER_INT64_ATOMIC_MIN_MAX` allows `Min` and
448                    // `Max` on operations in `Storage`, without a return value.
449                    if matches!(
450                        *fun,
451                        crate::AtomicFunction::Min | crate::AtomicFunction::Max
452                    ) && matches!(pointer_space, crate::AddressSpace::Storage { .. })
453                        && result.is_none()
454                    {
455                        if !self
456                            .capabilities
457                            .contains(super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX)
458                        {
459                            log::error!("Int64 min-max atomic operations are not supported");
460                            return Err(AtomicError::MissingCapability(
461                                super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
462                            )
463                            .with_span_handle(value, context.expressions)
464                            .into_other());
465                        }
466                    } else {
467                        // Otherwise, we require the full 64-bit atomic capability.
468                        log::error!("Int64 atomic operations are not supported");
469                        return Err(AtomicError::MissingCapability(
470                            super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
471                        )
472                        .with_span_handle(value, context.expressions)
473                        .into_other());
474                    }
475                }
476            }
477            // Check for the special restrictions on 32-bit floating-point atomic operations.
478            crate::Scalar::F32 => {
479                // `Capabilities::SHADER_FLOAT32_ATOMIC` allows 32-bit floating-point
480                // atomic operations `Add`, `Subtract`, and `Exchange`
481                // in the `Storage` address space.
482                if !self
483                    .capabilities
484                    .contains(super::Capabilities::SHADER_FLOAT32_ATOMIC)
485                {
486                    log::error!("Float32 atomic operations are not supported");
487                    return Err(AtomicError::MissingCapability(
488                        super::Capabilities::SHADER_FLOAT32_ATOMIC,
489                    )
490                    .with_span_handle(value, context.expressions)
491                    .into_other());
492                }
493                if !matches!(
494                    *fun,
495                    crate::AtomicFunction::Add
496                        | crate::AtomicFunction::Subtract
497                        | crate::AtomicFunction::Exchange { compare: None }
498                ) {
499                    log::error!("Float32 atomic operation {:?} is not supported", fun);
500                    return Err(AtomicError::InvalidOperator(*fun)
501                        .with_span_handle(value, context.expressions)
502                        .into_other());
503                }
504                if !matches!(pointer_space, crate::AddressSpace::Storage { .. }) {
505                    log::error!(
506                        "Float32 atomic operations are only supported in the Storage address space"
507                    );
508                    return Err(AtomicError::InvalidAddressSpace(pointer_space)
509                        .with_span_handle(value, context.expressions)
510                        .into_other());
511                }
512            }
513            _ => {}
514        }
515
516        // The result expression must be appropriate to the operation.
517        match result {
518            Some(result) => {
519                // The `result` handle must refer to an `AtomicResult` expression.
520                let crate::Expression::AtomicResult {
521                    ty: result_ty,
522                    comparison,
523                } = context.expressions[result]
524                else {
525                    return Err(AtomicError::InvalidResultExpression(result)
526                        .with_span_handle(result, context.expressions)
527                        .into_other());
528                };
529
530                // Note that this expression has been visited by the proper kind
531                // of statement.
532                if !self.needs_visit.remove(result) {
533                    return Err(AtomicError::ResultAlreadyPopulated(result)
534                        .with_span_handle(result, context.expressions)
535                        .into_other());
536                }
537
538                // The constraints on the result type depend on the atomic function.
539                if let crate::AtomicFunction::Exchange {
540                    compare: Some(compare),
541                } = *fun
542                {
543                    // The comparison value must be a scalar of the same type as the
544                    // atomic we're operating on.
545                    let compare_inner =
546                        context.resolve_type(compare, &self.valid_expression_set)?;
547                    if !compare_inner.equivalent(value_inner, context.types) {
548                        log::error!(
549                            "Atomic exchange comparison has a different type from the value"
550                        );
551                        return Err(AtomicError::InvalidOperand(compare)
552                            .with_span_handle(compare, context.expressions)
553                            .into_other());
554                    }
555
556                    // The result expression must be an `__atomic_compare_exchange_result`
557                    // struct whose `old_value` member is of the same type as the atomic
558                    // we're operating on.
559                    let crate::TypeInner::Struct { ref members, .. } =
560                        context.types[result_ty].inner
561                    else {
562                        return Err(AtomicError::ResultTypeMismatch(result)
563                            .with_span_handle(result, context.expressions)
564                            .into_other());
565                    };
566                    if !validate_atomic_compare_exchange_struct(
567                        context.types,
568                        members,
569                        |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar),
570                    ) {
571                        return Err(AtomicError::ResultTypeMismatch(result)
572                            .with_span_handle(result, context.expressions)
573                            .into_other());
574                    }
575
576                    // The result expression must be for a comparison operation.
577                    if !comparison {
578                        return Err(AtomicError::ResultExpressionNotExchange(result)
579                            .with_span_handle(result, context.expressions)
580                            .into_other());
581                    }
582                } else {
583                    // The result expression must be a scalar of the same type as the
584                    // atomic we're operating on.
585                    let result_inner = &context.types[result_ty].inner;
586                    if !result_inner.equivalent(value_inner, context.types) {
587                        return Err(AtomicError::ResultTypeMismatch(result)
588                            .with_span_handle(result, context.expressions)
589                            .into_other());
590                    }
591
592                    // The result expression must not be for a comparison.
593                    if comparison {
594                        return Err(AtomicError::ResultExpressionExchange(result)
595                            .with_span_handle(result, context.expressions)
596                            .into_other());
597                    }
598                }
599                self.emit_expression(result, context)?;
600            }
601
602            None => {
603                // Exchange operations must always produce a value.
604                if let crate::AtomicFunction::Exchange { compare: None } = *fun {
605                    log::error!("Atomic exchange's value is unused");
606                    return Err(AtomicError::MissingReturnValue
607                        .with_span_static(span, "atomic exchange operation")
608                        .into_other());
609                }
610            }
611        }
612
613        Ok(())
614    }
615    fn validate_subgroup_operation(
616        &mut self,
617        op: &crate::SubgroupOperation,
618        collective_op: &crate::CollectiveOperation,
619        argument: Handle<crate::Expression>,
620        result: Handle<crate::Expression>,
621        context: &BlockContext,
622    ) -> Result<(), WithSpan<FunctionError>> {
623        let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
624
625        let (is_scalar, scalar) = match *argument_inner {
626            crate::TypeInner::Scalar(scalar) => (true, scalar),
627            crate::TypeInner::Vector { scalar, .. } => (false, scalar),
628            _ => {
629                log::error!("Subgroup operand type {:?}", argument_inner);
630                return Err(SubgroupError::InvalidOperand(argument)
631                    .with_span_handle(argument, context.expressions)
632                    .into_other());
633            }
634        };
635
636        use crate::ScalarKind as sk;
637        use crate::SubgroupOperation as sg;
638        match (scalar.kind, *op) {
639            (sk::Bool, sg::All | sg::Any) if is_scalar => {}
640            (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
641            (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
642
643            (_, _) => {
644                log::error!("Subgroup operand type {:?}", argument_inner);
645                return Err(SubgroupError::InvalidOperand(argument)
646                    .with_span_handle(argument, context.expressions)
647                    .into_other());
648            }
649        };
650
651        use crate::CollectiveOperation as co;
652        match (*collective_op, *op) {
653            (
654                co::Reduce,
655                sg::All
656                | sg::Any
657                | sg::Add
658                | sg::Mul
659                | sg::Min
660                | sg::Max
661                | sg::And
662                | sg::Or
663                | sg::Xor,
664            ) => {}
665            (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
666
667            (_, _) => {
668                return Err(SubgroupError::UnknownOperation.with_span().into_other());
669            }
670        };
671
672        self.emit_expression(result, context)?;
673        match context.expressions[result] {
674            crate::Expression::SubgroupOperationResult { ty }
675                if { &context.types[ty].inner == argument_inner } => {}
676            _ => {
677                return Err(SubgroupError::ResultTypeMismatch(result)
678                    .with_span_handle(result, context.expressions)
679                    .into_other())
680            }
681        }
682        Ok(())
683    }
684    fn validate_subgroup_gather(
685        &mut self,
686        mode: &crate::GatherMode,
687        argument: Handle<crate::Expression>,
688        result: Handle<crate::Expression>,
689        context: &BlockContext,
690    ) -> Result<(), WithSpan<FunctionError>> {
691        match *mode {
692            crate::GatherMode::BroadcastFirst => {}
693            crate::GatherMode::Broadcast(index)
694            | crate::GatherMode::Shuffle(index)
695            | crate::GatherMode::ShuffleDown(index)
696            | crate::GatherMode::ShuffleUp(index)
697            | crate::GatherMode::ShuffleXor(index) => {
698                let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
699                match *index_ty {
700                    crate::TypeInner::Scalar(crate::Scalar::U32) => {}
701                    _ => {
702                        log::error!(
703                            "Subgroup gather index type {:?}, expected unsigned int",
704                            index_ty
705                        );
706                        return Err(SubgroupError::InvalidOperand(argument)
707                            .with_span_handle(index, context.expressions)
708                            .into_other());
709                    }
710                }
711            }
712        }
713        let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
714        if !matches!(*argument_inner,
715            crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
716            if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
717        ) {
718            log::error!("Subgroup gather operand type {:?}", argument_inner);
719            return Err(SubgroupError::InvalidOperand(argument)
720                .with_span_handle(argument, context.expressions)
721                .into_other());
722        }
723
724        self.emit_expression(result, context)?;
725        match context.expressions[result] {
726            crate::Expression::SubgroupOperationResult { ty }
727                if { &context.types[ty].inner == argument_inner } => {}
728            _ => {
729                return Err(SubgroupError::ResultTypeMismatch(result)
730                    .with_span_handle(result, context.expressions)
731                    .into_other())
732            }
733        }
734        Ok(())
735    }
736
737    fn validate_block_impl(
738        &mut self,
739        statements: &crate::Block,
740        context: &BlockContext,
741    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
742        use crate::{AddressSpace, Statement as S, TypeInner as Ti};
743        let mut finished = false;
744        let mut stages = super::ShaderStages::all();
745        for (statement, &span) in statements.span_iter() {
746            if finished {
747                return Err(FunctionError::InstructionsAfterReturn
748                    .with_span_static(span, "instructions after return"));
749            }
750            match *statement {
751                S::Emit(ref range) => {
752                    for handle in range.clone() {
753                        use crate::Expression as Ex;
754                        match context.expressions[handle] {
755                            Ex::Literal(_)
756                            | Ex::Constant(_)
757                            | Ex::Override(_)
758                            | Ex::ZeroValue(_)
759                            | Ex::Compose { .. }
760                            | Ex::Access { .. }
761                            | Ex::AccessIndex { .. }
762                            | Ex::Splat { .. }
763                            | Ex::Swizzle { .. }
764                            | Ex::FunctionArgument(_)
765                            | Ex::GlobalVariable(_)
766                            | Ex::LocalVariable(_)
767                            | Ex::Load { .. }
768                            | Ex::ImageSample { .. }
769                            | Ex::ImageLoad { .. }
770                            | Ex::ImageQuery { .. }
771                            | Ex::Unary { .. }
772                            | Ex::Binary { .. }
773                            | Ex::Select { .. }
774                            | Ex::Derivative { .. }
775                            | Ex::Relational { .. }
776                            | Ex::Math { .. }
777                            | Ex::As { .. }
778                            | Ex::ArrayLength(_)
779                            | Ex::RayQueryGetIntersection { .. }
780                            | Ex::RayQueryVertexPositions { .. } => {
781                                self.emit_expression(handle, context)?
782                            }
783                            Ex::CallResult(_)
784                            | Ex::AtomicResult { .. }
785                            | Ex::WorkGroupUniformLoadResult { .. }
786                            | Ex::RayQueryProceedResult
787                            | Ex::SubgroupBallotResult
788                            | Ex::SubgroupOperationResult { .. } => {
789                                return Err(FunctionError::EmitResult(handle)
790                                    .with_span_handle(handle, context.expressions));
791                            }
792                        }
793                    }
794                }
795                S::Block(ref block) => {
796                    let info = self.validate_block(block, context)?;
797                    stages &= info.stages;
798                    finished = info.finished;
799                }
800                S::If {
801                    condition,
802                    ref accept,
803                    ref reject,
804                } => {
805                    match *context.resolve_type(condition, &self.valid_expression_set)? {
806                        Ti::Scalar(crate::Scalar {
807                            kind: crate::ScalarKind::Bool,
808                            width: _,
809                        }) => {}
810                        _ => {
811                            return Err(FunctionError::InvalidIfType(condition)
812                                .with_span_handle(condition, context.expressions))
813                        }
814                    }
815                    stages &= self.validate_block(accept, context)?.stages;
816                    stages &= self.validate_block(reject, context)?.stages;
817                }
818                S::Switch {
819                    selector,
820                    ref cases,
821                } => {
822                    let uint = match context
823                        .resolve_type(selector, &self.valid_expression_set)?
824                        .scalar_kind()
825                    {
826                        Some(crate::ScalarKind::Uint) => true,
827                        Some(crate::ScalarKind::Sint) => false,
828                        _ => {
829                            return Err(FunctionError::InvalidSwitchType(selector)
830                                .with_span_handle(selector, context.expressions))
831                        }
832                    };
833                    self.switch_values.clear();
834                    for case in cases {
835                        match case.value {
836                            crate::SwitchValue::I32(_) if !uint => {}
837                            crate::SwitchValue::U32(_) if uint => {}
838                            crate::SwitchValue::Default => {}
839                            _ => {
840                                return Err(FunctionError::ConflictingCaseType.with_span_static(
841                                    case.body
842                                        .span_iter()
843                                        .next()
844                                        .map_or(Default::default(), |(_, s)| *s),
845                                    "conflicting switch arm here",
846                                ));
847                            }
848                        };
849                        if !self.switch_values.insert(case.value) {
850                            return Err(match case.value {
851                                crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
852                                    .with_span_static(
853                                        case.body
854                                            .span_iter()
855                                            .next()
856                                            .map_or(Default::default(), |(_, s)| *s),
857                                        "duplicated switch arm here",
858                                    ),
859                                _ => FunctionError::ConflictingSwitchCase(case.value)
860                                    .with_span_static(
861                                        case.body
862                                            .span_iter()
863                                            .next()
864                                            .map_or(Default::default(), |(_, s)| *s),
865                                        "conflicting switch arm here",
866                                    ),
867                            });
868                        }
869                    }
870                    if !self.switch_values.contains(&crate::SwitchValue::Default) {
871                        return Err(FunctionError::MissingDefaultCase
872                            .with_span_static(span, "missing default case"));
873                    }
874                    if let Some(case) = cases.last() {
875                        if case.fall_through {
876                            return Err(FunctionError::LastCaseFallTrough.with_span_static(
877                                case.body
878                                    .span_iter()
879                                    .next()
880                                    .map_or(Default::default(), |(_, s)| *s),
881                                "bad switch arm here",
882                            ));
883                        }
884                    }
885                    let pass_through_abilities = context.abilities
886                        & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
887                    let sub_context =
888                        context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
889                    for case in cases {
890                        stages &= self.validate_block(&case.body, &sub_context)?.stages;
891                    }
892                }
893                S::Loop {
894                    ref body,
895                    ref continuing,
896                    break_if,
897                } => {
898                    // special handling for block scoping is needed here,
899                    // because the continuing{} block inherits the scope
900                    let base_expression_count = self.valid_expression_list.len();
901                    let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
902                    stages &= self
903                        .validate_block_impl(
904                            body,
905                            &context.with_abilities(
906                                pass_through_abilities
907                                    | ControlFlowAbility::BREAK
908                                    | ControlFlowAbility::CONTINUE,
909                            ),
910                        )?
911                        .stages;
912                    stages &= self
913                        .validate_block_impl(
914                            continuing,
915                            &context.with_abilities(ControlFlowAbility::empty()),
916                        )?
917                        .stages;
918
919                    if let Some(condition) = break_if {
920                        match *context.resolve_type(condition, &self.valid_expression_set)? {
921                            Ti::Scalar(crate::Scalar {
922                                kind: crate::ScalarKind::Bool,
923                                width: _,
924                            }) => {}
925                            _ => {
926                                return Err(FunctionError::InvalidIfType(condition)
927                                    .with_span_handle(condition, context.expressions))
928                            }
929                        }
930                    }
931
932                    for handle in self.valid_expression_list.drain(base_expression_count..) {
933                        self.valid_expression_set.remove(handle);
934                    }
935                }
936                S::Break => {
937                    if !context.abilities.contains(ControlFlowAbility::BREAK) {
938                        return Err(FunctionError::BreakOutsideOfLoopOrSwitch
939                            .with_span_static(span, "invalid break"));
940                    }
941                    finished = true;
942                }
943                S::Continue => {
944                    if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
945                        return Err(FunctionError::ContinueOutsideOfLoop
946                            .with_span_static(span, "invalid continue"));
947                    }
948                    finished = true;
949                }
950                S::Return { value } => {
951                    if !context.abilities.contains(ControlFlowAbility::RETURN) {
952                        return Err(FunctionError::InvalidReturnSpot
953                            .with_span_static(span, "invalid return"));
954                    }
955                    let value_ty = value
956                        .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
957                        .transpose()?;
958                    let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
959                    // We can't return pointers, but it seems best not to embed that
960                    // assumption here, so use `TypeInner::equivalent` for comparison.
961                    let okay = match (value_ty, expected_ty) {
962                        (None, None) => true,
963                        (Some(value_inner), Some(expected_inner)) => {
964                            value_inner.equivalent(expected_inner, context.types)
965                        }
966                        (_, _) => false,
967                    };
968
969                    if !okay {
970                        log::error!(
971                            "Returning {:?} where {:?} is expected",
972                            value_ty,
973                            expected_ty
974                        );
975                        if let Some(handle) = value {
976                            return Err(FunctionError::InvalidReturnType(value)
977                                .with_span_handle(handle, context.expressions));
978                        } else {
979                            return Err(FunctionError::InvalidReturnType(value)
980                                .with_span_static(span, "invalid return"));
981                        }
982                    }
983                    finished = true;
984                }
985                S::Kill => {
986                    stages &= super::ShaderStages::FRAGMENT;
987                    finished = true;
988                }
989                S::Barrier(barrier) => {
990                    stages &= super::ShaderStages::COMPUTE;
991                    if barrier.contains(crate::Barrier::SUB_GROUP) {
992                        if !self.capabilities.contains(
993                            super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
994                        ) {
995                            return Err(FunctionError::MissingCapability(
996                                super::Capabilities::SUBGROUP
997                                    | super::Capabilities::SUBGROUP_BARRIER,
998                            )
999                            .with_span_static(span, "missing capability for this operation"));
1000                        }
1001                        if !self
1002                            .subgroup_operations
1003                            .contains(super::SubgroupOperationSet::BASIC)
1004                        {
1005                            return Err(FunctionError::InvalidSubgroup(
1006                                SubgroupError::UnsupportedOperation(
1007                                    super::SubgroupOperationSet::BASIC,
1008                                ),
1009                            )
1010                            .with_span_static(span, "support for this operation is not present"));
1011                        }
1012                    }
1013                }
1014                S::Store { pointer, value } => {
1015                    let mut current = pointer;
1016                    loop {
1017                        match context.expressions[current] {
1018                            crate::Expression::Access { base, .. }
1019                            | crate::Expression::AccessIndex { base, .. } => current = base,
1020                            crate::Expression::LocalVariable(_)
1021                            | crate::Expression::GlobalVariable(_)
1022                            | crate::Expression::FunctionArgument(_) => break,
1023                            _ => {
1024                                return Err(FunctionError::InvalidStorePointer(current)
1025                                    .with_span_handle(pointer, context.expressions))
1026                            }
1027                        }
1028                    }
1029
1030                    let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
1031                    match *value_ty {
1032                        Ti::Image { .. } | Ti::Sampler { .. } => {
1033                            return Err(FunctionError::InvalidStoreTexture {
1034                                actual: value,
1035                                actual_ty: value_ty.clone(),
1036                            }
1037                            .with_span_context((
1038                                context.expressions.get_span(value),
1039                                format!("this value is of type {value_ty:?}"),
1040                            ))
1041                            .with_span(span, "expects a texture argument"));
1042                        }
1043                        _ => {}
1044                    }
1045
1046                    let pointer_ty = context.resolve_pointer_type(pointer);
1047                    let good = match pointer_ty
1048                        .pointer_base_type()
1049                        .as_ref()
1050                        .map(|ty| context.inner_type(ty))
1051                    {
1052                        // The Naga IR allows storing a scalar to an atomic.
1053                        Some(&Ti::Atomic(ref scalar)) => *value_ty == Ti::Scalar(*scalar),
1054                        Some(other) => *value_ty == *other,
1055                        None => false,
1056                    };
1057                    if !good {
1058                        return Err(FunctionError::InvalidStoreTypes { pointer, value }
1059                            .with_span()
1060                            .with_handle(pointer, context.expressions)
1061                            .with_handle(value, context.expressions));
1062                    }
1063
1064                    if let Some(space) = pointer_ty.pointer_space() {
1065                        if !space.access().contains(crate::StorageAccess::STORE) {
1066                            return Err(FunctionError::InvalidStorePointer(pointer)
1067                                .with_span_static(
1068                                    context.expressions.get_span(pointer),
1069                                    "writing to this location is not permitted",
1070                                ));
1071                        }
1072                    }
1073                }
1074                S::ImageStore {
1075                    image,
1076                    coordinate,
1077                    array_index,
1078                    value,
1079                } => {
1080                    //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
1081                    // and could probably be refactored.
1082                    let global_var;
1083                    let image_ty;
1084                    match *context.get_expression(image) {
1085                        crate::Expression::GlobalVariable(var_handle) => {
1086                            global_var = &context.global_vars[var_handle];
1087                            image_ty = global_var.ty;
1088                        }
1089                        // The `image` operand is indexing into a binding array,
1090                        // so punch through the `Access`* expression and look at
1091                        // the global behind it.
1092                        crate::Expression::Access { base, .. }
1093                        | crate::Expression::AccessIndex { base, .. } => {
1094                            let crate::Expression::GlobalVariable(var_handle) =
1095                                *context.get_expression(base)
1096                            else {
1097                                return Err(FunctionError::InvalidImageStore(
1098                                    ExpressionError::ExpectedGlobalVariable,
1099                                )
1100                                .with_span_handle(image, context.expressions));
1101                            };
1102                            global_var = &context.global_vars[var_handle];
1103
1104                            // The global variable must be a binding array.
1105                            let Ti::BindingArray { base, .. } = context.types[global_var.ty].inner
1106                            else {
1107                                return Err(FunctionError::InvalidImageStore(
1108                                    ExpressionError::ExpectedBindingArrayType(global_var.ty),
1109                                )
1110                                .with_span_handle(global_var.ty, context.types));
1111                            };
1112
1113                            image_ty = base;
1114                        }
1115                        _ => {
1116                            return Err(FunctionError::InvalidImageStore(
1117                                ExpressionError::ExpectedGlobalVariable,
1118                            )
1119                            .with_span_handle(image, context.expressions))
1120                        }
1121                    };
1122
1123                    // The `image` operand must be an `Image`.
1124                    let Ti::Image {
1125                        class,
1126                        arrayed,
1127                        dim,
1128                    } = context.types[image_ty].inner
1129                    else {
1130                        return Err(FunctionError::InvalidImageStore(
1131                            ExpressionError::ExpectedImageType(global_var.ty),
1132                        )
1133                        .with_span()
1134                        .with_handle(global_var.ty, context.types)
1135                        .with_handle(image, context.expressions));
1136                    };
1137
1138                    // It had better be a storage image, since we're writing to it.
1139                    let crate::ImageClass::Storage { format, .. } = class else {
1140                        return Err(FunctionError::InvalidImageStore(
1141                            ExpressionError::InvalidImageClass(class),
1142                        )
1143                        .with_span_handle(image, context.expressions));
1144                    };
1145
1146                    // The `coordinate` operand must be a vector of the appropriate size.
1147                    if context
1148                        .resolve_type(coordinate, &self.valid_expression_set)?
1149                        .image_storage_coordinates()
1150                        .is_none_or(|coord_dim| coord_dim != dim)
1151                    {
1152                        return Err(FunctionError::InvalidImageStore(
1153                            ExpressionError::InvalidImageCoordinateType(dim, coordinate),
1154                        )
1155                        .with_span_handle(coordinate, context.expressions));
1156                    }
1157
1158                    // The `array_index` operand should be present if and only if
1159                    // the image itself is arrayed.
1160                    if arrayed != array_index.is_some() {
1161                        return Err(FunctionError::InvalidImageStore(
1162                            ExpressionError::InvalidImageArrayIndex,
1163                        )
1164                        .with_span_handle(coordinate, context.expressions));
1165                    }
1166
1167                    // If present, `array_index` must be a scalar integer type.
1168                    if let Some(expr) = array_index {
1169                        if !matches!(
1170                            *context.resolve_type(expr, &self.valid_expression_set)?,
1171                            Ti::Scalar(crate::Scalar {
1172                                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1173                                width: _,
1174                            })
1175                        ) {
1176                            return Err(FunctionError::InvalidImageStore(
1177                                ExpressionError::InvalidImageArrayIndexType(expr),
1178                            )
1179                            .with_span_handle(expr, context.expressions));
1180                        }
1181                    }
1182
1183                    let value_ty = crate::TypeInner::Vector {
1184                        size: crate::VectorSize::Quad,
1185                        scalar: format.into(),
1186                    };
1187
1188                    // The value we're writing had better match the scalar type
1189                    // for `image`'s format.
1190                    let actual_value_ty =
1191                        context.resolve_type(value, &self.valid_expression_set)?;
1192                    if actual_value_ty != &value_ty {
1193                        return Err(FunctionError::InvalidStoreValue {
1194                            actual: value,
1195                            actual_ty: actual_value_ty.clone(),
1196                            expected_ty: value_ty.clone(),
1197                        }
1198                        .with_span_context((
1199                            context.expressions.get_span(value),
1200                            format!("this value is of type {actual_value_ty:?}"),
1201                        ))
1202                        .with_span(
1203                            span,
1204                            format!("expects a value argument of type {value_ty:?}"),
1205                        ));
1206                    }
1207                }
1208                S::Call {
1209                    function,
1210                    ref arguments,
1211                    result,
1212                } => match self.validate_call(function, arguments, result, context) {
1213                    Ok(callee_stages) => stages &= callee_stages,
1214                    Err(error) => {
1215                        return Err(error.and_then(|error| {
1216                            FunctionError::InvalidCall { function, error }
1217                                .with_span_static(span, "invalid function call")
1218                        }))
1219                    }
1220                },
1221                S::Atomic {
1222                    pointer,
1223                    ref fun,
1224                    value,
1225                    result,
1226                } => {
1227                    self.validate_atomic(pointer, fun, value, result, span, context)?;
1228                }
1229                S::ImageAtomic {
1230                    image,
1231                    coordinate,
1232                    array_index,
1233                    fun,
1234                    value,
1235                } => {
1236                    let var = match *context.get_expression(image) {
1237                        crate::Expression::GlobalVariable(var_handle) => {
1238                            &context.global_vars[var_handle]
1239                        }
1240                        // We're looking at a binding index situation, so punch through the index and look at the global behind it.
1241                        crate::Expression::Access { base, .. }
1242                        | crate::Expression::AccessIndex { base, .. } => {
1243                            match *context.get_expression(base) {
1244                                crate::Expression::GlobalVariable(var_handle) => {
1245                                    &context.global_vars[var_handle]
1246                                }
1247                                _ => {
1248                                    return Err(FunctionError::InvalidImageAtomic(
1249                                        ExpressionError::ExpectedGlobalVariable,
1250                                    )
1251                                    .with_span_handle(image, context.expressions))
1252                                }
1253                            }
1254                        }
1255                        _ => {
1256                            return Err(FunctionError::InvalidImageAtomic(
1257                                ExpressionError::ExpectedGlobalVariable,
1258                            )
1259                            .with_span_handle(image, context.expressions))
1260                        }
1261                    };
1262
1263                    // Punch through a binding array to get the underlying type
1264                    let global_ty = match context.types[var.ty].inner {
1265                        Ti::BindingArray { base, .. } => &context.types[base].inner,
1266                        ref inner => inner,
1267                    };
1268
1269                    let value_ty = match *global_ty {
1270                        Ti::Image {
1271                            class,
1272                            arrayed,
1273                            dim,
1274                        } => {
1275                            match context
1276                                .resolve_type(coordinate, &self.valid_expression_set)?
1277                                .image_storage_coordinates()
1278                            {
1279                                Some(coord_dim) if coord_dim == dim => {}
1280                                _ => {
1281                                    return Err(FunctionError::InvalidImageAtomic(
1282                                        ExpressionError::InvalidImageCoordinateType(
1283                                            dim, coordinate,
1284                                        ),
1285                                    )
1286                                    .with_span_handle(coordinate, context.expressions));
1287                                }
1288                            };
1289                            if arrayed != array_index.is_some() {
1290                                return Err(FunctionError::InvalidImageAtomic(
1291                                    ExpressionError::InvalidImageArrayIndex,
1292                                )
1293                                .with_span_handle(coordinate, context.expressions));
1294                            }
1295                            if let Some(expr) = array_index {
1296                                match *context.resolve_type(expr, &self.valid_expression_set)? {
1297                                    Ti::Scalar(crate::Scalar {
1298                                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1299                                        width: _,
1300                                    }) => {}
1301                                    _ => {
1302                                        return Err(FunctionError::InvalidImageAtomic(
1303                                            ExpressionError::InvalidImageArrayIndexType(expr),
1304                                        )
1305                                        .with_span_handle(expr, context.expressions));
1306                                    }
1307                                }
1308                            }
1309                            match class {
1310                                crate::ImageClass::Storage { format, access } => {
1311                                    if !access.contains(crate::StorageAccess::ATOMIC) {
1312                                        return Err(FunctionError::InvalidImageAtomic(
1313                                            ExpressionError::InvalidImageStorageAccess(access),
1314                                        )
1315                                        .with_span_handle(image, context.expressions));
1316                                    }
1317                                    match format {
1318                                        crate::StorageFormat::R64Uint => {
1319                                            if !self.capabilities.intersects(
1320                                                super::Capabilities::TEXTURE_INT64_ATOMIC,
1321                                            ) {
1322                                                return Err(FunctionError::MissingCapability(
1323                                                    super::Capabilities::TEXTURE_INT64_ATOMIC,
1324                                                )
1325                                                .with_span_static(
1326                                                    span,
1327                                                    "missing capability for this operation",
1328                                                ));
1329                                            }
1330                                            match fun {
1331                                                crate::AtomicFunction::Min
1332                                                | crate::AtomicFunction::Max => {}
1333                                                _ => {
1334                                                    return Err(
1335                                                        FunctionError::InvalidImageAtomicFunction(
1336                                                            fun,
1337                                                        )
1338                                                        .with_span_handle(
1339                                                            image,
1340                                                            context.expressions,
1341                                                        ),
1342                                                    );
1343                                                }
1344                                            }
1345                                        }
1346                                        crate::StorageFormat::R32Sint
1347                                        | crate::StorageFormat::R32Uint => {
1348                                            if !self
1349                                                .capabilities
1350                                                .intersects(super::Capabilities::TEXTURE_ATOMIC)
1351                                            {
1352                                                return Err(FunctionError::MissingCapability(
1353                                                    super::Capabilities::TEXTURE_ATOMIC,
1354                                                )
1355                                                .with_span_static(
1356                                                    span,
1357                                                    "missing capability for this operation",
1358                                                ));
1359                                            }
1360                                            match fun {
1361                                                crate::AtomicFunction::Add
1362                                                | crate::AtomicFunction::And
1363                                                | crate::AtomicFunction::ExclusiveOr
1364                                                | crate::AtomicFunction::InclusiveOr
1365                                                | crate::AtomicFunction::Min
1366                                                | crate::AtomicFunction::Max => {}
1367                                                _ => {
1368                                                    return Err(
1369                                                        FunctionError::InvalidImageAtomicFunction(
1370                                                            fun,
1371                                                        )
1372                                                        .with_span_handle(
1373                                                            image,
1374                                                            context.expressions,
1375                                                        ),
1376                                                    );
1377                                                }
1378                                            }
1379                                        }
1380                                        _ => {
1381                                            return Err(FunctionError::InvalidImageAtomic(
1382                                                ExpressionError::InvalidImageFormat(format),
1383                                            )
1384                                            .with_span_handle(image, context.expressions));
1385                                        }
1386                                    }
1387                                    crate::TypeInner::Scalar(format.into())
1388                                }
1389                                _ => {
1390                                    return Err(FunctionError::InvalidImageAtomic(
1391                                        ExpressionError::InvalidImageClass(class),
1392                                    )
1393                                    .with_span_handle(image, context.expressions));
1394                                }
1395                            }
1396                        }
1397                        _ => {
1398                            return Err(FunctionError::InvalidImageAtomic(
1399                                ExpressionError::ExpectedImageType(var.ty),
1400                            )
1401                            .with_span()
1402                            .with_handle(var.ty, context.types)
1403                            .with_handle(image, context.expressions))
1404                        }
1405                    };
1406
1407                    if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
1408                        return Err(FunctionError::InvalidImageAtomicValue(value)
1409                            .with_span_handle(value, context.expressions));
1410                    }
1411                }
1412                S::WorkGroupUniformLoad { pointer, result } => {
1413                    stages &= super::ShaderStages::COMPUTE;
1414                    let pointer_inner =
1415                        context.resolve_type(pointer, &self.valid_expression_set)?;
1416                    match *pointer_inner {
1417                        Ti::Pointer {
1418                            space: AddressSpace::WorkGroup,
1419                            ..
1420                        } => {}
1421                        Ti::ValuePointer {
1422                            space: AddressSpace::WorkGroup,
1423                            ..
1424                        } => {}
1425                        _ => {
1426                            return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1427                                .with_span_static(span, "WorkGroupUniformLoad"))
1428                        }
1429                    }
1430                    self.emit_expression(result, context)?;
1431                    let ty = match &context.expressions[result] {
1432                        &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
1433                        _ => {
1434                            return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
1435                                result,
1436                            )
1437                            .with_span_static(span, "WorkGroupUniformLoad"));
1438                        }
1439                    };
1440                    let expected_pointer_inner = Ti::Pointer {
1441                        base: ty,
1442                        space: AddressSpace::WorkGroup,
1443                    };
1444                    if !expected_pointer_inner.equivalent(pointer_inner, context.types) {
1445                        return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1446                            .with_span_static(span, "WorkGroupUniformLoad"));
1447                    }
1448                }
1449                S::RayQuery { query, ref fun } => {
1450                    let query_var = match *context.get_expression(query) {
1451                        crate::Expression::LocalVariable(var) => &context.local_vars[var],
1452                        ref other => {
1453                            log::error!("Unexpected ray query expression {other:?}");
1454                            return Err(FunctionError::InvalidRayQueryExpression(query)
1455                                .with_span_static(span, "invalid query expression"));
1456                        }
1457                    };
1458                    let rq_vertex_return = match context.types[query_var.ty].inner {
1459                        Ti::RayQuery { vertex_return } => vertex_return,
1460                        ref other => {
1461                            log::error!("Unexpected ray query type {other:?}");
1462                            return Err(FunctionError::InvalidRayQueryType(query_var.ty)
1463                                .with_span_static(span, "invalid query type"));
1464                        }
1465                    };
1466                    match *fun {
1467                        crate::RayQueryFunction::Initialize {
1468                            acceleration_structure,
1469                            descriptor,
1470                        } => {
1471                            match *context
1472                                .resolve_type(acceleration_structure, &self.valid_expression_set)?
1473                            {
1474                                Ti::AccelerationStructure { vertex_return } => {
1475                                    if (!vertex_return) && rq_vertex_return {
1476                                        return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
1477                                    }
1478                                }
1479                                _ => {
1480                                    return Err(FunctionError::InvalidAccelerationStructure(
1481                                        acceleration_structure,
1482                                    )
1483                                    .with_span_static(span, "invalid acceleration structure"))
1484                                }
1485                            }
1486                            let desc_ty_given =
1487                                context.resolve_type(descriptor, &self.valid_expression_set)?;
1488                            let desc_ty_expected = context
1489                                .special_types
1490                                .ray_desc
1491                                .map(|handle| &context.types[handle].inner);
1492                            if Some(desc_ty_given) != desc_ty_expected {
1493                                return Err(FunctionError::InvalidRayDescriptor(descriptor)
1494                                    .with_span_static(span, "invalid ray descriptor"));
1495                            }
1496                        }
1497                        crate::RayQueryFunction::Proceed { result } => {
1498                            self.emit_expression(result, context)?;
1499                        }
1500                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1501                            match *context.resolve_type(hit_t, &self.valid_expression_set)? {
1502                                Ti::Scalar(crate::Scalar {
1503                                    kind: crate::ScalarKind::Float,
1504                                    width: _,
1505                                }) => {}
1506                                _ => {
1507                                    return Err(FunctionError::InvalidHitDistanceType(hit_t)
1508                                        .with_span_static(span, "invalid hit_t"))
1509                                }
1510                            }
1511                        }
1512                        crate::RayQueryFunction::ConfirmIntersection => {}
1513                        crate::RayQueryFunction::Terminate => {}
1514                    }
1515                }
1516                S::SubgroupBallot { result, predicate } => {
1517                    stages &= self.subgroup_stages;
1518                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1519                        return Err(FunctionError::MissingCapability(
1520                            super::Capabilities::SUBGROUP,
1521                        )
1522                        .with_span_static(span, "missing capability for this operation"));
1523                    }
1524                    if !self
1525                        .subgroup_operations
1526                        .contains(super::SubgroupOperationSet::BALLOT)
1527                    {
1528                        return Err(FunctionError::InvalidSubgroup(
1529                            SubgroupError::UnsupportedOperation(
1530                                super::SubgroupOperationSet::BALLOT,
1531                            ),
1532                        )
1533                        .with_span_static(span, "support for this operation is not present"));
1534                    }
1535                    if let Some(predicate) = predicate {
1536                        let predicate_inner =
1537                            context.resolve_type(predicate, &self.valid_expression_set)?;
1538                        if !matches!(
1539                            *predicate_inner,
1540                            crate::TypeInner::Scalar(crate::Scalar::BOOL,)
1541                        ) {
1542                            log::error!(
1543                                "Subgroup ballot predicate type {:?} expected bool",
1544                                predicate_inner
1545                            );
1546                            return Err(SubgroupError::InvalidOperand(predicate)
1547                                .with_span_handle(predicate, context.expressions)
1548                                .into_other());
1549                        }
1550                    }
1551                    self.emit_expression(result, context)?;
1552                }
1553                S::SubgroupCollectiveOperation {
1554                    ref op,
1555                    ref collective_op,
1556                    argument,
1557                    result,
1558                } => {
1559                    stages &= self.subgroup_stages;
1560                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1561                        return Err(FunctionError::MissingCapability(
1562                            super::Capabilities::SUBGROUP,
1563                        )
1564                        .with_span_static(span, "missing capability for this operation"));
1565                    }
1566                    let operation = op.required_operations();
1567                    if !self.subgroup_operations.contains(operation) {
1568                        return Err(FunctionError::InvalidSubgroup(
1569                            SubgroupError::UnsupportedOperation(operation),
1570                        )
1571                        .with_span_static(span, "support for this operation is not present"));
1572                    }
1573                    self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
1574                }
1575                S::SubgroupGather {
1576                    ref mode,
1577                    argument,
1578                    result,
1579                } => {
1580                    stages &= self.subgroup_stages;
1581                    if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1582                        return Err(FunctionError::MissingCapability(
1583                            super::Capabilities::SUBGROUP,
1584                        )
1585                        .with_span_static(span, "missing capability for this operation"));
1586                    }
1587                    let operation = mode.required_operations();
1588                    if !self.subgroup_operations.contains(operation) {
1589                        return Err(FunctionError::InvalidSubgroup(
1590                            SubgroupError::UnsupportedOperation(operation),
1591                        )
1592                        .with_span_static(span, "support for this operation is not present"));
1593                    }
1594                    self.validate_subgroup_gather(mode, argument, result, context)?;
1595                }
1596            }
1597        }
1598        Ok(BlockInfo { stages, finished })
1599    }
1600
1601    fn validate_block(
1602        &mut self,
1603        statements: &crate::Block,
1604        context: &BlockContext,
1605    ) -> Result<BlockInfo, WithSpan<FunctionError>> {
1606        let base_expression_count = self.valid_expression_list.len();
1607        let info = self.validate_block_impl(statements, context)?;
1608        for handle in self.valid_expression_list.drain(base_expression_count..) {
1609            self.valid_expression_set.remove(handle);
1610        }
1611        Ok(info)
1612    }
1613
1614    fn validate_local_var(
1615        &self,
1616        var: &crate::LocalVariable,
1617        gctx: crate::proc::GlobalCtx,
1618        fun_info: &FunctionInfo,
1619        local_expr_kind: &crate::proc::ExpressionKindTracker,
1620    ) -> Result<(), LocalVariableError> {
1621        log::debug!("var {:?}", var);
1622        let type_info = self
1623            .types
1624            .get(var.ty.index())
1625            .ok_or(LocalVariableError::InvalidType(var.ty))?;
1626        if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
1627            return Err(LocalVariableError::InvalidType(var.ty));
1628        }
1629
1630        if let Some(init) = var.init {
1631            let decl_ty = &gctx.types[var.ty].inner;
1632            let init_ty = fun_info[init].ty.inner_with(gctx.types);
1633            if !decl_ty.equivalent(init_ty, gctx.types) {
1634                return Err(LocalVariableError::InitializerType);
1635            }
1636
1637            if !local_expr_kind.is_const_or_override(init) {
1638                return Err(LocalVariableError::NonConstOrOverrideInitializer);
1639            }
1640        }
1641
1642        Ok(())
1643    }
1644
1645    pub(super) fn validate_function(
1646        &mut self,
1647        fun: &crate::Function,
1648        module: &crate::Module,
1649        mod_info: &ModuleInfo,
1650        entry_point: bool,
1651    ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1652        let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
1653
1654        let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
1655
1656        for (var_handle, var) in fun.local_variables.iter() {
1657            self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
1658                .map_err(|source| {
1659                    FunctionError::LocalVariable {
1660                        handle: var_handle,
1661                        name: var.name.clone().unwrap_or_default(),
1662                        source,
1663                    }
1664                    .with_span_handle(var.ty, &module.types)
1665                    .with_handle(var_handle, &fun.local_variables)
1666                })?;
1667        }
1668
1669        for (index, argument) in fun.arguments.iter().enumerate() {
1670            match module.types[argument.ty].inner.pointer_space() {
1671                Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
1672                Some(other) => {
1673                    return Err(FunctionError::InvalidArgumentPointerSpace {
1674                        index,
1675                        name: argument.name.clone().unwrap_or_default(),
1676                        space: other,
1677                    }
1678                    .with_span_handle(argument.ty, &module.types))
1679                }
1680            }
1681            // Check for the least informative error last.
1682            if !self.types[argument.ty.index()]
1683                .flags
1684                .contains(super::TypeFlags::ARGUMENT)
1685            {
1686                return Err(FunctionError::InvalidArgumentType {
1687                    index,
1688                    name: argument.name.clone().unwrap_or_default(),
1689                }
1690                .with_span_handle(argument.ty, &module.types));
1691            }
1692
1693            if !entry_point && argument.binding.is_some() {
1694                return Err(FunctionError::PipelineInputRegularFunction {
1695                    name: argument.name.clone().unwrap_or_default(),
1696                }
1697                .with_span_handle(argument.ty, &module.types));
1698            }
1699        }
1700
1701        if let Some(ref result) = fun.result {
1702            if !self.types[result.ty.index()]
1703                .flags
1704                .contains(super::TypeFlags::CONSTRUCTIBLE)
1705            {
1706                return Err(FunctionError::NonConstructibleReturnType
1707                    .with_span_handle(result.ty, &module.types));
1708            }
1709
1710            if !entry_point && result.binding.is_some() {
1711                return Err(FunctionError::PipelineOutputRegularFunction
1712                    .with_span_handle(result.ty, &module.types));
1713            }
1714        }
1715
1716        self.valid_expression_set.clear_for_arena(&fun.expressions);
1717        self.valid_expression_list.clear();
1718        self.needs_visit.clear_for_arena(&fun.expressions);
1719        for (handle, expr) in fun.expressions.iter() {
1720            if expr.needs_pre_emit() {
1721                self.valid_expression_set.insert(handle);
1722            }
1723            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1724                // Mark expressions that need to be visited by a particular kind of
1725                // statement.
1726                if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
1727                    *expr
1728                {
1729                    self.needs_visit.insert(handle);
1730                }
1731
1732                match self.validate_expression(
1733                    handle,
1734                    expr,
1735                    fun,
1736                    module,
1737                    &info,
1738                    mod_info,
1739                    &local_expr_kind,
1740                ) {
1741                    Ok(stages) => info.available_stages &= stages,
1742                    Err(source) => {
1743                        return Err(FunctionError::Expression { handle, source }
1744                            .with_span_handle(handle, &fun.expressions))
1745                    }
1746                }
1747            }
1748        }
1749
1750        if self.flags.contains(super::ValidationFlags::BLOCKS) {
1751            let stages = self
1752                .validate_block(
1753                    &fun.body,
1754                    &BlockContext::new(fun, module, &info, &mod_info.functions),
1755                )?
1756                .stages;
1757            info.available_stages &= stages;
1758
1759            if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1760                if let Some(handle) = self.needs_visit.iter().next() {
1761                    return Err(FunctionError::UnvisitedExpression(handle)
1762                        .with_span_handle(handle, &fun.expressions));
1763                }
1764            }
1765        }
1766        Ok(info)
1767    }
1768}