1use super::{compose::validate_compose, FunctionInfo, ModuleInfo, ShaderStages, TypeFlags};
2use crate::arena::UniqueArena;
3use crate::{
4 arena::Handle,
5 proc::OverloadSet as _,
6 proc::{IndexableLengthError, ResolveError},
7};
8
9#[derive(Clone, Debug, thiserror::Error)]
10#[cfg_attr(test, derive(PartialEq))]
11pub enum ExpressionError {
12 #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
13 NotInScope,
14 #[error("Base type {0:?} is not compatible with this expression")]
15 InvalidBaseType(Handle<crate::Expression>),
16 #[error("Accessing with index {0:?} can't be done")]
17 InvalidIndexType(Handle<crate::Expression>),
18 #[error("Accessing {0:?} via a negative index is invalid")]
19 NegativeIndex(Handle<crate::Expression>),
20 #[error("Accessing index {1} is out of {0:?} bounds")]
21 IndexOutOfBounds(Handle<crate::Expression>, u32),
22 #[error("Function argument {0:?} doesn't exist")]
23 FunctionArgumentDoesntExist(u32),
24 #[error("Loading of {0:?} can't be done")]
25 InvalidPointerType(Handle<crate::Expression>),
26 #[error("Array length of {0:?} can't be done")]
27 InvalidArrayType(Handle<crate::Expression>),
28 #[error("Get intersection of {0:?} can't be done")]
29 InvalidRayQueryType(Handle<crate::Expression>),
30 #[error("Splatting {0:?} can't be done")]
31 InvalidSplatType(Handle<crate::Expression>),
32 #[error("Swizzling {0:?} can't be done")]
33 InvalidVectorType(Handle<crate::Expression>),
34 #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
35 InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
36 #[error(transparent)]
37 Compose(#[from] super::ComposeError),
38 #[error(transparent)]
39 IndexableLength(#[from] IndexableLengthError),
40 #[error("Operation {0:?} can't work with {1:?}")]
41 InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
42 #[error(
43 "Operation {:?} can't work with {:?} (of type {:?}) and {:?} (of type {:?})",
44 op,
45 lhs_expr,
46 lhs_type,
47 rhs_expr,
48 rhs_type
49 )]
50 InvalidBinaryOperandTypes {
51 op: crate::BinaryOperator,
52 lhs_expr: Handle<crate::Expression>,
53 lhs_type: crate::TypeInner,
54 rhs_expr: Handle<crate::Expression>,
55 rhs_type: crate::TypeInner,
56 },
57 #[error("Expected selection argument types to match, but reject value of type {reject:?} does not match accept value of value {accept:?}")]
58 SelectValuesTypeMismatch {
59 accept: crate::TypeInner,
60 reject: crate::TypeInner,
61 },
62 #[error("Expected selection condition to be a boolean value, got {actual:?}")]
63 SelectConditionNotABool { actual: crate::TypeInner },
64 #[error("Relational argument {0:?} is not a boolean vector")]
65 InvalidBooleanVector(Handle<crate::Expression>),
66 #[error("Relational argument {0:?} is not a float")]
67 InvalidFloatArgument(Handle<crate::Expression>),
68 #[error("Type resolution failed")]
69 Type(#[from] ResolveError),
70 #[error("Not a global variable")]
71 ExpectedGlobalVariable,
72 #[error("Not a global variable or a function argument")]
73 ExpectedGlobalOrArgument,
74 #[error("Needs to be an binding array instead of {0:?}")]
75 ExpectedBindingArrayType(Handle<crate::Type>),
76 #[error("Needs to be an image instead of {0:?}")]
77 ExpectedImageType(Handle<crate::Type>),
78 #[error("Needs to be an image instead of {0:?}")]
79 ExpectedSamplerType(Handle<crate::Type>),
80 #[error("Unable to operate on image class {0:?}")]
81 InvalidImageClass(crate::ImageClass),
82 #[error("Image atomics are not supported for storage format {0:?}")]
83 InvalidImageFormat(crate::StorageFormat),
84 #[error("Image atomics require atomic storage access, {0:?} is insufficient")]
85 InvalidImageStorageAccess(crate::StorageAccess),
86 #[error("Derivatives can only be taken from scalar and vector floats")]
87 InvalidDerivative,
88 #[error("Image array index parameter is misplaced")]
89 InvalidImageArrayIndex,
90 #[error("Inappropriate sample or level-of-detail index for texel access")]
91 InvalidImageOtherIndex,
92 #[error("Image array index type of {0:?} is not an integer scalar")]
93 InvalidImageArrayIndexType(Handle<crate::Expression>),
94 #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
95 InvalidImageOtherIndexType(Handle<crate::Expression>),
96 #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
97 InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
98 #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
99 ComparisonSamplingMismatch {
100 image: crate::ImageClass,
101 sampler: bool,
102 has_ref: bool,
103 },
104 #[error("Sample offset must be a const-expression")]
105 InvalidSampleOffsetExprType,
106 #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
107 InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
108 #[error("Depth reference {0:?} is not a scalar float")]
109 InvalidDepthReference(Handle<crate::Expression>),
110 #[error("Depth sample level can only be Auto or Zero")]
111 InvalidDepthSampleLevel,
112 #[error("Gather level can only be Zero")]
113 InvalidGatherLevel,
114 #[error("Gather component {0:?} doesn't exist in the image")]
115 InvalidGatherComponent(crate::SwizzleComponent),
116 #[error("Gather can't be done for image dimension {0:?}")]
117 InvalidGatherDimension(crate::ImageDimension),
118 #[error("Sample level (exact) type {0:?} has an invalid type")]
119 InvalidSampleLevelExactType(Handle<crate::Expression>),
120 #[error("Sample level (bias) type {0:?} is not a scalar float")]
121 InvalidSampleLevelBiasType(Handle<crate::Expression>),
122 #[error("Bias can't be done for image dimension {0:?}")]
123 InvalidSampleLevelBiasDimension(crate::ImageDimension),
124 #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
125 InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
126 #[error("Unable to cast")]
127 InvalidCastArgument,
128 #[error("Invalid argument count for {0:?}")]
129 WrongArgumentCount(crate::MathFunction),
130 #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
131 InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
132 #[error(
133 "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
134 )]
135 InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
136 #[error("Shader requires capability {0:?}")]
137 MissingCapabilities(super::Capabilities),
138 #[error(transparent)]
139 Literal(#[from] LiteralError),
140 #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")]
141 UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes),
142}
143
144#[derive(Clone, Debug, thiserror::Error)]
145#[cfg_attr(test, derive(PartialEq))]
146pub enum ConstExpressionError {
147 #[error("The expression is not a constant or override expression")]
148 NonConstOrOverride,
149 #[error("The expression is not a fully evaluated constant expression")]
150 NonFullyEvaluatedConst,
151 #[error(transparent)]
152 Compose(#[from] super::ComposeError),
153 #[error("Splatting {0:?} can't be done")]
154 InvalidSplatType(Handle<crate::Expression>),
155 #[error("Type resolution failed")]
156 Type(#[from] ResolveError),
157 #[error(transparent)]
158 Literal(#[from] LiteralError),
159 #[error(transparent)]
160 Width(#[from] super::r#type::WidthError),
161}
162
163#[derive(Clone, Debug, thiserror::Error)]
164#[cfg_attr(test, derive(PartialEq))]
165pub enum LiteralError {
166 #[error("Float literal is NaN")]
167 NaN,
168 #[error("Float literal is infinite")]
169 Infinity,
170 #[error(transparent)]
171 Width(#[from] super::r#type::WidthError),
172}
173
174struct ExpressionTypeResolver<'a> {
175 root: Handle<crate::Expression>,
176 types: &'a UniqueArena<crate::Type>,
177 info: &'a FunctionInfo,
178}
179
180impl core::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'_> {
181 type Output = crate::TypeInner;
182
183 #[allow(clippy::panic)]
184 fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
185 if handle < self.root {
186 self.info[handle].ty.inner_with(self.types)
187 } else {
188 panic!(
190 "Depends on {:?}, which has not been processed yet",
191 self.root
192 )
193 }
194 }
195}
196
197impl super::Validator {
198 pub(super) fn validate_const_expression(
199 &self,
200 handle: Handle<crate::Expression>,
201 gctx: crate::proc::GlobalCtx,
202 mod_info: &ModuleInfo,
203 global_expr_kind: &crate::proc::ExpressionKindTracker,
204 ) -> Result<(), ConstExpressionError> {
205 use crate::Expression as E;
206
207 if !global_expr_kind.is_const_or_override(handle) {
208 return Err(ConstExpressionError::NonConstOrOverride);
209 }
210
211 match gctx.global_expressions[handle] {
212 E::Literal(literal) => {
213 self.validate_literal(literal)?;
214 }
215 E::Constant(_) | E::ZeroValue(_) => {}
216 E::Compose { ref components, ty } => {
217 validate_compose(
218 ty,
219 gctx,
220 components.iter().map(|&handle| mod_info[handle].clone()),
221 )?;
222 }
223 E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
224 crate::TypeInner::Scalar { .. } => {}
225 _ => return Err(ConstExpressionError::InvalidSplatType(value)),
226 },
227 _ if global_expr_kind.is_const(handle) || self.overrides_resolved => {
228 return Err(ConstExpressionError::NonFullyEvaluatedConst)
229 }
230 _ => {}
232 }
233
234 Ok(())
235 }
236
237 #[allow(clippy::too_many_arguments)]
238 pub(super) fn validate_expression(
239 &self,
240 root: Handle<crate::Expression>,
241 expression: &crate::Expression,
242 function: &crate::Function,
243 module: &crate::Module,
244 info: &FunctionInfo,
245 mod_info: &ModuleInfo,
246 expr_kind: &crate::proc::ExpressionKindTracker,
247 ) -> Result<ShaderStages, ExpressionError> {
248 use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
249
250 let resolver = ExpressionTypeResolver {
251 root,
252 types: &module.types,
253 info,
254 };
255
256 let stages = match *expression {
257 E::Access { base, index } => {
258 let base_type = &resolver[base];
259 match *base_type {
260 Ti::Matrix { .. }
261 | Ti::Vector { .. }
262 | Ti::Array { .. }
263 | Ti::Pointer { .. }
264 | Ti::ValuePointer { size: Some(_), .. }
265 | Ti::BindingArray { .. } => {}
266 ref other => {
267 log::error!("Indexing of {:?}", other);
268 return Err(ExpressionError::InvalidBaseType(base));
269 }
270 };
271 match resolver[index] {
272 Ti::Scalar(Sc {
274 kind: Sk::Sint | Sk::Uint,
275 ..
276 }) => {}
277 ref other => {
278 log::error!("Indexing by {:?}", other);
279 return Err(ExpressionError::InvalidIndexType(index));
280 }
281 }
282
283 match module
285 .to_ctx()
286 .eval_expr_to_u32_from(index, &function.expressions)
287 {
288 Ok(value) => {
289 let length = if self.overrides_resolved {
290 base_type.indexable_length_resolved(module)
291 } else {
292 base_type.indexable_length_pending(module)
293 }?;
294 if let crate::proc::IndexableLength::Known(known_length) = length {
297 if value >= known_length {
298 return Err(ExpressionError::IndexOutOfBounds(base, value));
299 }
300 }
301 }
302 Err(crate::proc::U32EvalError::Negative) => {
303 return Err(ExpressionError::NegativeIndex(base))
304 }
305 Err(crate::proc::U32EvalError::NonConst) => {}
306 }
307
308 ShaderStages::all()
309 }
310 E::AccessIndex { base, index } => {
311 fn resolve_index_limit(
312 module: &crate::Module,
313 top: Handle<crate::Expression>,
314 ty: &crate::TypeInner,
315 top_level: bool,
316 ) -> Result<u32, ExpressionError> {
317 let limit = match *ty {
318 Ti::Vector { size, .. }
319 | Ti::ValuePointer {
320 size: Some(size), ..
321 } => size as u32,
322 Ti::Matrix { columns, .. } => columns as u32,
323 Ti::Array {
324 size: crate::ArraySize::Constant(len),
325 ..
326 } => len.get(),
327 Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, Ti::Pointer { base, .. } if top_level => {
329 resolve_index_limit(module, top, &module.types[base].inner, false)?
330 }
331 Ti::Struct { ref members, .. } => members.len() as u32,
332 ref other => {
333 log::error!("Indexing of {:?}", other);
334 return Err(ExpressionError::InvalidBaseType(top));
335 }
336 };
337 Ok(limit)
338 }
339
340 let limit = resolve_index_limit(module, base, &resolver[base], true)?;
341 if index >= limit {
342 return Err(ExpressionError::IndexOutOfBounds(base, limit));
343 }
344 ShaderStages::all()
345 }
346 E::Splat { size: _, value } => match resolver[value] {
347 Ti::Scalar { .. } => ShaderStages::all(),
348 ref other => {
349 log::error!("Splat scalar type {:?}", other);
350 return Err(ExpressionError::InvalidSplatType(value));
351 }
352 },
353 E::Swizzle {
354 size,
355 vector,
356 pattern,
357 } => {
358 let vec_size = match resolver[vector] {
359 Ti::Vector { size: vec_size, .. } => vec_size,
360 ref other => {
361 log::error!("Swizzle vector type {:?}", other);
362 return Err(ExpressionError::InvalidVectorType(vector));
363 }
364 };
365 for &sc in pattern[..size as usize].iter() {
366 if sc as u8 >= vec_size as u8 {
367 return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
368 }
369 }
370 ShaderStages::all()
371 }
372 E::Literal(literal) => {
373 self.validate_literal(literal)?;
374 ShaderStages::all()
375 }
376 E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
377 E::Compose { ref components, ty } => {
378 validate_compose(
379 ty,
380 module.to_ctx(),
381 components.iter().map(|&handle| info[handle].ty.clone()),
382 )?;
383 ShaderStages::all()
384 }
385 E::FunctionArgument(index) => {
386 if index >= function.arguments.len() as u32 {
387 return Err(ExpressionError::FunctionArgumentDoesntExist(index));
388 }
389 ShaderStages::all()
390 }
391 E::GlobalVariable(_handle) => ShaderStages::all(),
392 E::LocalVariable(_handle) => ShaderStages::all(),
393 E::Load { pointer } => {
394 match resolver[pointer] {
395 Ti::Pointer { base, .. }
396 if self.types[base.index()]
397 .flags
398 .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
399 Ti::ValuePointer { .. } => {}
400 ref other => {
401 log::error!("Loading {:?}", other);
402 return Err(ExpressionError::InvalidPointerType(pointer));
403 }
404 }
405 ShaderStages::all()
406 }
407 E::ImageSample {
408 image,
409 sampler,
410 gather,
411 coordinate,
412 array_index,
413 offset,
414 level,
415 depth_ref,
416 } => {
417 let image_ty = Self::global_var_ty(module, function, image)?;
419 let sampler_ty = Self::global_var_ty(module, function, sampler)?;
420
421 let comparison = match module.types[sampler_ty].inner {
422 Ti::Sampler { comparison } => comparison,
423 _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
424 };
425
426 let (class, dim) = match module.types[image_ty].inner {
427 Ti::Image {
428 class,
429 arrayed,
430 dim,
431 } => {
432 if arrayed != array_index.is_some() {
434 return Err(ExpressionError::InvalidImageArrayIndex);
435 }
436 if let Some(expr) = array_index {
437 match resolver[expr] {
438 Ti::Scalar(Sc {
439 kind: Sk::Sint | Sk::Uint,
440 ..
441 }) => {}
442 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
443 }
444 }
445 (class, dim)
446 }
447 _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
448 };
449
450 let image_depth = match class {
452 crate::ImageClass::Sampled {
453 kind: crate::ScalarKind::Float,
454 multi: false,
455 } => false,
456 crate::ImageClass::Sampled {
457 kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
458 multi: false,
459 } if gather.is_some() => false,
460 crate::ImageClass::Depth { multi: false } => true,
461 _ => return Err(ExpressionError::InvalidImageClass(class)),
462 };
463 if comparison != depth_ref.is_some() || (comparison && !image_depth) {
464 return Err(ExpressionError::ComparisonSamplingMismatch {
465 image: class,
466 sampler: comparison,
467 has_ref: depth_ref.is_some(),
468 });
469 }
470
471 let num_components = match dim {
473 crate::ImageDimension::D1 => 1,
474 crate::ImageDimension::D2 => 2,
475 crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
476 };
477 match resolver[coordinate] {
478 Ti::Scalar(Sc {
479 kind: Sk::Float, ..
480 }) if num_components == 1 => {}
481 Ti::Vector {
482 size,
483 scalar:
484 Sc {
485 kind: Sk::Float, ..
486 },
487 } if size as u32 == num_components => {}
488 _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
489 }
490
491 if let Some(const_expr) = offset {
493 if !expr_kind.is_const(const_expr) {
494 return Err(ExpressionError::InvalidSampleOffsetExprType);
495 }
496
497 match resolver[const_expr] {
498 Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
499 Ti::Vector {
500 size,
501 scalar: Sc { kind: Sk::Sint, .. },
502 } if size as u32 == num_components => {}
503 _ => {
504 return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
505 }
506 }
507 }
508
509 if let Some(expr) = depth_ref {
511 match resolver[expr] {
512 Ti::Scalar(Sc {
513 kind: Sk::Float, ..
514 }) => {}
515 _ => return Err(ExpressionError::InvalidDepthReference(expr)),
516 }
517 match level {
518 crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
519 _ => return Err(ExpressionError::InvalidDepthSampleLevel),
520 }
521 }
522
523 if let Some(component) = gather {
524 match dim {
525 crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
526 crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
527 return Err(ExpressionError::InvalidGatherDimension(dim))
528 }
529 };
530 let max_component = match class {
531 crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
532 _ => crate::SwizzleComponent::W,
533 };
534 if component > max_component {
535 return Err(ExpressionError::InvalidGatherComponent(component));
536 }
537 match level {
538 crate::SampleLevel::Zero => {}
539 _ => return Err(ExpressionError::InvalidGatherLevel),
540 }
541 }
542
543 match level {
545 crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
546 crate::SampleLevel::Zero => ShaderStages::all(),
547 crate::SampleLevel::Exact(expr) => {
548 match class {
549 crate::ImageClass::Depth { .. } => match resolver[expr] {
550 Ti::Scalar(Sc {
551 kind: Sk::Sint | Sk::Uint,
552 ..
553 }) => {}
554 _ => {
555 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
556 }
557 },
558 _ => match resolver[expr] {
559 Ti::Scalar(Sc {
560 kind: Sk::Float, ..
561 }) => {}
562 _ => {
563 return Err(ExpressionError::InvalidSampleLevelExactType(expr))
564 }
565 },
566 }
567 ShaderStages::all()
568 }
569 crate::SampleLevel::Bias(expr) => {
570 match resolver[expr] {
571 Ti::Scalar(Sc {
572 kind: Sk::Float, ..
573 }) => {}
574 _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
575 }
576 match class {
577 crate::ImageClass::Sampled {
578 kind: Sk::Float,
579 multi: false,
580 } => {
581 if dim == crate::ImageDimension::D1 {
582 return Err(ExpressionError::InvalidSampleLevelBiasDimension(
583 dim,
584 ));
585 }
586 }
587 _ => return Err(ExpressionError::InvalidImageClass(class)),
588 }
589 ShaderStages::FRAGMENT
590 }
591 crate::SampleLevel::Gradient { x, y } => {
592 match resolver[x] {
593 Ti::Scalar(Sc {
594 kind: Sk::Float, ..
595 }) if num_components == 1 => {}
596 Ti::Vector {
597 size,
598 scalar:
599 Sc {
600 kind: Sk::Float, ..
601 },
602 } if size as u32 == num_components => {}
603 _ => {
604 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
605 }
606 }
607 match resolver[y] {
608 Ti::Scalar(Sc {
609 kind: Sk::Float, ..
610 }) if num_components == 1 => {}
611 Ti::Vector {
612 size,
613 scalar:
614 Sc {
615 kind: Sk::Float, ..
616 },
617 } if size as u32 == num_components => {}
618 _ => {
619 return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
620 }
621 }
622 ShaderStages::all()
623 }
624 }
625 }
626 E::ImageLoad {
627 image,
628 coordinate,
629 array_index,
630 sample,
631 level,
632 } => {
633 let ty = Self::global_var_ty(module, function, image)?;
634 match module.types[ty].inner {
635 Ti::Image {
636 class,
637 arrayed,
638 dim,
639 } => {
640 match resolver[coordinate].image_storage_coordinates() {
641 Some(coord_dim) if coord_dim == dim => {}
642 _ => {
643 return Err(ExpressionError::InvalidImageCoordinateType(
644 dim, coordinate,
645 ))
646 }
647 };
648 if arrayed != array_index.is_some() {
649 return Err(ExpressionError::InvalidImageArrayIndex);
650 }
651 if let Some(expr) = array_index {
652 match resolver[expr] {
653 Ti::Scalar(Sc {
654 kind: Sk::Sint | Sk::Uint,
655 width: _,
656 }) => {}
657 _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
658 }
659 }
660
661 match (sample, class.is_multisampled()) {
662 (None, false) => {}
663 (Some(sample), true) => {
664 if resolver[sample].scalar_kind() != Some(Sk::Sint) {
665 return Err(ExpressionError::InvalidImageOtherIndexType(
666 sample,
667 ));
668 }
669 }
670 _ => {
671 return Err(ExpressionError::InvalidImageOtherIndex);
672 }
673 }
674
675 match (level, class.is_mipmapped()) {
676 (None, false) => {}
677 (Some(level), true) => match resolver[level] {
678 Ti::Scalar(Sc {
679 kind: Sk::Sint | Sk::Uint,
680 width: _,
681 }) => {}
682 _ => {
683 return Err(ExpressionError::InvalidImageArrayIndexType(level))
684 }
685 },
686 _ => {
687 return Err(ExpressionError::InvalidImageOtherIndex);
688 }
689 }
690 }
691 _ => return Err(ExpressionError::ExpectedImageType(ty)),
692 }
693 ShaderStages::all()
694 }
695 E::ImageQuery { image, query } => {
696 let ty = Self::global_var_ty(module, function, image)?;
697 match module.types[ty].inner {
698 Ti::Image { class, arrayed, .. } => {
699 let good = match query {
700 crate::ImageQuery::NumLayers => arrayed,
701 crate::ImageQuery::Size { level: None } => true,
702 crate::ImageQuery::Size { level: Some(level) } => {
703 match resolver[level] {
704 Ti::Scalar(Sc::I32 | Sc::U32) => {}
705 _ => {
706 return Err(ExpressionError::InvalidImageOtherIndexType(
707 level,
708 ))
709 }
710 }
711 class.is_mipmapped()
712 }
713 crate::ImageQuery::NumLevels => class.is_mipmapped(),
714 crate::ImageQuery::NumSamples => class.is_multisampled(),
715 };
716 if !good {
717 return Err(ExpressionError::InvalidImageClass(class));
718 }
719 }
720 _ => return Err(ExpressionError::ExpectedImageType(ty)),
721 }
722 ShaderStages::all()
723 }
724 E::Unary { op, expr } => {
725 use crate::UnaryOperator as Uo;
726 let inner = &resolver[expr];
727 match (op, inner.scalar_kind()) {
728 (Uo::Negate, Some(Sk::Float | Sk::Sint))
729 | (Uo::LogicalNot, Some(Sk::Bool))
730 | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
731 other => {
732 log::error!("Op {:?} kind {:?}", op, other);
733 return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
734 }
735 }
736 ShaderStages::all()
737 }
738 E::Binary { op, left, right } => {
739 use crate::BinaryOperator as Bo;
740 let left_inner = &resolver[left];
741 let right_inner = &resolver[right];
742 let good = match op {
743 Bo::Add | Bo::Subtract => match *left_inner {
744 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
745 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
746 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
747 },
748 Ti::Matrix { .. } => left_inner == right_inner,
749 _ => false,
750 },
751 Bo::Divide | Bo::Modulo => match *left_inner {
752 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
753 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
754 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
755 },
756 _ => false,
757 },
758 Bo::Multiply => {
759 let kind_allowed = match left_inner.scalar_kind() {
760 Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
761 Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
762 };
763 let types_match = match (left_inner, right_inner) {
764 (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
766 | (
767 &Ti::Vector {
768 scalar: scalar1, ..
769 },
770 &Ti::Scalar(scalar2),
771 )
772 | (
773 &Ti::Scalar(scalar1),
774 &Ti::Vector {
775 scalar: scalar2, ..
776 },
777 ) => scalar1 == scalar2,
778 (
780 &Ti::Scalar(Sc {
781 kind: Sk::Float, ..
782 }),
783 &Ti::Matrix { .. },
784 )
785 | (
786 &Ti::Matrix { .. },
787 &Ti::Scalar(Sc {
788 kind: Sk::Float, ..
789 }),
790 ) => true,
791 (
793 &Ti::Vector {
794 size: size1,
795 scalar: scalar1,
796 },
797 &Ti::Vector {
798 size: size2,
799 scalar: scalar2,
800 },
801 ) => scalar1 == scalar2 && size1 == size2,
802 (
804 &Ti::Matrix { columns, .. },
805 &Ti::Vector {
806 size,
807 scalar:
808 Sc {
809 kind: Sk::Float, ..
810 },
811 },
812 ) => columns == size,
813 (
815 &Ti::Vector {
816 size,
817 scalar:
818 Sc {
819 kind: Sk::Float, ..
820 },
821 },
822 &Ti::Matrix { rows, .. },
823 ) => size == rows,
824 (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
825 columns == rows
826 }
827 _ => false,
828 };
829 let left_width = left_inner.scalar_width().unwrap_or(0);
830 let right_width = right_inner.scalar_width().unwrap_or(0);
831 kind_allowed && types_match && left_width == right_width
832 }
833 Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
834 Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
835 match *left_inner {
836 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
837 Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
838 Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
839 },
840 ref other => {
841 log::error!("Op {:?} left type {:?}", op, other);
842 false
843 }
844 }
845 }
846 Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
847 Ti::Scalar(Sc { kind: Sk::Bool, .. })
848 | Ti::Vector {
849 scalar: Sc { kind: Sk::Bool, .. },
850 ..
851 } => left_inner == right_inner,
852 ref other => {
853 log::error!("Op {:?} left type {:?}", op, other);
854 false
855 }
856 },
857 Bo::And | Bo::InclusiveOr => match *left_inner {
858 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
859 Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
860 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
861 },
862 ref other => {
863 log::error!("Op {:?} left type {:?}", op, other);
864 false
865 }
866 },
867 Bo::ExclusiveOr => match *left_inner {
868 Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
869 Sk::Sint | Sk::Uint => left_inner == right_inner,
870 Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
871 },
872 ref other => {
873 log::error!("Op {:?} left type {:?}", op, other);
874 false
875 }
876 },
877 Bo::ShiftLeft | Bo::ShiftRight => {
878 let (base_size, base_scalar) = match *left_inner {
879 Ti::Scalar(scalar) => (Ok(None), scalar),
880 Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
881 ref other => {
882 log::error!("Op {:?} base type {:?}", op, other);
883 (Err(()), Sc::BOOL)
884 }
885 };
886 let shift_size = match *right_inner {
887 Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
888 Ti::Vector {
889 size,
890 scalar: Sc { kind: Sk::Uint, .. },
891 } => Ok(Some(size)),
892 ref other => {
893 log::error!("Op {:?} shift type {:?}", op, other);
894 Err(())
895 }
896 };
897 match base_scalar.kind {
898 Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
899 Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
900 }
901 }
902 };
903 if !good {
904 log::error!(
905 "Left: {:?} of type {:?}",
906 function.expressions[left],
907 left_inner
908 );
909 log::error!(
910 "Right: {:?} of type {:?}",
911 function.expressions[right],
912 right_inner
913 );
914 return Err(ExpressionError::InvalidBinaryOperandTypes {
915 op,
916 lhs_expr: left,
917 lhs_type: left_inner.clone(),
918 rhs_expr: right,
919 rhs_type: right_inner.clone(),
920 });
921 }
922 ShaderStages::all()
923 }
924 E::Select {
925 condition,
926 accept,
927 reject,
928 } => {
929 let accept_inner = &resolver[accept];
930 let reject_inner = &resolver[reject];
931 let condition_ty = &resolver[condition];
932 let condition_good = match *condition_ty {
933 Ti::Scalar(Sc {
934 kind: Sk::Bool,
935 width: _,
936 }) => {
937 match *accept_inner {
940 Ti::Scalar { .. } | Ti::Vector { .. } => true,
941 _ => false,
942 }
943 }
944 Ti::Vector {
945 size,
946 scalar:
947 Sc {
948 kind: Sk::Bool,
949 width: _,
950 },
951 } => match *accept_inner {
952 Ti::Vector {
953 size: other_size, ..
954 } => size == other_size,
955 _ => false,
956 },
957 _ => false,
958 };
959 if accept_inner != reject_inner {
960 return Err(ExpressionError::SelectValuesTypeMismatch {
961 accept: accept_inner.clone(),
962 reject: reject_inner.clone(),
963 });
964 }
965 if !condition_good {
966 return Err(ExpressionError::SelectConditionNotABool {
967 actual: condition_ty.clone(),
968 });
969 }
970 ShaderStages::all()
971 }
972 E::Derivative { expr, .. } => {
973 match resolver[expr] {
974 Ti::Scalar(Sc {
975 kind: Sk::Float, ..
976 })
977 | Ti::Vector {
978 scalar:
979 Sc {
980 kind: Sk::Float, ..
981 },
982 ..
983 } => {}
984 _ => return Err(ExpressionError::InvalidDerivative),
985 }
986 ShaderStages::FRAGMENT
987 }
988 E::Relational { fun, argument } => {
989 use crate::RelationalFunction as Rf;
990 let argument_inner = &resolver[argument];
991 match fun {
992 Rf::All | Rf::Any => match *argument_inner {
993 Ti::Vector {
994 scalar: Sc { kind: Sk::Bool, .. },
995 ..
996 } => {}
997 ref other => {
998 log::error!("All/Any of type {:?}", other);
999 return Err(ExpressionError::InvalidBooleanVector(argument));
1000 }
1001 },
1002 Rf::IsNan | Rf::IsInf => match *argument_inner {
1003 Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
1004 if scalar.kind == Sk::Float => {}
1005 ref other => {
1006 log::error!("Float test of type {:?}", other);
1007 return Err(ExpressionError::InvalidFloatArgument(argument));
1008 }
1009 },
1010 }
1011 ShaderStages::all()
1012 }
1013 E::Math {
1014 fun,
1015 arg,
1016 arg1,
1017 arg2,
1018 arg3,
1019 } => {
1020 let actuals: &[_] = match (arg1, arg2, arg3) {
1021 (None, None, None) => &[arg],
1022 (Some(arg1), None, None) => &[arg, arg1],
1023 (Some(arg1), Some(arg2), None) => &[arg, arg1, arg2],
1024 (Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3],
1025 _ => return Err(ExpressionError::WrongArgumentCount(fun)),
1026 };
1027
1028 let resolve = |arg| &resolver[arg];
1029 let actual_types: &[_] = match *actuals {
1030 [arg0] => &[resolve(arg0)],
1031 [arg0, arg1] => &[resolve(arg0), resolve(arg1)],
1032 [arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)],
1033 [arg0, arg1, arg2, arg3] => {
1034 &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)]
1035 }
1036 _ => unreachable!(),
1037 };
1038
1039 let mut overloads = fun.overloads();
1041 log::debug!(
1042 "initial overloads for {:?}: {:#?}",
1043 fun,
1044 overloads.for_debug(&module.types)
1045 );
1046
1047 for (i, (&expr, &ty)) in actuals.iter().zip(actual_types).enumerate() {
1055 overloads = overloads.arg(i, ty, &module.types);
1058 log::debug!(
1059 "overloads after arg {i}: {:#?}",
1060 overloads.for_debug(&module.types)
1061 );
1062
1063 if overloads.is_empty() {
1064 log::debug!("all overloads eliminated");
1065 return Err(ExpressionError::InvalidArgumentType(fun, i as u32, expr));
1066 }
1067 }
1068
1069 if actuals.len() < overloads.min_arguments() {
1070 return Err(ExpressionError::WrongArgumentCount(fun));
1071 }
1072
1073 ShaderStages::all()
1074 }
1075 E::As {
1076 expr,
1077 kind,
1078 convert,
1079 } => {
1080 let mut base_scalar = match resolver[expr] {
1081 crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
1082 scalar
1083 }
1084 crate::TypeInner::Matrix { scalar, .. } => scalar,
1085 _ => return Err(ExpressionError::InvalidCastArgument),
1086 };
1087 base_scalar.kind = kind;
1088 if let Some(width) = convert {
1089 base_scalar.width = width;
1090 }
1091 if self.check_width(base_scalar).is_err() {
1092 return Err(ExpressionError::InvalidCastArgument);
1093 }
1094 ShaderStages::all()
1095 }
1096 E::CallResult(function) => mod_info.functions[function.index()].available_stages,
1097 E::AtomicResult { .. } => {
1098 ShaderStages::all()
1103 }
1104 E::WorkGroupUniformLoadResult { ty } => {
1105 if self.types[ty.index()]
1106 .flags
1107 .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
1110 {
1111 ShaderStages::COMPUTE
1112 } else {
1113 return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
1114 }
1115 }
1116 E::ArrayLength(expr) => match resolver[expr] {
1117 Ti::Pointer { base, .. } => {
1118 let base_ty = &resolver.types[base];
1119 if let Ti::Array {
1120 size: crate::ArraySize::Dynamic,
1121 ..
1122 } = base_ty.inner
1123 {
1124 ShaderStages::all()
1125 } else {
1126 return Err(ExpressionError::InvalidArrayType(expr));
1127 }
1128 }
1129 ref other => {
1130 log::error!("Array length of {:?}", other);
1131 return Err(ExpressionError::InvalidArrayType(expr));
1132 }
1133 },
1134 E::RayQueryProceedResult => ShaderStages::all(),
1135 E::RayQueryGetIntersection {
1136 query,
1137 committed: _,
1138 } => match resolver[query] {
1139 Ti::Pointer {
1140 base,
1141 space: crate::AddressSpace::Function,
1142 } => match resolver.types[base].inner {
1143 Ti::RayQuery { .. } => ShaderStages::all(),
1144 ref other => {
1145 log::error!("Intersection result of a pointer to {:?}", other);
1146 return Err(ExpressionError::InvalidRayQueryType(query));
1147 }
1148 },
1149 ref other => {
1150 log::error!("Intersection result of {:?}", other);
1151 return Err(ExpressionError::InvalidRayQueryType(query));
1152 }
1153 },
1154 E::RayQueryVertexPositions {
1155 query,
1156 committed: _,
1157 } => match resolver[query] {
1158 Ti::Pointer {
1159 base,
1160 space: crate::AddressSpace::Function,
1161 } => match resolver.types[base].inner {
1162 Ti::RayQuery {
1163 vertex_return: true,
1164 } => ShaderStages::all(),
1165 ref other => {
1166 log::error!("Intersection result of a pointer to {:?}", other);
1167 return Err(ExpressionError::InvalidRayQueryType(query));
1168 }
1169 },
1170 ref other => {
1171 log::error!("Intersection result of {:?}", other);
1172 return Err(ExpressionError::InvalidRayQueryType(query));
1173 }
1174 },
1175 E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
1176 };
1177 Ok(stages)
1178 }
1179
1180 fn global_var_ty(
1181 module: &crate::Module,
1182 function: &crate::Function,
1183 expr: Handle<crate::Expression>,
1184 ) -> Result<Handle<crate::Type>, ExpressionError> {
1185 use crate::Expression as Ex;
1186
1187 match function.expressions[expr] {
1188 Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
1189 Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
1190 Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
1191 match function.expressions[base] {
1192 Ex::GlobalVariable(var_handle) => {
1193 let array_ty = module.global_variables[var_handle].ty;
1194
1195 match module.types[array_ty].inner {
1196 crate::TypeInner::BindingArray { base, .. } => Ok(base),
1197 _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
1198 }
1199 }
1200 _ => Err(ExpressionError::ExpectedGlobalVariable),
1201 }
1202 }
1203 _ => Err(ExpressionError::ExpectedGlobalVariable),
1204 }
1205 }
1206
1207 pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
1208 let _ = self.check_width(literal.scalar())?;
1209 check_literal_value(literal)?;
1210
1211 Ok(())
1212 }
1213}
1214
1215pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
1216 let is_nan = match literal {
1217 crate::Literal::F64(v) => v.is_nan(),
1218 crate::Literal::F32(v) => v.is_nan(),
1219 _ => false,
1220 };
1221 if is_nan {
1222 return Err(LiteralError::NaN);
1223 }
1224
1225 let is_infinite = match literal {
1226 crate::Literal::F64(v) => v.is_infinite(),
1227 crate::Literal::F32(v) => v.is_infinite(),
1228 _ => false,
1229 };
1230 if is_infinite {
1231 return Err(LiteralError::Infinity);
1232 }
1233
1234 Ok(())
1235}
1236
/// Test helper that places `expr` in an otherwise empty function.
///
/// The function's body is a single `Emit` covering `expr`. The resulting
/// module is validated with only `ValidationFlags::EXPRESSIONS` enabled and
/// the given capabilities.
#[cfg(test)]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut module = crate::Module::default();

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    let emitted = function.expressions.range_from(0);
    function
        .body
        .push(crate::Statement::Emit(emitted), Span::default());
    module.functions.append(function, Span::default());

    super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps).validate(&module)
}
1259
/// Test helper that appends `expr` to an empty module's `global_expressions`.
///
/// The module is then validated with only `ValidationFlags::CONSTANTS`
/// enabled and the given capabilities.
#[cfg(test)]
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    let mut module = crate::Module::default();
    module
        .global_expressions
        .append(expr, crate::span::Span::default());

    super::Validator::new(super::ValidationFlags::CONSTANTS, caps).validate(&module)
}
1275
/// An `f64` literal inside a function must be rejected unless the
/// `FLOAT64` capability is enabled.
#[test]
fn f64_runtime_literals() {
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    let without_f64 = validate_with_expression(literal(), super::Capabilities::default());
    let error = without_f64.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: ExpressionError::Literal(LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                )),
                ..
            },
            ..
        }
    ));

    let caps = super::Capabilities::default() | super::Capabilities::FLOAT64;
    let with_f64 = validate_with_expression(literal(), caps);
    assert!(with_f64.is_ok());
}
1306
/// An `f64` literal in the module's global expressions must be rejected
/// unless the `FLOAT64` capability is enabled.
#[test]
fn f64_const_literals() {
    let literal = || crate::Expression::Literal(crate::Literal::F64(0.57721_56649));

    let without_f64 = validate_with_const_expression(literal(), super::Capabilities::default());
    let error = without_f64.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: ConstExpressionError::Literal(LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    let caps = super::Capabilities::default() | super::Capabilities::FLOAT64;
    let with_f64 = validate_with_const_expression(literal(), caps);
    assert!(with_f64.is_ok());
}