1use alloc::vec::Vec;
6
7use arrayvec::ArrayVec;
8use spirv::Word;
9
10use super::{
11 index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
12 Instruction, LocalType, LookupType, NumericType, ResultMember, WrappedFunction, Writer,
13 WriterFlags,
14};
15use crate::{arena::Handle, proc::index::GuardedIndex, Statement};
16
17fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
18 match *type_inner {
19 crate::TypeInner::Scalar(_) => Dimension::Scalar,
20 crate::TypeInner::Vector { .. } => Dimension::Vector,
21 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
22 _ => unreachable!(),
23 }
24}
25
/// An adjustment to the result type of the access chain that `write_access_chain`
/// builds for an expression.
enum AccessTypeAdjustment {
    /// No adjustment. NOTE(review): presumably the access chain's result type is taken
    /// from the expression's own (pointer) type — confirm in `write_access_chain`.
    None,

    /// Make the access chain produce a pointer in the given storage class.
    ///
    /// In this file it is passed with `spirv::StorageClass::UniformConstant` when a
    /// `BindingArray` value is indexed (`Expression::Access` / `Expression::AccessIndex`).
    /// The resulting pointer is then loaded with `OpLoad`.
    /// NOTE(review): presumably needed because the IR type of such an expression is not
    /// itself a pointer — confirm against `write_access_chain`.
    IntroducePointer(spirv::StorageClass),
}
73
/// The outcome of computing a pointer for an expression with `write_access_chain`.
enum ExpressionPointer {
    /// The pointer has already been computed; `pointer_id` is its result id and can be
    /// used directly (e.g. as the operand of an `OpLoad`).
    Ready { pointer_id: Word },

    /// The pointer may only be used when `condition` holds.
    ///
    /// NOTE(review): presumably produced by bounds-check handling. `condition` would be
    /// the in-bounds test, and `access` the not-yet-emitted access-chain instruction
    /// that the caller must place in a block guarded by `condition` — confirm in
    /// `write_access_chain`. Binding-array accesses in this file reject this variant with
    /// `Error::FeatureNotImplemented("Texture array out-of-bounds handling")`.
    Conditional {
        condition: Word,
        access: Instruction,
    },
}
92
/// How a block should be terminated once the statements written into it are done.
///
/// NOTE(review): the code that consumes this enum is outside the visible chunk; the
/// variant descriptions below are inferred and should be confirmed there.
enum BlockExit {
    /// Presumably: terminate the block with a return from the current function.
    Return,
    /// Presumably: terminate the block with an unconditional branch.
    Branch {
        /// Label id of the block to branch to.
        target: Word,
    },
    /// Presumably: terminate a loop's continuing block with a conditional break
    /// (the IR `break_if`).
    BreakIf {
        /// IR expression whose value decides whether to break.
        condition: Handle<crate::Expression>,
        /// NOTE(review): presumably the label id to branch to when not breaking
        /// (e.g. the loop's header/preamble) — confirm at the use site.
        preamble_id: Word,
    },
}
114
/// Reports what happened to a requested block exit.
///
/// Marked `#[must_use]` so callers cannot silently ignore the result.
/// NOTE(review): the producer of this value is outside the visible chunk; the variant
/// meanings below are inferred and should be confirmed there.
#[must_use]
enum BlockExitDisposition {
    /// Presumably: the requested exit was emitted as the block's terminator.
    Used,

    /// Presumably: the requested exit was not emitted, e.g. because the block already
    /// ended with its own terminator.
    Discarded,
}
138
/// Branch targets of the innermost enclosing loop, if any.
///
/// `Default` yields `None` for both targets. NOTE(review): presumably this is the state
/// used outside of any loop — confirm at the use sites.
#[derive(Clone, Copy, Default)]
struct LoopContext {
    /// Label id of the loop's continuing block, when known.
    continuing_id: Option<Word>,
    /// Label id of the block that a `break` branches to, when known.
    break_id: Option<Word>,
}
144
/// Source-level debug information made available to the SPIR-V writer.
#[derive(Debug)]
pub(crate) struct DebugInfoInner<'a> {
    /// Text of the shader source, borrowed for the lifetime `'a`.
    pub source_code: &'a str,
    /// SPIR-V result id that identifies the source file.
    /// NOTE(review): presumably the id of an `OpString` naming the file — confirm.
    pub source_file_id: Word,
}
150
151impl Writer {
152 fn write_epilogue_position_y_flip(
157 &mut self,
158 position_id: Word,
159 body: &mut Vec<Instruction>,
160 ) -> Result<(), Error> {
161 let float_ptr_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Output);
162 let index_y_id = self.get_index_constant(1);
163 let access_id = self.id_gen.next();
164 body.push(Instruction::access_chain(
165 float_ptr_type_id,
166 access_id,
167 position_id,
168 &[index_y_id],
169 ));
170
171 let float_type_id = self.get_f32_type_id();
172 let load_id = self.id_gen.next();
173 body.push(Instruction::load(float_type_id, load_id, access_id, None));
174
175 let neg_id = self.id_gen.next();
176 body.push(Instruction::unary(
177 spirv::Op::FNegate,
178 float_type_id,
179 neg_id,
180 load_id,
181 ));
182
183 body.push(Instruction::store(access_id, neg_id, None));
184 Ok(())
185 }
186
187 fn write_epilogue_frag_depth_clamp(
189 &mut self,
190 frag_depth_id: Word,
191 body: &mut Vec<Instruction>,
192 ) -> Result<(), Error> {
193 let float_type_id = self.get_f32_type_id();
194 let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
195 let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
196
197 let original_id = self.id_gen.next();
198 body.push(Instruction::load(
199 float_type_id,
200 original_id,
201 frag_depth_id,
202 None,
203 ));
204
205 let clamp_id = self.id_gen.next();
206 body.push(Instruction::ext_inst(
207 self.gl450_ext_inst_id,
208 spirv::GLOp::FClamp,
209 float_type_id,
210 clamp_id,
211 &[original_id, zero_scalar_id, one_scalar_id],
212 ));
213
214 body.push(Instruction::store(frag_depth_id, clamp_id, None));
215 Ok(())
216 }
217
218 fn write_entry_point_return(
219 &mut self,
220 value_id: Word,
221 ir_result: &crate::FunctionResult,
222 result_members: &[ResultMember],
223 body: &mut Vec<Instruction>,
224 ) -> Result<(), Error> {
225 for (index, res_member) in result_members.iter().enumerate() {
226 let member_value_id = match ir_result.binding {
227 Some(_) => value_id,
228 None => {
229 let member_value_id = self.id_gen.next();
230 body.push(Instruction::composite_extract(
231 res_member.type_id,
232 member_value_id,
233 value_id,
234 &[index as u32],
235 ));
236 member_value_id
237 }
238 };
239
240 body.push(Instruction::store(res_member.id, member_value_id, None));
241
242 match res_member.built_in {
243 Some(crate::BuiltIn::Position { .. })
244 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
245 {
246 self.write_epilogue_position_y_flip(res_member.id, body)?;
247 }
248 Some(crate::BuiltIn::FragDepth)
249 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
250 {
251 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
252 }
253 _ => {}
254 }
255 }
256 Ok(())
257 }
258}
259
260impl BlockContext<'_> {
    /// Emits a guard that forces the surrounding loop to terminate after a bounded
    /// number of passes through this code.
    ///
    /// Creates a `Function`-storage `vec2<u32>` variable, initialized to
    /// `(u32::MAX, u32::MAX)` and recorded in `self.function.force_loop_bounding_vars`.
    /// It is named `loop_bound` when `WriterFlags::DEBUG` is set. The two components
    /// act as one 64-bit counter: component 0 is the high word, component 1 the low word.
    ///
    /// `block` is terminated with a branch to a new block that loads the counter. If
    /// both components are zero, that block branches to `merge_id`. Otherwise it
    /// branches to a second new block, which decrements the 64-bit counter by one and
    /// stores it back. That second block is returned unterminated, so the caller can
    /// keep appending to it.
    ///
    /// NOTE(review): presumably called once per loop iteration, with `merge_id` being
    /// the loop's merge block, so the loop exits after roughly 2^64 passes — confirm at
    /// the call site.
    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
        let uint_type_id = self.writer.get_u32_type_id();
        let uint2_type_id = self.writer.get_vec2u_type_id();
        let uint2_ptr_type_id = self
            .writer
            .get_vec2u_pointer_type_id(spirv::StorageClass::Function);
        let bool_type_id = self.writer.get_bool_type_id();
        let bool2_type_id = self.writer.get_vec2_bool_type_id();
        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
        let zero_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[zero_uint_const_id, zero_uint_const_id],
        );
        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
        let max_uint_const_id = self
            .writer
            .get_constant_scalar(crate::Literal::U32(u32::MAX));
        // Initial counter value: all 64 bits set.
        let max_uint2_const_id = self.writer.get_constant_composite(
            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                size: crate::VectorSize::Bi,
                scalar: crate::Scalar::U32,
            })),
            &[max_uint_const_id, max_uint_const_id],
        );

        // The counter variable; it is pushed to `force_loop_bounding_vars` rather than
        // into the current block.
        let loop_counter_var_id = self.gen_id();
        if self.writer.flags.contains(WriterFlags::DEBUG) {
            self.writer
                .debugs
                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
        }
        let var = super::LocalVariable {
            id: loop_counter_var_id,
            instruction: Instruction::variable(
                uint2_ptr_type_id,
                loop_counter_var_id,
                spirv::StorageClass::Function,
                Some(max_uint2_const_id),
            ),
        };
        self.function.force_loop_bounding_vars.push(var);

        let break_if_block = self.gen_id();

        // Close the caller's block and continue in a fresh one.
        self.function
            .consume(block, Instruction::branch(break_if_block));
        block = Block::new(break_if_block);

        // Load the current counter value.
        let load_id = self.gen_id();
        block.body.push(Instruction::load(
            uint2_type_id,
            load_id,
            loop_counter_var_id,
            None,
        ));

        // all(counter == vec2(0u)): true only when both words have reached zero.
        let eq_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool2_type_id,
            eq_id,
            zero_uint2_const_id,
            load_id,
        ));
        let all_eq_id = self.gen_id();
        block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            all_eq_id,
            eq_id,
        ));

        // Structured selection: exhausted -> `merge_id`, otherwise fall into the
        // decrement block, which is also declared as the selection's merge block.
        let inc_counter_block_id = self.gen_id();
        block.body.push(Instruction::selection_merge(
            inc_counter_block_id,
            spirv::SelectionControl::empty(),
        ));
        self.function.consume(
            block,
            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
        );
        block = Block::new(inc_counter_block_id);

        // 64-bit decrement: always subtract 1 from the low word (component 1), and
        // borrow 1 from the high word (component 0) when the low word was zero
        // before the subtraction.
        let low_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            uint_type_id,
            low_id,
            load_id,
            &[1],
        ));
        let low_overflow_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool_type_id,
            low_overflow_id,
            low_id,
            zero_uint_const_id,
        ));
        // carry = low == 0 ? 1u : 0u
        let carry_bit_id = self.gen_id();
        block.body.push(Instruction::select(
            uint_type_id,
            carry_bit_id,
            low_overflow_id,
            one_uint_const_id,
            zero_uint_const_id,
        ));
        // counter -= vec2(carry, 1u)
        let decrement_id = self.gen_id();
        block.body.push(Instruction::composite_construct(
            uint2_type_id,
            decrement_id,
            &[carry_bit_id, one_uint_const_id],
        ));
        let result_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::ISub,
            uint2_type_id,
            result_id,
            load_id,
            decrement_id,
        ));
        block
            .body
            .push(Instruction::store(loop_counter_var_id, result_id, None));

        // Returned unterminated; the caller appends the rest of the loop body.
        block
    }
