1mod analyzer;
6mod compose;
7mod expression;
8mod function;
9mod handles;
10mod interface;
11mod r#type;
12
13use alloc::{boxed::Box, string::String, vec, vec::Vec};
14use core::ops;
15
16use bit_set::BitSet;
17
18use crate::{
19 arena::{Handle, HandleSet},
20 proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
21 FastHashSet,
22};
23
24use crate::span::{AddSpan as _, WithSpan};
28pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
29pub use compose::ComposeError;
30pub use expression::{check_literal_value, LiteralError};
31pub use expression::{ConstExpressionError, ExpressionError};
32pub use function::{CallError, FunctionError, LocalVariableError};
33pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
34pub use r#type::{Disalignment, PushConstantError, TypeError, TypeFlags, WidthError};
35
36use self::handles::InvalidHandleError;
37
bitflags::bitflags! {
    /// Selects which groups of checks a [`Validator`] performs.
    ///
    /// [`ValidationFlags::default`] enables every flag. Within this file only
    /// `CONSTANTS` is consulted directly (it gates the global-expression,
    /// constant and override passes of module validation).
    ///
    /// NOTE(review): the remaining per-flag summaries restate the flag names;
    /// the code that reads them is outside this file, so confirm there.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ValidationFlags: u8 {
        /// Presumably: checks on expressions.
        const EXPRESSIONS = 0x1;
        /// Presumably: checks on statements / blocks.
        const BLOCKS = 0x2;
        /// Presumably: control-flow uniformity analysis.
        const CONTROL_FLOW_UNIFORMITY = 0x4;
        /// Presumably: host-shareable struct layout checks.
        const STRUCT_LAYOUTS = 0x8;
        /// Global expressions, constants and overrides are validated only
        /// when this flag is set.
        const CONSTANTS = 0x10;
        /// Presumably: checks on resource/varying bindings.
        const BINDINGS = 0x20;
    }
}
70
71impl Default for ValidationFlags {
72 fn default() -> Self {
73 Self::all()
74 }
75}
76
bitflags::bitflags! {
    /// Optional shader features that a [`Validator`] is told it may accept.
    ///
    /// [`Capabilities::default`] enables only `MULTISAMPLED_SHADING` and
    /// `CUBE_ARRAY_TEXTURES`. [`Validator::new`] additionally reads `SUBGROUP`
    /// and `SUBGROUP_VERTEX_STAGE` to seed its subgroup settings.
    ///
    /// NOTE(review): the one-line summaries on the individual flags restate
    /// the flag names; the checks that consume them live outside this file,
    /// so confirm the exact semantics there.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u32 {
        /// Push constants.
        const PUSH_CONSTANT = 1 << 0;
        /// 64-bit floating point values.
        const FLOAT64 = 1 << 1;
        /// The primitive-index built-in.
        const PRIMITIVE_INDEX = 1 << 2;
        /// Non-uniform indexing of sampled-texture and storage-buffer arrays.
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 3;
        /// Non-uniform indexing of storage-texture arrays.
        const STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 1 << 4;
        /// Non-uniform indexing of uniform-buffer arrays.
        const UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 1 << 5;
        /// Non-uniform indexing of samplers.
        const SAMPLER_NON_UNIFORM_INDEXING = 1 << 6;
        /// The clip-distance built-in.
        const CLIP_DISTANCE = 1 << 7;
        /// The cull-distance built-in.
        const CULL_DISTANCE = 1 << 8;
        /// 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 1 << 9;
        /// Multiview rendering.
        const MULTIVIEW = 1 << 10;
        /// Early depth testing.
        const EARLY_DEPTH_TEST = 1 << 11;
        /// Multisampled (per-sample) shading. Part of the default set.
        const MULTISAMPLED_SHADING = 1 << 12;
        /// Ray queries.
        const RAY_QUERY = 1 << 13;
        /// Dual-source blending.
        const DUAL_SOURCE_BLENDING = 1 << 14;
        /// Cube array textures. Part of the default set.
        const CUBE_ARRAY_TEXTURES = 1 << 15;
        /// 64-bit integers.
        const SHADER_INT64 = 1 << 16;
        /// Subgroup operations. When set, `Validator::new` enables every
        /// `SubgroupOperationSet` flag and the fragment and compute stages.
        const SUBGROUP = 1 << 17;
        /// Subgroup barriers.
        const SUBGROUP_BARRIER = 1 << 18;
        /// Subgroup operations in the vertex stage. When set,
        /// `Validator::new` adds `ShaderStages::VERTEX` to the allowed stages.
        const SUBGROUP_VERTEX_STAGE = 1 << 19;
        /// 64-bit integer atomic min/max.
        const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 20;
        /// All 64-bit integer atomic operations.
        const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 21;
        /// 32-bit float atomics.
        const SHADER_FLOAT32_ATOMIC = 1 << 22;
        /// Atomic operations on textures.
        const TEXTURE_ATOMIC = 1 << 23;
        /// 64-bit integer atomic operations on textures.
        const TEXTURE_INT64_ATOMIC = 1 << 24;
        /// Ray-hit vertex position queries.
        const RAY_HIT_VERTEX_POSITION = 1 << 25;
        /// 16-bit floating point values.
        const SHADER_FLOAT16 = 1 << 26;
    }
}
170
171impl Default for Capabilities {
172 fn default() -> Self {
173 Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
174 }
175}
176
bitflags::bitflags! {
    /// Groups of subgroup operations that a [`Validator`] may allow.
    ///
    /// The `required_operations` helpers on `SubgroupOperation` and
    /// `GatherMode` map each operation to the group it needs.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    pub struct SubgroupOperationSet: u8 {
        /// NOTE(review): no operation in this file maps to this group;
        /// presumably covers basic subgroup built-ins/elect — confirm.
        const BASIC = 1 << 0;
        /// Required by `SubgroupOperation::{All, Any}`.
        const VOTE = 1 << 1;
        /// Required by `SubgroupOperation::{Add, Mul, Min, Max, And, Or, Xor}`.
        const ARITHMETIC = 1 << 2;
        /// Required by `GatherMode::{BroadcastFirst, Broadcast}`.
        const BALLOT = 1 << 3;
        /// Required by `GatherMode::{Shuffle, ShuffleXor}`.
        const SHUFFLE = 1 << 4;
        /// Required by `GatherMode::{ShuffleUp, ShuffleDown}`.
        const SHUFFLE_RELATIVE = 1 << 5;
    }
}
204
205impl super::SubgroupOperation {
206 const fn required_operations(&self) -> SubgroupOperationSet {
207 use SubgroupOperationSet as S;
208 match *self {
209 Self::All | Self::Any => S::VOTE,
210 Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
211 S::ARITHMETIC
212 }
213 }
214 }
215}
216
217impl super::GatherMode {
218 const fn required_operations(&self) -> SubgroupOperationSet {
219 use SubgroupOperationSet as S;
220 match *self {
221 Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
222 Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
223 Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
224 }
225 }
226}
227
bitflags::bitflags! {
    /// A set of shader stages.
    ///
    /// In this file it records the stages in which a [`Validator`] allows
    /// subgroup operations.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u8 {
        /// The vertex stage.
        const VERTEX = 0x1;
        /// The fragment stage.
        const FRAGMENT = 0x2;
        /// The compute stage.
        const COMPUTE = 0x4;
    }
}
239
/// The result of successfully validating a module.
///
/// Can be indexed by type, function and (global) expression handles; see the
/// `ops::Index` implementations below.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    // Flags of each validated type, pushed in type-arena iteration order.
    type_flags: Vec<TypeFlags>,
    // Analysis result of each function, pushed in function-arena iteration order.
    functions: Vec<FunctionInfo>,
    // Analysis result of each entry point, in `module.entry_points` order.
    entry_points: Vec<FunctionInfo>,
    // One type resolution per global expression; allocated with placeholders
    // and then filled in by `process_const_expression`.
    const_expression_types: Box<[TypeResolution]>,
}
249
250impl ops::Index<Handle<crate::Type>> for ModuleInfo {
251 type Output = TypeFlags;
252 fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
253 &self.type_flags[handle.index()]
254 }
255}
256
257impl ops::Index<Handle<crate::Function>> for ModuleInfo {
258 type Output = FunctionInfo;
259 fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
260 &self.functions[handle.index()]
261 }
262}
263
264impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
265 type Output = TypeResolution;
266 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
267 &self.const_expression_types[handle.index()]
268 }
269}
270
/// Validates a module and produces a [`ModuleInfo`] describing it.
///
/// The configuration fields (`flags`, `capabilities`, `subgroup_*`) persist
/// across runs; the remaining fields are working state.
#[derive(Debug)]
pub struct Validator {
    // Which groups of checks to run.
    flags: ValidationFlags,
    // Optional features the caller has declared as supported.
    capabilities: Capabilities,
    // Stages in which subgroup operations are accepted.
    subgroup_stages: ShaderStages,
    // Groups of subgroup operations that are accepted.
    subgroup_operations: SubgroupOperationSet,
    // Per-type validation results, indexed by type handle index.
    types: Vec<r#type::TypeInfo>,
    // Computes type layouts; refreshed at the start of every validation run.
    layouter: Layouter,
    // NOTE(review): presumably tracks used I/O locations of the entry point
    // being validated; not used in this file.
    location_mask: BitSet,
    // NOTE(review): presumably tracks used blend sources; not used in this file.
    blend_src_mask: BitSet,
    // NOTE(review): presumably resource bindings seen in the current entry
    // point; not used in this file.
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    // NOTE(review): scratch storage for expression validation; only cleared
    // in this file.
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: HandleSet<crate::Expression>,
    // Override ids seen so far, used to reject duplicates.
    override_ids: FastHashSet<u16>,

