1use alloc::{boxed::Box, vec};
9use core::ops;
10
11use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
12use crate::diagnostic_filter::{DiagnosticFilterNode, StandardFilterableTriggeringRule};
13use crate::span::{AddSpan as _, WithSpan};
14use crate::{
15 arena::{Arena, Handle},
16 proc::{ResolveContext, TypeResolution},
17};
18
/// Handle of the expression that a non-uniform value is attributed to,
/// or `None` when no such expression has been recorded.
pub type NonUniformResult = Option<Handle<crate::Expression>>;
20
// While this is `true`, `UniformityRequirements::DERIVATIVE` and
// `UniformityRequirements::IMPLICIT_LEVEL` are both defined with the value 0,
// i.e. they carry no bits and are indistinguishable from the empty set.
const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
22
bitflags::bitflags! {
    /// Set of reasons why a piece of code requires uniform control flow.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct UniformityRequirements: u8 {
        /// Recorded for `Barrier` and `WorkGroupUniformLoad` statements.
        const WORK_GROUP_BARRIER = 0x1;
        /// Recorded for `Derivative` expressions.
        /// Has no bits set while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
        /// Recorded for `ImageSample` expressions whose sample level reports
        /// implicit derivatives.
        /// Has no bits set while `DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE` is `true`.
        const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
    }
}
34
/// Uniformity summary attached to an expression, a block result or a whole
/// function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Uniformity {
    /// Expression the non-uniformity is attributed to; `None` means the value
    /// is treated as uniform.
    pub non_uniform_result: NonUniformResult,
    /// Uniformity requirements collected for the associated code.
    pub requirements: UniformityRequirements,
}
56
57impl Uniformity {
58 const fn new() -> Self {
59 Uniformity {
60 non_uniform_result: None,
61 requirements: UniformityRequirements::empty(),
62 }
63 }
64}
65
bitflags::bitflags! {
    /// Ways in which control flow may have left a block early while a
    /// uniformity disruptor was active.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct ExitFlags: u8 {
        /// Set by a `Return` statement that is analyzed while a disruptor is
        /// already present.
        const MAY_RETURN = 0x1;
        /// Set by a `Kill` statement that is analyzed while a disruptor is
        /// already present, and by calls to functions whose `may_kill` is set.
        const MAY_KILL = 0x2;
    }
}
80
/// Outcome of analyzing a statement or block: its uniformity plus the ways it
/// may exit early.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
    /// Uniformity of the analyzed code.
    result: Uniformity,
    /// Early-exit possibilities recorded for the analyzed code.
    exit: ExitFlags,
}
87
88impl ops::BitOr for FunctionUniformity {
89 type Output = Self;
90 fn bitor(self, other: Self) -> Self {
91 FunctionUniformity {
92 result: Uniformity {
93 non_uniform_result: self
94 .result
95 .non_uniform_result
96 .or(other.result.non_uniform_result),
97 requirements: self.result.requirements | other.result.requirements,
98 },
99 exit: self.exit | other.exit,
100 }
101 }
102}
103
104impl FunctionUniformity {
105 const fn new() -> Self {
106 FunctionUniformity {
107 result: Uniformity::new(),
108 exit: ExitFlags::empty(),
109 }
110 }
111
112 const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
114 if self.exit.contains(ExitFlags::MAY_RETURN) {
115 Some(UniformityDisruptor::Return)
116 } else if self.exit.contains(ExitFlags::MAY_KILL) {
117 Some(UniformityDisruptor::Discard)
118 } else {
119 None
120 }
121 }
122}
123
bitflags::bitflags! {
    /// Ways in which a function uses a global variable.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct GlobalUse: u8 {
        /// Recorded by plain operand references (`add_ref`) and, together with
        /// `WRITE`, by `Atomic` statements.
        const READ = 0x1;
        /// Recorded for the target of `Store` and `ImageStore` statements and,
        /// together with `READ`, by `Atomic` statements.
        const WRITE = 0x2;
        /// Recorded for the operand of `ImageQuery` and `ArrayLength` expressions.
        const QUERY = 0x4;
        /// Recorded for the image operand of `ImageAtomic` statements.
        const ATOMIC = 0x8;
    }
}
140
/// An image/sampler pair, both identified as global variables, that is used
/// together in a sampling operation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct SamplingKey {
    /// Global variable holding the image.
    pub image: Handle<crate::GlobalVariable>,
    /// Global variable holding the sampler.
    pub sampler: Handle<crate::GlobalVariable>,
}
148
/// Analysis results recorded for one expression of a function.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ExpressionInfo {
    /// Uniformity computed for this expression.
    pub uniformity: Uniformity,

    /// Number of references to this expression counted so far; starts at zero
    /// when the expression is processed and is incremented by the `add_ref*`
    /// helpers.
    pub ref_count: usize,

    /// Global variable this expression designates, if any. It is set for
    /// `GlobalVariable` expressions and carried over from the base of
    /// `Access` / `AccessIndex` expressions.
    assignable_global: Option<Handle<crate::GlobalVariable>>,

    /// Resolved type of the expression.
    pub ty: TypeResolution,
}
183
184impl ExpressionInfo {
185 const fn new() -> Self {
186 ExpressionInfo {
187 uniformity: Uniformity::new(),
188 ref_count: 0,
189 assignable_global: None,
190 ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
192 kind: crate::ScalarKind::Bool,
193 width: 0,
194 })),
195 }
196 }
197}
198
/// Identifies a resource either as a global variable or as a function
/// argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
enum GlobalOrArgument {
    /// The resource is the given global variable.
    Global(Handle<crate::GlobalVariable>),
    /// The resource is the function argument with the given index.
    Argument(u32),
}
206
207impl GlobalOrArgument {
208 fn from_expression(
209 expression_arena: &Arena<crate::Expression>,
210 expression: Handle<crate::Expression>,
211 ) -> Result<GlobalOrArgument, ExpressionError> {
212 Ok(match expression_arena[expression] {
213 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
214 crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
215 crate::Expression::Access { base, .. }
216 | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
217 crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
218 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
219 },
220 _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
221 })
222 }
223}
224
/// An image/sampler pairing in which each side is either a global variable or
/// a function argument. Entries are recorded when the two sides are not both
/// known to be globals.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
struct Sampling {
    /// Where the image comes from.
    image: GlobalOrArgument,
    /// Where the sampler comes from.
    sampler: GlobalOrArgument,
}
232
/// Analysis results collected for one function.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
    /// Validation flags the analysis was run with.
    #[allow(dead_code)]
    flags: ValidationFlags,
    /// Shader stages the function may be used in. Initialised to all stages by
    /// `ModuleInfo::process_function`.
    // NOTE(review): nothing in this file narrows it — presumably done by the
    // validator elsewhere; confirm.
    pub available_stages: ShaderStages,
    /// Uniformity of the function body as a whole.
    pub uniformity: Uniformity,
    /// Whether analysis of the body recorded `ExitFlags::MAY_KILL`.
    pub may_kill: bool,

    /// Image/sampler pairs, both known to be globals, sampled by this function
    /// or by the functions it calls.
    pub sampling_set: crate::FastHashSet<SamplingKey>,

    /// Usage flags per global variable, indexed by the global's handle index.
    global_uses: Box<[GlobalUse]>,

    /// Per-expression results, indexed by the expression's handle index.
    expressions: Box<[ExpressionInfo]>,

    /// Image/sampler pairings where at least one side is a function argument,
    /// to be resolved at call sites.
    sampling: crate::FastHashSet<Sampling>,

    /// Initialised to `false` by `ModuleInfo::process_function`.
    // NOTE(review): the meaning and the writer of this flag are not visible in
    // this file — confirm against the validator.
    pub dual_source_blending: bool,

    /// Leaf of the diagnostic filter chain copied from the analyzed function;
    /// used to look up the severity of uniformity diagnostics.
    diagnostic_filter_leaf: Option<Handle<DiagnosticFilterNode>>,
}
303
304impl FunctionInfo {
305 pub const fn global_variable_count(&self) -> usize {
306 self.global_uses.len()
307 }
308 pub const fn expression_count(&self) -> usize {
309 self.expressions.len()
310 }
311 pub fn dominates_global_use(&self, other: &Self) -> bool {
312 for (self_global_uses, other_global_uses) in
313 self.global_uses.iter().zip(other.global_uses.iter())
314 {
315 if !self_global_uses.contains(*other_global_uses) {
316 return false;
317 }
318 }
319 true
320 }
321}
322
323impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
324 type Output = GlobalUse;
325 fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
326 &self.global_uses[handle.index()]
327 }
328}
329
330impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
331 type Output = ExpressionInfo;
332 fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
333 &self.expressions[handle.index()]
334 }
335}
336
/// Reason why control flow is no longer considered uniform.
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum UniformityDisruptor {
    /// A branch was taken on the given non-uniform expression.
    #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
    Expression(Handle<crate::Expression>),
    /// Derived from `ExitFlags::MAY_RETURN`.
    #[error("There is a Return earlier in the control flow of the function")]
    Return,
    /// Derived from `ExitFlags::MAY_KILL`.
    #[error("There is a Discard earlier in the entry point across all called functions")]
    Discard,
}
348
349impl FunctionInfo {
350 #[must_use]
358 fn add_ref_impl(
359 &mut self,
360 expr: Handle<crate::Expression>,
361 global_use: GlobalUse,
362 ) -> NonUniformResult {
363 let info = &mut self.expressions[expr.index()];
364 info.ref_count += 1;
365 if let Some(global) = info.assignable_global {
367 self.global_uses[global.index()] |= global_use;
368 }
369 info.uniformity.non_uniform_result
370 }
371
372 #[must_use]
379 fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult {
380 self.add_ref_impl(expr, GlobalUse::READ)
381 }
382
383 #[must_use]
402 fn add_assignable_ref(
403 &mut self,
404 expr: Handle<crate::Expression>,
405 assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
406 ) -> NonUniformResult {
407 let info = &mut self.expressions[expr.index()];
408 info.ref_count += 1;
409 if let Some(global) = info.assignable_global {
412 if let Some(_old) = assignable_global.replace(global) {
413 unreachable!()
414 }
415 }
416 info.uniformity.non_uniform_result
417 }
418
419 fn process_call(
421 &mut self,
422 callee: &Self,
423 arguments: &[Handle<crate::Expression>],
424 expression_arena: &Arena<crate::Expression>,
425 ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
426 self.sampling_set
427 .extend(callee.sampling_set.iter().cloned());
428 for sampling in callee.sampling.iter() {
429 let image_storage = match sampling.image {
432 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
433 GlobalOrArgument::Argument(i) => {
434 let Some(handle) = arguments.get(i as usize).cloned() else {
435 break;
437 };
438 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
439 |source| {
440 FunctionError::Expression { handle, source }
441 .with_span_handle(handle, expression_arena)
442 },
443 )?
444 }
445 };
446
447 let sampler_storage = match sampling.sampler {
448 GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
449 GlobalOrArgument::Argument(i) => {
450 let Some(handle) = arguments.get(i as usize).cloned() else {
451 break;
453 };
454 GlobalOrArgument::from_expression(expression_arena, handle).map_err(
455 |source| {
456 FunctionError::Expression { handle, source }
457 .with_span_handle(handle, expression_arena)
458 },
459 )?
460 }
461 };
462
463 match (image_storage, sampler_storage) {
468 (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
469 self.sampling_set.insert(SamplingKey { image, sampler });
470 }
471 (image, sampler) => {
472 self.sampling.insert(Sampling { image, sampler });
473 }
474 }
475 }
476
477 for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
479 *mine |= *other;
480 }
481
482 Ok(FunctionUniformity {
483 result: callee.uniformity.clone(),
484 exit: if callee.may_kill {
485 ExitFlags::MAY_KILL
486 } else {
487 ExitFlags::empty()
488 },
489 })
490 }
491
    /// Analyzes the expression `handle` and stores the resulting
    /// [`ExpressionInfo`] (uniformity, designated global, resolved type) in
    /// `self.expressions`, with its reference count reset to zero.
    ///
    /// Each operand of the expression is reference-counted through the
    /// `add_ref*` helpers, which also record global usage. The expression's
    /// non-uniform source is the first operand (in the order written below)
    /// that has one. Note that `Option::or` evaluates its argument eagerly, so
    /// every operand is counted even after a non-uniform source was found.
    ///
    /// # Errors
    ///
    /// * `ExpressionError::MissingCapabilities` when a binding array is
    ///   indexed with a non-uniform index and `capabilities` lacks the
    ///   capability needed for that.
    /// * Whatever `GlobalOrArgument::from_expression` reports for the image
    ///   or sampler of an `ImageSample`.
    /// * Whatever type resolution through `resolve_context` reports.
    #[allow(clippy::or_fun_call)]
    fn process_expression(
        &mut self,
        handle: Handle<crate::Expression>,
        expression_arena: &Arena<crate::Expression>,
        other_functions: &[FunctionInfo],
        resolve_context: &ResolveContext,
        capabilities: super::Capabilities,
    ) -> Result<(), ExpressionError> {
        use crate::{Expression as E, SampleLevel as Sl};

        let expression = &expression_arena[handle];
        // Filled in by `GlobalVariable` itself or inherited from the base of
        // an `Access` / `AccessIndex`.
        let mut assignable_global = None;
        let uniformity = match *expression {
            E::Access { base, index } => {
                let base_ty = self[base].ty.inner_with(resolve_context.types);

                // Capabilities required if this turns out to be a binding
                // array indexed by a non-uniform value.
                let mut needed_caps = super::Capabilities::empty();
                let is_binding_array = match *base_ty {
                    crate::TypeInner::BindingArray {
                        base: array_element_ty_handle,
                        ..
                    } => {
                        // Short aliases for the long capability names.
                        let sto = super::Capabilities::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
                        let uni = super::Capabilities::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
                        let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;

                        let array_element_ty =
                            &resolve_context.types[array_element_ty_handle].inner;

                        needed_caps |= match *array_element_ty {
                            // Images: storage images and all other image
                            // classes need different capabilities.
                            crate::TypeInner::Image { class, .. } => match class {
                                crate::ImageClass::Storage { .. } => sto,
                                _ => st_sb,
                            },
                            crate::TypeInner::Sampler { .. } => sampler,
                            // Anything else is handled as a buffer, picking
                            // the capability from the global's address space.
                            // The base is expected to be a global variable in
                            // the uniform or storage space here.
                            _ => {
                                if let E::GlobalVariable(global_handle) = expression_arena[base] {
                                    let global = &resolve_context.global_vars[global_handle];
                                    match global.space {
                                        crate::AddressSpace::Uniform => uni,
                                        crate::AddressSpace::Storage { .. } => st_sb,
                                        _ => unreachable!(),
                                    }
                                } else {
                                    unreachable!()
                                }
                            }
                        };

                        true
                    }
                    _ => false,
                };

                // Non-uniform indexing of a binding array is only allowed
                // with the matching capability.
                if self[index].uniformity.non_uniform_result.is_some()
                    && !capabilities.contains(needed_caps)
                    && is_binding_array
                {
                    return Err(ExpressionError::MissingCapabilities(needed_caps));
                }

                Uniformity {
                    non_uniform_result: self
                        .add_assignable_ref(base, &mut assignable_global)
                        .or(self.add_ref(index)),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::AccessIndex { base, .. } => Uniformity {
                non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
                requirements: UniformityRequirements::empty(),
            },
            E::Splat { size: _, value } => Uniformity {
                non_uniform_result: self.add_ref(value),
                requirements: UniformityRequirements::empty(),
            },
            E::Swizzle { vector, .. } => Uniformity {
                non_uniform_result: self.add_ref(vector),
                requirements: UniformityRequirements::empty(),
            },
            // Constant-like expressions have no operands and are uniform.
            E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
            E::Compose { ref components, .. } => {
                // First non-uniform component wins; all components are counted.
                let non_uniform_result = components
                    .iter()
                    .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
                Uniformity {
                    non_uniform_result,
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::FunctionArgument(index) => {
                let arg = &resolve_context.arguments[index as usize];
                // Only these work-group related built-ins are treated as
                // uniform; every other argument is non-uniform.
                let uniform = match arg.binding {
                    Some(crate::Binding::BuiltIn(
                        crate::BuiltIn::WorkGroupId
                        | crate::BuiltIn::WorkGroupSize
                        | crate::BuiltIn::NumWorkGroups,
                    )) => true,
                    _ => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::GlobalVariable(gh) => {
                use crate::AddressSpace as As;
                assignable_global = Some(gh);
                let var = &resolve_context.global_vars[gh];
                // Uniformity is decided by the address space alone.
                let uniform = match var.space {
                    As::Function | As::Private => false,
                    As::WorkGroup => true,
                    As::Uniform | As::PushConstant => true,
                    // Storage is uniform only when it cannot be stored to.
                    As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
                    As::Handle => false,
                };
                Uniformity {
                    non_uniform_result: if uniform { None } else { Some(handle) },
                    requirements: UniformityRequirements::empty(),
                }
            }
            // Local variables are always treated as non-uniform.
            E::LocalVariable(_) => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::Load { pointer } => Uniformity {
                non_uniform_result: self.add_ref(pointer),
                requirements: UniformityRequirements::empty(),
            },
            E::ImageSample {
                image,
                sampler,
                gather: _,
                coordinate,
                array_index,
                offset,
                level,
                depth_ref,
            } => {
                let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
                let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;

                // Fully global pairs are final; anything involving an
                // argument is kept aside in `sampling`.
                match (image_storage, sampler_storage) {
                    (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
                        self.sampling_set.insert(SamplingKey { image, sampler });
                    }
                    _ => {
                        self.sampling.insert(Sampling {
                            image: image_storage,
                            sampler: sampler_storage,
                        });
                    }
                }

                // "nur" stands for "non-uniform result".
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let level_nur = match level {
                    Sl::Auto | Sl::Zero => None,
                    Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
                    Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
                };
                let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
                let offset_nur = offset.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(sampler))
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(level_nur)
                        .or(dref_nur)
                        .or(offset_nur),
                    // Sampling with implicit derivatives imposes a
                    // uniformity requirement.
                    requirements: if level.implicit_derivatives() {
                        UniformityRequirements::IMPLICIT_LEVEL
                    } else {
                        UniformityRequirements::empty()
                    },
                }
            }
            E::ImageLoad {
                image,
                coordinate,
                array_index,
                sample,
                level,
            } => {
                let array_nur = array_index.and_then(|h| self.add_ref(h));
                let sample_nur = sample.and_then(|h| self.add_ref(h));
                let level_nur = level.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self
                        .add_ref(image)
                        .or(self.add_ref(coordinate))
                        .or(array_nur)
                        .or(sample_nur)
                        .or(level_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::ImageQuery { image, query } => {
                // Only a size query with an explicit level has an operand.
                let query_nur = match query {
                    crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
                    _ => None,
                };
                Uniformity {
                    // The image is marked as queried rather than read.
                    non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::Unary { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            E::Binary { left, right, .. } => Uniformity {
                non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
                requirements: UniformityRequirements::empty(),
            },
            E::Select {
                condition,
                accept,
                reject,
            } => Uniformity {
                non_uniform_result: self
                    .add_ref(condition)
                    .or(self.add_ref(accept))
                    .or(self.add_ref(reject)),
                requirements: UniformityRequirements::empty(),
            },
            E::Derivative { expr, .. } => Uniformity {
                // The result is exactly as (non-)uniform as the operand; the
                // derivative itself only adds a requirement.
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::DERIVATIVE,
            },
            E::Relational { argument, .. } => Uniformity {
                non_uniform_result: self.add_ref(argument),
                requirements: UniformityRequirements::empty(),
            },
            E::Math {
                fun: _,
                arg,
                arg1,
                arg2,
                arg3,
            } => {
                let arg1_nur = arg1.and_then(|h| self.add_ref(h));
                let arg2_nur = arg2.and_then(|h| self.add_ref(h));
                let arg3_nur = arg3.and_then(|h| self.add_ref(h));
                Uniformity {
                    non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
                    requirements: UniformityRequirements::empty(),
                }
            }
            E::As { expr, .. } => Uniformity {
                non_uniform_result: self.add_ref(expr),
                requirements: UniformityRequirements::empty(),
            },
            // A call result takes over the uniformity of the called function.
            E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
            E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            // The result of a work-group uniform load is treated as uniform.
            E::WorkGroupUniformLoadResult { .. } => Uniformity {
                non_uniform_result: None,
                requirements: UniformityRequirements::empty(),
            },
            E::ArrayLength(expr) => Uniformity {
                // The operand is marked as queried rather than read.
                non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupBallotResult => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::SubgroupOperationResult { .. } => Uniformity {
                non_uniform_result: Some(handle),
                requirements: UniformityRequirements::empty(),
            },
            E::RayQueryVertexPositions {
                query,
                committed: _,
            } => Uniformity {
                non_uniform_result: self.add_ref(query),
                requirements: UniformityRequirements::empty(),
            },
        };

        // Operand types are looked up from the already stored infos.
        let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
        self.expressions[handle.index()] = ExpressionInfo {
            uniformity,
            ref_count: 0,
            assignable_global,
            ty,
        };
        Ok(())
    }
832
    /// Analyzes the statements of a block in order and returns the combined
    /// uniformity and exit flags of the block.
    ///
    /// `disruptor` is the reason, if any, why control flow is already
    /// considered non-uniform on entry. It is updated after every statement
    /// with that statement's early-exit disruptor, so later statements see
    /// the effect of earlier conditional returns and kills.
    ///
    /// Operands of statements are reference-counted through the `add_ref*`
    /// helpers, which also record global usage.
    ///
    /// # Errors
    ///
    /// Propagates errors from nested blocks, from `process_call`, and from
    /// the diagnostic reported when an emitted expression with uniformity
    /// requirements is reached while a disruptor is active.
    #[allow(clippy::or_fun_call)]
    fn process_block(
        &mut self,
        statements: &crate::Block,
        other_functions: &[FunctionInfo],
        mut disruptor: Option<UniformityDisruptor>,
        expression_arena: &Arena<crate::Expression>,
        diagnostic_filter_arena: &Arena<DiagnosticFilterNode>,
    ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
        use crate::Statement as S;

        let mut combined_uniformity = FunctionUniformity::new();
        for statement in statements {
            let uniformity = match *statement {
                S::Emit(ref range) => {
                    let mut requirements = UniformityRequirements::empty();
                    for expr in range.clone() {
                        let req = self.expressions[expr.index()].uniformity.requirements;
                        // Only checked when control-flow uniformity
                        // validation is enabled and the expression actually
                        // requires something.
                        if self
                            .flags
                            .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
                            && !req.is_empty()
                        {
                            if let Some(cause) = disruptor {
                                // The severity comes from the function's
                                // diagnostic filter chain.
                                let severity = DiagnosticFilterNode::search(
                                    self.diagnostic_filter_leaf,
                                    diagnostic_filter_arena,
                                    StandardFilterableTriggeringRule::DerivativeUniformity,
                                );
                                // NOTE(review): presumably returns `Err` only
                                // for error severity and logs otherwise —
                                // confirm against `report_diag`.
                                severity.report_diag(
                                    FunctionError::NonUniformControlFlow(req, expr, cause)
                                        .with_span_handle(expr, expression_arena),
                                    |e, level| log::log!(level, "{e}"),
                                )?;
                            }
                        }
                        requirements |= req;
                    }
                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Break | S::Continue => FunctionUniformity::new(),
                // A kill only counts as an early exit when control flow is
                // already disrupted.
                S::Kill => FunctionUniformity {
                    result: Uniformity::new(),
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_KILL
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Barrier(_) => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: None,
                        requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                    },
                    exit: ExitFlags::empty(),
                },
                S::WorkGroupUniformLoad { pointer, .. } => {
                    // The pointer is counted as read; its non-uniformity is
                    // not used.
                    let _condition_nur = self.add_ref(pointer);

                    FunctionUniformity {
                        result: Uniformity {
                            non_uniform_result: None,
                            requirements: UniformityRequirements::WORK_GROUP_BARRIER,
                        },
                        exit: ExitFlags::empty(),
                    }
                }
                S::Block(ref b) => self.process_block(
                    b,
                    other_functions,
                    disruptor,
                    expression_arena,
                    diagnostic_filter_arena,
                )?,
                S::If {
                    condition,
                    ref accept,
                    ref reject,
                } => {
                    let condition_nur = self.add_ref(condition);
                    // A non-uniform condition disrupts both branches, unless
                    // a disruptor already exists (the existing one is kept).
                    let branch_disruptor =
                        disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
                    let accept_uniformity = self.process_block(
                        accept,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    let reject_uniformity = self.process_block(
                        reject,
                        other_functions,
                        branch_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    accept_uniformity | reject_uniformity
                }
                S::Switch {
                    selector,
                    ref cases,
                } => {
                    let selector_nur = self.add_ref(selector);
                    let branch_disruptor =
                        disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
                    let mut uniformity = FunctionUniformity::new();
                    let mut case_disruptor = branch_disruptor;
                    for case in cases.iter() {
                        let case_uniformity = self.process_block(
                            &case.body,
                            other_functions,
                            case_disruptor,
                            expression_arena,
                            diagnostic_filter_arena,
                        )?;
                        // A fall-through case passes its early-exit disruptor
                        // on to the next case; otherwise the next case starts
                        // from the switch-level disruptor again.
                        case_disruptor = if case.fall_through {
                            case_disruptor.or(case_uniformity.exit_disruptor())
                        } else {
                            branch_disruptor
                        };
                        uniformity = uniformity | case_uniformity;
                    }
                    uniformity
                }
                S::Loop {
                    ref body,
                    ref continuing,
                    break_if,
                } => {
                    let body_uniformity = self.process_block(
                        body,
                        other_functions,
                        disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // Early exits in the body disrupt the continuing block.
                    let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
                    let continuing_uniformity = self.process_block(
                        continuing,
                        other_functions,
                        continuing_disruptor,
                        expression_arena,
                        diagnostic_filter_arena,
                    )?;
                    // The break condition is only reference-counted.
                    if let Some(expr) = break_if {
                        let _ = self.add_ref(expr);
                    }
                    body_uniformity | continuing_uniformity
                }
                // The returned value determines the non-uniform source; a
                // return only counts as an early exit when control flow is
                // already disrupted.
                S::Return { value } => FunctionUniformity {
                    result: Uniformity {
                        non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
                        requirements: UniformityRequirements::empty(),
                    },
                    exit: if disruptor.is_some() {
                        ExitFlags::MAY_RETURN
                    } else {
                        ExitFlags::empty()
                    },
                },
                S::Store { pointer, value } => {
                    // The target is marked as written, the value as read.
                    let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::ImageStore {
                    image,
                    coordinate,
                    array_index,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::WRITE);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(coordinate);
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::Call {
                    function,
                    ref arguments,
                    result: _,
                } => {
                    for &argument in arguments {
                        let _ = self.add_ref(argument);
                    }
                    let info = &other_functions[function.index()];
                    // The call result expression is not handled here.
                    self.process_call(info, arguments, expression_arena)?
                }
                S::Atomic {
                    pointer,
                    ref fun,
                    value,
                    result: _,
                } => {
                    // An atomic both reads and writes its target.
                    let _ = self.add_ref_impl(pointer, GlobalUse::READ | GlobalUse::WRITE);
                    let _ = self.add_ref(value);
                    if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
                        let _ = self.add_ref(cmp);
                    }
                    FunctionUniformity::new()
                }
                S::ImageAtomic {
                    image,
                    coordinate,
                    array_index,
                    fun: _,
                    value,
                } => {
                    let _ = self.add_ref_impl(image, GlobalUse::ATOMIC);
                    let _ = self.add_ref(coordinate);
                    if let Some(expr) = array_index {
                        let _ = self.add_ref(expr);
                    }
                    let _ = self.add_ref(value);
                    FunctionUniformity::new()
                }
                S::RayQuery { query, ref fun } => {
                    let _ = self.add_ref(query);
                    // Count the operands specific to each ray query function.
                    match *fun {
                        crate::RayQueryFunction::Initialize {
                            acceleration_structure,
                            descriptor,
                        } => {
                            let _ = self.add_ref(acceleration_structure);
                            let _ = self.add_ref(descriptor);
                        }
                        crate::RayQueryFunction::Proceed { result: _ } => {}
                        crate::RayQueryFunction::GenerateIntersection { hit_t } => {
                            let _ = self.add_ref(hit_t);
                        }
                        crate::RayQueryFunction::ConfirmIntersection => {}
                        crate::RayQueryFunction::Terminate => {}
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupBallot {
                    result: _,
                    predicate,
                } => {
                    if let Some(predicate) = predicate {
                        let _ = self.add_ref(predicate);
                    }
                    FunctionUniformity::new()
                }
                S::SubgroupCollectiveOperation {
                    op: _,
                    collective_op: _,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    FunctionUniformity::new()
                }
                S::SubgroupGather {
                    mode,
                    argument,
                    result: _,
                } => {
                    let _ = self.add_ref(argument);
                    // Every mode except `BroadcastFirst` carries an index
                    // operand.
                    match mode {
                        crate::GatherMode::BroadcastFirst => {}
                        crate::GatherMode::Broadcast(index)
                        | crate::GatherMode::Shuffle(index)
                        | crate::GatherMode::ShuffleDown(index)
                        | crate::GatherMode::ShuffleUp(index)
                        | crate::GatherMode::ShuffleXor(index) => {
                            let _ = self.add_ref(index);
                        }
                    }
                    FunctionUniformity::new()
                }
            };

            // Early exits of this statement disrupt everything after it.
            disruptor = disruptor.or(uniformity.exit_disruptor());
            combined_uniformity = combined_uniformity | uniformity;
        }
        Ok(combined_uniformity)
    }
1155}
1156
1157impl ModuleInfo {
1158 pub(super) fn process_const_expression(
1160 &mut self,
1161 handle: Handle<crate::Expression>,
1162 resolve_context: &ResolveContext,
1163 gctx: crate::proc::GlobalCtx,
1164 ) -> Result<(), super::ConstExpressionError> {
1165 self.const_expression_types[handle.index()] =
1166 resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
1167 Ok(())
1168 }
1169
1170 pub(super) fn process_function(
1173 &self,
1174 fun: &crate::Function,
1175 module: &crate::Module,
1176 flags: ValidationFlags,
1177 capabilities: super::Capabilities,
1178 ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
1179 let mut info = FunctionInfo {
1180 flags,
1181 available_stages: ShaderStages::all(),
1182 uniformity: Uniformity::new(),
1183 may_kill: false,
1184 sampling_set: crate::FastHashSet::default(),
1185 global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
1186 expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
1187 sampling: crate::FastHashSet::default(),
1188 dual_source_blending: false,
1189 diagnostic_filter_leaf: fun.diagnostic_filter_leaf,
1190 };
1191 let resolve_context =
1192 ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
1193
1194 for (handle, _) in fun.expressions.iter() {
1195 if let Err(source) = info.process_expression(
1196 handle,
1197 &fun.expressions,
1198 &self.functions,
1199 &resolve_context,
1200 capabilities,
1201 ) {
1202 return Err(FunctionError::Expression { handle, source }
1203 .with_span_handle(handle, &fun.expressions));
1204 }
1205 }
1206
1207 for (_, expr) in fun.local_variables.iter() {
1208 if let Some(init) = expr.init {
1209 let _ = info.add_ref(init);
1210 }
1211 }
1212
1213 let uniformity = info.process_block(
1214 &fun.body,
1215 &self.functions,
1216 None,
1217 &fun.expressions,
1218 &module.diagnostic_filters,
1219 )?;
1220 info.uniformity = uniformity.result;
1221 info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
1222
1223 Ok(info)
1224 }
1225
1226 pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
1227 &self.entry_points[index]
1228 }
1229}
1230
1231#[test]
1232fn uniform_control_flow() {
1233 use crate::{Expression as E, Statement as S};
1234
1235 let mut type_arena = crate::UniqueArena::new();
1236 let ty = type_arena.insert(
1237 crate::Type {
1238 name: None,
1239 inner: crate::TypeInner::Vector {
1240 size: crate::VectorSize::Bi,
1241 scalar: crate::Scalar::F32,
1242 },
1243 },
1244 Default::default(),
1245 );
1246 let mut global_var_arena = Arena::new();
1247 let non_uniform_global = global_var_arena.append(
1248 crate::GlobalVariable {
1249 name: None,
1250 init: None,
1251 ty,
1252 space: crate::AddressSpace::Handle,
1253 binding: None,
1254 },
1255 Default::default(),
1256 );
1257 let uniform_global = global_var_arena.append(
1258 crate::GlobalVariable {
1259 name: None,
1260 init: None,
1261 ty,
1262 binding: None,
1263 space: crate::AddressSpace::Uniform,
1264 },
1265 Default::default(),
1266 );
1267
1268 let mut expressions = Arena::new();
1269 let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
1271 let derivative_expr = expressions.append(
1273 E::Derivative {
1274 axis: crate::DerivativeAxis::X,
1275 ctrl: crate::DerivativeControl::None,
1276 expr: constant_expr,
1277 },
1278 Default::default(),
1279 );
1280 let emit_range_constant_derivative = expressions.range_from(0);
1281 let non_uniform_global_expr =
1282 expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
1283 let uniform_global_expr =
1284 expressions.append(E::GlobalVariable(uniform_global), Default::default());
1285 let emit_range_globals = expressions.range_from(2);
1286
1287 let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
1289 let access_expr = expressions.append(
1291 E::AccessIndex {
1292 base: non_uniform_global_expr,
1293 index: 1,
1294 },
1295 Default::default(),
1296 );
1297 let emit_range_query_access_globals = expressions.range_from(2);
1298
1299 let mut info = FunctionInfo {
1300 flags: ValidationFlags::all(),
1301 available_stages: ShaderStages::all(),
1302 uniformity: Uniformity::new(),
1303 may_kill: false,
1304 sampling_set: crate::FastHashSet::default(),
1305 global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
1306 expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
1307 sampling: crate::FastHashSet::default(),
1308 dual_source_blending: false,
1309 diagnostic_filter_leaf: None,
1310 };
1311 let resolve_context = ResolveContext {
1312 constants: &Arena::new(),
1313 overrides: &Arena::new(),
1314 types: &type_arena,
1315 special_types: &crate::SpecialTypes::default(),
1316 global_vars: &global_var_arena,
1317 local_vars: &Arena::new(),
1318 functions: &Arena::new(),
1319 arguments: &[],
1320 };
1321 for (handle, _) in expressions.iter() {
1322 info.process_expression(
1323 handle,
1324 &expressions,
1325 &[],
1326 &resolve_context,
1327 super::Capabilities::empty(),
1328 )
1329 .unwrap();
1330 }
1331 assert_eq!(info[non_uniform_global_expr].ref_count, 1);
1332 assert_eq!(info[uniform_global_expr].ref_count, 1);
1333 assert_eq!(info[query_expr].ref_count, 0);
1334 assert_eq!(info[access_expr].ref_count, 0);
1335 assert_eq!(info[non_uniform_global], GlobalUse::empty());
1336 assert_eq!(info[uniform_global], GlobalUse::QUERY);
1337
1338 let stmt_emit1 = S::Emit(emit_range_globals.clone());
1339 let stmt_if_uniform = S::If {
1340 condition: uniform_global_expr,
1341 accept: crate::Block::new(),
1342 reject: vec![
1343 S::Emit(emit_range_constant_derivative.clone()),
1344 S::Store {
1345 pointer: constant_expr,
1346 value: derivative_expr,
1347 },
1348 ]
1349 .into(),
1350 };
1351 assert_eq!(
1352 info.process_block(
1353 &vec![stmt_emit1, stmt_if_uniform].into(),
1354 &[],
1355 None,
1356 &expressions,
1357 &Arena::new(),
1358 ),
1359 Ok(FunctionUniformity {
1360 result: Uniformity {
1361 non_uniform_result: None,
1362 requirements: UniformityRequirements::DERIVATIVE,
1363 },
1364 exit: ExitFlags::empty(),
1365 }),
1366 );
1367 assert_eq!(info[constant_expr].ref_count, 2);
1368 assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
1369
1370 let stmt_emit2 = S::Emit(emit_range_globals.clone());
1371 let stmt_if_non_uniform = S::If {
1372 condition: non_uniform_global_expr,
1373 accept: vec![
1374 S::Emit(emit_range_constant_derivative),
1375 S::Store {
1376 pointer: constant_expr,
1377 value: derivative_expr,
1378 },
1379 ]
1380 .into(),
1381 reject: crate::Block::new(),
1382 };
1383 {
1384 let block_info = info.process_block(
1385 &vec![stmt_emit2.clone(), stmt_if_non_uniform.clone()].into(),
1386 &[],
1387 None,
1388 &expressions,
1389 &Arena::new(),
1390 );
1391 if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
1392 assert_eq!(info[derivative_expr].ref_count, 2);
1393 } else {
1394 assert_eq!(
1395 block_info,
1396 Err(FunctionError::NonUniformControlFlow(
1397 UniformityRequirements::DERIVATIVE,
1398 derivative_expr,
1399 UniformityDisruptor::Expression(non_uniform_global_expr)
1400 )
1401 .with_span()),
1402 );
1403 assert_eq!(info[derivative_expr].ref_count, 1);
1404
1405 let mut diagnostic_filters = Arena::new();
1407 let diagnostic_filter_leaf = diagnostic_filters.append(
1408 DiagnosticFilterNode {
1409 inner: crate::diagnostic_filter::DiagnosticFilter {
1410 new_severity: crate::diagnostic_filter::Severity::Off,
1411 triggering_rule:
1412 crate::diagnostic_filter::FilterableTriggeringRule::Standard(
1413 StandardFilterableTriggeringRule::DerivativeUniformity,
1414 ),
1415 },
1416 parent: None,
1417 },
1418 crate::Span::default(),
1419 );
1420 let mut info = FunctionInfo {
1421 diagnostic_filter_leaf: Some(diagnostic_filter_leaf),
1422 ..info.clone()
1423 };
1424
1425 let block_info = info.process_block(
1426 &vec![stmt_emit2, stmt_if_non_uniform].into(),
1427 &[],
1428 None,
1429 &expressions,
1430 &diagnostic_filters,
1431 );
1432 assert_eq!(
1433 block_info,
1434 Ok(FunctionUniformity {
1435 result: Uniformity {
1436 non_uniform_result: None,
1437 requirements: UniformityRequirements::DERIVATIVE,
1438 },
1439 exit: ExitFlags::empty()
1440 }),
1441 );
1442 assert_eq!(info[derivative_expr].ref_count, 2);
1443 }
1444 }
1445 assert_eq!(info[non_uniform_global], GlobalUse::READ);
1446
1447 let stmt_emit3 = S::Emit(emit_range_globals);
1448 let stmt_return_non_uniform = S::Return {
1449 value: Some(non_uniform_global_expr),
1450 };
1451 assert_eq!(
1452 info.process_block(
1453 &vec![stmt_emit3, stmt_return_non_uniform].into(),
1454 &[],
1455 Some(UniformityDisruptor::Return),
1456 &expressions,
1457 &Arena::new(),
1458 ),
1459 Ok(FunctionUniformity {
1460 result: Uniformity {
1461 non_uniform_result: Some(non_uniform_global_expr),
1462 requirements: UniformityRequirements::empty(),
1463 },
1464 exit: ExitFlags::MAY_RETURN,
1465 }),
1466 );
1467 assert_eq!(info[non_uniform_global_expr].ref_count, 3);
1468
1469 let stmt_emit4 = S::Emit(emit_range_query_access_globals);
1471 let stmt_assign = S::Store {
1472 pointer: access_expr,
1473 value: query_expr,
1474 };
1475 let stmt_return_pointer = S::Return {
1476 value: Some(access_expr),
1477 };
1478 let stmt_kill = S::Kill;
1479 assert_eq!(
1480 info.process_block(
1481 &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
1482 &[],
1483 Some(UniformityDisruptor::Discard),
1484 &expressions,
1485 &Arena::new(),
1486 ),
1487 Ok(FunctionUniformity {
1488 result: Uniformity {
1489 non_uniform_result: Some(non_uniform_global_expr),
1490 requirements: UniformityRequirements::empty(),
1491 },
1492 exit: ExitFlags::all(),
1493 }),
1494 );
1495 assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
1496}