411
412 pub(super) fn cache_expression_value(
414 &mut self,
415 expr_handle: Handle<crate::Expression>,
416 block: &mut Block,
417 ) -> Result<(), Error> {
418 let is_named_expression = self
419 .ir_function
420 .named_expressions
421 .contains_key(&expr_handle);
422
423 if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
424 return Ok(());
425 }
426
427 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
428 let id = match self.ir_function.expressions[expr_handle] {
429 crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
430 crate::Expression::Constant(handle) => {
431 let init = self.ir_module.constants[handle].init;
432 self.writer.constant_ids[init]
433 }
434 crate::Expression::Override(_) => return Err(Error::Override),
435 crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
436 crate::Expression::Compose { ty, ref components } => {
437 self.temp_list.clear();
438 if self.expression_constness.is_const(expr_handle) {
439 self.temp_list.extend(
440 crate::proc::flatten_compose(
441 ty,
442 components,
443 &self.ir_function.expressions,
444 &self.ir_module.types,
445 )
446 .map(|component| self.cached[component]),
447 );
448 self.writer
449 .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
450 } else {
451 self.temp_list
452 .extend(components.iter().map(|&component| self.cached[component]));
453
454 let id = self.gen_id();
455 block.body.push(Instruction::composite_construct(
456 result_type_id,
457 id,
458 &self.temp_list,
459 ));
460 id
461 }
462 }
463 crate::Expression::Splat { size, value } => {
464 let value_id = self.cached[value];
465 let components = &[value_id; 4][..size as usize];
466
467 if self.expression_constness.is_const(expr_handle) {
468 let ty = self
469 .writer
470 .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
471 self.writer.get_constant_composite(ty, components)
472 } else {
473 let id = self.gen_id();
474 block.body.push(Instruction::composite_construct(
475 result_type_id,
476 id,
477 components,
478 ));
479 id
480 }
481 }
482 crate::Expression::Access { base, index } => {
483 let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
484 match *base_ty_inner {
485 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
486 0
492 }
493 _ if self.function.spilled_accesses.contains(base) => {
494 self.function.spilled_accesses.insert(expr_handle);
502 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
503 }
504 crate::TypeInner::Vector { .. } => {
505 self.write_vector_access(expr_handle, base, index, block)?
506 }
507 crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
508 match GuardedIndex::from_expression(
510 index,
511 &self.ir_function.expressions,
512 self.ir_module,
513 ) {
514 GuardedIndex::Known(value) => {
515 let id = self.gen_id();
525 let base_id = self.cached[base];
526 block.body.push(Instruction::composite_extract(
527 result_type_id,
528 id,
529 base_id,
530 &[value],
531 ));
532 id
533 }
534 GuardedIndex::Expression(_) => {
535 self.spill_to_internal_variable(base, block);
542
543 self.function.spilled_accesses.insert(expr_handle);
546 self.maybe_access_spilled_composite(
547 expr_handle,
548 block,
549 result_type_id,
550 )?
551 }
552 }
553 }
554 crate::TypeInner::BindingArray {
555 base: binding_type, ..
556 } => {
557 let result_id = match self.write_access_chain(
560 expr_handle,
561 block,
562 AccessTypeAdjustment::IntroducePointer(
563 spirv::StorageClass::UniformConstant,
564 ),
565 )? {
566 ExpressionPointer::Ready { pointer_id } => pointer_id,
567 ExpressionPointer::Conditional { .. } => {
568 return Err(Error::FeatureNotImplemented(
569 "Texture array out-of-bounds handling",
570 ));
571 }
572 };
573
574 let binding_type_id = self.get_handle_type_id(binding_type);
575
576 let load_id = self.gen_id();
577 block.body.push(Instruction::load(
578 binding_type_id,
579 load_id,
580 result_id,
581 None,
582 ));
583
584 if self.fun_info[index].uniformity.non_uniform_result.is_some() {
588 self.writer
589 .decorate_non_uniform_binding_array_access(load_id)?;
590 }
591
592 load_id
593 }
594 ref other => {
595 log::error!(
596 "Unable to access base {:?} of type {:?}",
597 self.ir_function.expressions[base],
598 other
599 );
600 return Err(Error::Validation(
601 "only vectors and arrays may be dynamically indexed by value",
602 ));
603 }
604 }
605 }
606 crate::Expression::AccessIndex { base, index } => {
607 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
608 crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
609 0
615 }
616 _ if self.function.spilled_accesses.contains(base) => {
617 self.function.spilled_accesses.insert(expr_handle);
625 self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
626 }
627 crate::TypeInner::Vector { .. }
628 | crate::TypeInner::Matrix { .. }
629 | crate::TypeInner::Array { .. }
630 | crate::TypeInner::Struct { .. } => {
631 let id = self.gen_id();
636 let base_id = self.cached[base];
637 block.body.push(Instruction::composite_extract(
638 result_type_id,
639 id,
640 base_id,
641 &[index],
642 ));
643 id
644 }
645 crate::TypeInner::BindingArray {
646 base: binding_type, ..
647 } => {
648 let result_id = match self.write_access_chain(
651 expr_handle,
652 block,
653 AccessTypeAdjustment::IntroducePointer(
654 spirv::StorageClass::UniformConstant,
655 ),
656 )? {
657 ExpressionPointer::Ready { pointer_id } => pointer_id,
658 ExpressionPointer::Conditional { .. } => {
659 return Err(Error::FeatureNotImplemented(
660 "Texture array out-of-bounds handling",
661 ));
662 }
663 };
664
665 let binding_type_id = self.get_handle_type_id(binding_type);
666
667 let load_id = self.gen_id();
668 block.body.push(Instruction::load(
669 binding_type_id,
670 load_id,
671 result_id,
672 None,
673 ));
674
675 load_id
676 }
677 ref other => {
678 log::error!("Unable to access index of {:?}", other);
679 return Err(Error::FeatureNotImplemented("access index for type"));
680 }
681 }
682 }
683 crate::Expression::GlobalVariable(handle) => {
684 self.writer.global_variables[handle].access_id
685 }
686 crate::Expression::Swizzle {
687 size,
688 vector,
689 pattern,
690 } => {
691 let vector_id = self.cached[vector];
692 self.temp_list.clear();
693 for &sc in pattern[..size as usize].iter() {
694 self.temp_list.push(sc as Word);
695 }
696 let id = self.gen_id();
697 block.body.push(Instruction::vector_shuffle(
698 result_type_id,
699 id,
700 vector_id,
701 vector_id,
702 &self.temp_list,
703 ));
704 id
705 }
706 crate::Expression::Unary { op, expr } => {
707 let id = self.gen_id();
708 let expr_id = self.cached[expr];
709 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
710
711 let spirv_op = match op {
712 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
713 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
714 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
715 _ => return Err(Error::Validation("Unexpected kind for negation")),
716 },
717 crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
718 crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
719 };
720
721 block
722 .body
723 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
724 id
725 }
726 crate::Expression::Binary { op, left, right } => {
727 let id = self.gen_id();
728 let left_id = self.cached[left];
729 let right_id = self.cached[right];
730 let left_type_id = self.get_expression_type_id(&self.fun_info[left].ty);
731 let right_type_id = self.get_expression_type_id(&self.fun_info[right].ty);
732
733 if let Some(function_id) =
734 self.writer
735 .wrapped_functions
736 .get(&WrappedFunction::BinaryOp {
737 op,
738 left_type_id,
739 right_type_id,
740 })
741 {
742 block.body.push(Instruction::function_call(
743 result_type_id,
744 id,
745 *function_id,
746 &[left_id, right_id],
747 ));
748 } else {
749 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
750 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
751
752 let left_dimension = get_dimension(left_ty_inner);
753 let right_dimension = get_dimension(right_ty_inner);
754
755 let mut reverse_operands = false;
756
757 let spirv_op = match op {
758 crate::BinaryOperator::Add => match *left_ty_inner {
759 crate::TypeInner::Scalar(scalar)
760 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
761 crate::ScalarKind::Float => spirv::Op::FAdd,
762 _ => spirv::Op::IAdd,
763 },
764 crate::TypeInner::Matrix {
765 columns,
766 rows,
767 scalar,
768 } => {
769 self.write_matrix_matrix_column_op(
770 block,
771 id,
772 result_type_id,
773 left_id,
774 right_id,
775 columns,
776 rows,
777 scalar.width,
778 spirv::Op::FAdd,
779 );
780
781 self.cached[expr_handle] = id;
782 return Ok(());
783 }
784 _ => unimplemented!(),
785 },
786 crate::BinaryOperator::Subtract => match *left_ty_inner {
787 crate::TypeInner::Scalar(scalar)
788 | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
789 crate::ScalarKind::Float => spirv::Op::FSub,
790 _ => spirv::Op::ISub,
791 },
792 crate::TypeInner::Matrix {
793 columns,
794 rows,
795 scalar,
796 } => {
797 self.write_matrix_matrix_column_op(
798 block,
799 id,
800 result_type_id,
801 left_id,
802 right_id,
803 columns,
804 rows,
805 scalar.width,
806 spirv::Op::FSub,
807 );
808
809 self.cached[expr_handle] = id;
810 return Ok(());
811 }
812 _ => unimplemented!(),
813 },
814 crate::BinaryOperator::Multiply => {
815 match (left_dimension, right_dimension) {
816 (Dimension::Scalar, Dimension::Vector) => {
817 self.write_vector_scalar_mult(
818 block,
819 id,
820 result_type_id,
821 right_id,
822 left_id,
823 right_ty_inner,
824 );
825
826 self.cached[expr_handle] = id;
827 return Ok(());
828 }
829 (Dimension::Vector, Dimension::Scalar) => {
830 self.write_vector_scalar_mult(
831 block,
832 id,
833 result_type_id,
834 left_id,
835 right_id,
836 left_ty_inner,
837 );
838
839 self.cached[expr_handle] = id;
840 return Ok(());
841 }
842 (Dimension::Vector, Dimension::Matrix) => {
843 spirv::Op::VectorTimesMatrix
844 }
845 (Dimension::Matrix, Dimension::Scalar) => {
846 spirv::Op::MatrixTimesScalar
847 }
848 (Dimension::Scalar, Dimension::Matrix) => {
849 reverse_operands = true;
850 spirv::Op::MatrixTimesScalar
851 }
852 (Dimension::Matrix, Dimension::Vector) => {
853 spirv::Op::MatrixTimesVector
854 }
855 (Dimension::Matrix, Dimension::Matrix) => {
856 spirv::Op::MatrixTimesMatrix
857 }
858 (Dimension::Vector, Dimension::Vector)
859 | (Dimension::Scalar, Dimension::Scalar)
860 if left_ty_inner.scalar_kind()
861 == Some(crate::ScalarKind::Float) =>
862 {
863 spirv::Op::FMul
864 }
865 (Dimension::Vector, Dimension::Vector)
866 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
867 }
868 }
869 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
870 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
871 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
872 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
873 _ => unimplemented!(),
874 },
875 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
876 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
879 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
880 unreachable!("Should have been handled by wrapped function")
881 }
882 _ => unimplemented!(),
883 },
884 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
885 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
886 spirv::Op::IEqual
887 }
888 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
889 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
890 _ => unimplemented!(),
891 },
892 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
893 Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
894 spirv::Op::INotEqual
895 }
896 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
897 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
898 _ => unimplemented!(),
899 },
900 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
901 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
902 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
903 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
904 _ => unimplemented!(),
905 },
906 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
907 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
908 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
909 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
910 _ => unimplemented!(),
911 },
912 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
913 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
914 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
915 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
916 _ => unimplemented!(),
917 },
918 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
919 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
920 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
921 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
922 _ => unimplemented!(),
923 },
924 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
925 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
926 _ => spirv::Op::BitwiseAnd,
927 },
928 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
929 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
930 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
931 _ => spirv::Op::BitwiseOr,
932 },
933 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
934 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
935 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
936 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
937 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
938 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
939 _ => unimplemented!(),
940 },
941 };
942
943 block.body.push(Instruction::binary(
944 spirv_op,
945 result_type_id,
946 id,
947 if reverse_operands { right_id } else { left_id },
948 if reverse_operands { left_id } else { right_id },
949 ));
950 }
951 id
952 }
953 crate::Expression::Math {
954 fun,
955 arg,
956 arg1,
957 arg2,
958 arg3,
959 } => {
960 use crate::MathFunction as Mf;
961 enum MathOp {
962 Ext(spirv::GLOp),
963 Custom(Instruction),
964 }
965
966 let arg0_id = self.cached[arg];
967 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
968 let arg_scalar_kind = arg_ty.scalar_kind();
969 let arg1_id = match arg1 {
970 Some(handle) => self.cached[handle],
971 None => 0,
972 };
973 let arg2_id = match arg2 {
974 Some(handle) => self.cached[handle],
975 None => 0,
976 };
977 let arg3_id = match arg3 {
978 Some(handle) => self.cached[handle],
979 None => 0,
980 };
981
982 let id = self.gen_id();
983 let math_op = match fun {
984 Mf::Abs => {
986 match arg_scalar_kind {
987 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
988 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
989 Some(crate::ScalarKind::Uint) => {
990 MathOp::Custom(Instruction::unary(
991 spirv::Op::CopyObject, result_type_id,
993 id,
994 arg0_id,
995 ))
996 }
997 other => unimplemented!("Unexpected abs({:?})", other),
998 }
999 }
1000 Mf::Min => MathOp::Ext(match arg_scalar_kind {
1001 Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
1002 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
1003 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
1004 other => unimplemented!("Unexpected min({:?})", other),
1005 }),
1006 Mf::Max => MathOp::Ext(match arg_scalar_kind {
1007 Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
1008 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
1009 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
1010 other => unimplemented!("Unexpected max({:?})", other),
1011 }),
1012 Mf::Clamp => match arg_scalar_kind {
1013 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FClamp),
1017 Some(_) => {
1018 let (min_op, max_op) = match arg_scalar_kind {
1019 Some(crate::ScalarKind::Sint) => {
1020 (spirv::GLOp::SMin, spirv::GLOp::SMax)
1021 }
1022 Some(crate::ScalarKind::Uint) => {
1023 (spirv::GLOp::UMin, spirv::GLOp::UMax)
1024 }
1025 _ => unreachable!(),
1026 };
1027
1028 let max_id = self.gen_id();
1029 block.body.push(Instruction::ext_inst(
1030 self.writer.gl450_ext_inst_id,
1031 max_op,
1032 result_type_id,
1033 max_id,
1034 &[arg0_id, arg1_id],
1035 ));
1036
1037 MathOp::Custom(Instruction::ext_inst(
1038 self.writer.gl450_ext_inst_id,
1039 min_op,
1040 result_type_id,
1041 id,
1042 &[max_id, arg2_id],
1043 ))
1044 }
1045 other => unimplemented!("Unexpected max({:?})", other),
1046 },
1047 Mf::Saturate => {
1048 let (maybe_size, scalar) = match *arg_ty {
1049 crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
1050 crate::TypeInner::Scalar(scalar) => (None, scalar),
1051 ref other => unimplemented!("Unexpected saturate({:?})", other),
1052 };
1053 let scalar = crate::Scalar::float(scalar.width);
1054 let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
1055 let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
1056
1057 if let Some(size) = maybe_size {
1058 let ty =
1059 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1060
1061 self.temp_list.clear();
1062 self.temp_list.resize(size as _, arg1_id);
1063
1064 arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
1065
1066 self.temp_list.fill(arg2_id);
1067
1068 arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
1069 }
1070
1071 MathOp::Custom(Instruction::ext_inst(
1072 self.writer.gl450_ext_inst_id,
1073 spirv::GLOp::FClamp,
1074 result_type_id,
1075 id,
1076 &[arg0_id, arg1_id, arg2_id],
1077 ))
1078 }
1079 Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
1081 Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
1082 Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
1083 Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
1084 Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
1085 Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
1086 Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
1087 Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
1088 Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
1089 Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
1090 Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
1091 Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
1092 Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
1093 Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
1094 Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
1095 Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
1097 Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
1098 Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
1099 Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
1100 Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
1101 Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
1102 Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
1103 Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
1104 Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
1106 crate::TypeInner::Vector {
1107 scalar:
1108 crate::Scalar {
1109 kind: crate::ScalarKind::Float,
1110 ..
1111 },
1112 ..
1113 } => MathOp::Custom(Instruction::binary(
1114 spirv::Op::Dot,
1115 result_type_id,
1116 id,
1117 arg0_id,
1118 arg1_id,
1119 )),
1120 crate::TypeInner::Vector { size, .. } => {
1122 self.write_dot_product(
1123 id,
1124 result_type_id,
1125 arg0_id,
1126 arg1_id,
1127 size as u32,
1128 block,
1129 );
1130 self.cached[expr_handle] = id;
1131 return Ok(());
1132 }
1133 _ => unreachable!(
1134 "Correct TypeInner for dot product should be already validated"
1135 ),
1136 },
1137 Mf::Outer => MathOp::Custom(Instruction::binary(
1138 spirv::Op::OuterProduct,
1139 result_type_id,
1140 id,
1141 arg0_id,
1142 arg1_id,
1143 )),
1144 Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
1145 Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
1146 Mf::Length => MathOp::Ext(spirv::GLOp::Length),
1147 Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
1148 Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
1149 Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
1150 Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
1151 Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
1153 Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
1154 Mf::Log => MathOp::Ext(spirv::GLOp::Log),
1155 Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
1156 Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
1157 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
1159 Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
1160 Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
1161 other => unimplemented!("Unexpected sign({:?})", other),
1162 }),
1163 Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
1164 Mf::Mix => {
1165 let selector = arg2.unwrap();
1166 let selector_ty =
1167 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
1168 match (arg_ty, selector_ty) {
1169 (
1171 &crate::TypeInner::Vector { size, .. },
1172 &crate::TypeInner::Scalar(scalar),
1173 ) => {
1174 let selector_type_id =
1175 self.get_numeric_type_id(NumericType::Vector { size, scalar });
1176 self.temp_list.clear();
1177 self.temp_list.resize(size as usize, arg2_id);
1178
1179 let selector_id = self.gen_id();
1180 block.body.push(Instruction::composite_construct(
1181 selector_type_id,
1182 selector_id,
1183 &self.temp_list,
1184 ));
1185
1186 MathOp::Custom(Instruction::ext_inst(
1187 self.writer.gl450_ext_inst_id,
1188 spirv::GLOp::FMix,
1189 result_type_id,
1190 id,
1191 &[arg0_id, arg1_id, selector_id],
1192 ))
1193 }
1194 _ => MathOp::Ext(spirv::GLOp::FMix),
1195 }
1196 }
1197 Mf::Step => MathOp::Ext(spirv::GLOp::Step),
1198 Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
1199 Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
1200 Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
1201 Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
1202 Mf::Transpose => MathOp::Custom(Instruction::unary(
1203 spirv::Op::Transpose,
1204 result_type_id,
1205 id,
1206 arg0_id,
1207 )),
1208 Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
1209 Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
1210 spirv::Op::QuantizeToF16,
1211 result_type_id,
1212 id,
1213 arg0_id,
1214 )),
1215 Mf::ReverseBits => MathOp::Custom(Instruction::unary(
1216 spirv::Op::BitReverse,
1217 result_type_id,
1218 id,
1219 arg0_id,
1220 )),
1221 Mf::CountTrailingZeros => {
1222 let uint_id = match *arg_ty {
1223 crate::TypeInner::Vector { size, scalar } => {
1224 let ty =
1225 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1226
1227 self.temp_list.clear();
1228 self.temp_list.resize(
1229 size as _,
1230 self.writer
1231 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1232 );
1233
1234 self.writer.get_constant_composite(ty, &self.temp_list)
1235 }
1236 crate::TypeInner::Scalar(scalar) => self
1237 .writer
1238 .get_constant_scalar_with(scalar.width * 8, scalar)?,
1239 _ => unreachable!(),
1240 };
1241
1242 let lsb_id = self.gen_id();
1243 block.body.push(Instruction::ext_inst(
1244 self.writer.gl450_ext_inst_id,
1245 spirv::GLOp::FindILsb,
1246 result_type_id,
1247 lsb_id,
1248 &[arg0_id],
1249 ));
1250
1251 MathOp::Custom(Instruction::ext_inst(
1252 self.writer.gl450_ext_inst_id,
1253 spirv::GLOp::UMin,
1254 result_type_id,
1255 id,
1256 &[uint_id, lsb_id],
1257 ))
1258 }
1259 Mf::CountLeadingZeros => {
1260 let (int_type_id, int_id, width) = match *arg_ty {
1261 crate::TypeInner::Vector { size, scalar } => {
1262 let ty =
1263 LocalType::Numeric(NumericType::Vector { size, scalar }).into();
1264
1265 self.temp_list.clear();
1266 self.temp_list.resize(
1267 size as _,
1268 self.writer
1269 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1270 );
1271
1272 (
1273 self.get_type_id(ty),
1274 self.writer.get_constant_composite(ty, &self.temp_list),
1275 scalar.width,
1276 )
1277 }
1278 crate::TypeInner::Scalar(scalar) => (
1279 self.get_numeric_type_id(NumericType::Scalar(scalar)),
1280 self.writer
1281 .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?,
1282 scalar.width,
1283 ),
1284 _ => unreachable!(),
1285 };
1286
1287 if width != 4 {
1288 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1289 };
1290
1291 let msb_id = self.gen_id();
1292 block.body.push(Instruction::ext_inst(
1293 self.writer.gl450_ext_inst_id,
1294 if width != 4 {
1295 spirv::GLOp::FindILsb
1296 } else {
1297 spirv::GLOp::FindUMsb
1298 },
1299 int_type_id,
1300 msb_id,
1301 &[arg0_id],
1302 ));
1303
1304 MathOp::Custom(Instruction::binary(
1305 spirv::Op::ISub,
1306 result_type_id,
1307 id,
1308 int_id,
1309 msb_id,
1310 ))
1311 }
1312 Mf::CountOneBits => MathOp::Custom(Instruction::unary(
1313 spirv::Op::BitCount,
1314 result_type_id,
1315 id,
1316 arg0_id,
1317 )),
1318 Mf::ExtractBits => {
1319 let op = match arg_scalar_kind {
1320 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
1321 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
1322 other => unimplemented!("Unexpected sign({:?})", other),
1323 };
1324
1325 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1340 let width_constant = self
1341 .writer
1342 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1343
1344 let u32_type =
1345 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1346
1347 let offset_id = self.gen_id();
1349 block.body.push(Instruction::ext_inst(
1350 self.writer.gl450_ext_inst_id,
1351 spirv::GLOp::UMin,
1352 u32_type,
1353 offset_id,
1354 &[arg1_id, width_constant],
1355 ));
1356
1357 let max_count_id = self.gen_id();
1359 block.body.push(Instruction::binary(
1360 spirv::Op::ISub,
1361 u32_type,
1362 max_count_id,
1363 width_constant,
1364 offset_id,
1365 ));
1366
1367 let count_id = self.gen_id();
1369 block.body.push(Instruction::ext_inst(
1370 self.writer.gl450_ext_inst_id,
1371 spirv::GLOp::UMin,
1372 u32_type,
1373 count_id,
1374 &[arg2_id, max_count_id],
1375 ));
1376
1377 MathOp::Custom(Instruction::ternary(
1378 op,
1379 result_type_id,
1380 id,
1381 arg0_id,
1382 offset_id,
1383 count_id,
1384 ))
1385 }
1386 Mf::InsertBits => {
1387 let bit_width = arg_ty.scalar_width().unwrap() * 8;
1390 let width_constant = self
1391 .writer
1392 .get_constant_scalar(crate::Literal::U32(bit_width as u32));
1393
1394 let u32_type =
1395 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1396
1397 let offset_id = self.gen_id();
1399 block.body.push(Instruction::ext_inst(
1400 self.writer.gl450_ext_inst_id,
1401 spirv::GLOp::UMin,
1402 u32_type,
1403 offset_id,
1404 &[arg2_id, width_constant],
1405 ));
1406
1407 let max_count_id = self.gen_id();
1409 block.body.push(Instruction::binary(
1410 spirv::Op::ISub,
1411 u32_type,
1412 max_count_id,
1413 width_constant,
1414 offset_id,
1415 ));
1416
1417 let count_id = self.gen_id();
1419 block.body.push(Instruction::ext_inst(
1420 self.writer.gl450_ext_inst_id,
1421 spirv::GLOp::UMin,
1422 u32_type,
1423 count_id,
1424 &[arg3_id, max_count_id],
1425 ));
1426
1427 MathOp::Custom(Instruction::quaternary(
1428 spirv::Op::BitFieldInsert,
1429 result_type_id,
1430 id,
1431 arg0_id,
1432 arg1_id,
1433 offset_id,
1434 count_id,
1435 ))
1436 }
1437 Mf::FirstTrailingBit => MathOp::Ext(spirv::GLOp::FindILsb),
1438 Mf::FirstLeadingBit => {
1439 if arg_ty.scalar_width() == Some(4) {
1440 let thing = match arg_scalar_kind {
1441 Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
1442 Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
1443 other => unimplemented!("Unexpected firstLeadingBit({:?})", other),
1444 };
1445 MathOp::Ext(thing)
1446 } else {
1447 unreachable!("This is validated out until a polyfill is implemented. https://github.com/gfx-rs/wgpu/issues/5276");
1448 }
1449 }
1450 Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
1451 Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
1452 Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
1453 Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
1454 Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
1455 fun @ (Mf::Pack4xI8 | Mf::Pack4xU8) => {
1456 let (int_type, is_signed) = match fun {
1457 Mf::Pack4xI8 => (crate::ScalarKind::Sint, true),
1458 Mf::Pack4xU8 => (crate::ScalarKind::Uint, false),
1459 _ => unreachable!(),
1460 };
1461 let uint_type_id =
1462 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
1463
1464 let int_type_id =
1465 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
1466 kind: int_type,
1467 width: 4,
1468 }));
1469
1470 let mut last_instruction = Instruction::new(spirv::Op::Nop);
1471
1472 let zero = self.writer.get_constant_scalar(crate::Literal::U32(0));
1473 let mut preresult = zero;
1474 block
1475 .body
1476 .reserve(usize::from(VEC_LENGTH) * (2 + usize::from(is_signed)));
1477
1478 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1479 const VEC_LENGTH: u8 = 4;
1480 for i in 0..u32::from(VEC_LENGTH) {
1481 let offset =
1482 self.writer.get_constant_scalar(crate::Literal::U32(i * 8));
1483 let mut extracted = self.gen_id();
1484 block.body.push(Instruction::binary(
1485 spirv::Op::CompositeExtract,
1486 int_type_id,
1487 extracted,
1488 arg0_id,
1489 i,
1490 ));
1491 if is_signed {
1492 let casted = self.gen_id();
1493 block.body.push(Instruction::unary(
1494 spirv::Op::Bitcast,
1495 uint_type_id,
1496 casted,
1497 extracted,
1498 ));
1499 extracted = casted;
1500 }
1501 let is_last = i == u32::from(VEC_LENGTH - 1);
1502 if is_last {
1503 last_instruction = Instruction::quaternary(
1504 spirv::Op::BitFieldInsert,
1505 result_type_id,
1506 id,
1507 preresult,
1508 extracted,
1509 offset,
1510 eight,
1511 )
1512 } else {
1513 let new_preresult = self.gen_id();
1514 block.body.push(Instruction::quaternary(
1515 spirv::Op::BitFieldInsert,
1516 result_type_id,
1517 new_preresult,
1518 preresult,
1519 extracted,
1520 offset,
1521 eight,
1522 ));
1523 preresult = new_preresult;
1524 }
1525 }
1526
1527 MathOp::Custom(last_instruction)
1528 }
1529 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
1530 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
1531 Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
1532 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
1533 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
1534 fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
1535 let (int_type, extract_op, is_signed) = match fun {
1536 Mf::Unpack4xI8 => {
1537 (crate::ScalarKind::Sint, spirv::Op::BitFieldSExtract, true)
1538 }
1539 Mf::Unpack4xU8 => {
1540 (crate::ScalarKind::Uint, spirv::Op::BitFieldUExtract, false)
1541 }
1542 _ => unreachable!(),
1543 };
1544
1545 let sint_type_id =
1546 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::I32));
1547
1548 let eight = self.writer.get_constant_scalar(crate::Literal::U32(8));
1549 let int_type_id =
1550 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar {
1551 kind: int_type,
1552 width: 4,
1553 }));
1554 block
1555 .body
1556 .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed));
1557 let arg_id = if is_signed {
1558 let new_arg_id = self.gen_id();
1559 block.body.push(Instruction::unary(
1560 spirv::Op::Bitcast,
1561 sint_type_id,
1562 new_arg_id,
1563 arg0_id,
1564 ));
1565 new_arg_id
1566 } else {
1567 arg0_id
1568 };
1569
1570 const VEC_LENGTH: u8 = 4;
1571 let parts: [_; VEC_LENGTH as usize] =
1572 core::array::from_fn(|_| self.gen_id());
1573 for (i, part_id) in parts.into_iter().enumerate() {
1574 let index = self
1575 .writer
1576 .get_constant_scalar(crate::Literal::U32(i as u32 * 8));
1577 block.body.push(Instruction::ternary(
1578 extract_op,
1579 int_type_id,
1580 part_id,
1581 arg_id,
1582 index,
1583 eight,
1584 ));
1585 }
1586
1587 MathOp::Custom(Instruction::composite_construct(result_type_id, id, &parts))
1588 }
1589 };
1590
1591 block.body.push(match math_op {
1592 MathOp::Ext(op) => Instruction::ext_inst(
1593 self.writer.gl450_ext_inst_id,
1594 op,
1595 result_type_id,
1596 id,
1597 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
1598 ),
1599 MathOp::Custom(inst) => inst,
1600 });
1601 id
1602 }
1603 crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
1604 crate::Expression::Load { pointer } => {
1605 self.write_checked_load(pointer, block, AccessTypeAdjustment::None, result_type_id)?
1606 }
1607 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
1608 crate::Expression::CallResult(_)
1609 | crate::Expression::AtomicResult { .. }
1610 | crate::Expression::WorkGroupUniformLoadResult { .. }
1611 | crate::Expression::RayQueryProceedResult
1612 | crate::Expression::SubgroupBallotResult
1613 | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
1614 crate::Expression::As {
1615 expr,
1616 kind,
1617 convert,
1618 } => self.write_as_expression(expr, convert, kind, block, result_type_id)?,
1619 crate::Expression::ImageLoad {
1620 image,
1621 coordinate,
1622 array_index,
1623 sample,
1624 level,
1625 } => self.write_image_load(
1626 result_type_id,
1627 image,
1628 coordinate,
1629 array_index,
1630 level,
1631 sample,
1632 block,
1633 )?,
1634 crate::Expression::ImageSample {
1635 image,
1636 sampler,
1637 gather,
1638 coordinate,
1639 array_index,
1640 offset,
1641 level,
1642 depth_ref,
1643 } => self.write_image_sample(
1644 result_type_id,
1645 image,
1646 sampler,
1647 gather,
1648 coordinate,
1649 array_index,
1650 offset,
1651 level,
1652 depth_ref,
1653 block,
1654 )?,
1655 crate::Expression::Select {
1656 condition,
1657 accept,
1658 reject,
1659 } => {
1660 let id = self.gen_id();
1661 let mut condition_id = self.cached[condition];
1662 let accept_id = self.cached[accept];
1663 let reject_id = self.cached[reject];
1664
1665 let condition_ty = self.fun_info[condition]
1666 .ty
1667 .inner_with(&self.ir_module.types);
1668 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
1669
1670 if let (
1671 &crate::TypeInner::Scalar(
1672 condition_scalar @ crate::Scalar {
1673 kind: crate::ScalarKind::Bool,
1674 ..
1675 },
1676 ),
1677 &crate::TypeInner::Vector { size, .. },
1678 ) = (condition_ty, object_ty)
1679 {
1680 self.temp_list.clear();
1681 self.temp_list.resize(size as usize, condition_id);
1682
1683 let bool_vector_type_id = self.get_numeric_type_id(NumericType::Vector {
1684 size,
1685 scalar: condition_scalar,
1686 });
1687
1688 let id = self.gen_id();
1689 block.body.push(Instruction::composite_construct(
1690 bool_vector_type_id,
1691 id,
1692 &self.temp_list,
1693 ));
1694 condition_id = id
1695 }
1696
1697 let instruction =
1698 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
1699 block.body.push(instruction);
1700 id
1701 }
1702 crate::Expression::Derivative { axis, ctrl, expr } => {
1703 use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
1704 match ctrl {
1705 Ctrl::Coarse | Ctrl::Fine => {
1706 self.writer.require_any(
1707 "DerivativeControl",
1708 &[spirv::Capability::DerivativeControl],
1709 )?;
1710 }
1711 Ctrl::None => {}
1712 }
1713 let id = self.gen_id();
1714 let expr_id = self.cached[expr];
1715 let op = match (axis, ctrl) {
1716 (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
1717 (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
1718 (Axis::X, Ctrl::None) => spirv::Op::DPdx,
1719 (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
1720 (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
1721 (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
1722 (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
1723 (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
1724 (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
1725 };
1726 block
1727 .body
1728 .push(Instruction::derivative(op, result_type_id, id, expr_id));
1729 id
1730 }
1731 crate::Expression::ImageQuery { image, query } => {
1732 self.write_image_query(result_type_id, image, query, block)?
1733 }
1734 crate::Expression::Relational { fun, argument } => {
1735 use crate::RelationalFunction as Rf;
1736 let arg_id = self.cached[argument];
1737 let op = match fun {
1738 Rf::All => spirv::Op::All,
1739 Rf::Any => spirv::Op::Any,
1740 Rf::IsNan => spirv::Op::IsNan,
1741 Rf::IsInf => spirv::Op::IsInf,
1742 };
1743 let id = self.gen_id();
1744 block
1745 .body
1746 .push(Instruction::relational(op, result_type_id, id, arg_id));
1747 id
1748 }
1749 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
1750 crate::Expression::RayQueryGetIntersection { query, committed } => {
1751 let query_id = self.cached[query];
1752 let func_id = self
1753 .writer
1754 .write_ray_query_get_intersection_function(committed, self.ir_module);
1755 let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
1756 let intersection_type_id = self.get_handle_type_id(ray_intersection);
1757 let id = self.gen_id();
1758 block.body.push(Instruction::function_call(
1759 intersection_type_id,
1760 id,
1761 func_id,
1762 &[query_id],
1763 ));
1764 id
1765 }
1766 crate::Expression::RayQueryVertexPositions { query, committed } => {
1767 self.writer.require_any(
1768 "RayQueryVertexPositions",
1769 &[spirv::Capability::RayQueryPositionFetchKHR],
1770 )?;
1771 self.write_ray_query_return_vertex_position(query, block, committed)
1772 }
1773 };
1774
1775 self.cached[expr_handle] = id;
1776 Ok(())
1777 }
1778
    /// Lower an [`Expression::As`] cast of `expr` to SPIR-V, returning the id of
    /// the converted value.
    ///
    /// - `convert`: `None` reinterprets the bits (an `OpBitcast`, or no
    ///   instruction at all when the cast is an identity); `Some(width)` is a
    ///   value conversion to a scalar of `kind` with that byte width.
    /// - `kind`: the scalar kind of the result.
    /// - `result_type_id`: SPIR-V type id used as the result type of whatever
    ///   conversion instruction is emitted.
    ///
    /// Any instructions needed are appended to `block`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Validation`] if a matrix is converted to a non-float
    /// kind, or if `expr` is not a scalar, vector, or matrix.
    ///
    /// [`Expression::As`]: crate::Expression::As
    fn write_as_expression(
        &mut self,
        expr: Handle<crate::Expression>,
        convert: Option<u8>,
        kind: crate::ScalarKind,

        block: &mut Block,
        result_type_id: u32,
    ) -> Result<u32, Error> {
        use crate::ScalarKind as Sk;
        let expr_id = self.cached[expr];
        let ty = self.fun_info[expr].ty.inner_with(&self.ir_module.types);

        // SPIR-V conversion instructions accept scalars and vectors but not
        // matrices, so a matrix is converted one column at a time and the
        // converted columns are reassembled into the result matrix.
        if let crate::TypeInner::Matrix {
            columns,
            rows,
            scalar,
        } = *ty
        {
            let Some(convert) = convert else {
                // Pure reinterpretation: the matrix passes through unchanged.
                return Ok(expr_id);
            };

            if convert == scalar.width {
                // Same width: nothing to convert, pass through unchanged.
                return Ok(expr_id);
            }

            if kind != Sk::Float {
                // Only float-to-float matrix conversions are supported.
                return Err(Error::Validation("Matrices must be floats"));
            }

            // Vector type of each source column.
            let column_src_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar,
                })));

            // Vector type of each column after conversion.
            let column_dst_ty =
                self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
                    size: rows,
                    scalar: crate::Scalar {
                        kind,
                        width: convert,
                    },
                })));

            let mut components = ArrayVec::<Word, 4>::new();

            for column in 0..columns as usize {
                // Pull out one column as a vector...
                let column_id = self.gen_id();
                block.body.push(Instruction::composite_extract(
                    column_src_ty,
                    column_id,
                    expr_id,
                    &[column as u32],
                ));

                // ...and convert it to the destination float width.
                let column_conv_id = self.gen_id();
                block.body.push(Instruction::unary(
                    spirv::Op::FConvert,
                    column_dst_ty,
                    column_conv_id,
                    column_id,
                ));

                components.push(column_conv_id);
            }

            let construct_id = self.gen_id();

            // Rebuild the matrix from the converted columns.
            block.body.push(Instruction::composite_construct(
                result_type_id,
                construct_id,
                &components,
            ));

            return Ok(construct_id);
        }

        // Non-matrix sources must be scalars or vectors; `src_size` is `None`
        // for scalars.
        let (src_scalar, src_size) = match *ty {
            crate::TypeInner::Scalar(scalar) => (scalar, None),
            crate::TypeInner::Vector { scalar, size } => (scalar, Some(size)),
            ref other => {
                log::error!("As source {:?}", other);
                return Err(Error::Validation("Unexpected Expression::As source"));
            }
        };

        // How the cast is realized: pass the operand through unchanged, or emit
        // a single instruction taking one, two, or three operands.
        enum Cast {
            Identity(Word),
            Unary(spirv::Op, Word),
            Binary(spirv::Op, Word, Word),
            Ternary(spirv::Op, Word, Word, Word),
        }
        // Arm order matters: the identity, bool-to-bool and bitcast arms take
        // precedence over the value conversions further down.
        let cast = match (src_scalar.kind, kind, convert) {
            // Same kind and no width change: emit no instruction at all.
            (src_kind, kind, convert)
                if src_kind == kind
                    && convert.filter(|&width| width != src_scalar.width).is_none() =>
            {
                Cast::Identity(expr_id)
            }
            (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject, expr_id),
            // No `convert`: reinterpret the bits.
            (_, _, None) => Cast::Unary(spirv::Op::Bitcast, expr_id),
            // Converting to bool: compare against zero. The float comparison is
            // unordered, so a NaN operand yields `true`.
            (_, Sk::Bool, Some(_)) => {
                let op = match src_scalar.kind {
                    Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
                    Sk::Float => spirv::Op::FUnordNotEqual,
                    Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
                };
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?;
                // Splat the zero for vector operands.
                let zero_id = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: src_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        self.writer.get_constant_composite(ty, &self.temp_list)
                    }
                    None => zero_scalar_id,
                };

                Cast::Binary(op, expr_id, zero_id)
            }
            // Converting from bool: `OpSelect` between 1 and 0 of the
            // destination type, splatted for vector operands.
            (Sk::Bool, _, Some(dst_width)) => {
                let dst_scalar = crate::Scalar {
                    kind,
                    width: dst_width,
                };
                let zero_scalar_id = self.writer.get_constant_scalar_with(0, dst_scalar)?;
                let one_scalar_id = self.writer.get_constant_scalar_with(1, dst_scalar)?;
                let (accept_id, reject_id) = match src_size {
                    Some(size) => {
                        let ty = LocalType::Numeric(NumericType::Vector {
                            size,
                            scalar: dst_scalar,
                        })
                        .into();

                        self.temp_list.clear();
                        self.temp_list.resize(size as _, zero_scalar_id);

                        let vec0_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        self.temp_list.fill(one_scalar_id);

                        let vec1_id = self.writer.get_constant_composite(ty, &self.temp_list);

                        (vec1_id, vec0_id)
                    }
                    None => (one_scalar_id, zero_scalar_id),
                };

                Cast::Ternary(spirv::Op::Select, expr_id, accept_id, reject_id)
            }
            // Float to integer. SPIR-V leaves out-of-range float-to-int
            // conversion undefined, so the operand is first clamped (GLSL.std.450
            // FClamp) to the bounds returned by `min_max_float_representable_by`
            // for the destination scalar, then converted.
            (Sk::Float, Sk::Sint | Sk::Uint, Some(width)) => {
                let dst_scalar = crate::Scalar { kind, width };
                let (min, max) =
                    crate::proc::min_max_float_representable_by(src_scalar, dst_scalar);
                let expr_type_id = self.get_expression_type_id(&self.fun_info[expr].ty);

                // Splat a scalar constant to the source vector type when needed.
                let maybe_splat_const = |writer: &mut Writer, const_id| match src_size {
                    None => const_id,
                    Some(size) => {
                        let constituent_ids = [const_id; crate::VectorSize::MAX];
                        writer.get_constant_composite(
                            LookupType::Local(LocalType::Numeric(NumericType::Vector {
                                size,
                                scalar: src_scalar,
                            })),
                            &constituent_ids[..size as usize],
                        )
                    }
                };
                let min_const_id = self.writer.get_constant_scalar(min);
                let min_const_id = maybe_splat_const(self.writer, min_const_id);
                let max_const_id = self.writer.get_constant_scalar(max);
                let max_const_id = maybe_splat_const(self.writer, max_const_id);

                let clamp_id = self.gen_id();
                block.body.push(Instruction::ext_inst(
                    self.writer.gl450_ext_inst_id,
                    spirv::GLOp::FClamp,
                    expr_type_id,
                    clamp_id,
                    &[expr_id, min_const_id, max_const_id],
                ));

                let op = match dst_scalar.kind {
                    crate::ScalarKind::Sint => spirv::Op::ConvertFToS,
                    crate::ScalarKind::Uint => spirv::Op::ConvertFToU,
                    _ => unreachable!(),
                };
                Cast::Unary(op, clamp_id)
            }
            (Sk::Float, Sk::Float, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::FConvert, expr_id)
            }
            (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF, expr_id),
            (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF, expr_id),
            (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            // Width-changing conversions across signedness pick the opcode by the
            // *destination* signedness.
            // NOTE(review): `SConvert` sign-extends its operand, so widening an
            // unsigned value to a signed type treats the source as signed (and
            // `UConvert` zero-extends a signed source) — confirm this is intended.
            (Sk::Uint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::SConvert, expr_id)
            }
            (Sk::Sint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
                Cast::Unary(spirv::Op::UConvert, expr_id)
            }
            // Remaining cases, e.g. same-width Sint <-> Uint, only need a
            // reinterpretation of the bits.
            _ => Cast::Unary(spirv::Op::Bitcast, expr_id),
        };
        // Emit the chosen instruction (if any) and return the id of the result.
        Ok(match cast {
            Cast::Identity(expr) => expr,
            Cast::Unary(op, op1) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::unary(op, result_type_id, id, op1));
                id
            }
            Cast::Binary(op, op1, op2) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::binary(op, result_type_id, id, op1, op2));
                id
            }
            Cast::Ternary(op, op1, op2, op3) => {
                let id = self.gen_id();
                block
                    .body
                    .push(Instruction::ternary(op, result_type_id, id, op1, op2, op3));
                id
            }
        })
    }