    // Set by `validate` (false) and `validate_resolved_overrides` (true).
    // When true, an override without an initializer is an error.
    overrides_resolved: bool,

    // NOTE(review): expression set used by function validation elsewhere;
    // not touched in this file apart from construction.
    needs_visit: HandleSet<crate::Expression>,
}
312
/// Reasons a module-scope constant can fail validation.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
    /// The initializer expression is not classified as a const-expression.
    #[error("Initializer must be a const-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the constant")]
    InvalidType,
    /// The declared type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
}
323
/// Reasons a pipeline-overridable constant can fail validation.
///
/// Variants without a description of when they are raised are produced by
/// code outside this file.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
    #[error("Override name and ID are missing")]
    MissingNameAndID,
    /// Another override already uses the same id.
    #[error("Override ID must be unique")]
    DuplicateID,
    #[error("Initializer must be a const-expression or override-expression")]
    InitializerExprType,
    /// The initializer's type is not equivalent to the declared type.
    #[error("The type doesn't match the override")]
    InvalidType,
    /// The declared type lacks the `CONSTRUCTIBLE` type flag.
    #[error("The type is not constructible")]
    NonConstructibleType,
    /// The declared type is not one of the accepted scalars
    /// (bool, i32, u32, f16, f32, f64).
    #[error("The type is not a scalar")]
    TypeNotScalar,
    #[error("Override declarations are not allowed")]
    NotAllowed,
    /// The override has no initializer although overrides are expected to
    /// be resolved (see `Validator::validate_resolved_overrides`).
    #[error("Override is uninitialized")]
    UninitializedOverride,
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
}
349
/// The top-level error returned by module validation.
///
/// Most variants identify the offending item (handle and name) and wrap the
/// more specific error as `source`.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
    /// The module's handles failed the up-front handle validation.
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    /// The layouter could not compute a type layout.
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    /// A type failed validation.
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    /// A global expression failed type resolution or const-expression checks.
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    // Not produced in this file.
    #[error("Array size expression {handle:?} is not strictly positive")]
    ArraySizeError { handle: Handle<crate::Expression> },
    /// A module-scope constant failed validation.
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    /// An override failed validation.
    #[error("Override {handle:?} '{name}' is invalid")]
    Override {
        handle: Handle<crate::Override>,
        name: String,
        source: OverrideError,
    },
    /// A global variable failed validation.
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    /// A function failed validation.
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    /// An entry point failed validation, or duplicates an earlier entry
    /// point with the same stage and name.
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    // Not produced in this file.
    #[error("Module is corrupted")]
    Corrupted,
}
403
404impl crate::TypeInner {
405 const fn is_sized(&self) -> bool {
406 match *self {
407 Self::Scalar { .. }
408 | Self::Vector { .. }
409 | Self::Matrix { .. }
410 | Self::Array {
411 size: crate::ArraySize::Constant(_),
412 ..
413 }
414 | Self::Atomic { .. }
415 | Self::Pointer { .. }
416 | Self::ValuePointer { .. }
417 | Self::Struct { .. } => true,
418 Self::Array { .. }
419 | Self::Image { .. }
420 | Self::Sampler { .. }
421 | Self::AccelerationStructure { .. }
422 | Self::RayQuery { .. }
423 | Self::BindingArray { .. } => false,
424 }
425 }
426
427 const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
429 match *self {
430 Self::Scalar(crate::Scalar {
431 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
432 ..
433 }) => Some(crate::ImageDimension::D1),
434 Self::Vector {
435 size: crate::VectorSize::Bi,
436 scalar:
437 crate::Scalar {
438 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
439 ..
440 },
441 } => Some(crate::ImageDimension::D2),
442 Self::Vector {
443 size: crate::VectorSize::Tri,
444 scalar:
445 crate::Scalar {
446 kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
447 ..
448 },
449 } => Some(crate::ImageDimension::D3),
450 _ => None,
451 }
452 }
453}
454
455impl Validator {
456 pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
458 let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) {
459 use SubgroupOperationSet as S;
460 S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
461 } else {
462 SubgroupOperationSet::empty()
463 };
464 let subgroup_stages = {
465 let mut stages = ShaderStages::empty();
466 if capabilities.contains(Capabilities::SUBGROUP_VERTEX_STAGE) {
467 stages |= ShaderStages::VERTEX;
468 }
469 if capabilities.contains(Capabilities::SUBGROUP) {
470 stages |= ShaderStages::FRAGMENT | ShaderStages::COMPUTE;
471 }
472 stages
473 };
474
475 Validator {
476 flags,
477 capabilities,
478 subgroup_stages,
479 subgroup_operations,
480 types: Vec::new(),
481 layouter: Layouter::default(),
482 location_mask: BitSet::new(),
483 blend_src_mask: BitSet::new(),
484 ep_resource_bindings: FastHashSet::default(),
485 switch_values: FastHashSet::default(),
486 valid_expression_list: Vec::new(),
487 valid_expression_set: HandleSet::new(),
488 override_ids: FastHashSet::default(),
489 overrides_resolved: false,
490 needs_visit: HandleSet::new(),
491 }
492 }
493
    /// Replaces the set of shader stages in which subgroup operations are
    /// accepted, overriding whatever [`Validator::new`] derived from the
    /// capabilities. Returns `self` so calls can be chained.
    pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
        self.subgroup_stages = stages;
        self
    }
