1use alloc::{format, string::String};
2
3use thiserror::Error;
4
5use crate::arena::{Arena, Handle, UniqueArena};
6use crate::common::ForDebugWithTypes;
7
/// The resolved type of an expression.
///
/// A resolution either names a type by handle, or carries a `TypeInner`
/// directly. Use [`TypeResolution::inner_with`] to get at the `TypeInner` in
/// both cases.
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub enum TypeResolution {
    /// A type referred to by its handle in a `UniqueArena<crate::Type>`.
    Handle(Handle<crate::Type>),

    /// A type held inline by value rather than through an arena handle.
    Value(crate::TypeInner),
}
108
109impl TypeResolution {
110 pub const fn handle(&self) -> Option<Handle<crate::Type>> {
111 match *self {
112 Self::Handle(handle) => Some(handle),
113 Self::Value(_) => None,
114 }
115 }
116
117 pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
118 match *self {
119 Self::Handle(handle) => &arena[handle].inner,
120 Self::Value(ref inner) => inner,
121 }
122 }
123}
124
125impl Clone for TypeResolution {
127 fn clone(&self) -> Self {
128 use crate::TypeInner as Ti;
129 match *self {
130 Self::Handle(handle) => Self::Handle(handle),
131 Self::Value(ref v) => Self::Value(match *v {
132 Ti::Scalar(scalar) => Ti::Scalar(scalar),
133 Ti::Vector { size, scalar } => Ti::Vector { size, scalar },
134 Ti::Matrix {
135 rows,
136 columns,
137 scalar,
138 } => Ti::Matrix {
139 rows,
140 columns,
141 scalar,
142 },
143 Ti::Pointer { base, space } => Ti::Pointer { base, space },
144 Ti::ValuePointer {
145 size,
146 scalar,
147 space,
148 } => Ti::ValuePointer {
149 size,
150 scalar,
151 space,
152 },
153 Ti::Array { base, size, stride } => Ti::Array { base, size, stride },
154 _ => unreachable!("Unexpected clone type: {:?}", v),
155 }),
156 }
157 }
158}
159
/// Errors produced while resolving the type of an expression.
#[derive(Clone, Debug, Error, PartialEq)]
pub enum ResolveError {
    /// A constant `index` exceeds the components, columns, or members of
    /// the type of `expr`.
    #[error("Index {index} is out of bounds for expression {expr:?}")]
    OutOfBoundsIndex {
        expr: Handle<crate::Expression>,
        index: u32,
    },
    /// The type of `expr` cannot be accessed into.
    ///
    /// `indexed` is `true` for a constant-index access (`AccessIndex`) and
    /// `false` for a dynamic one (`Access`).
    #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
    InvalidAccess {
        expr: Handle<crate::Expression>,
        indexed: bool,
    },
    /// A pointer's pointee type `ty` cannot be accessed into.
    ///
    /// `indexed` has the same meaning as in `InvalidAccess`.
    #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
    InvalidSubAccess {
        ty: Handle<crate::Type>,
        indexed: bool,
    },
    /// The operand of a `Splat` is not a scalar.
    #[error("Invalid scalar {0:?}")]
    InvalidScalar(Handle<crate::Expression>),
    /// The operand of a `Swizzle` is not a vector.
    #[error("Invalid vector {0:?}")]
    InvalidVector(Handle<crate::Expression>),
    /// The operand of a `Load` is not a pointer or value pointer.
    #[error("Invalid pointer {0:?}")]
    InvalidPointer(Handle<crate::Expression>),
    /// The image operand of an image sample/load/query is not an image.
    #[error("Invalid image {0:?}")]
    InvalidImage(Handle<crate::Expression>),
    /// The function called `name` is not defined.
    #[error("Function {name} not defined")]
    FunctionNotDefined { name: String },
    /// The result of a call was used, but the callee has no result.
    #[error("Function without return type")]
    FunctionReturnsVoid,
    /// The operand types do not fit the operation.
    ///
    /// The string is a debug rendering of the offending operation.
    #[error("Incompatible operands: {0}")]
    IncompatibleOperands(String),
    /// A `FunctionArgument` index is past the end of the argument list.
    #[error("Function argument {0} doesn't exist")]
    FunctionArgumentNotFound(u32),
    /// A required entry of the module's special types is not set.
    #[error("Special type is not registered within the module")]
    MissingSpecialType,
    /// No overload of the named math builtin accepts the given arguments.
    #[error("Call to builtin {0} has incorrect or ambiguous arguments")]
    BuiltinArgumentsInvalid(String),
}
198
199impl From<crate::proc::MissingSpecialType> for ResolveError {
200 fn from(_unit_struct: crate::proc::MissingSpecialType) -> Self {
201 ResolveError::MissingSpecialType
202 }
203}
204
/// Borrowed arenas and lists that expression type resolution reads from.
pub struct ResolveContext<'a> {
    /// Constants; a `Constant` expression resolves to the constant's `ty`.
    pub constants: &'a Arena<crate::Constant>,
    /// Overrides; an `Override` expression resolves to the override's `ty`.
    pub overrides: &'a Arena<crate::Override>,
    /// The type arena that `TypeResolution::Handle` values index into.
    pub types: &'a UniqueArena<crate::Type>,
    /// Special types, e.g. for ray-query results and builtin overload conclusions.
    pub special_types: &'a crate::SpecialTypes,
    /// Global variables; referencing one yields a pointer (or the bare type
    /// for the `Handle` address space).
    pub global_vars: &'a Arena<crate::GlobalVariable>,
    /// Local variables; referencing one yields a `Function`-space pointer.
    pub local_vars: &'a Arena<crate::LocalVariable>,
    /// Functions; used to find the result type of a `CallResult`.
    pub functions: &'a Arena<crate::Function>,
    /// Arguments, indexed by `FunctionArgument` expressions.
    pub arguments: &'a [crate::FunctionArgument],
}
215
impl<'a> ResolveContext<'a> {
    /// Builds a resolution context.
    ///
    /// Every module-level arena is borrowed from `module`. The caller supplies
    /// the local variables and arguments in scope (presumably those of the
    /// function whose expressions are being resolved — confirm at call sites).
    pub const fn with_locals(
        module: &'a crate::Module,
        local_vars: &'a Arena<crate::LocalVariable>,
        arguments: &'a [crate::FunctionArgument],
    ) -> Self {
        Self {
            constants: &module.constants,
            overrides: &module.overrides,
            types: &module.types,
            special_types: &module.special_types,
            global_vars: &module.global_variables,
            local_vars,
            functions: &module.functions,
            arguments,
        }
    }