2048
    /// Build an `OpAccessChain` for the pointer expression `expr_handle`.
    ///
    /// Walks back through the chain of `Access` / `AccessIndex` expressions to
    /// its root: a spilled composite, a global variable, a local variable, or a
    /// function argument. Any bounds-checking code the indices need is emitted
    /// into `block` along the way.
    ///
    /// The access chain's result type is the type of `expr_handle`, adjusted by
    /// `type_adjustment`: with `IntroducePointer(class)` it becomes a pointer in
    /// storage class `class`.
    ///
    /// Returns:
    /// - [`ExpressionPointer::Ready`] when the pointer is usable now. It is
    ///   either the root itself (no indices) or an `OpAccessChain` already
    ///   pushed onto `block`.
    /// - [`ExpressionPointer::Conditional`] when some index needed a conditional
    ///   bounds check. The `OpAccessChain` is *not* emitted; it is returned
    ///   together with the combined check condition for the caller to emit.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) if the chain contains any other kind of
    /// expression.
    fn write_access_chain(
        &mut self,
        mut expr_handle: Handle<crate::Expression>,
        block: &mut Block,
        type_adjustment: AccessTypeAdjustment,
    ) -> Result<ExpressionPointer, Error> {
        let result_type_id = {
            let resolution = &self.fun_info[expr_handle].ty;
            match type_adjustment {
                AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
                AccessTypeAdjustment::IntroducePointer(class) => {
                    self.writer.get_resolution_pointer_id(resolution, class)
                }
            }
        };

        // Conjunction (OpLogicalAnd chain) of every conditional bounds check
        // seen so far, if any.
        let mut accumulated_checks = None;

        // Set if any step indexes a binding array with a non-uniform index.
        let mut is_non_uniform_binding_array = false;

        // Indices are collected starting from the outermost access and moving
        // toward the root, i.e. in reverse order; they are reversed below.
        self.temp_list.clear();
        let root_id = loop {
            // A composite spilled to a temporary is addressed via that variable.
            if let Some(spilled) = self.function.spilled_composites.get(&expr_handle) {
                break spilled.id;
            }

            expr_handle = match self.ir_function.expressions[expr_handle] {
                crate::Expression::Access { base, index } => {
                    is_non_uniform_binding_array |=
                        self.is_nonuniform_binding_array_access(base, index);

                    let index = GuardedIndex::Expression(index);
                    let index_id =
                        self.write_access_chain_index(base, index, &mut accumulated_checks, block)?;
                    self.temp_list.push(index_id);

                    base
                }
                crate::Expression::AccessIndex { base, index } => {
                    // Look through a pointer to its pointee type. The `base`
                    // bound inside this `if` shadows the outer one only there.
                    let mut base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
                    if let crate::TypeInner::Pointer { base, .. } = *base_ty {
                        base_ty = &self.ir_module.types[base].inner;
                    }
                    // Struct member indices are plain constants with no bounds
                    // check; other constant indices go through the bounds-check
                    // policy.
                    let index_id = if let crate::TypeInner::Struct { .. } = *base_ty {
                        self.get_index_constant(index)
                    } else {
                        self.write_access_chain_index(
                            base,
                            GuardedIndex::Known(index),
                            &mut accumulated_checks,
                            block,
                        )?
                    };

                    self.temp_list.push(index_id);
                    base
                }
                crate::Expression::GlobalVariable(handle) => {
                    let gv = &self.writer.global_variables[handle];
                    break gv.access_id;
                }
                crate::Expression::LocalVariable(variable) => {
                    let local_var = &self.function.variables[&variable];
                    break local_var.id;
                }
                crate::Expression::FunctionArgument(index) => {
                    break self.function.parameter_id(index);
                }
                ref other => unimplemented!("Unexpected pointer expression {:?}", other),
            }
        };

        let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
            // No indices: the root pointer is the result.
            (
                root_id,
                ExpressionPointer::Ready {
                    pointer_id: root_id,
                },
            )
        } else {
            // Put the indices into root-to-leaf order.
            self.temp_list.reverse();
            let pointer_id = self.gen_id();
            let access =
                Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);

            // With pending conditional checks, hand the access chain back
            // un-emitted so the caller can place it behind the condition.
            let expr_pointer = match accumulated_checks {
                Some(condition) => ExpressionPointer::Conditional { condition, access },
                None => {
                    block.body.push(access);
                    ExpressionPointer::Ready { pointer_id }
                }
            };
            (pointer_id, expr_pointer)
        };
        if is_non_uniform_binding_array {
            // Let the writer decorate the resulting pointer as a non-uniform
            // binding array access.
            self.writer
                .decorate_non_uniform_binding_array_access(pointer_id)?;
        }

        Ok(expr_pointer)
    }