498
    /// Replaces the set of accepted subgroup operation groups, overriding
    /// whatever [`Validator::new`] derived from the capabilities. Returns
    /// `self` so calls can be chained.
    pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
        self.subgroup_operations = operations;
        self
    }
503
504 pub fn reset(&mut self) {
506 self.types.clear();
507 self.layouter.clear();
508 self.location_mask.clear();
509 self.blend_src_mask.clear();
510 self.ep_resource_bindings.clear();
511 self.switch_values.clear();
512 self.valid_expression_list.clear();
513 self.valid_expression_set.clear();
514 self.override_ids.clear();
515 }
516
517 fn validate_constant(
518 &self,
519 handle: Handle<crate::Constant>,
520 gctx: crate::proc::GlobalCtx,
521 mod_info: &ModuleInfo,
522 global_expr_kind: &ExpressionKindTracker,
523 ) -> Result<(), ConstantError> {
524 let con = &gctx.constants[handle];
525
526 let type_info = &self.types[con.ty.index()];
527 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
528 return Err(ConstantError::NonConstructibleType);
529 }
530
531 if !global_expr_kind.is_const(con.init) {
532 return Err(ConstantError::InitializerExprType);
533 }
534
535 let decl_ty = &gctx.types[con.ty].inner;
536 let init_ty = mod_info[con.init].inner_with(gctx.types);
537 if !decl_ty.equivalent(init_ty, gctx.types) {
538 return Err(ConstantError::InvalidType);
539 }
540
541 Ok(())
542 }
543
544 fn validate_override(
545 &mut self,
546 handle: Handle<crate::Override>,
547 gctx: crate::proc::GlobalCtx,
548 mod_info: &ModuleInfo,
549 ) -> Result<(), OverrideError> {
550 let o = &gctx.overrides[handle];
551
552 if let Some(id) = o.id {
553 if !self.override_ids.insert(id) {
554 return Err(OverrideError::DuplicateID);
555 }
556 }
557
558 let type_info = &self.types[o.ty.index()];
559 if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
560 return Err(OverrideError::NonConstructibleType);
561 }
562
563 let decl_ty = &gctx.types[o.ty].inner;
564 match decl_ty {
565 &crate::TypeInner::Scalar(
566 crate::Scalar::BOOL
567 | crate::Scalar::I32
568 | crate::Scalar::U32
569 | crate::Scalar::F16
570 | crate::Scalar::F32
571 | crate::Scalar::F64,
572 ) => {}
573 _ => return Err(OverrideError::TypeNotScalar),
574 }
575
576 if let Some(init) = o.init {
577 let init_ty = mod_info[init].inner_with(gctx.types);
578 if !decl_ty.equivalent(init_ty, gctx.types) {
579 return Err(OverrideError::InvalidType);
580 }
581 } else if self.overrides_resolved {
582 return Err(OverrideError::UninitializedOverride);
583 }
584
585 Ok(())
586 }
587
    /// Validates `module` and returns the collected [`ModuleInfo`].
    ///
    /// Overrides are not required to have initializers in this mode
    /// (`overrides_resolved` is set to `false` before the run).
    ///
    /// # Errors
    ///
    /// Returns the first [`ValidationError`] encountered, with span
    /// information attached.
    pub fn validate(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.overrides_resolved = false;
        self.validate_impl(module)
    }
