1use alloc::{format, string::String};
2
3use super::validate_atomic_compare_exchange_struct;
4use super::{
5 analyzer::{UniformityDisruptor, UniformityRequirements},
6 ExpressionError, FunctionInfo, ModuleInfo,
7};
8use crate::arena::{Arena, UniqueArena};
9use crate::arena::{Handle, HandleSet};
10use crate::proc::TypeResolution;
11use crate::span::WithSpan;
12use crate::span::{AddSpan as _, MapErrWithSpan as _};
13
/// Errors reported while validating a `Call` statement.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
    /// Resolving the type of the argument at position `index` failed.
    #[error("Argument {index} expression is invalid")]
    Argument {
        index: usize,
        /// The underlying expression error.
        source: ExpressionError,
    },
    /// The call's result expression was already in scope before the call.
    #[error("Result expression {0:?} has already been introduced earlier")]
    ResultAlreadyInScope(Handle<crate::Expression>),
    /// The call's result expression has already been filled in by another `Call`.
    #[error("Result expression {0:?} is populated by multiple `Call` statements")]
    ResultAlreadyPopulated(Handle<crate::Expression>),
    /// The number of arguments differs from the callee's parameter count.
    #[error("Requires {required} arguments, but {seen} are provided")]
    ArgumentCount { required: usize, seen: usize },
    /// An argument's type does not match the callee's declared parameter type.
    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
    ArgumentType {
        index: usize,
        required: Handle<crate::Type>,
        seen_expression: Handle<crate::Expression>,
    },
    /// The result expression is missing when the callee returns a value, or it
    /// is not a `CallResult` for the called function.
    #[error("The emitted expression doesn't match the call")]
    ExpressionMismatch(Option<Handle<crate::Expression>>),
}
37
/// Errors reported while validating an `Atomic` statement.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AtomicError {
    /// The pointer operand is not a pointer, or does not point to an `Atomic` type.
    #[error("Pointer {0:?} to atomic is invalid.")]
    InvalidPointer(Handle<crate::Expression>),
    /// The operation is not allowed in this address space (e.g. 32-bit float
    /// atomics outside of `Storage`).
    #[error("Address space {0:?} is not supported.")]
    InvalidAddressSpace(crate::AddressSpace),
    /// An operand (the value or the compare-exchange comparand) has the wrong type.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The atomic function is not supported for this scalar type.
    #[error("Operator {0:?} is not supported.")]
    InvalidOperator(crate::AtomicFunction),
    /// The result handle does not refer to an `AtomicResult` expression.
    #[error("Result expression {0:?} is not an `AtomicResult` expression")]
    InvalidResultExpression(Handle<crate::Expression>),
    /// The result is flagged as a comparison, but the operation is not a compare-exchange.
    #[error("Result expression {0:?} is marked as an `exchange`")]
    ResultExpressionExchange(Handle<crate::Expression>),
    /// The operation is a compare-exchange, but its result is not flagged as a comparison.
    #[error("Result expression {0:?} is not marked as an `exchange`")]
    ResultExpressionNotExchange(Handle<crate::Expression>),
    /// The result expression's type does not agree with the operation.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// A plain `Exchange` (no comparand) was given no result expression.
    #[error("Exchange operations must return a value")]
    MissingReturnValue,
    /// A validator capability required by this atomic operation is not enabled.
    #[error("Capability {0:?} is required")]
    MissingCapability(super::Capabilities),
    /// The result expression has already been filled in by another `Atomic` statement.
    #[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
    ResultAlreadyPopulated(Handle<crate::Expression>),
}
64
/// Errors reported while validating subgroup statements.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum SubgroupError {
    /// An operand of a subgroup operation or gather has an unsupported type.
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    /// The result is not a `SubgroupOperationResult` of the operand's type.
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
    /// The validator was not configured to support the required subgroup operation set.
    #[error("Support for subgroup operation {0:?} is required")]
    UnsupportedOperation(super::SubgroupOperationSet),
    /// The combination of collective operation and subgroup operation is not recognized.
    #[error("Unknown operation")]
    UnknownOperation,
    /// A `Broadcast`/`QuadBroadcast` invocation index is not a const-expression.
    #[error("Invocation ID must be a const-expression")]
    InvalidInvocationIdExprType(Handle<crate::Expression>),
}
79
/// Errors reported for a function's local variable declarations.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
    /// The variable's declared type cannot be stored in a local variable.
    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
    InvalidType(Handle<crate::Type>),
    /// The initializer's type does not match the variable's declared type.
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
    /// The initializer is neither a const-expression nor an override-expression.
    #[error("Initializer is not a const or override expression")]
    NonConstOrOverrideInitializer,
}
90
/// Errors reported while validating a function and its body.
///
/// Several variants wrap a more specific error type; `AtomicError` and
/// `SubgroupError` convert into this type via `From` (see the `#[from]`
/// attributes below).
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
    /// Resolving or validating the expression `handle` failed.
    #[error("Expression {handle:?} is invalid")]
    Expression {
        handle: Handle<crate::Expression>,
        source: ExpressionError,
    },
    /// An expression was brought into scope a second time.
    #[error("Expression {0:?} can't be introduced - it's already in scope")]
    ExpressionAlreadyInScope(Handle<crate::Expression>),
    /// A local variable declaration is invalid; `source` gives the reason.
    #[error("Local variable {handle:?} '{name}' is invalid")]
    LocalVariable {
        handle: Handle<crate::LocalVariable>,
        name: String,
        source: LocalVariableError,
    },
    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
    InvalidArgumentType { index: usize, name: String },
    #[error("The function's given return type cannot be returned from functions")]
    NonConstructibleReturnType,
    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
    InvalidArgumentPointerSpace {
        index: usize,
        name: String,
        space: crate::AddressSpace,
    },
    /// A `break` appeared in a block without the `BREAK` control-flow ability.
    #[error("The `break` is used outside of a `loop` or `switch` context")]
    BreakOutsideOfLoopOrSwitch,
    /// A `continue` appeared in a block without the `CONTINUE` control-flow ability.
    #[error("The `continue` is used outside of a `loop` context")]
    ContinueOutsideOfLoop,
    /// A `return` appeared in a block without the `RETURN` ability, such as a
    /// loop's `continuing` block.
    #[error("The `return` is called within a `continuing` block")]
    InvalidReturnSpot,
    /// The returned value (or its absence) does not match the declared return type.
    #[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
    InvalidReturnType {
        expression: Option<Handle<crate::Expression>>,
        expected_ty: Option<Handle<crate::Type>>,
    },
    /// A condition is not a boolean scalar. This is also reported for a
    /// loop's `break_if` condition, not only for `if`.
    #[error("The `if` condition {0:?} is not a boolean scalar")]
    InvalidIfType(Handle<crate::Expression>),
    #[error("The `switch` value {0:?} is not an integer scalar")]
    InvalidSwitchType(Handle<crate::Expression>),
    #[error("Multiple `switch` cases for {0:?} are present")]
    ConflictingSwitchCase(crate::SwitchValue),
    /// A case value's signedness does not match the selector's type.
    #[error("The `switch` contains cases with conflicting types")]
    ConflictingCaseType,
    #[error("The `switch` is missing a `default` case")]
    MissingDefaultCase,
    #[error("Multiple `default` cases are present")]
    MultipleDefaultCases,
    // NOTE: the variant name is spelled "FallTrough" (sic); renaming it would
    // break callers that match on it.
    #[error("The last `switch` case contains a `fallthrough`")]
    LastCaseFallTrough,
    /// A `Store` pointer is not an access chain rooted at a local variable,
    /// a global variable, or a function argument.
    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
    InvalidStorePointer(Handle<crate::Expression>),
    #[error("Image store texture parameter type mismatch")]
    InvalidStoreTexture {
        actual: Handle<crate::Expression>,
        actual_ty: crate::TypeInner,
    },
    #[error("Image store value parameter type mismatch")]
    InvalidStoreValue {
        actual: Handle<crate::Expression>,
        actual_ty: crate::TypeInner,
        expected_ty: crate::TypeInner,
    },
    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
    InvalidStoreTypes {
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
    },
    #[error("Image store parameters are invalid")]
    InvalidImageStore(#[source] ExpressionError),
    #[error("Image atomic parameters are invalid")]
    InvalidImageAtomic(#[source] ExpressionError),
    #[error("Image atomic function is invalid")]
    InvalidImageAtomicFunction(crate::AtomicFunction),
    #[error("Image atomic value is invalid")]
    InvalidImageAtomicValue(Handle<crate::Expression>),
    /// Validation of a call to `function` failed; `error` gives the reason.
    #[error("Call to {function:?} is invalid")]
    InvalidCall {
        function: Handle<crate::Function>,
        #[source]
        error: CallError,
    },
    #[error("Atomic operation is invalid")]
    InvalidAtomic(#[from] AtomicError),
    #[error("Ray Query {0:?} is not a local variable")]
    InvalidRayQueryExpression(Handle<crate::Expression>),
    #[error("Acceleration structure {0:?} is not a matching expression")]
    InvalidAccelerationStructure(Handle<crate::Expression>),
    #[error(
        "Acceleration structure {0:?} is missing flag vertex_return while Ray Query {1:?} does"
    )]
    MissingAccelerationStructureVertexReturn(Handle<crate::Expression>, Handle<crate::Expression>),
    #[error("Ray Query {0:?} is missing flag vertex_return")]
    MissingRayQueryVertexReturn(Handle<crate::Expression>),
    #[error("Ray descriptor {0:?} is not a matching expression")]
    InvalidRayDescriptor(Handle<crate::Expression>),
    #[error("Ray Query {0:?} does not have a matching type")]
    InvalidRayQueryType(Handle<crate::Type>),
    #[error("Hit distance {0:?} must be an f32")]
    InvalidHitDistanceType(Handle<crate::Expression>),
    /// A validator capability required by a statement is not enabled
    /// (e.g. subgroup barriers).
    #[error("Shader requires capability {0:?}")]
    MissingCapability(super::Capabilities),
    #[error(
        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
    )]
    NonUniformControlFlow(
        UniformityRequirements,
        Handle<crate::Expression>,
        UniformityDisruptor,
    ),
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
    PipelineInputRegularFunction { name: String },
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
    PipelineOutputRegularFunction,
    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
    #[error("Subgroup operation is invalid")]
    InvalidSubgroup(#[from] SubgroupError),
    /// An `Emit` range covered a result expression (`CallResult`,
    /// `AtomicResult`, ...), which only its owning statement may introduce.
    #[error("Emit statement should not cover \"result\" expressions like {0:?}")]
    EmitResult(Handle<crate::Expression>),
    #[error("Expression not visited by the appropriate statement")]
    UnvisitedExpression(Handle<crate::Expression>),
}
221
bitflags::bitflags! {
    /// The control-flow statements that a block being validated may contain.
    ///
    /// A function body starts with only `RETURN`. `loop` and `switch`
    /// statements grant additional abilities to the blocks they contain, and
    /// a loop's `continuing` block is validated with none at all.
    #[repr(transparent)]
    #[derive(Clone, Copy)]
    struct ControlFlowAbility: u8 {
        /// `return` is allowed in this block.
        const RETURN = 0x1;
        /// `break` is allowed (inside a loop body or a switch case).
        const BREAK = 0x2;
        /// `continue` is allowed (inside a loop body, passed through switch cases).
        const CONTINUE = 0x4;
    }
}
234
/// Summary produced by validating a block of statements.
struct BlockInfo {
    /// Shader stages in which every statement of the block is permitted.
    /// Starts as all stages and is narrowed, e.g. to fragment-only by `Kill`
    /// or compute-only by barriers.
    stages: super::ShaderStages,
}
238
/// Read-only state shared while validating the statements of one function.
struct BlockContext<'a> {
    /// Control-flow statements (`return`/`break`/`continue`) permitted here.
    abilities: ControlFlowAbility,
    /// Analysis results for the function; used to look up expression types.
    info: &'a FunctionInfo,
    /// The function's expression arena.
    expressions: &'a Arena<crate::Expression>,
    /// The module's type arena.
    types: &'a UniqueArena<crate::Type>,
    /// The function's local variables.
    local_vars: &'a Arena<crate::LocalVariable>,
    /// The module's global variables.
    global_vars: &'a Arena<crate::GlobalVariable>,
    /// The module's functions, used to look up call targets.
    functions: &'a Arena<crate::Function>,
    /// The module's special types.
    special_types: &'a crate::SpecialTypes,
    /// Info for already-validated functions, indexed by function handle index.
    /// NOTE(review): presumably only functions preceding this one — confirm.
    prev_infos: &'a [FunctionInfo],
    /// The function's declared return type, if it returns a value.
    return_type: Option<Handle<crate::Type>>,
    /// Tracks which local expressions are const-expressions.
    local_expr_kind: &'a crate::proc::ExpressionKindTracker,
}
252
253impl<'a> BlockContext<'a> {
254 fn new(
255 fun: &'a crate::Function,
256 module: &'a crate::Module,
257 info: &'a FunctionInfo,
258 prev_infos: &'a [FunctionInfo],
259 local_expr_kind: &'a crate::proc::ExpressionKindTracker,
260 ) -> Self {
261 Self {
262 abilities: ControlFlowAbility::RETURN,
263 info,
264 expressions: &fun.expressions,
265 types: &module.types,
266 local_vars: &fun.local_variables,
267 global_vars: &module.global_variables,
268 functions: &module.functions,
269 special_types: &module.special_types,
270 prev_infos,
271 return_type: fun.result.as_ref().map(|fr| fr.ty),
272 local_expr_kind,
273 }
274 }
275
276 const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
277 BlockContext { abilities, ..*self }
278 }
279
280 fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
281 &self.expressions[handle]
282 }
283
284 fn resolve_type_impl(
285 &self,
286 handle: Handle<crate::Expression>,
287 valid_expressions: &HandleSet<crate::Expression>,
288 ) -> Result<&TypeResolution, WithSpan<ExpressionError>> {
289 if !valid_expressions.contains(handle) {
290 Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
291 } else {
292 Ok(&self.info[handle].ty)
293 }
294 }
295
296 fn resolve_type(
297 &self,
298 handle: Handle<crate::Expression>,
299 valid_expressions: &HandleSet<crate::Expression>,
300 ) -> Result<&TypeResolution, WithSpan<FunctionError>> {
301 self.resolve_type_impl(handle, valid_expressions)
302 .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
303 }
304
305 fn resolve_type_inner(
306 &self,
307 handle: Handle<crate::Expression>,
308 valid_expressions: &HandleSet<crate::Expression>,
309 ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
310 self.resolve_type(handle, valid_expressions)
311 .map(|tr| tr.inner_with(self.types))
312 }
313
314 fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
315 self.info[handle].ty.inner_with(self.types)
316 }
317
318 fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
319 crate::proc::compare_types(lhs, rhs, self.types)
320 }
321}
322
323impl super::Validator {
324 fn validate_call(
325 &mut self,
326 function: Handle<crate::Function>,
327 arguments: &[Handle<crate::Expression>],
328 result: Option<Handle<crate::Expression>>,
329 context: &BlockContext,
330 ) -> Result<super::ShaderStages, WithSpan<CallError>> {
331 let fun = &context.functions[function];
332 if fun.arguments.len() != arguments.len() {
333 return Err(CallError::ArgumentCount {
334 required: fun.arguments.len(),
335 seen: arguments.len(),
336 }
337 .with_span());
338 }
339 for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
340 let ty = context
341 .resolve_type_impl(expr, &self.valid_expression_set)
342 .map_err_inner(|source| {
343 CallError::Argument { index, source }
344 .with_span_handle(expr, context.expressions)
345 })?;
346 if !context.compare_types(&TypeResolution::Handle(arg.ty), ty) {
347 return Err(CallError::ArgumentType {
348 index,
349 required: arg.ty,
350 seen_expression: expr,
351 }
352 .with_span_handle(expr, context.expressions));
353 }
354 }
355
356 if let Some(expr) = result {
357 if self.valid_expression_set.insert(expr) {
358 self.valid_expression_list.push(expr);
359 } else {
360 return Err(CallError::ResultAlreadyInScope(expr)
361 .with_span_handle(expr, context.expressions));
362 }
363 match context.expressions[expr] {
364 crate::Expression::CallResult(callee)
365 if fun.result.is_some() && callee == function =>
366 {
367 if !self.needs_visit.remove(expr) {
368 return Err(CallError::ResultAlreadyPopulated(expr)
369 .with_span_handle(expr, context.expressions));
370 }
371 }
372 _ => {
373 return Err(CallError::ExpressionMismatch(result)
374 .with_span_handle(expr, context.expressions))
375 }
376 }
377 } else if fun.result.is_some() {
378 return Err(CallError::ExpressionMismatch(result).with_span());
379 }
380
381 let callee_info = &context.prev_infos[function.index()];
382 Ok(callee_info.available_stages)
383 }
384
385 fn emit_expression(
386 &mut self,
387 handle: Handle<crate::Expression>,
388 context: &BlockContext,
389 ) -> Result<(), WithSpan<FunctionError>> {
390 if self.valid_expression_set.insert(handle) {
391 self.valid_expression_list.push(handle);
392 Ok(())
393 } else {
394 Err(FunctionError::ExpressionAlreadyInScope(handle)
395 .with_span_handle(handle, context.expressions))
396 }
397 }
398
399 fn validate_atomic(
400 &mut self,
401 pointer: Handle<crate::Expression>,
402 fun: &crate::AtomicFunction,
403 value: Handle<crate::Expression>,
404 result: Option<Handle<crate::Expression>>,
405 span: crate::Span,
406 context: &BlockContext,
407 ) -> Result<(), WithSpan<FunctionError>> {
408 let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?;
410 let crate::TypeInner::Pointer {
411 base: pointer_base,
412 space: pointer_space,
413 } = *pointer_inner
414 else {
415 log::error!("Atomic operation on type {:?}", *pointer_inner);
416 return Err(AtomicError::InvalidPointer(pointer)
417 .with_span_handle(pointer, context.expressions)
418 .into_other());
419 };
420 let crate::TypeInner::Atomic(pointer_scalar) = context.types[pointer_base].inner else {
421 log::error!(
422 "Atomic pointer to type {:?}",
423 context.types[pointer_base].inner
424 );
425 return Err(AtomicError::InvalidPointer(pointer)
426 .with_span_handle(pointer, context.expressions)
427 .into_other());
428 };
429
430 let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?;
432 let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
433 log::error!("Atomic operand type {:?}", *value_inner);
434 return Err(AtomicError::InvalidOperand(value)
435 .with_span_handle(value, context.expressions)
436 .into_other());
437 };
438 if pointer_scalar != value_scalar {
439 log::error!("Atomic operand type {:?}", *value_inner);
440 return Err(AtomicError::InvalidOperand(value)
441 .with_span_handle(value, context.expressions)
442 .into_other());
443 }
444
445 match pointer_scalar {
446 crate::Scalar::I64 | crate::Scalar::U64 => {
452 if self
455 .capabilities
456 .contains(super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS)
457 {
458 } else {
460 if matches!(
463 *fun,
464 crate::AtomicFunction::Min | crate::AtomicFunction::Max
465 ) && matches!(pointer_space, crate::AddressSpace::Storage { .. })
466 && result.is_none()
467 {
468 if !self
469 .capabilities
470 .contains(super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX)
471 {
472 log::error!("Int64 min-max atomic operations are not supported");
473 return Err(AtomicError::MissingCapability(
474 super::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX,
475 )
476 .with_span_handle(value, context.expressions)
477 .into_other());
478 }
479 } else {
480 log::error!("Int64 atomic operations are not supported");
482 return Err(AtomicError::MissingCapability(
483 super::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS,
484 )
485 .with_span_handle(value, context.expressions)
486 .into_other());
487 }
488 }
489 }
490 crate::Scalar::F32 => {
492 if !self
496 .capabilities
497 .contains(super::Capabilities::SHADER_FLOAT32_ATOMIC)
498 {
499 log::error!("Float32 atomic operations are not supported");
500 return Err(AtomicError::MissingCapability(
501 super::Capabilities::SHADER_FLOAT32_ATOMIC,
502 )
503 .with_span_handle(value, context.expressions)
504 .into_other());
505 }
506 if !matches!(
507 *fun,
508 crate::AtomicFunction::Add
509 | crate::AtomicFunction::Subtract
510 | crate::AtomicFunction::Exchange { compare: None }
511 ) {
512 log::error!("Float32 atomic operation {:?} is not supported", fun);
513 return Err(AtomicError::InvalidOperator(*fun)
514 .with_span_handle(value, context.expressions)
515 .into_other());
516 }
517 if !matches!(pointer_space, crate::AddressSpace::Storage { .. }) {
518 log::error!(
519 "Float32 atomic operations are only supported in the Storage address space"
520 );
521 return Err(AtomicError::InvalidAddressSpace(pointer_space)
522 .with_span_handle(value, context.expressions)
523 .into_other());
524 }
525 }
526 _ => {}
527 }
528
529 match result {
531 Some(result) => {
532 let crate::Expression::AtomicResult {
534 ty: result_ty,
535 comparison,
536 } = context.expressions[result]
537 else {
538 return Err(AtomicError::InvalidResultExpression(result)
539 .with_span_handle(result, context.expressions)
540 .into_other());
541 };
542
543 if !self.needs_visit.remove(result) {
546 return Err(AtomicError::ResultAlreadyPopulated(result)
547 .with_span_handle(result, context.expressions)
548 .into_other());
549 }
550
551 if let crate::AtomicFunction::Exchange {
553 compare: Some(compare),
554 } = *fun
555 {
556 let compare_inner =
559 context.resolve_type_inner(compare, &self.valid_expression_set)?;
560 if !compare_inner.non_struct_equivalent(value_inner, context.types) {
561 log::error!(
562 "Atomic exchange comparison has a different type from the value"
563 );
564 return Err(AtomicError::InvalidOperand(compare)
565 .with_span_handle(compare, context.expressions)
566 .into_other());
567 }
568
569 let crate::TypeInner::Struct { ref members, .. } =
573 context.types[result_ty].inner
574 else {
575 return Err(AtomicError::ResultTypeMismatch(result)
576 .with_span_handle(result, context.expressions)
577 .into_other());
578 };
579 if !validate_atomic_compare_exchange_struct(
580 context.types,
581 members,
582 |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar),
583 ) {
584 return Err(AtomicError::ResultTypeMismatch(result)
585 .with_span_handle(result, context.expressions)
586 .into_other());
587 }
588
589 if !comparison {
591 return Err(AtomicError::ResultExpressionNotExchange(result)
592 .with_span_handle(result, context.expressions)
593 .into_other());
594 }
595 } else {
596 let result_inner = &context.types[result_ty].inner;
599 if !result_inner.non_struct_equivalent(value_inner, context.types) {
600 return Err(AtomicError::ResultTypeMismatch(result)
601 .with_span_handle(result, context.expressions)
602 .into_other());
603 }
604
605 if comparison {
607 return Err(AtomicError::ResultExpressionExchange(result)
608 .with_span_handle(result, context.expressions)
609 .into_other());
610 }
611 }
612 self.emit_expression(result, context)?;
613 }
614
615 None => {
616 if let crate::AtomicFunction::Exchange { compare: None } = *fun {
618 log::error!("Atomic exchange's value is unused");
619 return Err(AtomicError::MissingReturnValue
620 .with_span_static(span, "atomic exchange operation")
621 .into_other());
622 }
623 }
624 }
625
626 Ok(())
627 }
628 fn validate_subgroup_operation(
629 &mut self,
630 op: &crate::SubgroupOperation,
631 collective_op: &crate::CollectiveOperation,
632 argument: Handle<crate::Expression>,
633 result: Handle<crate::Expression>,
634 context: &BlockContext,
635 ) -> Result<(), WithSpan<FunctionError>> {
636 let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
637
638 let (is_scalar, scalar) = match *argument_inner {
639 crate::TypeInner::Scalar(scalar) => (true, scalar),
640 crate::TypeInner::Vector { scalar, .. } => (false, scalar),
641 _ => {
642 log::error!("Subgroup operand type {:?}", argument_inner);
643 return Err(SubgroupError::InvalidOperand(argument)
644 .with_span_handle(argument, context.expressions)
645 .into_other());
646 }
647 };
648
649 use crate::ScalarKind as sk;
650 use crate::SubgroupOperation as sg;
651 match (scalar.kind, *op) {
652 (sk::Bool, sg::All | sg::Any) if is_scalar => {}
653 (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
654 (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
655
656 (_, _) => {
657 log::error!("Subgroup operand type {:?}", argument_inner);
658 return Err(SubgroupError::InvalidOperand(argument)
659 .with_span_handle(argument, context.expressions)
660 .into_other());
661 }
662 };
663
664 use crate::CollectiveOperation as co;
665 match (*collective_op, *op) {
666 (
667 co::Reduce,
668 sg::All
669 | sg::Any
670 | sg::Add
671 | sg::Mul
672 | sg::Min
673 | sg::Max
674 | sg::And
675 | sg::Or
676 | sg::Xor,
677 ) => {}
678 (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
679
680 (_, _) => {
681 return Err(SubgroupError::UnknownOperation.with_span().into_other());
682 }
683 };
684
685 self.emit_expression(result, context)?;
686 match context.expressions[result] {
687 crate::Expression::SubgroupOperationResult { ty }
688 if { &context.types[ty].inner == argument_inner } => {}
689 _ => {
690 return Err(SubgroupError::ResultTypeMismatch(result)
691 .with_span_handle(result, context.expressions)
692 .into_other())
693 }
694 }
695 Ok(())
696 }
697 fn validate_subgroup_gather(
698 &mut self,
699 mode: &crate::GatherMode,
700 argument: Handle<crate::Expression>,
701 result: Handle<crate::Expression>,
702 context: &BlockContext,
703 ) -> Result<(), WithSpan<FunctionError>> {
704 match *mode {
705 crate::GatherMode::BroadcastFirst => {}
706 crate::GatherMode::Broadcast(index)
707 | crate::GatherMode::Shuffle(index)
708 | crate::GatherMode::ShuffleDown(index)
709 | crate::GatherMode::ShuffleUp(index)
710 | crate::GatherMode::ShuffleXor(index)
711 | crate::GatherMode::QuadBroadcast(index) => {
712 let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
713 match *index_ty {
714 crate::TypeInner::Scalar(crate::Scalar::U32) => {}
715 _ => {
716 log::error!(
717 "Subgroup gather index type {:?}, expected unsigned int",
718 index_ty
719 );
720 return Err(SubgroupError::InvalidOperand(argument)
721 .with_span_handle(index, context.expressions)
722 .into_other());
723 }
724 }
725 }
726 crate::GatherMode::QuadSwap(_) => {}
727 }
728 match *mode {
729 crate::GatherMode::Broadcast(index) | crate::GatherMode::QuadBroadcast(index) => {
730 if !context.local_expr_kind.is_const(index) {
731 return Err(SubgroupError::InvalidInvocationIdExprType(index)
732 .with_span_handle(index, context.expressions)
733 .into_other());
734 }
735 }
736 _ => {}
737 }
738 let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
739 if !matches!(*argument_inner,
740 crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
741 if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
742 ) {
743 log::error!("Subgroup gather operand type {:?}", argument_inner);
744 return Err(SubgroupError::InvalidOperand(argument)
745 .with_span_handle(argument, context.expressions)
746 .into_other());
747 }
748
749 self.emit_expression(result, context)?;
750 match context.expressions[result] {
751 crate::Expression::SubgroupOperationResult { ty }
752 if { &context.types[ty].inner == argument_inner } => {}
753 _ => {
754 return Err(SubgroupError::ResultTypeMismatch(result)
755 .with_span_handle(result, context.expressions)
756 .into_other())
757 }
758 }
759 Ok(())
760 }
761
762 fn validate_block_impl(
763 &mut self,
764 statements: &crate::Block,
765 context: &BlockContext,
766 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
767 use crate::{AddressSpace, Statement as S, TypeInner as Ti};
768 let mut stages = super::ShaderStages::all();
769 for (statement, &span) in statements.span_iter() {
770 match *statement {
771 S::Emit(ref range) => {
772 for handle in range.clone() {
773 use crate::Expression as Ex;
774 match context.expressions[handle] {
775 Ex::Literal(_)
776 | Ex::Constant(_)
777 | Ex::Override(_)
778 | Ex::ZeroValue(_)
779 | Ex::Compose { .. }
780 | Ex::Access { .. }
781 | Ex::AccessIndex { .. }
782 | Ex::Splat { .. }
783 | Ex::Swizzle { .. }
784 | Ex::FunctionArgument(_)
785 | Ex::GlobalVariable(_)
786 | Ex::LocalVariable(_)
787 | Ex::Load { .. }
788 | Ex::ImageSample { .. }
789 | Ex::ImageLoad { .. }
790 | Ex::ImageQuery { .. }
791 | Ex::Unary { .. }
792 | Ex::Binary { .. }
793 | Ex::Select { .. }
794 | Ex::Derivative { .. }
795 | Ex::Relational { .. }
796 | Ex::Math { .. }
797 | Ex::As { .. }
798 | Ex::ArrayLength(_)
799 | Ex::RayQueryGetIntersection { .. }
800 | Ex::RayQueryVertexPositions { .. } => {
801 self.emit_expression(handle, context)?
802 }
803 Ex::CallResult(_)
804 | Ex::AtomicResult { .. }
805 | Ex::WorkGroupUniformLoadResult { .. }
806 | Ex::RayQueryProceedResult
807 | Ex::SubgroupBallotResult
808 | Ex::SubgroupOperationResult { .. } => {
809 return Err(FunctionError::EmitResult(handle)
810 .with_span_handle(handle, context.expressions));
811 }
812 }
813 }
814 }
815 S::Block(ref block) => {
816 let info = self.validate_block(block, context)?;
817 stages &= info.stages;
818 }
819 S::If {
820 condition,
821 ref accept,
822 ref reject,
823 } => {
824 match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
825 Ti::Scalar(crate::Scalar {
826 kind: crate::ScalarKind::Bool,
827 width: _,
828 }) => {}
829 _ => {
830 return Err(FunctionError::InvalidIfType(condition)
831 .with_span_handle(condition, context.expressions))
832 }
833 }
834 stages &= self.validate_block(accept, context)?.stages;
835 stages &= self.validate_block(reject, context)?.stages;
836 }
837 S::Switch {
838 selector,
839 ref cases,
840 } => {
841 let uint = match context
842 .resolve_type_inner(selector, &self.valid_expression_set)?
843 .scalar_kind()
844 {
845 Some(crate::ScalarKind::Uint) => true,
846 Some(crate::ScalarKind::Sint) => false,
847 _ => {
848 return Err(FunctionError::InvalidSwitchType(selector)
849 .with_span_handle(selector, context.expressions))
850 }
851 };
852 self.switch_values.clear();
853 for case in cases {
854 match case.value {
855 crate::SwitchValue::I32(_) if !uint => {}
856 crate::SwitchValue::U32(_) if uint => {}
857 crate::SwitchValue::Default => {}
858 _ => {
859 return Err(FunctionError::ConflictingCaseType.with_span_static(
860 case.body
861 .span_iter()
862 .next()
863 .map_or(Default::default(), |(_, s)| *s),
864 "conflicting switch arm here",
865 ));
866 }
867 };
868 if !self.switch_values.insert(case.value) {
869 return Err(match case.value {
870 crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
871 .with_span_static(
872 case.body
873 .span_iter()
874 .next()
875 .map_or(Default::default(), |(_, s)| *s),
876 "duplicated switch arm here",
877 ),
878 _ => FunctionError::ConflictingSwitchCase(case.value)
879 .with_span_static(
880 case.body
881 .span_iter()
882 .next()
883 .map_or(Default::default(), |(_, s)| *s),
884 "conflicting switch arm here",
885 ),
886 });
887 }
888 }
889 if !self.switch_values.contains(&crate::SwitchValue::Default) {
890 return Err(FunctionError::MissingDefaultCase
891 .with_span_static(span, "missing default case"));
892 }
893 if let Some(case) = cases.last() {
894 if case.fall_through {
895 return Err(FunctionError::LastCaseFallTrough.with_span_static(
896 case.body
897 .span_iter()
898 .next()
899 .map_or(Default::default(), |(_, s)| *s),
900 "bad switch arm here",
901 ));
902 }
903 }
904 let pass_through_abilities = context.abilities
905 & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
906 let sub_context =
907 context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
908 for case in cases {
909 stages &= self.validate_block(&case.body, &sub_context)?.stages;
910 }
911 }
912 S::Loop {
913 ref body,
914 ref continuing,
915 break_if,
916 } => {
917 let base_expression_count = self.valid_expression_list.len();
920 let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
921 stages &= self
922 .validate_block_impl(
923 body,
924 &context.with_abilities(
925 pass_through_abilities
926 | ControlFlowAbility::BREAK
927 | ControlFlowAbility::CONTINUE,
928 ),
929 )?
930 .stages;
931 stages &= self
932 .validate_block_impl(
933 continuing,
934 &context.with_abilities(ControlFlowAbility::empty()),
935 )?
936 .stages;
937
938 if let Some(condition) = break_if {
939 match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
940 Ti::Scalar(crate::Scalar {
941 kind: crate::ScalarKind::Bool,
942 width: _,
943 }) => {}
944 _ => {
945 return Err(FunctionError::InvalidIfType(condition)
946 .with_span_handle(condition, context.expressions))
947 }
948 }
949 }
950
951 for handle in self.valid_expression_list.drain(base_expression_count..) {
952 self.valid_expression_set.remove(handle);
953 }
954 }
955 S::Break => {
956 if !context.abilities.contains(ControlFlowAbility::BREAK) {
957 return Err(FunctionError::BreakOutsideOfLoopOrSwitch
958 .with_span_static(span, "invalid break"));
959 }
960 }
961 S::Continue => {
962 if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
963 return Err(FunctionError::ContinueOutsideOfLoop
964 .with_span_static(span, "invalid continue"));
965 }
966 }
967 S::Return { value } => {
968 if !context.abilities.contains(ControlFlowAbility::RETURN) {
969 return Err(FunctionError::InvalidReturnSpot
970 .with_span_static(span, "invalid return"));
971 }
972 let value_ty = value
973 .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
974 .transpose()?;
975 let okay = match (value_ty, context.return_type) {
978 (None, None) => true,
979 (Some(value_inner), Some(expected_ty)) => {
980 context.compare_types(value_inner, &TypeResolution::Handle(expected_ty))
981 }
982 (_, _) => false,
983 };
984
985 if !okay {
986 log::error!(
987 "Returning {:?} where {:?} is expected",
988 value_ty,
989 context.return_type,
990 );
991 if let Some(handle) = value {
992 return Err(FunctionError::InvalidReturnType {
993 expression: value,
994 expected_ty: context.return_type,
995 }
996 .with_span_handle(handle, context.expressions));
997 } else {
998 return Err(FunctionError::InvalidReturnType {
999 expression: value,
1000 expected_ty: context.return_type,
1001 }
1002 .with_span_static(span, "invalid return"));
1003 }
1004 }
1005 }
1006 S::Kill => {
1007 stages &= super::ShaderStages::FRAGMENT;
1008 }
1009 S::ControlBarrier(barrier) | S::MemoryBarrier(barrier) => {
1010 stages &= super::ShaderStages::COMPUTE;
1011 if barrier.contains(crate::Barrier::SUB_GROUP) {
1012 if !self.capabilities.contains(
1013 super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
1014 ) {
1015 return Err(FunctionError::MissingCapability(
1016 super::Capabilities::SUBGROUP
1017 | super::Capabilities::SUBGROUP_BARRIER,
1018 )
1019 .with_span_static(span, "missing capability for this operation"));
1020 }
1021 if !self
1022 .subgroup_operations
1023 .contains(super::SubgroupOperationSet::BASIC)
1024 {
1025 return Err(FunctionError::InvalidSubgroup(
1026 SubgroupError::UnsupportedOperation(
1027 super::SubgroupOperationSet::BASIC,
1028 ),
1029 )
1030 .with_span_static(span, "support for this operation is not present"));
1031 }
1032 }
1033 }
1034 S::Store { pointer, value } => {
1035 let mut current = pointer;
1036 loop {
1037 match context.expressions[current] {
1038 crate::Expression::Access { base, .. }
1039 | crate::Expression::AccessIndex { base, .. } => current = base,
1040 crate::Expression::LocalVariable(_)
1041 | crate::Expression::GlobalVariable(_)
1042 | crate::Expression::FunctionArgument(_) => break,
1043 _ => {
1044 return Err(FunctionError::InvalidStorePointer(current)
1045 .with_span_handle(pointer, context.expressions))
1046 }
1047 }
1048 }
1049
1050 let value_tr = context.resolve_type(value, &self.valid_expression_set)?;
1051 let value_ty = value_tr.inner_with(context.types);
1052 match *value_ty {
1053 Ti::Image { .. } | Ti::Sampler { .. } => {
1054 return Err(FunctionError::InvalidStoreTexture {
1055 actual: value,
1056 actual_ty: value_ty.clone(),
1057 }
1058 .with_span_context((
1059 context.expressions.get_span(value),
1060 format!("this value is of type {value_ty:?}"),
1061 ))
1062 .with_span(span, "expects a texture argument"));
1063 }
1064 _ => {}
1065 }
1066
1067 let pointer_ty = context.resolve_pointer_type(pointer);
1068 let pointer_base_tr = pointer_ty.pointer_base_type();
1069 let pointer_base_ty = pointer_base_tr
1070 .as_ref()
1071 .map(|ty| ty.inner_with(context.types));
1072 let good = if let Some(&Ti::Atomic(ref scalar)) = pointer_base_ty {
1073 *value_ty == Ti::Scalar(*scalar)
1075 } else if let Some(tr) = pointer_base_tr {
1076 context.compare_types(value_tr, &tr)
1077 } else {
1078 false
1079 };
1080
1081 if !good {
1082 return Err(FunctionError::InvalidStoreTypes { pointer, value }
1083 .with_span()
1084 .with_handle(pointer, context.expressions)
1085 .with_handle(value, context.expressions));
1086 }
1087
1088 if let Some(space) = pointer_ty.pointer_space() {
1089 if !space.access().contains(crate::StorageAccess::STORE) {
1090 return Err(FunctionError::InvalidStorePointer(pointer)
1091 .with_span_static(
1092 context.expressions.get_span(pointer),
1093 "writing to this location is not permitted",
1094 ));
1095 }
1096 }
1097 }
1098 S::ImageStore {
1099 image,
1100 coordinate,
1101 array_index,
1102 value,
1103 } => {
1104 let global_var;
1107 let image_ty;
1108 match *context.get_expression(image) {
1109 crate::Expression::GlobalVariable(var_handle) => {
1110 global_var = &context.global_vars[var_handle];
1111 image_ty = global_var.ty;
1112 }
1113 crate::Expression::Access { base, .. }
1117 | crate::Expression::AccessIndex { base, .. } => {
1118 let crate::Expression::GlobalVariable(var_handle) =
1119 *context.get_expression(base)
1120 else {
1121 return Err(FunctionError::InvalidImageStore(
1122 ExpressionError::ExpectedGlobalVariable,
1123 )
1124 .with_span_handle(image, context.expressions));
1125 };
1126 global_var = &context.global_vars[var_handle];
1127
1128 let Ti::BindingArray { base, .. } = context.types[global_var.ty].inner
1130 else {
1131 return Err(FunctionError::InvalidImageStore(
1132 ExpressionError::ExpectedBindingArrayType(global_var.ty),
1133 )
1134 .with_span_handle(global_var.ty, context.types));
1135 };
1136
1137 image_ty = base;
1138 }
1139 _ => {
1140 return Err(FunctionError::InvalidImageStore(
1141 ExpressionError::ExpectedGlobalVariable,
1142 )
1143 .with_span_handle(image, context.expressions))
1144 }
1145 };
1146
1147 let Ti::Image {
1149 class,
1150 arrayed,
1151 dim,
1152 } = context.types[image_ty].inner
1153 else {
1154 return Err(FunctionError::InvalidImageStore(
1155 ExpressionError::ExpectedImageType(global_var.ty),
1156 )
1157 .with_span()
1158 .with_handle(global_var.ty, context.types)
1159 .with_handle(image, context.expressions));
1160 };
1161
1162 let crate::ImageClass::Storage { format, .. } = class else {
1164 return Err(FunctionError::InvalidImageStore(
1165 ExpressionError::InvalidImageClass(class),
1166 )
1167 .with_span_handle(image, context.expressions));
1168 };
1169
1170 if context
1172 .resolve_type_inner(coordinate, &self.valid_expression_set)?
1173 .image_storage_coordinates()
1174 .is_none_or(|coord_dim| coord_dim != dim)
1175 {
1176 return Err(FunctionError::InvalidImageStore(
1177 ExpressionError::InvalidImageCoordinateType(dim, coordinate),
1178 )
1179 .with_span_handle(coordinate, context.expressions));
1180 }
1181
1182 if arrayed != array_index.is_some() {
1185 return Err(FunctionError::InvalidImageStore(
1186 ExpressionError::InvalidImageArrayIndex,
1187 )
1188 .with_span_handle(coordinate, context.expressions));
1189 }
1190
1191 if let Some(expr) = array_index {
1193 if !matches!(
1194 *context.resolve_type_inner(expr, &self.valid_expression_set)?,
1195 Ti::Scalar(crate::Scalar {
1196 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1197 width: _,
1198 })
1199 ) {
1200 return Err(FunctionError::InvalidImageStore(
1201 ExpressionError::InvalidImageArrayIndexType(expr),
1202 )
1203 .with_span_handle(expr, context.expressions));
1204 }
1205 }
1206
1207 let value_ty = crate::TypeInner::Vector {
1208 size: crate::VectorSize::Quad,
1209 scalar: format.into(),
1210 };
1211
1212 let actual_value_ty =
1215 context.resolve_type_inner(value, &self.valid_expression_set)?;
1216 if actual_value_ty != &value_ty {
1217 return Err(FunctionError::InvalidStoreValue {
1218 actual: value,
1219 actual_ty: actual_value_ty.clone(),
1220 expected_ty: value_ty.clone(),
1221 }
1222 .with_span_context((
1223 context.expressions.get_span(value),
1224 format!("this value is of type {actual_value_ty:?}"),
1225 ))
1226 .with_span(
1227 span,
1228 format!("expects a value argument of type {value_ty:?}"),
1229 ));
1230 }
1231 }
1232 S::Call {
1233 function,
1234 ref arguments,
1235 result,
1236 } => match self.validate_call(function, arguments, result, context) {
1237 Ok(callee_stages) => stages &= callee_stages,
1238 Err(error) => {
1239 return Err(error.and_then(|error| {
1240 FunctionError::InvalidCall { function, error }
1241 .with_span_static(span, "invalid function call")
1242 }))
1243 }
1244 },
1245 S::Atomic {
1246 pointer,
1247 ref fun,
1248 value,
1249 result,
1250 } => {
1251 self.validate_atomic(pointer, fun, value, result, span, context)?;
1252 }
1253 S::ImageAtomic {
1254 image,
1255 coordinate,
1256 array_index,
1257 fun,
1258 value,
1259 } => {
1260 let var = match *context.get_expression(image) {
1261 crate::Expression::GlobalVariable(var_handle) => {
1262 &context.global_vars[var_handle]
1263 }
1264 crate::Expression::Access { base, .. }
1266 | crate::Expression::AccessIndex { base, .. } => {
1267 match *context.get_expression(base) {
1268 crate::Expression::GlobalVariable(var_handle) => {
1269 &context.global_vars[var_handle]
1270 }
1271 _ => {
1272 return Err(FunctionError::InvalidImageAtomic(
1273 ExpressionError::ExpectedGlobalVariable,
1274 )
1275 .with_span_handle(image, context.expressions))
1276 }
1277 }
1278 }
1279 _ => {
1280 return Err(FunctionError::InvalidImageAtomic(
1281 ExpressionError::ExpectedGlobalVariable,
1282 )
1283 .with_span_handle(image, context.expressions))
1284 }
1285 };
1286
1287 let global_ty = match context.types[var.ty].inner {
1289 Ti::BindingArray { base, .. } => &context.types[base].inner,
1290 ref inner => inner,
1291 };
1292
1293 let value_ty = match *global_ty {
1294 Ti::Image {
1295 class,
1296 arrayed,
1297 dim,
1298 } => {
1299 match context
1300 .resolve_type_inner(coordinate, &self.valid_expression_set)?
1301 .image_storage_coordinates()
1302 {
1303 Some(coord_dim) if coord_dim == dim => {}
1304 _ => {
1305 return Err(FunctionError::InvalidImageAtomic(
1306 ExpressionError::InvalidImageCoordinateType(
1307 dim, coordinate,
1308 ),
1309 )
1310 .with_span_handle(coordinate, context.expressions));
1311 }
1312 };
1313 if arrayed != array_index.is_some() {
1314 return Err(FunctionError::InvalidImageAtomic(
1315 ExpressionError::InvalidImageArrayIndex,
1316 )
1317 .with_span_handle(coordinate, context.expressions));
1318 }
1319 if let Some(expr) = array_index {
1320 match *context
1321 .resolve_type_inner(expr, &self.valid_expression_set)?
1322 {
1323 Ti::Scalar(crate::Scalar {
1324 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
1325 width: _,
1326 }) => {}
1327 _ => {
1328 return Err(FunctionError::InvalidImageAtomic(
1329 ExpressionError::InvalidImageArrayIndexType(expr),
1330 )
1331 .with_span_handle(expr, context.expressions));
1332 }
1333 }
1334 }
1335 match class {
1336 crate::ImageClass::Storage { format, access } => {
1337 if !access.contains(crate::StorageAccess::ATOMIC) {
1338 return Err(FunctionError::InvalidImageAtomic(
1339 ExpressionError::InvalidImageStorageAccess(access),
1340 )
1341 .with_span_handle(image, context.expressions));
1342 }
1343 match format {
1344 crate::StorageFormat::R64Uint => {
1345 if !self.capabilities.intersects(
1346 super::Capabilities::TEXTURE_INT64_ATOMIC,
1347 ) {
1348 return Err(FunctionError::MissingCapability(
1349 super::Capabilities::TEXTURE_INT64_ATOMIC,
1350 )
1351 .with_span_static(
1352 span,
1353 "missing capability for this operation",
1354 ));
1355 }
1356 match fun {
1357 crate::AtomicFunction::Min
1358 | crate::AtomicFunction::Max => {}
1359 _ => {
1360 return Err(
1361 FunctionError::InvalidImageAtomicFunction(
1362 fun,
1363 )
1364 .with_span_handle(
1365 image,
1366 context.expressions,
1367 ),
1368 );
1369 }
1370 }
1371 }
1372 crate::StorageFormat::R32Sint
1373 | crate::StorageFormat::R32Uint => {
1374 if !self
1375 .capabilities
1376 .intersects(super::Capabilities::TEXTURE_ATOMIC)
1377 {
1378 return Err(FunctionError::MissingCapability(
1379 super::Capabilities::TEXTURE_ATOMIC,
1380 )
1381 .with_span_static(
1382 span,
1383 "missing capability for this operation",
1384 ));
1385 }
1386 match fun {
1387 crate::AtomicFunction::Add
1388 | crate::AtomicFunction::And
1389 | crate::AtomicFunction::ExclusiveOr
1390 | crate::AtomicFunction::InclusiveOr
1391 | crate::AtomicFunction::Min
1392 | crate::AtomicFunction::Max => {}
1393 _ => {
1394 return Err(
1395 FunctionError::InvalidImageAtomicFunction(
1396 fun,
1397 )
1398 .with_span_handle(
1399 image,
1400 context.expressions,
1401 ),
1402 );
1403 }
1404 }
1405 }
1406 _ => {
1407 return Err(FunctionError::InvalidImageAtomic(
1408 ExpressionError::InvalidImageFormat(format),
1409 )
1410 .with_span_handle(image, context.expressions));
1411 }
1412 }
1413 crate::TypeInner::Scalar(format.into())
1414 }
1415 _ => {
1416 return Err(FunctionError::InvalidImageAtomic(
1417 ExpressionError::InvalidImageClass(class),
1418 )
1419 .with_span_handle(image, context.expressions));
1420 }
1421 }
1422 }
1423 _ => {
1424 return Err(FunctionError::InvalidImageAtomic(
1425 ExpressionError::ExpectedImageType(var.ty),
1426 )
1427 .with_span()
1428 .with_handle(var.ty, context.types)
1429 .with_handle(image, context.expressions))
1430 }
1431 };
1432
1433 if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty {
1434 return Err(FunctionError::InvalidImageAtomicValue(value)
1435 .with_span_handle(value, context.expressions));
1436 }
1437 }
1438 S::WorkGroupUniformLoad { pointer, result } => {
1439 stages &= super::ShaderStages::COMPUTE;
1440 let pointer_inner =
1441 context.resolve_type_inner(pointer, &self.valid_expression_set)?;
1442 match *pointer_inner {
1443 Ti::Pointer {
1444 space: AddressSpace::WorkGroup,
1445 ..
1446 } => {}
1447 Ti::ValuePointer {
1448 space: AddressSpace::WorkGroup,
1449 ..
1450 } => {}
1451 _ => {
1452 return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1453 .with_span_static(span, "WorkGroupUniformLoad"))
1454 }
1455 }
1456 self.emit_expression(result, context)?;
1457 let ty = match &context.expressions[result] {
1458 &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
1459 _ => {
1460 return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
1461 result,
1462 )
1463 .with_span_static(span, "WorkGroupUniformLoad"));
1464 }
1465 };
1466 let expected_pointer_inner = Ti::Pointer {
1467 base: ty,
1468 space: AddressSpace::WorkGroup,
1469 };
1470 if !expected_pointer_inner.non_struct_equivalent(pointer_inner, context.types) {
1471 return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
1472 .with_span_static(span, "WorkGroupUniformLoad"));
1473 }
1474 }
1475 S::RayQuery { query, ref fun } => {
1476 let query_var = match *context.get_expression(query) {
1477 crate::Expression::LocalVariable(var) => &context.local_vars[var],
1478 ref other => {
1479 log::error!("Unexpected ray query expression {other:?}");
1480 return Err(FunctionError::InvalidRayQueryExpression(query)
1481 .with_span_static(span, "invalid query expression"));
1482 }
1483 };
1484 let rq_vertex_return = match context.types[query_var.ty].inner {
1485 Ti::RayQuery { vertex_return } => vertex_return,
1486 ref other => {
1487 log::error!("Unexpected ray query type {other:?}");
1488 return Err(FunctionError::InvalidRayQueryType(query_var.ty)
1489 .with_span_static(span, "invalid query type"));
1490 }
1491 };
1492 match *fun {
1493 crate::RayQueryFunction::Initialize {
1494 acceleration_structure,
1495 descriptor,
1496 } => {
1497 match *context.resolve_type_inner(
1498 acceleration_structure,
1499 &self.valid_expression_set,
1500 )? {
1501 Ti::AccelerationStructure { vertex_return } => {
1502 if (!vertex_return) && rq_vertex_return {
1503 return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
1504 }
1505 }
1506 _ => {
1507 return Err(FunctionError::InvalidAccelerationStructure(
1508 acceleration_structure,
1509 )
1510 .with_span_static(span, "invalid acceleration structure"))
1511 }
1512 }
1513 let desc_ty_given = context
1514 .resolve_type_inner(descriptor, &self.valid_expression_set)?;
1515 let desc_ty_expected = context
1516 .special_types
1517 .ray_desc
1518 .map(|handle| &context.types[handle].inner);
1519 if Some(desc_ty_given) != desc_ty_expected {
1520 return Err(FunctionError::InvalidRayDescriptor(descriptor)
1521 .with_span_static(span, "invalid ray descriptor"));
1522 }
1523 }
1524 crate::RayQueryFunction::Proceed { result } => {
1525 self.emit_expression(result, context)?;
1526 }
1527 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1528 match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? {
1529 Ti::Scalar(crate::Scalar {
1530 kind: crate::ScalarKind::Float,
1531 width: _,
1532 }) => {}
1533 _ => {
1534 return Err(FunctionError::InvalidHitDistanceType(hit_t)
1535 .with_span_static(span, "invalid hit_t"))
1536 }
1537 }
1538 }
1539 crate::RayQueryFunction::ConfirmIntersection => {}
1540 crate::RayQueryFunction::Terminate => {}
1541 }
1542 }
1543 S::SubgroupBallot { result, predicate } => {
1544 stages &= self.subgroup_stages;
1545 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1546 return Err(FunctionError::MissingCapability(
1547 super::Capabilities::SUBGROUP,
1548 )
1549 .with_span_static(span, "missing capability for this operation"));
1550 }
1551 if !self
1552 .subgroup_operations
1553 .contains(super::SubgroupOperationSet::BALLOT)
1554 {
1555 return Err(FunctionError::InvalidSubgroup(
1556 SubgroupError::UnsupportedOperation(
1557 super::SubgroupOperationSet::BALLOT,
1558 ),
1559 )
1560 .with_span_static(span, "support for this operation is not present"));
1561 }
1562 if let Some(predicate) = predicate {
1563 let predicate_inner =
1564 context.resolve_type_inner(predicate, &self.valid_expression_set)?;
1565 if !matches!(
1566 *predicate_inner,
1567 crate::TypeInner::Scalar(crate::Scalar::BOOL,)
1568 ) {
1569 log::error!(
1570 "Subgroup ballot predicate type {:?} expected bool",
1571 predicate_inner
1572 );
1573 return Err(SubgroupError::InvalidOperand(predicate)
1574 .with_span_handle(predicate, context.expressions)
1575 .into_other());
1576 }
1577 }
1578 self.emit_expression(result, context)?;
1579 }
1580 S::SubgroupCollectiveOperation {
1581 ref op,
1582 ref collective_op,
1583 argument,
1584 result,
1585 } => {
1586 stages &= self.subgroup_stages;
1587 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1588 return Err(FunctionError::MissingCapability(
1589 super::Capabilities::SUBGROUP,
1590 )
1591 .with_span_static(span, "missing capability for this operation"));
1592 }
1593 let operation = op.required_operations();
1594 if !self.subgroup_operations.contains(operation) {
1595 return Err(FunctionError::InvalidSubgroup(
1596 SubgroupError::UnsupportedOperation(operation),
1597 )
1598 .with_span_static(span, "support for this operation is not present"));
1599 }
1600 self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
1601 }
1602 S::SubgroupGather {
1603 ref mode,
1604 argument,
1605 result,
1606 } => {
1607 stages &= self.subgroup_stages;
1608 if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
1609 return Err(FunctionError::MissingCapability(
1610 super::Capabilities::SUBGROUP,
1611 )
1612 .with_span_static(span, "missing capability for this operation"));
1613 }
1614 let operation = mode.required_operations();
1615 if !self.subgroup_operations.contains(operation) {
1616 return Err(FunctionError::InvalidSubgroup(
1617 SubgroupError::UnsupportedOperation(operation),
1618 )
1619 .with_span_static(span, "support for this operation is not present"));
1620 }
1621 self.validate_subgroup_gather(mode, argument, result, context)?;
1622 }
1623 }
1624 }
1625 Ok(BlockInfo { stages })
1626 }
1627
1628 fn validate_block(
1629 &mut self,
1630 statements: &crate::Block,
1631 context: &BlockContext,
1632 ) -> Result<BlockInfo, WithSpan<FunctionError>> {
1633 let base_expression_count = self.valid_expression_list.len();
1634 let info = self.validate_block_impl(statements, context)?;
1635 for handle in self.valid_expression_list.drain(base_expression_count..) {
1636 self.valid_expression_set.remove(handle);
1637 }
1638 Ok(info)
1639 }
1640
1641 fn validate_local_var(
1642 &self,
1643 var: &crate::LocalVariable,
1644 gctx: crate::proc::GlobalCtx,
1645 fun_info: &FunctionInfo,
1646 local_expr_kind: &crate::proc::ExpressionKindTracker,
1647 ) -> Result<(), LocalVariableError> {
1648 log::debug!("var {:?}", var);
1649 let type_info = self
1650 .types
1651 .get(var.ty.index())
1652 .ok_or(LocalVariableError::InvalidType(var.ty))?;
1653 if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
1654 return Err(LocalVariableError::InvalidType(var.ty));
1655 }
1656
1657 if let Some(init) = var.init {
1658 if !gctx.compare_types(&TypeResolution::Handle(var.ty), &fun_info[init].ty) {
1659 return Err(LocalVariableError::InitializerType);
1660 }
1661
1662 if !local_expr_kind.is_const_or_override(init) {
1663 return Err(LocalVariableError::NonConstOrOverrideInitializer);
1664 }
1665 }
1666
1667 Ok(())
1668 }
1669
1670 pub(super) fn validate_function(
1671 &mut self,
1672 fun: &crate::Function,
1673 module: &crate::Module,
1674 mod_info: &ModuleInfo,
1675 entry_point: bool,
1676 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1677 let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
1678
1679 let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
1680
1681 for (var_handle, var) in fun.local_variables.iter() {
1682 self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
1683 .map_err(|source| {
1684 FunctionError::LocalVariable {
1685 handle: var_handle,
1686 name: var.name.clone().unwrap_or_default(),
1687 source,
1688 }
1689 .with_span_handle(var.ty, &module.types)
1690 .with_handle(var_handle, &fun.local_variables)
1691 })?;
1692 }
1693
1694 for (index, argument) in fun.arguments.iter().enumerate() {
1695 match module.types[argument.ty].inner.pointer_space() {
1696 Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
1697 Some(other) => {
1698 return Err(FunctionError::InvalidArgumentPointerSpace {
1699 index,
1700 name: argument.name.clone().unwrap_or_default(),
1701 space: other,
1702 }
1703 .with_span_handle(argument.ty, &module.types))
1704 }
1705 }
1706 if !self.types[argument.ty.index()]
1708 .flags
1709 .contains(super::TypeFlags::ARGUMENT)
1710 {
1711 return Err(FunctionError::InvalidArgumentType {
1712 index,
1713 name: argument.name.clone().unwrap_or_default(),
1714 }
1715 .with_span_handle(argument.ty, &module.types));
1716 }
1717
1718 if !entry_point && argument.binding.is_some() {
1719 return Err(FunctionError::PipelineInputRegularFunction {
1720 name: argument.name.clone().unwrap_or_default(),
1721 }
1722 .with_span_handle(argument.ty, &module.types));
1723 }
1724 }
1725
1726 if let Some(ref result) = fun.result {
1727 if !self.types[result.ty.index()]
1728 .flags
1729 .contains(super::TypeFlags::CONSTRUCTIBLE)
1730 {
1731 return Err(FunctionError::NonConstructibleReturnType
1732 .with_span_handle(result.ty, &module.types));
1733 }
1734
1735 if !entry_point && result.binding.is_some() {
1736 return Err(FunctionError::PipelineOutputRegularFunction
1737 .with_span_handle(result.ty, &module.types));
1738 }
1739 }
1740
1741 self.valid_expression_set.clear_for_arena(&fun.expressions);
1742 self.valid_expression_list.clear();
1743 self.needs_visit.clear_for_arena(&fun.expressions);
1744 for (handle, expr) in fun.expressions.iter() {
1745 if expr.needs_pre_emit() {
1746 self.valid_expression_set.insert(handle);
1747 }
1748 if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1749 if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
1752 *expr
1753 {
1754 self.needs_visit.insert(handle);
1755 }
1756
1757 match self.validate_expression(
1758 handle,
1759 expr,
1760 fun,
1761 module,
1762 &info,
1763 mod_info,
1764 &local_expr_kind,
1765 ) {
1766 Ok(stages) => info.available_stages &= stages,
1767 Err(source) => {
1768 return Err(FunctionError::Expression { handle, source }
1769 .with_span_handle(handle, &fun.expressions))
1770 }
1771 }
1772 }
1773 }
1774
1775 if self.flags.contains(super::ValidationFlags::BLOCKS) {
1776 let stages = self
1777 .validate_block(
1778 &fun.body,
1779 &BlockContext::new(fun, module, &info, &mod_info.functions, &local_expr_kind),
1780 )?
1781 .stages;
1782 info.available_stages &= stages;
1783
1784 if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
1785 if let Some(handle) = self.needs_visit.iter().next() {
1786 return Err(FunctionError::UnvisitedExpression(handle)
1787 .with_span_handle(handle, &fun.expressions));
1788 }
1789 }
1790 }
1791 Ok(info)
1792 }
1793}