2183
2184 fn is_nonuniform_binding_array_access(
2185 &mut self,
2186 base: Handle<crate::Expression>,
2187 index: Handle<crate::Expression>,
2188 ) -> bool {
2189 let crate::Expression::GlobalVariable(var_handle) = self.ir_function.expressions[base]
2190 else {
2191 return false;
2192 };
2193
2194 let gvar = &self.ir_module.global_variables[var_handle];
2197 let crate::TypeInner::BindingArray { .. } = self.ir_module.types[gvar.ty].inner else {
2198 return false;
2199 };
2200
2201 self.fun_info[index].uniformity.non_uniform_result.is_some()
2202 }
2203
2204 fn write_access_chain_index(
2214 &mut self,
2215 base: Handle<crate::Expression>,
2216 index: GuardedIndex,
2217 accumulated_checks: &mut Option<Word>,
2218 block: &mut Block,
2219 ) -> Result<Word, Error> {
2220 match self.write_bounds_check(base, index, block)? {
2221 BoundsCheckResult::KnownInBounds(known_index) => {
2222 let scalar = crate::Literal::U32(known_index);
2225 Ok(self.writer.get_constant_scalar(scalar))
2226 }
2227 BoundsCheckResult::Computed(computed_index_id) => Ok(computed_index_id),
2228 BoundsCheckResult::Conditional {
2229 condition_id: condition,
2230 index_id: index,
2231 } => {
2232 self.extend_bounds_check_condition_chain(accumulated_checks, condition, block);
2233
2234 Ok(index)
2236 }
2237 }
2238 }
2239
2240 fn extend_bounds_check_condition_chain(
2259 &mut self,
2260 chain: &mut Option<Word>,
2261 comparison_id: Word,
2262 block: &mut Block,
2263 ) {
2264 match *chain {
2265 Some(ref mut prior_checks) => {
2266 let combined = self.gen_id();
2267 block.body.push(Instruction::binary(
2268 spirv::Op::LogicalAnd,
2269 self.writer.get_bool_type_id(),
2270 combined,
2271 *prior_checks,
2272 comparison_id,
2273 ));
2274 *prior_checks = combined;
2275 }
2276 None => {
2277 *chain = Some(comparison_id);
2279 }
2280 }
2281 }
2282
2283 fn write_checked_load(
2284 &mut self,
2285 pointer: Handle<crate::Expression>,
2286 block: &mut Block,
2287 access_type_adjustment: AccessTypeAdjustment,
2288 result_type_id: Word,
2289 ) -> Result<Word, Error> {
2290 match self.write_access_chain(pointer, block, access_type_adjustment)? {
2291 ExpressionPointer::Ready { pointer_id } => {
2292 let id = self.gen_id();
2293 let atomic_space =
2294 match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
2295 crate::TypeInner::Pointer { base, space } => {
2296 match self.ir_module.types[base].inner {
2297 crate::TypeInner::Atomic { .. } => Some(space),
2298 _ => None,
2299 }
2300 }
2301 _ => None,
2302 };
2303 let instruction = if let Some(space) = atomic_space {
2304 let (semantics, scope) = space.to_spirv_semantics_and_scope();
2305 let scope_constant_id = self.get_scope_constant(scope as u32);
2306 let semantics_id = self.get_index_constant(semantics.bits());
2307 Instruction::atomic_load(
2308 result_type_id,
2309 id,
2310 pointer_id,
2311 scope_constant_id,
2312 semantics_id,
2313 )
2314 } else {
2315 Instruction::load(result_type_id, id, pointer_id, None)
2316 };
2317 block.body.push(instruction);
2318 Ok(id)
2319 }
2320 ExpressionPointer::Conditional { condition, access } => {
2321 let value = self.write_conditional_indexed_load(
2323 result_type_id,
2324 condition,
2325 block,
2326 move |id_gen, block| {
2327 let pointer_id = access.result_id.unwrap();
2329 let value_id = id_gen.next();
2330 block.body.push(access);
2331 block.body.push(Instruction::load(
2332 result_type_id,
2333 value_id,
2334 pointer_id,
2335 None,
2336 ));
2337 value_id
2338 },
2339 );
2340 Ok(value)
2341 }
2342 }
2343 }
2344
2345 fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
2346 use indexmap::map::Entry;
2347
2348 let spill_variable_id = match self.function.spilled_composites.entry(base) {
2350 Entry::Occupied(preexisting) => preexisting.get().id,
2351 Entry::Vacant(vacant) => {
2352 let pointer_type_id = self.writer.get_resolution_pointer_id(
2355 &self.fun_info[base].ty,
2356 spirv::StorageClass::Function,
2357 );
2358 let id = self.writer.id_gen.next();
2359 vacant.insert(super::LocalVariable {
2360 id,
2361 instruction: Instruction::variable(
2362 pointer_type_id,
2363 id,
2364 spirv::StorageClass::Function,
2365 None,
2366 ),
2367 });
2368 id
2369 }
2370 };
2371
2372 let base_id = self.cached[base];
2397 block
2398 .body
2399 .push(Instruction::store(spill_variable_id, base_id, None));
2400 }
2401
2402 fn maybe_access_spilled_composite(
2419 &mut self,
2420 access: Handle<crate::Expression>,
2421 block: &mut Block,
2422 result_type_id: Word,
2423 ) -> Result<Word, Error> {
2424 let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
2425 if access_uses == self.fun_info[access].ref_count {
2426 Ok(0)
2430 } else {
2431 self.write_checked_load(
2436 access,
2437 block,
2438 AccessTypeAdjustment::IntroducePointer(spirv::StorageClass::Function),
2439 result_type_id,
2440 )
2441 }
2442 }
2443
2444 #[allow(clippy::too_many_arguments)]
2446 fn write_matrix_matrix_column_op(
2447 &mut self,
2448 block: &mut Block,
2449 result_id: Word,
2450 result_type_id: Word,
2451 left_id: Word,
2452 right_id: Word,
2453 columns: crate::VectorSize,
2454 rows: crate::VectorSize,
2455 width: u8,
2456 op: spirv::Op,
2457 ) {
2458 self.temp_list.clear();
2459
2460 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
2461 size: rows,
2462 scalar: crate::Scalar::float(width),
2463 });
2464
2465 for index in 0..columns as u32 {
2466 let column_id_left = self.gen_id();
2467 let column_id_right = self.gen_id();
2468 let column_id_res = self.gen_id();
2469
2470 block.body.push(Instruction::composite_extract(
2471 vector_type_id,
2472 column_id_left,
2473 left_id,
2474 &[index],
2475 ));
2476 block.body.push(Instruction::composite_extract(
2477 vector_type_id,
2478 column_id_right,
2479 right_id,
2480 &[index],
2481 ));
2482 block.body.push(Instruction::binary(
2483 op,
2484 vector_type_id,
2485 column_id_res,
2486 column_id_left,
2487 column_id_right,
2488 ));
2489
2490 self.temp_list.push(column_id_res);
2491 }
2492
2493 block.body.push(Instruction::composite_construct(
2494 result_type_id,
2495 result_id,
2496 &self.temp_list,
2497 ));
2498 }
2499
2500 fn write_vector_scalar_mult(
2502 &mut self,
2503 block: &mut Block,
2504 result_id: Word,
2505 result_type_id: Word,
2506 vector_id: Word,
2507 scalar_id: Word,
2508 vector: &crate::TypeInner,
2509 ) {
2510 let (size, kind) = match *vector {
2511 crate::TypeInner::Vector {
2512 size,
2513 scalar: crate::Scalar { kind, .. },
2514 } => (size, kind),
2515 _ => unreachable!(),
2516 };
2517
2518 let (op, operand_id) = match kind {
2519 crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
2520 _ => {
2521 let operand_id = self.gen_id();
2522 self.temp_list.clear();
2523 self.temp_list.resize(size as usize, scalar_id);
2524 block.body.push(Instruction::composite_construct(
2525 result_type_id,
2526 operand_id,
2527 &self.temp_list,
2528 ));
2529 (spirv::Op::IMul, operand_id)
2530 }
2531 };
2532
2533 block.body.push(Instruction::binary(
2534 op,
2535 result_type_id,
2536 result_id,
2537 vector_id,
2538 operand_id,
2539 ));
2540 }
2541
2542 fn write_dot_product(
2544 &mut self,
2545 result_id: Word,
2546 result_type_id: Word,
2547 arg0_id: Word,
2548 arg1_id: Word,
2549 size: u32,
2550 block: &mut Block,
2551 ) {
2552 let mut partial_sum = self.writer.get_constant_null(result_type_id);
2553 let last_component = size - 1;
2554 for index in 0..=last_component {
2555 let a_id = self.gen_id();
2557 block.body.push(Instruction::composite_extract(
2558 result_type_id,
2559 a_id,
2560 arg0_id,
2561 &[index],
2562 ));
2563 let b_id = self.gen_id();
2564 block.body.push(Instruction::composite_extract(
2565 result_type_id,
2566 b_id,
2567 arg1_id,
2568 &[index],
2569 ));
2570 let prod_id = self.gen_id();
2571 block.body.push(Instruction::binary(
2572 spirv::Op::IMul,
2573 result_type_id,
2574 prod_id,
2575 a_id,
2576 b_id,
2577 ));
2578
2579 let id = if index == last_component {
2581 result_id
2582 } else {
2583 self.gen_id()
2584 };
2585
2586 block.body.push(Instruction::binary(
2588 spirv::Op::IAdd,
2589 result_type_id,
2590 id,
2591 partial_sum,
2592 prod_id,
2593 ));
2594 partial_sum = id;
2596 }
2597 }
2598
2599 fn write_block(
2616 &mut self,
2617 label_id: Word,
2618 naga_block: &crate::Block,
2619 exit: BlockExit,
2620 loop_context: LoopContext,
2621 debug_info: Option<&DebugInfoInner>,
2622 ) -> Result<BlockExitDisposition, Error> {
2623 let mut block = Block::new(label_id);
2624 for (statement, span) in naga_block.span_iter() {
2625 if let (Some(debug_info), false) = (
2626 debug_info,
2627 matches!(
2628 statement,
2629 &(Statement::Block(..)
2630 | Statement::Break
2631 | Statement::Continue
2632 | Statement::Kill
2633 | Statement::Return { .. }
2634 | Statement::Loop { .. })
2635 ),
2636 ) {
2637 let loc: crate::SourceLocation = span.location(debug_info.source_code);
2638 block.body.push(Instruction::line(
2639 debug_info.source_file_id,
2640 loc.line_number,
2641 loc.line_position,
2642 ));
2643 };
2644 match *statement {
2645 Statement::Emit(ref range) => {
2646 for handle in range.clone() {
2647 if !self.expression_constness.is_const(handle) {
2649 self.cache_expression_value(handle, &mut block)?;
2650 }
2651 }
2652 }
2653 Statement::Block(ref block_statements) => {
2654 let scope_id = self.gen_id();
2655 self.function.consume(block, Instruction::branch(scope_id));
2656
2657 let merge_id = self.gen_id();
2658 let merge_used = self.write_block(
2659 scope_id,
2660 block_statements,
2661 BlockExit::Branch { target: merge_id },
2662 loop_context,
2663 debug_info,
2664 )?;
2665
2666 match merge_used {
2667 BlockExitDisposition::Used => {
2668 block = Block::new(merge_id);
2669 }
2670 BlockExitDisposition::Discarded => {
2671 return Ok(BlockExitDisposition::Discarded);
2672 }
2673 }
2674 }
2675 Statement::If {
2676 condition,
2677 ref accept,
2678 ref reject,
2679 } => {
2680 let condition_id = self.cached[condition];
2681
2682 let merge_id = self.gen_id();
2683 block.body.push(Instruction::selection_merge(
2684 merge_id,
2685 spirv::SelectionControl::NONE,
2686 ));
2687
2688 let accept_id = if accept.is_empty() {
2689 None
2690 } else {
2691 Some(self.gen_id())
2692 };
2693 let reject_id = if reject.is_empty() {
2694 None
2695 } else {
2696 Some(self.gen_id())
2697 };
2698
2699 self.function.consume(
2700 block,
2701 Instruction::branch_conditional(
2702 condition_id,
2703 accept_id.unwrap_or(merge_id),
2704 reject_id.unwrap_or(merge_id),
2705 ),
2706 );
2707
2708 if let Some(block_id) = accept_id {
2709 let _ = self.write_block(
2714 block_id,
2715 accept,
2716 BlockExit::Branch { target: merge_id },
2717 loop_context,
2718 debug_info,
2719 )?;
2720 }
2721 if let Some(block_id) = reject_id {
2722 let _ = self.write_block(
2727 block_id,
2728 reject,
2729 BlockExit::Branch { target: merge_id },
2730 loop_context,
2731 debug_info,
2732 )?;
2733 }
2734
2735 block = Block::new(merge_id);
2736 }
2737 Statement::Switch {
2738 selector,
2739 ref cases,
2740 } => {
2741 let selector_id = self.cached[selector];
2742
2743 let merge_id = self.gen_id();
2744 block.body.push(Instruction::selection_merge(
2745 merge_id,
2746 spirv::SelectionControl::NONE,
2747 ));
2748
2749 let mut default_id = None;
2750 let mut last_id = None;
2752
2753 let mut raw_cases = Vec::with_capacity(cases.len());
2754 let mut case_ids = Vec::with_capacity(cases.len());
2755 for case in cases.iter() {
2756 let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
2758
2759 if case.fall_through && case.body.is_empty() {
2760 last_id = Some(label_id);
2761 }
2762
2763 case_ids.push(label_id);
2764
2765 match case.value {
2766 crate::SwitchValue::I32(value) => {
2767 raw_cases.push(super::instructions::Case {
2768 value: value as Word,
2769 label_id,
2770 });
2771 }
2772 crate::SwitchValue::U32(value) => {
2773 raw_cases.push(super::instructions::Case { value, label_id });
2774 }
2775 crate::SwitchValue::Default => {
2776 default_id = Some(label_id);
2777 }
2778 }
2779 }
2780
2781 let default_id = default_id.unwrap();
2782
2783 self.function.consume(
2784 block,
2785 Instruction::switch(selector_id, default_id, &raw_cases),
2786 );
2787
2788 let inner_context = LoopContext {
2789 break_id: Some(merge_id),
2790 ..loop_context
2791 };
2792
2793 for (i, (case, label_id)) in cases
2794 .iter()
2795 .zip(case_ids.iter())
2796 .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
2797 .enumerate()
2798 {
2799 let case_finish_id = if case.fall_through {
2800 case_ids[i + 1]
2801 } else {
2802 merge_id
2803 };
2804 let _ = self.write_block(
2813 *label_id,
2814 &case.body,
2815 BlockExit::Branch {
2816 target: case_finish_id,
2817 },
2818 inner_context,
2819 debug_info,
2820 )?;
2821 }
2822
2823 block = Block::new(merge_id);
2824 }
2825 Statement::Loop {
2826 ref body,
2827 ref continuing,
2828 break_if,
2829 } => {
2830 let preamble_id = self.gen_id();
2831 self.function
2832 .consume(block, Instruction::branch(preamble_id));
2833
2834 let merge_id = self.gen_id();
2835 let body_id = self.gen_id();
2836 let continuing_id = self.gen_id();
2837
2838 block = Block::new(preamble_id);
2841 if let Some(debug_info) = debug_info {
2844 let loc: crate::SourceLocation = span.location(debug_info.source_code);
2845 block.body.push(Instruction::line(
2846 debug_info.source_file_id,
2847 loc.line_number,
2848 loc.line_position,
2849 ))
2850 }
2851 block.body.push(Instruction::loop_merge(
2852 merge_id,
2853 continuing_id,
2854 spirv::SelectionControl::NONE,
2855 ));
2856
2857 if self.force_loop_bounding {
2858 block = self.write_force_bounded_loop_instructions(block, merge_id);
2859 }
2860 self.function.consume(block, Instruction::branch(body_id));
2861
2862 let _ = self.write_block(
2866 body_id,
2867 body,
2868 BlockExit::Branch {
2869 target: continuing_id,
2870 },
2871 LoopContext {
2872 continuing_id: Some(continuing_id),
2873 break_id: Some(merge_id),
2874 },
2875 debug_info,
2876 )?;
2877
2878 let exit = match break_if {
2879 Some(condition) => BlockExit::BreakIf {
2880 condition,
2881 preamble_id,
2882 },
2883 None => BlockExit::Branch {
2884 target: preamble_id,
2885 },
2886 };
2887
2888 let _ = self.write_block(
2892 continuing_id,
2893 continuing,
2894 exit,
2895 LoopContext {
2896 continuing_id: None,
2897 break_id: Some(merge_id),
2898 },
2899 debug_info,
2900 )?;
2901
2902 block = Block::new(merge_id);
2903 }
2904 Statement::Break => {
2905 self.function
2906 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
2907 return Ok(BlockExitDisposition::Discarded);
2908 }
2909 Statement::Continue => {
2910 self.function.consume(
2911 block,
2912 Instruction::branch(loop_context.continuing_id.unwrap()),
2913 );
2914 return Ok(BlockExitDisposition::Discarded);
2915 }
2916 Statement::Return { value: Some(value) } => {
2917 let value_id = self.cached[value];
2918 let instruction = match self.function.entry_point_context {
2919 Some(ref context) => {
2922 self.writer.write_entry_point_return(
2923 value_id,
2924 self.ir_function.result.as_ref().unwrap(),
2925 &context.results,
2926 &mut block.body,
2927 )?;
2928 Instruction::return_void()
2929 }
2930 None => Instruction::return_value(value_id),
2931 };
2932 self.function.consume(block, instruction);
2933 return Ok(BlockExitDisposition::Discarded);
2934 }
2935 Statement::Return { value: None } => {
2936 self.function.consume(block, Instruction::return_void());
2937 return Ok(BlockExitDisposition::Discarded);
2938 }
2939 Statement::Kill => {
2940 self.function.consume(block, Instruction::kill());
2941 return Ok(BlockExitDisposition::Discarded);
2942 }
2943 Statement::Barrier(flags) => {
2944 self.writer.write_barrier(flags, &mut block);
2945 }
2946 Statement::Store { pointer, value } => {
2947 let value_id = self.cached[value];
2948 match self.write_access_chain(
2949 pointer,
2950 &mut block,
2951 AccessTypeAdjustment::None,
2952 )? {
2953 ExpressionPointer::Ready { pointer_id } => {
2954 let atomic_space = match *self.fun_info[pointer]
2955 .ty
2956 .inner_with(&self.ir_module.types)
2957 {
2958 crate::TypeInner::Pointer { base, space } => {
2959 match self.ir_module.types[base].inner {
2960 crate::TypeInner::Atomic { .. } => Some(space),
2961 _ => None,
2962 }
2963 }
2964 _ => None,
2965 };
2966 let instruction = if let Some(space) = atomic_space {
2967 let (semantics, scope) = space.to_spirv_semantics_and_scope();
2968 let scope_constant_id = self.get_scope_constant(scope as u32);
2969 let semantics_id = self.get_index_constant(semantics.bits());
2970 Instruction::atomic_store(
2971 pointer_id,
2972 scope_constant_id,
2973 semantics_id,
2974 value_id,
2975 )
2976 } else {
2977 Instruction::store(pointer_id, value_id, None)
2978 };
2979 block.body.push(instruction);
2980 }
2981 ExpressionPointer::Conditional { condition, access } => {
2982 let mut selection = Selection::start(&mut block, ());
2983 selection.if_true(self, condition, ());
2984
2985 let pointer_id = access.result_id.unwrap();
2987 selection.block().body.push(access);
2988 selection
2989 .block()
2990 .body
2991 .push(Instruction::store(pointer_id, value_id, None));
2992
2993 selection.finish(self, ());
2996 }
2997 };
2998 }
2999 Statement::ImageStore {
3000 image,
3001 coordinate,
3002 array_index,
3003 value,
3004 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
3005 Statement::Call {
3006 function: local_function,
3007 ref arguments,
3008 result,
3009 } => {
3010 let id = self.gen_id();
3011 self.temp_list.clear();
3012 for &argument in arguments {
3013 self.temp_list.push(self.cached[argument]);
3014 }
3015
3016 let type_id = match result {
3017 Some(expr) => {
3018 self.cached[expr] = id;
3019 self.get_expression_type_id(&self.fun_info[expr].ty)
3020 }
3021 None => self.writer.void_type,
3022 };
3023
3024 block.body.push(Instruction::function_call(
3025 type_id,
3026 id,
3027 self.writer.lookup_function[&local_function],
3028 &self.temp_list,
3029 ));
3030 }
3031 Statement::Atomic {
3032 pointer,
3033 ref fun,
3034 value,
3035 result,
3036 } => {
3037 let id = self.gen_id();
3038 let result_type_id =
3042 self.get_expression_type_id(&self.fun_info[result.unwrap_or(value)].ty);
3043
3044 if let Some(result) = result {
3045 self.cached[result] = id;
3046 }
3047
3048 let pointer_id = match self.write_access_chain(
3049 pointer,
3050 &mut block,
3051 AccessTypeAdjustment::None,
3052 )? {
3053 ExpressionPointer::Ready { pointer_id } => pointer_id,
3054 ExpressionPointer::Conditional { .. } => {
3055 return Err(Error::FeatureNotImplemented(
3056 "Atomics out-of-bounds handling",
3057 ));
3058 }
3059 };
3060
3061 let space = self.fun_info[pointer]
3062 .ty
3063 .inner_with(&self.ir_module.types)
3064 .pointer_space()
3065 .unwrap();
3066 let (semantics, scope) = space.to_spirv_semantics_and_scope();
3067 let scope_constant_id = self.get_scope_constant(scope as u32);
3068 let semantics_id = self.get_index_constant(semantics.bits());
3069 let value_id = self.cached[value];
3070 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
3071
3072 let crate::TypeInner::Scalar(scalar) = *value_inner else {
3073 return Err(Error::FeatureNotImplemented(
3074 "Atomics with non-scalar values",
3075 ));
3076 };
3077
3078 let instruction = match *fun {
3079 crate::AtomicFunction::Add => {
3080 let spirv_op = match scalar.kind {
3081 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3082 spirv::Op::AtomicIAdd
3083 }
3084 crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
3085 _ => unimplemented!(),
3086 };
3087 Instruction::atomic_binary(
3088 spirv_op,
3089 result_type_id,
3090 id,
3091 pointer_id,
3092 scope_constant_id,
3093 semantics_id,
3094 value_id,
3095 )
3096 }
3097 crate::AtomicFunction::Subtract => {
3098 let (spirv_op, value_id) = match scalar.kind {
3099 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3100 (spirv::Op::AtomicISub, value_id)
3101 }
3102 crate::ScalarKind::Float => {
3103 let neg_result_id = self.gen_id();
3106 block.body.push(Instruction::unary(
3107 spirv::Op::FNegate,
3108 result_type_id,
3109 neg_result_id,
3110 value_id,
3111 ));
3112 (spirv::Op::AtomicFAddEXT, neg_result_id)
3113 }
3114 _ => unimplemented!(),
3115 };
3116 Instruction::atomic_binary(
3117 spirv_op,
3118 result_type_id,
3119 id,
3120 pointer_id,
3121 scope_constant_id,
3122 semantics_id,
3123 value_id,
3124 )
3125 }
3126 crate::AtomicFunction::And => {
3127 let spirv_op = match scalar.kind {
3128 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3129 spirv::Op::AtomicAnd
3130 }
3131 _ => unimplemented!(),
3132 };
3133 Instruction::atomic_binary(
3134 spirv_op,
3135 result_type_id,
3136 id,
3137 pointer_id,
3138 scope_constant_id,
3139 semantics_id,
3140 value_id,
3141 )
3142 }
3143 crate::AtomicFunction::InclusiveOr => {
3144 let spirv_op = match scalar.kind {
3145 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3146 spirv::Op::AtomicOr
3147 }
3148 _ => unimplemented!(),
3149 };
3150 Instruction::atomic_binary(
3151 spirv_op,
3152 result_type_id,
3153 id,
3154 pointer_id,
3155 scope_constant_id,
3156 semantics_id,
3157 value_id,
3158 )
3159 }
3160 crate::AtomicFunction::ExclusiveOr => {
3161 let spirv_op = match scalar.kind {
3162 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3163 spirv::Op::AtomicXor
3164 }
3165 _ => unimplemented!(),
3166 };
3167 Instruction::atomic_binary(
3168 spirv_op,
3169 result_type_id,
3170 id,
3171 pointer_id,
3172 scope_constant_id,
3173 semantics_id,
3174 value_id,
3175 )
3176 }
3177 crate::AtomicFunction::Min => {
3178 let spirv_op = match scalar.kind {
3179 crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
3180 crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
3181 _ => unimplemented!(),
3182 };
3183 Instruction::atomic_binary(
3184 spirv_op,
3185 result_type_id,
3186 id,
3187 pointer_id,
3188 scope_constant_id,
3189 semantics_id,
3190 value_id,
3191 )
3192 }
3193 crate::AtomicFunction::Max => {
3194 let spirv_op = match scalar.kind {
3195 crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
3196 crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
3197 _ => unimplemented!(),
3198 };
3199 Instruction::atomic_binary(
3200 spirv_op,
3201 result_type_id,
3202 id,
3203 pointer_id,
3204 scope_constant_id,
3205 semantics_id,
3206 value_id,
3207 )
3208 }
3209 crate::AtomicFunction::Exchange { compare: None } => {
3210 Instruction::atomic_binary(
3211 spirv::Op::AtomicExchange,
3212 result_type_id,
3213 id,
3214 pointer_id,
3215 scope_constant_id,
3216 semantics_id,
3217 value_id,
3218 )
3219 }
3220 crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
3221 let scalar_type_id =
3222 self.get_numeric_type_id(NumericType::Scalar(scalar));
3223 let bool_type_id =
3224 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL));
3225
3226 let cas_result_id = self.gen_id();
3227 let equality_result_id = self.gen_id();
3228 let equality_operator = match scalar.kind {
3229 crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3230 spirv::Op::IEqual
3231 }
3232 _ => unimplemented!(),
3233 };
3234 let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
3235 cas_instr.set_type(scalar_type_id);
3236 cas_instr.set_result(cas_result_id);
3237 cas_instr.add_operand(pointer_id);
3238 cas_instr.add_operand(scope_constant_id);
3239 cas_instr.add_operand(semantics_id); cas_instr.add_operand(semantics_id); cas_instr.add_operand(value_id);
3242 cas_instr.add_operand(self.cached[cmp]);
3243 block.body.push(cas_instr);
3244 block.body.push(Instruction::binary(
3245 equality_operator,
3246 bool_type_id,
3247 equality_result_id,
3248 cas_result_id,
3249 self.cached[cmp],
3250 ));
3251 Instruction::composite_construct(
3252 result_type_id,
3253 id,
3254 &[cas_result_id, equality_result_id],
3255 )
3256 }
3257 };
3258
3259 block.body.push(instruction);
3260 }
3261 Statement::ImageAtomic {
3262 image,
3263 coordinate,
3264 array_index,
3265 fun,
3266 value,
3267 } => {
3268 self.write_image_atomic(
3269 image,
3270 coordinate,
3271 array_index,
3272 fun,
3273 value,
3274 &mut block,
3275 )?;
3276 }
3277 Statement::WorkGroupUniformLoad { pointer, result } => {
3278 self.writer
3279 .write_barrier(crate::Barrier::WORK_GROUP, &mut block);
3280 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
3281 match self.write_access_chain(
3283 pointer,
3284 &mut block,
3285 AccessTypeAdjustment::None,
3286 )? {
3287 ExpressionPointer::Ready { pointer_id } => {
3288 let id = self.gen_id();
3289 block.body.push(Instruction::load(
3290 result_type_id,
3291 id,
3292 pointer_id,
3293 None,
3294 ));
3295 self.cached[result] = id;
3296 }
3297 ExpressionPointer::Conditional { condition, access } => {
3298 self.cached[result] = self.write_conditional_indexed_load(
3299 result_type_id,
3300 condition,
3301 &mut block,
3302 move |id_gen, block| {
3303 let pointer_id = access.result_id.unwrap();
3305 let value_id = id_gen.next();
3306 block.body.push(access);
3307 block.body.push(Instruction::load(
3308 result_type_id,
3309 value_id,
3310 pointer_id,
3311 None,
3312 ));
3313 value_id
3314 },
3315 )
3316 }
3317 }
3318 self.writer
3319 .write_barrier(crate::Barrier::WORK_GROUP, &mut block);
3320 }
3321 Statement::RayQuery { query, ref fun } => {
3322 self.write_ray_query_function(query, fun, &mut block);
3323 }
3324 Statement::SubgroupBallot {
3325 result,
3326 ref predicate,
3327 } => {
3328 self.write_subgroup_ballot(predicate, result, &mut block)?;
3329 }
3330 Statement::SubgroupCollectiveOperation {
3331 ref op,
3332 ref collective_op,
3333 argument,
3334 result,
3335 } => {
3336 self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
3337 }
3338 Statement::SubgroupGather {
3339 ref mode,
3340 argument,
3341 result,
3342 } => {
3343 self.write_subgroup_gather(mode, argument, result, &mut block)?;
3344 }
3345 }
3346 }
3347
3348 let termination = match exit {
3349 BlockExit::Return => match self.ir_function.result {
3352 Some(ref result) if self.function.entry_point_context.is_none() => {
3353 let type_id = self.get_handle_type_id(result.ty);
3354 let null_id = self.writer.get_constant_null(type_id);
3355 Instruction::return_value(null_id)
3356 }
3357 _ => Instruction::return_void(),
3358 },
3359 BlockExit::Branch { target } => Instruction::branch(target),
3360 BlockExit::BreakIf {
3361 condition,
3362 preamble_id,
3363 } => {
3364 let condition_id = self.cached[condition];
3365
3366 Instruction::branch_conditional(
3367 condition_id,
3368 loop_context.break_id.unwrap(),
3369 preamble_id,
3370 )
3371 }
3372 };
3373
3374 self.function.consume(block, termination);
3375 Ok(BlockExitDisposition::Used)
3376 }
3377
3378 pub(super) fn write_function_body(
3379 &mut self,
3380 entry_id: Word,
3381 debug_info: Option<&DebugInfoInner>,
3382 ) -> Result<(), Error> {
3383 let _ = self.write_block(
3386 entry_id,
3387 &self.ir_function.body,
3388 BlockExit::Return,
3389 LoopContext::default(),
3390 debug_info,
3391 )?;
3392
3393 Ok(())
3394 }
3395}