596
    /// Validates `module` like [`Validator::validate`], but additionally
    /// rejects any override that has no initializer
    /// (`overrides_resolved` is set to `true` before the run).
    ///
    /// # Errors
    ///
    /// Returns the first [`ValidationError`] encountered, with span
    /// information attached.
    pub fn validate_resolved_overrides(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.overrides_resolved = true;
        self.validate_impl(module)
    }
611
612 fn validate_impl(
613 &mut self,
614 module: &crate::Module,
615 ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
616 self.reset();
617 self.reset_types(module.types.len());
618
619 Self::validate_module_handles(module).map_err(|e| e.with_span())?;
620
621 self.layouter.update(module.to_ctx()).map_err(|e| {
622 let handle = e.ty;
623 ValidationError::from(e).with_span_handle(handle, &module.types)
624 })?;
625
626 let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
628 kind: crate::ScalarKind::Bool,
629 width: 0,
630 }));
631
632 let mut mod_info = ModuleInfo {
633 type_flags: Vec::with_capacity(module.types.len()),
634 functions: Vec::with_capacity(module.functions.len()),
635 entry_points: Vec::with_capacity(module.entry_points.len()),
636 const_expression_types: vec![placeholder; module.global_expressions.len()]
637 .into_boxed_slice(),
638 };
639
640 for (handle, ty) in module.types.iter() {
641 let ty_info = self
642 .validate_type(handle, module.to_ctx())
643 .map_err(|source| {
644 ValidationError::Type {
645 handle,
646 name: ty.name.clone().unwrap_or_default(),
647 source,
648 }
649 .with_span_handle(handle, &module.types)
650 })?;
651 mod_info.type_flags.push(ty_info.flags);
652 self.types[handle.index()] = ty_info;
653 }
654
655 {
656 let t = crate::Arena::new();
657 let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
658 for (handle, _) in module.global_expressions.iter() {
659 mod_info
660 .process_const_expression(handle, &resolve_context, module.to_ctx())
661 .map_err(|source| {
662 ValidationError::ConstExpression { handle, source }
663 .with_span_handle(handle, &module.global_expressions)
664 })?
665 }
666 }
667
668 let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
669
670 if self.flags.contains(ValidationFlags::CONSTANTS) {
671 for (handle, _) in module.global_expressions.iter() {
672 self.validate_const_expression(
673 handle,
674 module.to_ctx(),
675 &mod_info,
676 &global_expr_kind,
677 )
678 .map_err(|source| {
679 ValidationError::ConstExpression { handle, source }
680 .with_span_handle(handle, &module.global_expressions)
681 })?
682 }
683
684 for (handle, constant) in module.constants.iter() {
685 self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
686 .map_err(|source| {
687 ValidationError::Constant {
688 handle,
689 name: constant.name.clone().unwrap_or_default(),
690 source,
691 }
692 .with_span_handle(handle, &module.constants)
693 })?
694 }
695
696 for (handle, r#override) in module.overrides.iter() {
697 self.validate_override(handle, module.to_ctx(), &mod_info)
698 .map_err(|source| {
699 ValidationError::Override {
700 handle,
701 name: r#override.name.clone().unwrap_or_default(),
702 source,
703 }
704 .with_span_handle(handle, &module.overrides)
705 })?;
706 }
707 }
708
709 for (var_handle, var) in module.global_variables.iter() {
710 self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
711 .map_err(|source| {
712 ValidationError::GlobalVariable {
713 handle: var_handle,
714 name: var.name.clone().unwrap_or_default(),
715 source,
716 }
717 .with_span_handle(var_handle, &module.global_variables)
718 })?;
719 }
720
721 for (handle, fun) in module.functions.iter() {
722 match self.validate_function(fun, module, &mod_info, false) {
723 Ok(info) => mod_info.functions.push(info),
724 Err(error) => {
725 return Err(error.and_then(|source| {
726 ValidationError::Function {
727 handle,
728 name: fun.name.clone().unwrap_or_default(),
729 source,
730 }
731 .with_span_handle(handle, &module.functions)
732 }))
733 }
734 }
735 }
736
737 let mut ep_map = FastHashSet::default();
738 for ep in module.entry_points.iter() {
739 if !ep_map.insert((ep.stage, &ep.name)) {
740 return Err(ValidationError::EntryPoint {
741 stage: ep.stage,
742 name: ep.name.clone(),
743 source: EntryPointError::Conflict,
744 }
745 .with_span()); }
747
748 match self.validate_entry_point(ep, module, &mod_info) {
749 Ok(info) => mod_info.entry_points.push(info),
750 Err(error) => {
751 return Err(error.and_then(|source| {
752 ValidationError::EntryPoint {
753 stage: ep.stage,
754 name: ep.name.clone(),
755 source,
756 }
757 .with_span()
758 }));
759 }
760 }
761 }
762
763 Ok(mod_info)
764 }
765}
766
767fn validate_atomic_compare_exchange_struct(
768 types: &crate::UniqueArena<crate::Type>,
769 members: &[crate::StructMember],
770 scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
771) -> bool {
772 members.len() == 2
773 && members[0].name.as_deref() == Some("old_value")
774 && scalar_predicate(&types[members[0].ty].inner)
775 && members[1].name.as_deref() == Some("exchanged")
776 && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
777}