    /// Determines the type of `expr`.
    ///
    /// `past` must return the already-resolved type of an operand expression;
    /// any error it returns is propagated unchanged.
    ///
    /// The result is a `TypeResolution::Handle` when the type exists in
    /// `self.types`. Otherwise it is a `TypeResolution::Value` built here.
    ///
    /// # Errors
    ///
    /// Returns a [`ResolveError`] when an operand has a type the expression
    /// cannot accept, or a constant index is out of bounds. It also returns one
    /// when a referenced argument, function result, or special type is missing.
    pub fn resolve(
        &self,
        expr: &crate::Expression,
        past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
    ) -> Result<TypeResolution, ResolveError> {
        use crate::TypeInner as Ti;
        let types = self.types;
        Ok(match *expr {
            // Dynamic indexing: no bounds check is possible here, so only the
            // element type is computed.
            crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
                Ti::Array { base, .. } => TypeResolution::Handle(base),
                // Indexing a matrix selects a column, a vector of `rows` components.
                Ti::Matrix { rows, scalar, .. } => {
                    TypeResolution::Value(Ti::Vector { size: rows, scalar })
                }
                Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
                // A pointer to a vector, indexed, becomes a pointer to one component.
                Ti::ValuePointer {
                    size: Some(_),
                    scalar,
                    space,
                } => TypeResolution::Value(Ti::ValuePointer {
                    size: None,
                    scalar,
                    space,
                }),
                // Indexing through a pointer yields a pointer to the element,
                // in the same address space.
                Ti::Pointer { base, space } => {
                    TypeResolution::Value(match types[base].inner {
                        Ti::Array { base, .. } => Ti::Pointer { base, space },
                        Ti::Vector { size: _, scalar } => Ti::ValuePointer {
                            size: None,
                            scalar,
                            space,
                        },
                        Ti::Matrix {
                            columns: _,
                            rows,
                            scalar,
                        } => Ti::ValuePointer {
                            size: Some(rows),
                            scalar,
                            space,
                        },
                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
                        ref other => {
                            log::error!("Access sub-type {:?}", other);
                            return Err(ResolveError::InvalidSubAccess {
                                ty: base,
                                indexed: false,
                            });
                        }
                    })
                }
                Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
                ref other => {
                    log::error!("Access type {:?}", other);
                    return Err(ResolveError::InvalidAccess {
                        expr: base,
                        indexed: false,
                    });
                }
            },
            // Constant indexing: vector, matrix, value-pointer and struct
            // indices are bounds-checked. Array indices are not checked here.
            crate::Expression::AccessIndex { base, index } => {
                match *past(base)?.inner_with(types) {
                    Ti::Vector { size, scalar } => {
                        if index >= size as u32 {
                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                        }
                        TypeResolution::Value(Ti::Scalar(scalar))
                    }
                    Ti::Matrix {
                        columns,
                        rows,
                        scalar,
                    } => {
                        if index >= columns as u32 {
                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                        }
                        TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
                    }
                    Ti::Array { base, .. } => TypeResolution::Handle(base),
                    Ti::Struct { ref members, .. } => {
                        let member = members
                            .get(index as usize)
                            .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
                        TypeResolution::Handle(member.ty)
                    }
                    Ti::ValuePointer {
                        size: Some(size),
                        scalar,
                        space,
                    } => {
                        if index >= size as u32 {
                            return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                        }
                        TypeResolution::Value(Ti::ValuePointer {
                            size: None,
                            scalar,
                            space,
                        })
                    }
                    // `ty_base` names the pointee type so that `base` keeps
                    // referring to the expression handle, as the error values need.
                    Ti::Pointer {
                        base: ty_base,
                        space,
                    } => TypeResolution::Value(match types[ty_base].inner {
                        Ti::Array { base, .. } => Ti::Pointer { base, space },
                        Ti::Vector { size, scalar } => {
                            if index >= size as u32 {
                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                            }
                            Ti::ValuePointer {
                                size: None,
                                scalar,
                                space,
                            }
                        }
                        Ti::Matrix {
                            rows,
                            columns,
                            scalar,
                        } => {
                            if index >= columns as u32 {
                                return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
                            }
                            Ti::ValuePointer {
                                size: Some(rows),
                                scalar,
                                space,
                            }
                        }
                        Ti::Struct { ref members, .. } => {
                            let member = members
                                .get(index as usize)
                                .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
                            Ti::Pointer {
                                base: member.ty,
                                space,
                            }
                        }
                        Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
                        ref other => {
                            log::error!("Access index sub-type {:?}", other);
                            return Err(ResolveError::InvalidSubAccess {
                                ty: ty_base,
                                indexed: true,
                            });
                        }
                    }),
                    Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
                    ref other => {
                        log::error!("Access index type {:?}", other);
                        return Err(ResolveError::InvalidAccess {
                            expr: base,
                            indexed: true,
                        });
                    }
                }
            }
            crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
                Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }),
                ref other => {
                    log::error!("Scalar type {:?}", other);
                    return Err(ResolveError::InvalidScalar(value));
                }
            },
            // The swizzle's own `size` determines the result width; only the
            // component scalar comes from the source vector.
            crate::Expression::Swizzle {
                size,
                vector,
                pattern: _,
            } => match *past(vector)?.inner_with(types) {
                Ti::Vector { size: _, scalar } => {
                    TypeResolution::Value(Ti::Vector { size, scalar })
                }
                ref other => {
                    log::error!("Vector type {:?}", other);
                    return Err(ResolveError::InvalidVector(vector));
                }
            },
            crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
            crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
            crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
            crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
            crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
            crate::Expression::FunctionArgument(index) => {
                let arg = self
                    .arguments
                    .get(index as usize)
                    .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
                TypeResolution::Handle(arg.ty)
            }
            // Globals in the `Handle` address space resolve to their type
            // directly. All other globals resolve to a pointer to their type.
            crate::Expression::GlobalVariable(h) => {
                let var = &self.global_vars[h];
                if var.space == crate::AddressSpace::Handle {
                    TypeResolution::Handle(var.ty)
                } else {
                    TypeResolution::Value(Ti::Pointer {
                        base: var.ty,
                        space: var.space,
                    })
                }
            }
            crate::Expression::LocalVariable(h) => {
                let var = &self.local_vars[h];
                TypeResolution::Value(Ti::Pointer {
                    base: var.ty,
                    space: crate::AddressSpace::Function,
                })
            }
            crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
                Ti::Pointer { base, space: _ } => {
                    // Loading through a pointer to an atomic yields its plain scalar.
                    if let Ti::Atomic(scalar) = types[base].inner {
                        TypeResolution::Value(Ti::Scalar(scalar))
                    } else {
                        TypeResolution::Handle(base)
                    }
                }
                Ti::ValuePointer {
                    size,
                    scalar,
                    space: _,
                } => TypeResolution::Value(match size {
                    Some(size) => Ti::Vector { size, scalar },
                    None => Ti::Scalar(scalar),
                }),
                ref other => {
                    log::error!("Pointer type {:?}", other);
                    return Err(ResolveError::InvalidPointer(pointer));
                }
            },
            // Gathers always produce a 4-component, 4-byte-wide vector. The
            // kind comes from a `Sampled` image class and is `Float` otherwise.
            crate::Expression::ImageSample {
                image,
                gather: Some(_),
                ..
            } => match *past(image)?.inner_with(types) {
                Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
                    scalar: crate::Scalar {
                        kind: match class {
                            crate::ImageClass::Sampled { kind, multi: _ } => kind,
                            _ => crate::ScalarKind::Float,
                        },
                        width: 4,
                    },
                    size: crate::VectorSize::Quad,
                }),
                ref other => {
                    log::error!("Image type {:?}", other);
                    return Err(ResolveError::InvalidImage(image));
                }
            },
            // Non-gather samples and loads: depth images yield an `f32` scalar.
            // Other classes yield a 4-component vector of their scalar.
            crate::Expression::ImageSample { image, .. }
            | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
                Ti::Image { class, .. } => TypeResolution::Value(match class {
                    crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32),
                    crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
                        scalar: crate::Scalar { kind, width: 4 },
                        size: crate::VectorSize::Quad,
                    },
                    crate::ImageClass::Storage { format, .. } => Ti::Vector {
                        scalar: format.into(),
                        size: crate::VectorSize::Quad,
                    },
                }),
                ref other => {
                    log::error!("Image type {:?}", other);
                    return Err(ResolveError::InvalidImage(image));
                }
            },
            // Size queries return one `u32` per image dimension (cube: 2).
            // The level, layer and sample counts are a single `u32`.
            crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
                crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
                    Ti::Image { dim, .. } => match dim {
                        crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32),
                        crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
                            size: crate::VectorSize::Bi,
                            scalar: crate::Scalar::U32,
                        },
                        crate::ImageDimension::D3 => Ti::Vector {
                            size: crate::VectorSize::Tri,
                            scalar: crate::Scalar::U32,
                        },
                    },
                    ref other => {
                        log::error!("Image type {:?}", other);
                        return Err(ResolveError::InvalidImage(image));
                    }
                },
                crate::ImageQuery::NumLevels
                | crate::ImageQuery::NumLayers
                | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32),
            }),
            crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
            crate::Expression::Binary { op, left, right } => match op {
                // Arithmetic other than multiplication takes the left operand's type.
                crate::BinaryOperator::Add
                | crate::BinaryOperator::Subtract
                | crate::BinaryOperator::Divide
                | crate::BinaryOperator::Modulo => past(left)?.clone(),
                // Multiplication follows linear-algebra shapes:
                // - matrix * matrix takes left rows and right columns;
                // - matrix * vector takes the matrix rows;
                // - vector * matrix takes the matrix columns;
                // - scalar * x takes the non-scalar side;
                // - vector * vector takes the left type.
                crate::BinaryOperator::Multiply => {
                    let (res_left, res_right) = (past(left)?, past(right)?);
                    match (res_left.inner_with(types), res_right.inner_with(types)) {
                        (
                            &Ti::Matrix {
                                columns: _,
                                rows,
                                scalar,
                            },
                            &Ti::Matrix { columns, .. },
                        ) => TypeResolution::Value(Ti::Matrix {
                            columns,
                            rows,
                            scalar,
                        }),
                        (
                            &Ti::Matrix {
                                columns: _,
                                rows,
                                scalar,
                            },
                            &Ti::Vector { .. },
                        ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
                        (
                            &Ti::Vector { .. },
                            &Ti::Matrix {
                                columns,
                                rows: _,
                                scalar,
                            },
                        ) => TypeResolution::Value(Ti::Vector {
                            size: columns,
                            scalar,
                        }),
                        (&Ti::Scalar { .. }, _) => res_right.clone(),
                        (_, &Ti::Scalar { .. }) => res_left.clone(),
                        (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
                        (tl, tr) => {
                            return Err(ResolveError::IncompatibleOperands(format!(
                                "{tl:?} * {tr:?}"
                            )))
                        }
                    }
                }
                // Comparisons produce `bool` with the left operand's shape:
                // a scalar or a vector of the same size.
                crate::BinaryOperator::Equal
                | crate::BinaryOperator::NotEqual
                | crate::BinaryOperator::Less
                | crate::BinaryOperator::LessEqual
                | crate::BinaryOperator::Greater
                | crate::BinaryOperator::GreaterEqual => {
                    let scalar = crate::Scalar::BOOL;
                    let inner = match *past(left)?.inner_with(types) {
                        Ti::Scalar { .. } => Ti::Scalar(scalar),
                        Ti::Vector { size, .. } => Ti::Vector { size, scalar },
                        ref other => {
                            return Err(ResolveError::IncompatibleOperands(format!(
                                "{op:?}({other:?}, _)"
                            )))
                        }
                    };
                    TypeResolution::Value(inner)
                }
                // Logical operators accept only a scalar `bool` on the left.
                crate::BinaryOperator::LogicalAnd | crate::BinaryOperator::LogicalOr => {
                    let bool = Ti::Scalar(crate::Scalar::BOOL);
                    let ty = past(left)?.inner_with(types);
                    if *ty == bool {
                        TypeResolution::Value(bool)
                    } else {
                        return Err(ResolveError::IncompatibleOperands(format!(
                            "{op:?}({:?}, _)",
                            ty.for_debug(types),
                        )));
                    }
                }
                // Bitwise and shift operators take the left operand's type.
                crate::BinaryOperator::And
                | crate::BinaryOperator::ExclusiveOr
                | crate::BinaryOperator::InclusiveOr
                | crate::BinaryOperator::ShiftLeft
                | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
            },
            crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
            crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
            crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
            crate::Expression::Select { accept, .. } => past(accept)?.clone(),
            crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
            crate::Expression::Relational { fun, argument } => match fun {
                // `all`/`any` always reduce to a single `bool`.
                crate::RelationalFunction::All | crate::RelationalFunction::Any => {
                    TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
                }
                // `isNan`/`isInf` preserve the argument's scalar/vector shape.
                crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => {
                    match *past(argument)?.inner_with(types) {
                        Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)),
                        Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
                            scalar: crate::Scalar::BOOL,
                            size,
                        }),
                        ref other => {
                            return Err(ResolveError::IncompatibleOperands(format!(
                                "{fun:?}({other:?})"
                            )))
                        }
                    }
                }
            },
            // Math builtins are resolved by narrowing `fun`'s overload set.
            // Only `arg` and `arg1` are used for the narrowing; `arg2`/`arg3`
            // are not consulted here (presumably not needed to fix the result
            // type — confirm against the overload tables).
            crate::Expression::Math {
                fun,
                arg,
                arg1,
                arg2: _,
                arg3: _,
            } => {
                use crate::proc::OverloadSet as _;

                let mut overloads = fun.overloads();
                log::debug!(
                    "initial overloads for {fun:?}, {:#?}",
                    overloads.for_debug(types)
                );

                let res_arg = past(arg)?;
                overloads = overloads.arg(0, res_arg.inner_with(types), types);
                log::debug!(
                    "overloads after arg 0 of type {:?}: {:#?}",
                    res_arg.for_debug(types),
                    overloads.for_debug(types)
                );

                if let Some(arg1) = arg1 {
                    let res_arg1 = past(arg1)?;
                    overloads = overloads.arg(1, res_arg1.inner_with(types), types);
                    log::debug!(
                        "overloads after arg 1 of type {:?}: {:#?}",
                        res_arg1.for_debug(types),
                        overloads.for_debug(types)
                    );
                }

                if overloads.is_empty() {
                    return Err(ResolveError::BuiltinArgumentsInvalid(format!("{fun:?}")));
                }

                let rule = overloads.most_preferred();

                // The conclusion may need special types; a missing one is
                // converted into `ResolveError::MissingSpecialType` by `?`.
                rule.conclusion.into_resolution(self.special_types)?
            }
            // Conversions keep the operand's shape. `convert`, when present,
            // replaces the width; otherwise the width is kept.
            crate::Expression::As {
                expr,
                kind,
                convert,
            } => match *past(expr)?.inner_with(types) {
                Ti::Scalar(crate::Scalar { width, .. }) => {
                    TypeResolution::Value(Ti::Scalar(crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    }))
                }
                Ti::Vector {
                    size,
                    scalar: crate::Scalar { kind: _, width },
                } => TypeResolution::Value(Ti::Vector {
                    size,
                    scalar: crate::Scalar {
                        kind,
                        width: convert.unwrap_or(width),
                    },
                }),
                // Matrices keep their original scalar kind; only the width
                // changes (NOTE(review): `kind` is ignored here — presumably
                // because matrices are float-only; confirm).
                Ti::Matrix {
                    columns,
                    rows,
                    mut scalar,
                } => {
                    if let Some(width) = convert {
                        scalar.width = width;
                    }
                    TypeResolution::Value(Ti::Matrix {
                        columns,
                        rows,
                        scalar,
                    })
                }
                ref other => {
                    return Err(ResolveError::IncompatibleOperands(format!(
                        "{other:?} as {kind:?}"
                    )))
                }
            },
            crate::Expression::CallResult(function) => {
                let result = self.functions[function]
                    .result
                    .as_ref()
                    .ok_or(ResolveError::FunctionReturnsVoid)?;
                TypeResolution::Handle(result.ty)
            }
            crate::Expression::ArrayLength(_) => {
                TypeResolution::Value(Ti::Scalar(crate::Scalar::U32))
            }
            crate::Expression::RayQueryProceedResult => {
                TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
            }
            // Ray-query results use types the module must have registered as
            // special types.
            crate::Expression::RayQueryGetIntersection { .. } => {
                let result = self
                    .special_types
                    .ray_intersection
                    .ok_or(ResolveError::MissingSpecialType)?;
                TypeResolution::Handle(result)
            }
            crate::Expression::RayQueryVertexPositions { .. } => {
                let result = self
                    .special_types
                    .ray_vertex_return
                    .ok_or(ResolveError::MissingSpecialType)?;
                TypeResolution::Handle(result)
            }
            crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
                scalar: crate::Scalar::U32,
                size: crate::VectorSize::Quad,
            }),
        })
    }
}
775
/// Guards against accidental growth of [`ResolveError`].
///
/// The largest payloads are `String`s, and several variants carry one. So
/// (under current rustc layout rules) the enum is expected to be one `String`
/// plus a pointer-aligned discriminant word:
/// - 32 bytes on 64-bit targets, which is what this test used to hard-code;
/// - 16 bytes on 32-bit targets, where a literal `32` fails spuriously.
///
/// Expressing the expectation in terms of `String` and `usize` keeps the
/// 64-bit check unchanged while making it target-independent.
#[test]
fn test_error_size() {
    assert_eq!(
        size_of::<ResolveError>(),
        size_of::<String>() + size_of::<usize>()
    );
}