naga/back/spv/
writer.rs

1use alloc::{
2    string::{String, ToString},
3    vec,
4    vec::Vec,
5};
6
7use hashbrown::hash_map::Entry;
8use spirv::Word;
9
10use super::{
11    block::DebugInfoInner,
12    helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
13    Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error,
14    Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType,
15    LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options,
16    PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
17};
18use crate::{
19    arena::{Handle, HandleVec, UniqueArena},
20    back::spv::{BindingInfo, WrappedFunction},
21    proc::{Alignment, TypeResolution},
22    valid::{FunctionInfo, ModuleInfo},
23};
24
25struct FunctionInterface<'a> {
26    varying_ids: &'a mut Vec<Word>,
27    stage: crate::ShaderStage,
28}
29
30impl Function {
31    pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
32        self.signature.as_ref().unwrap().to_words(sink);
33        for argument in self.parameters.iter() {
34            argument.instruction.to_words(sink);
35        }
36        for (index, block) in self.blocks.iter().enumerate() {
37            Instruction::label(block.label_id).to_words(sink);
38            if index == 0 {
39                for local_var in self.variables.values() {
40                    local_var.instruction.to_words(sink);
41                }
42                for local_var in self.force_loop_bounding_vars.iter() {
43                    local_var.instruction.to_words(sink);
44                }
45                for internal_var in self.spilled_composites.values() {
46                    internal_var.instruction.to_words(sink);
47                }
48            }
49            for instruction in block.body.iter() {
50                instruction.to_words(sink);
51            }
52        }
53        Instruction::function_end().to_words(sink);
54    }
55}
56
57impl Writer {
58    pub fn new(options: &Options) -> Result<Self, Error> {
59        let (major, minor) = options.lang_version;
60        if major != 1 {
61            return Err(Error::UnsupportedVersion(major, minor));
62        }
63        let raw_version = ((major as u32) << 16) | ((minor as u32) << 8);
64
65        let mut capabilities_used = crate::FastIndexSet::default();
66        capabilities_used.insert(spirv::Capability::Shader);
67
68        let mut id_gen = IdGenerator::default();
69        let gl450_ext_inst_id = id_gen.next();
70        let void_type = id_gen.next();
71
72        Ok(Writer {
73            physical_layout: PhysicalLayout::new(raw_version),
74            logical_layout: LogicalLayout::default(),
75            id_gen,
76            capabilities_available: options.capabilities.clone(),
77            capabilities_used,
78            extensions_used: crate::FastIndexSet::default(),
79            debugs: vec![],
80            annotations: vec![],
81            flags: options.flags,
82            bounds_check_policies: options.bounds_check_policies,
83            zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
84            force_loop_bounding: options.force_loop_bounding,
85            void_type,
86            lookup_type: crate::FastHashMap::default(),
87            lookup_function: crate::FastHashMap::default(),
88            lookup_function_type: crate::FastHashMap::default(),
89            wrapped_functions: crate::FastHashMap::default(),
90            constant_ids: HandleVec::new(),
91            cached_constants: crate::FastHashMap::default(),
92            global_variables: HandleVec::new(),
93            binding_map: options.binding_map.clone(),
94            saved_cached: CachedExpressions::default(),
95            gl450_ext_inst_id,
96            temp_list: Vec::new(),
97            ray_get_intersection_function: None,
98        })
99    }
100
101    /// Reset `Writer` to its initial state, retaining any allocations.
102    ///
103    /// Why not just implement `Recyclable` for `Writer`? By design,
104    /// `Recyclable::recycle` requires ownership of the value, not just
105    /// `&mut`; see the trait documentation. But we need to use this method
106    /// from functions like `Writer::write`, which only have `&mut Writer`.
107    /// Workarounds include unsafe code (`core::ptr::read`, then `write`, ugh)
108    /// or something like a `Default` impl that returns an oddly-initialized
109    /// `Writer`, which is worse.
110    fn reset(&mut self) {
111        use super::recyclable::Recyclable;
112        use core::mem::take;
113
114        let mut id_gen = IdGenerator::default();
115        let gl450_ext_inst_id = id_gen.next();
116        let void_type = id_gen.next();
117
118        // Every field of the old writer that is not determined by the `Options`
119        // passed to `Writer::new` should be reset somehow.
120        let fresh = Writer {
121            // Copied from the old Writer:
122            flags: self.flags,
123            bounds_check_policies: self.bounds_check_policies,
124            zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
125            force_loop_bounding: self.force_loop_bounding,
126            capabilities_available: take(&mut self.capabilities_available),
127            binding_map: take(&mut self.binding_map),
128
129            // Initialized afresh:
130            id_gen,
131            void_type,
132            gl450_ext_inst_id,
133
134            // Recycled:
135            capabilities_used: take(&mut self.capabilities_used).recycle(),
136            extensions_used: take(&mut self.extensions_used).recycle(),
137            physical_layout: self.physical_layout.clone().recycle(),
138            logical_layout: take(&mut self.logical_layout).recycle(),
139            debugs: take(&mut self.debugs).recycle(),
140            annotations: take(&mut self.annotations).recycle(),
141            lookup_type: take(&mut self.lookup_type).recycle(),
142            lookup_function: take(&mut self.lookup_function).recycle(),
143            lookup_function_type: take(&mut self.lookup_function_type).recycle(),
144            wrapped_functions: take(&mut self.wrapped_functions).recycle(),
145            constant_ids: take(&mut self.constant_ids).recycle(),
146            cached_constants: take(&mut self.cached_constants).recycle(),
147            global_variables: take(&mut self.global_variables).recycle(),
148            saved_cached: take(&mut self.saved_cached).recycle(),
149            temp_list: take(&mut self.temp_list).recycle(),
150            ray_get_intersection_function: None,
151        };
152
153        *self = fresh;
154
155        self.capabilities_used.insert(spirv::Capability::Shader);
156    }
157
158    /// Indicate that the code requires any one of the listed capabilities.
159    ///
160    /// If nothing in `capabilities` appears in the available capabilities
161    /// specified in the [`Options`] from which this `Writer` was created,
162    /// return an error. The `what` string is used in the error message to
163    /// explain what provoked the requirement. (If no available capabilities were
164    /// given, assume everything is available.)
165    ///
166    /// The first acceptable capability will be added to this `Writer`'s
167    /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
168    /// result. For this reason, more specific capabilities should be listed
169    /// before more general.
170    ///
171    /// [`capabilities_used`]: Writer::capabilities_used
172    pub(super) fn require_any(
173        &mut self,
174        what: &'static str,
175        capabilities: &[spirv::Capability],
176    ) -> Result<(), Error> {
177        match *capabilities {
178            [] => Ok(()),
179            [first, ..] => {
180                // Find the first acceptable capability, or return an error if
181                // there is none.
182                let selected = match self.capabilities_available {
183                    None => first,
184                    Some(ref available) => {
185                        match capabilities
186                            .iter()
187                            // need explicit type for hashbrown::HashSet::contains fn call to keep rustc happy
188                            .find(|cap| available.contains::<spirv::Capability>(cap))
189                        {
190                            Some(&cap) => cap,
191                            None => {
192                                return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
193                            }
194                        }
195                    }
196                };
197                self.capabilities_used.insert(selected);
198                Ok(())
199            }
200        }
201    }
202
203    /// Indicate that the code uses the given extension.
204    pub(super) fn use_extension(&mut self, extension: &'static str) {
205        self.extensions_used.insert(extension);
206    }
207
208    pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
209        match self.lookup_type.entry(lookup_ty) {
210            Entry::Occupied(e) => *e.get(),
211            Entry::Vacant(e) => {
212                let local = match lookup_ty {
213                    LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
214                    LookupType::Local(local) => local,
215                };
216
217                let id = self.id_gen.next();
218                e.insert(id);
219                self.write_type_declaration_local(id, local);
220                id
221            }
222        }
223    }
224
225    pub(super) fn get_handle_type_id(&mut self, handle: Handle<crate::Type>) -> Word {
226        self.get_type_id(LookupType::Handle(handle))
227    }
228
229    pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
230        match *tr {
231            TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
232            TypeResolution::Value(ref inner) => {
233                let inner_local_type = self.localtype_from_inner(inner).unwrap();
234                LookupType::Local(inner_local_type)
235            }
236        }
237    }
238
239    pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
240        let lookup_ty = self.get_expression_lookup_type(tr);
241        self.get_type_id(lookup_ty)
242    }
243
244    pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
245        self.get_type_id(LookupType::Local(local))
246    }
247
248    pub(super) fn get_pointer_type_id(&mut self, base: Word, class: spirv::StorageClass) -> Word {
249        self.get_type_id(LookupType::Local(LocalType::Pointer { base, class }))
250    }
251
252    pub(super) fn get_handle_pointer_type_id(
253        &mut self,
254        base: Handle<crate::Type>,
255        class: spirv::StorageClass,
256    ) -> Word {
257        let base_id = self.get_handle_type_id(base);
258        self.get_pointer_type_id(base_id, class)
259    }
260
261    pub(super) fn get_ray_query_pointer_id(&mut self, module: &crate::Module) -> Word {
262        let rq_ty = module
263            .types
264            .get(&crate::Type {
265                name: None,
266                inner: crate::TypeInner::RayQuery {
267                    vertex_return: false,
268                },
269            })
270            .or_else(|| {
271                module.types.get(&crate::Type {
272                    name: None,
273                    inner: crate::TypeInner::RayQuery {
274                        vertex_return: true,
275                    },
276                })
277            })
278            .expect("ray_query type should have been populated by the variable passed into this!");
279        self.get_handle_pointer_type_id(rq_ty, spirv::StorageClass::Function)
280    }
281
282    /// Return a SPIR-V type for a pointer to `resolution`.
283    ///
284    /// The given `resolution` must be one that we can represent
285    /// either as a `LocalType::Pointer` or `LocalType::LocalPointer`.
286    pub(super) fn get_resolution_pointer_id(
287        &mut self,
288        resolution: &TypeResolution,
289        class: spirv::StorageClass,
290    ) -> Word {
291        let resolution_type_id = self.get_expression_type_id(resolution);
292        self.get_pointer_type_id(resolution_type_id, class)
293    }
294
295    pub(super) fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
296        self.get_type_id(LocalType::Numeric(numeric).into())
297    }
298
299    pub(super) fn get_u32_type_id(&mut self) -> Word {
300        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32))
301    }
302
303    pub(super) fn get_f32_type_id(&mut self) -> Word {
304        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32))
305    }
306
307    pub(super) fn get_vec2u_type_id(&mut self) -> Word {
308        self.get_numeric_type_id(NumericType::Vector {
309            size: crate::VectorSize::Bi,
310            scalar: crate::Scalar::U32,
311        })
312    }
313
314    pub(super) fn get_vec3u_type_id(&mut self) -> Word {
315        self.get_numeric_type_id(NumericType::Vector {
316            size: crate::VectorSize::Tri,
317            scalar: crate::Scalar::U32,
318        })
319    }
320
321    pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
322        let f32_id = self.get_f32_type_id();
323        self.get_pointer_type_id(f32_id, class)
324    }
325
326    pub(super) fn get_vec2u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
327        let vec2u_id = self.get_numeric_type_id(NumericType::Vector {
328            size: crate::VectorSize::Bi,
329            scalar: crate::Scalar::U32,
330        });
331        self.get_pointer_type_id(vec2u_id, class)
332    }
333
334    pub(super) fn get_vec3u_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
335        let vec3u_id = self.get_numeric_type_id(NumericType::Vector {
336            size: crate::VectorSize::Tri,
337            scalar: crate::Scalar::U32,
338        });
339        self.get_pointer_type_id(vec3u_id, class)
340    }
341
342    pub(super) fn get_bool_type_id(&mut self) -> Word {
343        self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::BOOL))
344    }
345
346    pub(super) fn get_vec2_bool_type_id(&mut self) -> Word {
347        self.get_numeric_type_id(NumericType::Vector {
348            size: crate::VectorSize::Bi,
349            scalar: crate::Scalar::BOOL,
350        })
351    }
352
353    pub(super) fn get_vec3_bool_type_id(&mut self) -> Word {
354        self.get_numeric_type_id(NumericType::Vector {
355            size: crate::VectorSize::Tri,
356            scalar: crate::Scalar::BOOL,
357        })
358    }
359
360    pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
361        self.annotations
362            .push(Instruction::decorate(id, decoration, operands));
363    }
364
365    /// Return `inner` as a `LocalType`, if that's possible.
366    ///
367    /// If `inner` can be represented as a `LocalType`, return
368    /// `Some(local_type)`.
369    ///
370    /// Otherwise, return `None`. In this case, the type must always be looked
371    /// up using a `LookupType::Handle`.
372    fn localtype_from_inner(&mut self, inner: &crate::TypeInner) -> Option<LocalType> {
373        Some(match *inner {
374            crate::TypeInner::Scalar(_)
375            | crate::TypeInner::Atomic(_)
376            | crate::TypeInner::Vector { .. }
377            | crate::TypeInner::Matrix { .. } => {
378                // We expect `NumericType::from_inner` to handle all
379                // these cases, so unwrap.
380                LocalType::Numeric(NumericType::from_inner(inner).unwrap())
381            }
382            crate::TypeInner::Pointer { base, space } => {
383                let base_type_id = self.get_handle_type_id(base);
384                LocalType::Pointer {
385                    base: base_type_id,
386                    class: map_storage_class(space),
387                }
388            }
389            crate::TypeInner::ValuePointer {
390                size,
391                scalar,
392                space,
393            } => {
394                let base_numeric_type = match size {
395                    Some(size) => NumericType::Vector { size, scalar },
396                    None => NumericType::Scalar(scalar),
397                };
398                LocalType::Pointer {
399                    base: self.get_numeric_type_id(base_numeric_type),
400                    class: map_storage_class(space),
401                }
402            }
403            crate::TypeInner::Image {
404                dim,
405                arrayed,
406                class,
407            } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
408            crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
409            crate::TypeInner::AccelerationStructure { .. } => LocalType::AccelerationStructure,
410            crate::TypeInner::RayQuery { .. } => LocalType::RayQuery,
411            crate::TypeInner::Array { .. }
412            | crate::TypeInner::Struct { .. }
413            | crate::TypeInner::BindingArray { .. } => return None,
414        })
415    }
416
417    /// Emits code for any wrapper functions required by the expressions in ir_function.
418    /// The IDs of any emitted functions will be stored in [`Self::wrapped_functions`].
419    fn write_wrapped_functions(
420        &mut self,
421        ir_function: &crate::Function,
422        info: &FunctionInfo,
423        ir_module: &crate::Module,
424    ) -> Result<(), Error> {
425        log::trace!("Generating wrapped functions for {:?}", ir_function.name);
426
427        for (expr_handle, expr) in ir_function.expressions.iter() {
428            match *expr {
429                crate::Expression::Binary { op, left, right } => {
430                    let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
431                    if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
432                        match (op, expr_ty.scalar().kind) {
433                            // Division and modulo are undefined behaviour when the
434                            // dividend is the minimum representable value and the divisor
435                            // is negative one, or when the divisor is zero. These wrapped
436                            // functions override the divisor to one in these cases,
437                            // matching the WGSL spec.
438                            (
439                                crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
440                                crate::ScalarKind::Sint | crate::ScalarKind::Uint,
441                            ) => {
442                                self.write_wrapped_binary_op(
443                                    op,
444                                    expr_ty,
445                                    &info[left].ty,
446                                    &info[right].ty,
447                                )?;
448                            }
449                            _ => {}
450                        }
451                    }
452                }
453                _ => {}
454            }
455        }
456
457        Ok(())
458    }
459
460    /// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
461    ///
462    /// Define a function that performs an integer division or modulo operation,
463    /// except that using a divisor of zero or causing signed overflow with a
464    /// divisor of -1 returns the numerator unchanged, rather than exhibiting
465    /// undefined behavior.
466    ///
467    /// Store the generated function's id in the [`wrapped_functions`] table.
468    ///
469    /// The operator `op` must be either [`Divide`] or [`Modulo`].
470    ///
471    /// # Panics
472    ///
473    /// The `return_type`, `left_type` or `right_type` arguments must all be
474    /// integer scalars or vectors. If not, this function panics.
475    ///
476    /// [`wrapped_functions`]: Writer::wrapped_functions
477    /// [`Divide`]: crate::BinaryOperator::Divide
478    /// [`Modulo`]: crate::BinaryOperator::Modulo
479    fn write_wrapped_binary_op(
480        &mut self,
481        op: crate::BinaryOperator,
482        return_type: NumericType,
483        left_type: &TypeResolution,
484        right_type: &TypeResolution,
485    ) -> Result<(), Error> {
486        let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
487        let left_type_id = self.get_expression_type_id(left_type);
488        let right_type_id = self.get_expression_type_id(right_type);
489
490        // Check if we've already emitted this function.
491        let wrapped = WrappedFunction::BinaryOp {
492            op,
493            left_type_id,
494            right_type_id,
495        };
496        let function_id = match self.wrapped_functions.entry(wrapped) {
497            Entry::Occupied(_) => return Ok(()),
498            Entry::Vacant(e) => *e.insert(self.id_gen.next()),
499        };
500
501        let scalar = return_type.scalar();
502
503        if self.flags.contains(WriterFlags::DEBUG) {
504            let function_name = match op {
505                crate::BinaryOperator::Divide => "naga_div",
506                crate::BinaryOperator::Modulo => "naga_mod",
507                _ => unreachable!(),
508            };
509            self.debugs
510                .push(Instruction::name(function_id, function_name));
511        }
512        let mut function = Function::default();
513
514        let function_type_id = self.get_function_type(LookupFunctionType {
515            parameter_type_ids: vec![left_type_id, right_type_id],
516            return_type_id,
517        });
518        function.signature = Some(Instruction::function(
519            return_type_id,
520            function_id,
521            spirv::FunctionControl::empty(),
522            function_type_id,
523        ));
524
525        let lhs_id = self.id_gen.next();
526        let rhs_id = self.id_gen.next();
527        if self.flags.contains(WriterFlags::DEBUG) {
528            self.debugs.push(Instruction::name(lhs_id, "lhs"));
529            self.debugs.push(Instruction::name(rhs_id, "rhs"));
530        }
531        let left_par = Instruction::function_parameter(left_type_id, lhs_id);
532        let right_par = Instruction::function_parameter(right_type_id, rhs_id);
533        for instruction in [left_par, right_par] {
534            function.parameters.push(FunctionArgument {
535                instruction,
536                handle_id: 0,
537            });
538        }
539
540        let label_id = self.id_gen.next();
541        let mut block = Block::new(label_id);
542
543        let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
544        let bool_type_id = self.get_numeric_type_id(bool_type);
545
546        let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
547            NumericType::Scalar(_) => const_id,
548            NumericType::Vector { size, .. } => {
549                let constituent_ids = [const_id; crate::VectorSize::MAX];
550                writer.get_constant_composite(
551                    LookupType::Local(LocalType::Numeric(return_type)),
552                    &constituent_ids[..size as usize],
553                )
554            }
555            NumericType::Matrix { .. } => unreachable!(),
556        };
557
558        let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
559        let composite_zero_id = maybe_splat_const(self, const_zero_id);
560        let rhs_eq_zero_id = self.id_gen.next();
561        block.body.push(Instruction::binary(
562            spirv::Op::IEqual,
563            bool_type_id,
564            rhs_eq_zero_id,
565            rhs_id,
566            composite_zero_id,
567        ));
568        let divisor_selector_id = match scalar.kind {
569            crate::ScalarKind::Sint => {
570                let (const_min_id, const_neg_one_id) = match scalar.width {
571                    4 => Ok((
572                        self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
573                        self.get_constant_scalar(crate::Literal::I32(-1i32)),
574                    )),
575                    8 => Ok((
576                        self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
577                        self.get_constant_scalar(crate::Literal::I64(-1i64)),
578                    )),
579                    _ => Err(Error::Validation("Unexpected scalar width")),
580                }?;
581                let composite_min_id = maybe_splat_const(self, const_min_id);
582                let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
583
584                let lhs_eq_int_min_id = self.id_gen.next();
585                block.body.push(Instruction::binary(
586                    spirv::Op::IEqual,
587                    bool_type_id,
588                    lhs_eq_int_min_id,
589                    lhs_id,
590                    composite_min_id,
591                ));
592                let rhs_eq_neg_one_id = self.id_gen.next();
593                block.body.push(Instruction::binary(
594                    spirv::Op::IEqual,
595                    bool_type_id,
596                    rhs_eq_neg_one_id,
597                    rhs_id,
598                    composite_neg_one_id,
599                ));
600                let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
601                block.body.push(Instruction::binary(
602                    spirv::Op::LogicalAnd,
603                    bool_type_id,
604                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
605                    lhs_eq_int_min_id,
606                    rhs_eq_neg_one_id,
607                ));
608                let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
609                block.body.push(Instruction::binary(
610                    spirv::Op::LogicalOr,
611                    bool_type_id,
612                    rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
613                    rhs_eq_zero_id,
614                    lhs_eq_int_min_and_rhs_eq_neg_one_id,
615                ));
616                rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
617            }
618            crate::ScalarKind::Uint => rhs_eq_zero_id,
619            _ => unreachable!(),
620        };
621
622        let const_one_id = self.get_constant_scalar_with(1, scalar)?;
623        let composite_one_id = maybe_splat_const(self, const_one_id);
624        let divisor_id = self.id_gen.next();
625        block.body.push(Instruction::select(
626            right_type_id,
627            divisor_id,
628            divisor_selector_id,
629            composite_one_id,
630            rhs_id,
631        ));
632        let op = match (op, scalar.kind) {
633            (crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
634            (crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
635            (crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
636            (crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
637            _ => unreachable!(),
638        };
639        let return_id = self.id_gen.next();
640        block.body.push(Instruction::binary(
641            op,
642            return_type_id,
643            return_id,
644            lhs_id,
645            divisor_id,
646        ));
647
648        function.consume(block, Instruction::return_value(return_id));
649        function.to_words(&mut self.logical_layout.function_definitions);
650        Ok(())
651    }
652
653    fn write_function(
654        &mut self,
655        ir_function: &crate::Function,
656        info: &FunctionInfo,
657        ir_module: &crate::Module,
658        mut interface: Option<FunctionInterface>,
659        debug_info: &Option<DebugInfoInner>,
660    ) -> Result<Word, Error> {
661        self.write_wrapped_functions(ir_function, info, ir_module)?;
662
663        log::trace!("Generating code for {:?}", ir_function.name);
664        let mut function = Function::default();
665
666        let prelude_id = self.id_gen.next();
667        let mut prelude = Block::new(prelude_id);
668        let mut ep_context = EntryPointContext {
669            argument_ids: Vec::new(),
670            results: Vec::new(),
671        };
672
673        let mut local_invocation_id = None;
674
675        let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
676        for argument in ir_function.arguments.iter() {
677            let class = spirv::StorageClass::Input;
678            let handle_ty = ir_module.types[argument.ty].inner.is_handle();
679            let argument_type_id = if handle_ty {
680                self.get_handle_pointer_type_id(argument.ty, spirv::StorageClass::UniformConstant)
681            } else {
682                self.get_handle_type_id(argument.ty)
683            };
684
685            if let Some(ref mut iface) = interface {
686                let id = if let Some(ref binding) = argument.binding {
687                    let name = argument.name.as_deref();
688
689                    let varying_id = self.write_varying(
690                        ir_module,
691                        iface.stage,
692                        class,
693                        name,
694                        argument.ty,
695                        binding,
696                    )?;
697                    iface.varying_ids.push(varying_id);
698                    let id = self.id_gen.next();
699                    prelude
700                        .body
701                        .push(Instruction::load(argument_type_id, id, varying_id, None));
702
703                    if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
704                        local_invocation_id = Some(id);
705                    }
706
707                    id
708                } else if let crate::TypeInner::Struct { ref members, .. } =
709                    ir_module.types[argument.ty].inner
710                {
711                    let struct_id = self.id_gen.next();
712                    let mut constituent_ids = Vec::with_capacity(members.len());
713                    for member in members {
714                        let type_id = self.get_handle_type_id(member.ty);
715                        let name = member.name.as_deref();
716                        let binding = member.binding.as_ref().unwrap();
717                        let varying_id = self.write_varying(
718                            ir_module,
719                            iface.stage,
720                            class,
721                            name,
722                            member.ty,
723                            binding,
724                        )?;
725                        iface.varying_ids.push(varying_id);
726                        let id = self.id_gen.next();
727                        prelude
728                            .body
729                            .push(Instruction::load(type_id, id, varying_id, None));
730                        constituent_ids.push(id);
731
732                        if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
733                            local_invocation_id = Some(id);
734                        }
735                    }
736                    prelude.body.push(Instruction::composite_construct(
737                        argument_type_id,
738                        struct_id,
739                        &constituent_ids,
740                    ));
741                    struct_id
742                } else {
743                    unreachable!("Missing argument binding on an entry point");
744                };
745                ep_context.argument_ids.push(id);
746            } else {
747                let argument_id = self.id_gen.next();
748                let instruction = Instruction::function_parameter(argument_type_id, argument_id);
749                if self.flags.contains(WriterFlags::DEBUG) {
750                    if let Some(ref name) = argument.name {
751                        self.debugs.push(Instruction::name(argument_id, name));
752                    }
753                }
754                function.parameters.push(FunctionArgument {
755                    instruction,
756                    handle_id: if handle_ty {
757                        let id = self.id_gen.next();
758                        prelude.body.push(Instruction::load(
759                            self.get_handle_type_id(argument.ty),
760                            id,
761                            argument_id,
762                            None,
763                        ));
764                        id
765                    } else {
766                        0
767                    },
768                });
769                parameter_type_ids.push(argument_type_id);
770            };
771        }
772
773        let return_type_id = match ir_function.result {
774            Some(ref result) => {
775                if let Some(ref mut iface) = interface {
776                    let mut has_point_size = false;
777                    let class = spirv::StorageClass::Output;
778                    if let Some(ref binding) = result.binding {
779                        has_point_size |=
780                            *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
781                        let type_id = self.get_handle_type_id(result.ty);
782                        let varying_id = self.write_varying(
783                            ir_module,
784                            iface.stage,
785                            class,
786                            None,
787                            result.ty,
788                            binding,
789                        )?;
790                        iface.varying_ids.push(varying_id);
791                        ep_context.results.push(ResultMember {
792                            id: varying_id,
793                            type_id,
794                            built_in: binding.to_built_in(),
795                        });
796                    } else if let crate::TypeInner::Struct { ref members, .. } =
797                        ir_module.types[result.ty].inner
798                    {
799                        for member in members {
800                            let type_id = self.get_handle_type_id(member.ty);
801                            let name = member.name.as_deref();
802                            let binding = member.binding.as_ref().unwrap();
803                            has_point_size |=
804                                *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
805                            let varying_id = self.write_varying(
806                                ir_module,
807                                iface.stage,
808                                class,
809                                name,
810                                member.ty,
811                                binding,
812                            )?;
813                            iface.varying_ids.push(varying_id);
814                            ep_context.results.push(ResultMember {
815                                id: varying_id,
816                                type_id,
817                                built_in: binding.to_built_in(),
818                            });
819                        }
820                    } else {
821                        unreachable!("Missing result binding on an entry point");
822                    }
823
824                    if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
825                        && iface.stage == crate::ShaderStage::Vertex
826                        && !has_point_size
827                    {
828                        // add point size artificially
829                        let varying_id = self.id_gen.next();
830                        let pointer_type_id = self.get_f32_pointer_type_id(class);
831                        Instruction::variable(pointer_type_id, varying_id, class, None)
832                            .to_words(&mut self.logical_layout.declarations);
833                        self.decorate(
834                            varying_id,
835                            spirv::Decoration::BuiltIn,
836                            &[spirv::BuiltIn::PointSize as u32],
837                        );
838                        iface.varying_ids.push(varying_id);
839
840                        let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
841                        prelude
842                            .body
843                            .push(Instruction::store(varying_id, default_value_id, None));
844                    }
845                    self.void_type
846                } else {
847                    self.get_handle_type_id(result.ty)
848                }
849            }
850            None => self.void_type,
851        };
852
853        let lookup_function_type = LookupFunctionType {
854            parameter_type_ids,
855            return_type_id,
856        };
857
858        let function_id = self.id_gen.next();
859        if self.flags.contains(WriterFlags::DEBUG) {
860            if let Some(ref name) = ir_function.name {
861                self.debugs.push(Instruction::name(function_id, name));
862            }
863        }
864
865        let function_type = self.get_function_type(lookup_function_type);
866        function.signature = Some(Instruction::function(
867            return_type_id,
868            function_id,
869            spirv::FunctionControl::empty(),
870            function_type,
871        ));
872
873        if interface.is_some() {
874            function.entry_point_context = Some(ep_context);
875        }
876
877        // fill up the `GlobalVariable::access_id`
878        for gv in self.global_variables.iter_mut() {
879            gv.reset_for_function();
880        }
881        for (handle, var) in ir_module.global_variables.iter() {
882            if info[handle].is_empty() {
883                continue;
884            }
885
886            let mut gv = self.global_variables[handle].clone();
887            if let Some(ref mut iface) = interface {
888                // Have to include global variables in the interface
889                if self.physical_layout.version >= 0x10400 {
890                    iface.varying_ids.push(gv.var_id);
891                }
892            }
893
894            // Handle globals are pre-emitted and should be loaded automatically.
895            //
896            // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
897            match ir_module.types[var.ty].inner {
898                crate::TypeInner::BindingArray { .. } => {
899                    gv.access_id = gv.var_id;
900                }
901                _ => {
902                    if var.space == crate::AddressSpace::Handle {
903                        let var_type_id = self.get_handle_type_id(var.ty);
904                        let id = self.id_gen.next();
905                        prelude
906                            .body
907                            .push(Instruction::load(var_type_id, id, gv.var_id, None));
908                        gv.access_id = gv.var_id;
909                        gv.handle_id = id;
910                    } else if global_needs_wrapper(ir_module, var) {
911                        let class = map_storage_class(var.space);
912                        let pointer_type_id = self.get_handle_pointer_type_id(var.ty, class);
913                        let index_id = self.get_index_constant(0);
914                        let id = self.id_gen.next();
915                        prelude.body.push(Instruction::access_chain(
916                            pointer_type_id,
917                            id,
918                            gv.var_id,
919                            &[index_id],
920                        ));
921                        gv.access_id = id;
922                    } else {
923                        // by default, the variable ID is accessed as is
924                        gv.access_id = gv.var_id;
925                    };
926                }
927            }
928
929            // work around borrow checking in the presence of `self.xxx()` calls
930            self.global_variables[handle] = gv;
931        }
932
933        // Create a `BlockContext` for generating SPIR-V for the function's
934        // body.
935        let mut context = BlockContext {
936            ir_module,
937            ir_function,
938            fun_info: info,
939            function: &mut function,
940            // Re-use the cached expression table from prior functions.
941            cached: core::mem::take(&mut self.saved_cached),
942
943            // Steal the Writer's temp list for a bit.
944            temp_list: core::mem::take(&mut self.temp_list),
945            force_loop_bounding: self.force_loop_bounding,
946            writer: self,
947            expression_constness: super::ExpressionConstnessTracker::from_arena(
948                &ir_function.expressions,
949            ),
950        };
951
952        // fill up the pre-emitted and const expressions
953        context.cached.reset(ir_function.expressions.len());
954        for (handle, expr) in ir_function.expressions.iter() {
955            if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
956                || context.expression_constness.is_const(handle)
957            {
958                context.cache_expression_value(handle, &mut prelude)?;
959            }
960        }
961
962        for (handle, variable) in ir_function.local_variables.iter() {
963            let id = context.gen_id();
964
965            if context.writer.flags.contains(WriterFlags::DEBUG) {
966                if let Some(ref name) = variable.name {
967                    context.writer.debugs.push(Instruction::name(id, name));
968                }
969            }
970
971            let init_word = variable.init.map(|constant| context.cached[constant]);
972            let pointer_type_id = context
973                .writer
974                .get_handle_pointer_type_id(variable.ty, spirv::StorageClass::Function);
975            let instruction = Instruction::variable(
976                pointer_type_id,
977                id,
978                spirv::StorageClass::Function,
979                init_word.or_else(|| match ir_module.types[variable.ty].inner {
980                    crate::TypeInner::RayQuery { .. } => None,
981                    _ => {
982                        let type_id = context.get_handle_type_id(variable.ty);
983                        Some(context.writer.write_constant_null(type_id))
984                    }
985                }),
986            );
987            context
988                .function
989                .variables
990                .insert(handle, LocalVariable { id, instruction });
991        }
992
993        for (handle, expr) in ir_function.expressions.iter() {
994            match *expr {
995                crate::Expression::LocalVariable(_) => {
996                    // Cache the `OpVariable` instruction we generated above as
997                    // the value of this expression.
998                    context.cache_expression_value(handle, &mut prelude)?;
999                }
1000                crate::Expression::Access { base, .. }
1001                | crate::Expression::AccessIndex { base, .. } => {
1002                    // Count references to `base` by `Access` and `AccessIndex`
1003                    // instructions. See `access_uses` for details.
1004                    *context.function.access_uses.entry(base).or_insert(0) += 1;
1005                }
1006                _ => {}
1007            }
1008        }
1009
1010        let next_id = context.gen_id();
1011
1012        context
1013            .function
1014            .consume(prelude, Instruction::branch(next_id));
1015
1016        let workgroup_vars_init_exit_block_id =
1017            match (context.writer.zero_initialize_workgroup_memory, interface) {
1018                (
1019                    super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
1020                    Some(
1021                        ref mut interface @ FunctionInterface {
1022                            stage: crate::ShaderStage::Compute,
1023                            ..
1024                        },
1025                    ),
1026                ) => context.writer.generate_workgroup_vars_init_block(
1027                    next_id,
1028                    ir_module,
1029                    info,
1030                    local_invocation_id,
1031                    interface,
1032                    context.function,
1033                ),
1034                _ => None,
1035            };
1036
1037        let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
1038            exit_id
1039        } else {
1040            next_id
1041        };
1042
1043        context.write_function_body(main_id, debug_info.as_ref())?;
1044
1045        // Consume the `BlockContext`, ending its borrows and letting the
1046        // `Writer` steal back its cached expression table and temp_list.
1047        let BlockContext {
1048            cached, temp_list, ..
1049        } = context;
1050        self.saved_cached = cached;
1051        self.temp_list = temp_list;
1052
1053        function.to_words(&mut self.logical_layout.function_definitions);
1054
1055        Ok(function_id)
1056    }
1057
1058    fn write_execution_mode(
1059        &mut self,
1060        function_id: Word,
1061        mode: spirv::ExecutionMode,
1062    ) -> Result<(), Error> {
1063        //self.check(mode.required_capabilities())?;
1064        Instruction::execution_mode(function_id, mode, &[])
1065            .to_words(&mut self.logical_layout.execution_modes);
1066        Ok(())
1067    }
1068
1069    // TODO Move to instructions module
1070    fn write_entry_point(
1071        &mut self,
1072        entry_point: &crate::EntryPoint,
1073        info: &FunctionInfo,
1074        ir_module: &crate::Module,
1075        debug_info: &Option<DebugInfoInner>,
1076    ) -> Result<Instruction, Error> {
1077        let mut interface_ids = Vec::new();
1078        let function_id = self.write_function(
1079            &entry_point.function,
1080            info,
1081            ir_module,
1082            Some(FunctionInterface {
1083                varying_ids: &mut interface_ids,
1084                stage: entry_point.stage,
1085            }),
1086            debug_info,
1087        )?;
1088
1089        let exec_model = match entry_point.stage {
1090            crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
1091            crate::ShaderStage::Fragment => {
1092                self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
1093                if let Some(ref result) = entry_point.function.result {
1094                    if contains_builtin(
1095                        result.binding.as_ref(),
1096                        result.ty,
1097                        &ir_module.types,
1098                        crate::BuiltIn::FragDepth,
1099                    ) {
1100                        self.write_execution_mode(
1101                            function_id,
1102                            spirv::ExecutionMode::DepthReplacing,
1103                        )?;
1104                    }
1105                }
1106                spirv::ExecutionModel::Fragment
1107            }
1108            crate::ShaderStage::Compute => {
1109                let execution_mode = spirv::ExecutionMode::LocalSize;
1110                //self.check(execution_mode.required_capabilities())?;
1111                Instruction::execution_mode(
1112                    function_id,
1113                    execution_mode,
1114                    &entry_point.workgroup_size,
1115                )
1116                .to_words(&mut self.logical_layout.execution_modes);
1117                spirv::ExecutionModel::GLCompute
1118            }
1119            crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(),
1120        };
1121        //self.check(exec_model.required_capabilities())?;
1122
1123        Ok(Instruction::entry_point(
1124            exec_model,
1125            function_id,
1126            &entry_point.name,
1127            interface_ids.as_slice(),
1128        ))
1129    }
1130
1131    fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
1132        use crate::ScalarKind as Sk;
1133
1134        let bits = (scalar.width * BITS_PER_BYTE) as u32;
1135        match scalar.kind {
1136            Sk::Sint | Sk::Uint => {
1137                let signedness = if scalar.kind == Sk::Sint {
1138                    super::instructions::Signedness::Signed
1139                } else {
1140                    super::instructions::Signedness::Unsigned
1141                };
1142                let cap = match bits {
1143                    8 => Some(spirv::Capability::Int8),
1144                    16 => Some(spirv::Capability::Int16),
1145                    64 => Some(spirv::Capability::Int64),
1146                    _ => None,
1147                };
1148                if let Some(cap) = cap {
1149                    self.capabilities_used.insert(cap);
1150                }
1151                Instruction::type_int(id, bits, signedness)
1152            }
1153            Sk::Float => {
1154                if bits == 64 {
1155                    self.capabilities_used.insert(spirv::Capability::Float64);
1156                }
1157                if bits == 16 {
1158                    self.capabilities_used.insert(spirv::Capability::Float16);
1159                    self.capabilities_used
1160                        .insert(spirv::Capability::StorageBuffer16BitAccess);
1161                    self.capabilities_used
1162                        .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess);
1163                    self.capabilities_used
1164                        .insert(spirv::Capability::StorageInputOutput16);
1165                }
1166                Instruction::type_float(id, bits)
1167            }
1168            Sk::Bool => Instruction::type_bool(id),
1169            Sk::AbstractInt | Sk::AbstractFloat => {
1170                unreachable!("abstract types should never reach the backend");
1171            }
1172        }
1173    }
1174
1175    fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
1176        match *inner {
1177            crate::TypeInner::Image {
1178                dim,
1179                arrayed,
1180                class,
1181            } => {
1182                let sampled = match class {
1183                    crate::ImageClass::Sampled { .. } => true,
1184                    crate::ImageClass::Depth { .. } => true,
1185                    crate::ImageClass::Storage { format, .. } => {
1186                        self.request_image_format_capabilities(format.into())?;
1187                        false
1188                    }
1189                };
1190
1191                match dim {
1192                    crate::ImageDimension::D1 => {
1193                        if sampled {
1194                            self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
1195                        } else {
1196                            self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
1197                        }
1198                    }
1199                    crate::ImageDimension::Cube if arrayed => {
1200                        if sampled {
1201                            self.require_any(
1202                                "sampled cube array images",
1203                                &[spirv::Capability::SampledCubeArray],
1204                            )?;
1205                        } else {
1206                            self.require_any(
1207                                "cube array storage images",
1208                                &[spirv::Capability::ImageCubeArray],
1209                            )?;
1210                        }
1211                    }
1212                    _ => {}
1213                }
1214            }
1215            crate::TypeInner::AccelerationStructure { .. } => {
1216                self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
1217            }
1218            crate::TypeInner::RayQuery { .. } => {
1219                self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
1220            }
1221            crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
1222                self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
1223            }
1224            crate::TypeInner::Atomic(crate::Scalar {
1225                width: 4,
1226                kind: crate::ScalarKind::Float,
1227            }) => {
1228                self.require_any(
1229                    "32 bit floating-point atomics",
1230                    &[spirv::Capability::AtomicFloat32AddEXT],
1231                )?;
1232                self.use_extension("SPV_EXT_shader_atomic_float_add");
1233            }
1234            // 16 bit floating-point support requires Float16 capability
1235            crate::TypeInner::Matrix {
1236                scalar: crate::Scalar::F16,
1237                ..
1238            }
1239            | crate::TypeInner::Vector {
1240                scalar: crate::Scalar::F16,
1241                ..
1242            }
1243            | crate::TypeInner::Scalar(crate::Scalar::F16) => {
1244                self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?;
1245                self.use_extension("SPV_KHR_16bit_storage");
1246            }
1247            _ => {}
1248        }
1249        Ok(())
1250    }
1251
1252    fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) {
1253        let instruction = match numeric {
1254            NumericType::Scalar(scalar) => self.make_scalar(id, scalar),
1255            NumericType::Vector { size, scalar } => {
1256                let scalar_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
1257                Instruction::type_vector(id, scalar_id, size)
1258            }
1259            NumericType::Matrix {
1260                columns,
1261                rows,
1262                scalar,
1263            } => {
1264                let column_id =
1265                    self.get_numeric_type_id(NumericType::Vector { size: rows, scalar });
1266                Instruction::type_matrix(id, column_id, columns)
1267            }
1268        };
1269
1270        instruction.to_words(&mut self.logical_layout.declarations);
1271    }
1272
1273    fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
1274        let instruction = match local_ty {
1275            LocalType::Numeric(numeric) => {
1276                self.write_numeric_type_declaration_local(id, numeric);
1277                return;
1278            }
1279            LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base),
1280            LocalType::Image(image) => {
1281                let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
1282                let type_id = self.get_localtype_id(local_type);
1283                Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
1284            }
1285            LocalType::Sampler => Instruction::type_sampler(id),
1286            LocalType::SampledImage { image_type_id } => {
1287                Instruction::type_sampled_image(id, image_type_id)
1288            }
1289            LocalType::BindingArray { base, size } => {
1290                let inner_ty = self.get_handle_type_id(base);
1291                let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
1292                Instruction::type_array(id, inner_ty, scalar_id)
1293            }
1294            LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
1295            LocalType::RayQuery => Instruction::type_ray_query(id),
1296        };
1297
1298        instruction.to_words(&mut self.logical_layout.declarations);
1299    }
1300
    /// Returns the SPIR-V id for the arena type `handle`, emitting its
    /// declaration into the declarations section when one is needed.
    ///
    /// Types that `localtype_from_inner` can express as a `LocalType` share one
    /// id per distinct `LocalType`, cached in `lookup_type`. Arrays, binding
    /// arrays and structs instead get a newly allocated id and a new
    /// declaration on each call. In both cases `handle` is then recorded in
    /// `lookup_type` as an alias for the id. When `WriterFlags::DEBUG` is set,
    /// an `OpName` is also emitted for named types.
    ///
    /// # Errors
    ///
    /// Propagates errors from `request_type_capabilities`, from resolving an
    /// array or binding-array size, and from `decorate_struct_member`.
    fn write_type_declaration_arena(
        &mut self,
        module: &crate::Module,
        handle: Handle<crate::Type>,
    ) -> Result<Word, Error> {
        let ty = &module.types[handle];
        // If it's a type that needs SPIR-V capabilities, request them now.
        // This needs to happen regardless of the LocalType lookup succeeding,
        // because some types which map to the same LocalType have different
        // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
        self.request_type_capabilities(&ty.inner)?;
        let id = if let Some(local) = self.localtype_from_inner(&ty.inner) {
            // This type can be represented as a `LocalType`, so check if we've
            // already written an instruction for it. If not, do so now, with
            // `write_type_declaration_local`.
            match self.lookup_type.entry(LookupType::Local(local)) {
                // We already have an id for this `LocalType`.
                Entry::Occupied(e) => *e.get(),

                // It's a type we haven't seen before.
                Entry::Vacant(e) => {
                    // The id is recorded in `lookup_type` before the
                    // declaration (and any component declarations it
                    // triggers) is written.
                    let id = self.id_gen.next();
                    e.insert(id);

                    self.write_type_declaration_local(id, local);

                    id
                }
            }
        } else {
            use spirv::Decoration;

            // Non-`LocalType` types always get a fresh id. It is allocated
            // before any element or member type ids are looked up below.
            let id = self.id_gen.next();
            let instruction = match ty.inner {
                crate::TypeInner::Array { base, size, stride } => {
                    // The IR's explicit stride becomes an `ArrayStride` decoration.
                    self.decorate(id, Decoration::ArrayStride, &[stride]);

                    let type_id = self.get_handle_type_id(base);
                    // A known length becomes `OpTypeArray` with a `u32` length
                    // constant; a dynamic length becomes `OpTypeRuntimeArray`.
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::BindingArray { base, size } => {
                    // Same shape as `Array` above, but no `ArrayStride`
                    // decoration is emitted.
                    let type_id = self.get_handle_type_id(base);
                    match size.resolve(module.to_ctx())? {
                        crate::proc::IndexableLength::Known(length) => {
                            let length_id = self.get_index_constant(length);
                            Instruction::type_array(id, type_id, length_id)
                        }
                        crate::proc::IndexableLength::Dynamic => {
                            Instruction::type_runtime_array(id, type_id)
                        }
                    }
                }
                crate::TypeInner::Struct {
                    ref members,
                    span: _,
                } => {
                    let mut has_runtime_array = false;
                    let mut member_ids = Vec::with_capacity(members.len());
                    for (index, member) in members.iter().enumerate() {
                        let member_ty = &module.types[member.ty];
                        match member_ty.inner {
                            crate::TypeInner::Array {
                                base: _,
                                size: crate::ArraySize::Dynamic,
                                stride: _,
                            } => {
                                has_runtime_array = true;
                            }
                            _ => (),
                        }
                        // Member decorations are emitted before the member's
                        // type id is resolved.
                        self.decorate_struct_member(id, index, member, &module.types)?;
                        let member_id = self.get_handle_type_id(member.ty);
                        member_ids.push(member_id);
                    }
                    // A struct with a runtime-sized array member is decorated `Block`.
                    if has_runtime_array {
                        self.decorate(id, Decoration::Block, &[]);
                    }
                    Instruction::type_struct(id, member_ids.as_slice())
                }

                // These all have TypeLocal representations, so they should have been
                // handled by `write_type_declaration_local` above.
                crate::TypeInner::Scalar(_)
                | crate::TypeInner::Atomic(_)
                | crate::TypeInner::Vector { .. }
                | crate::TypeInner::Matrix { .. }
                | crate::TypeInner::Pointer { .. }
                | crate::TypeInner::ValuePointer { .. }
                | crate::TypeInner::Image { .. }
                | crate::TypeInner::Sampler { .. }
                | crate::TypeInner::AccelerationStructure { .. }
                | crate::TypeInner::RayQuery { .. } => unreachable!(),
            };

            instruction.to_words(&mut self.logical_layout.declarations);
            id
        };

        // Add this handle as a new alias for that type.
        self.lookup_type.insert(LookupType::Handle(handle), id);

        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = ty.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        Ok(id)
    }
1418
1419    fn request_image_format_capabilities(
1420        &mut self,
1421        format: spirv::ImageFormat,
1422    ) -> Result<(), Error> {
1423        use spirv::ImageFormat as If;
1424        match format {
1425            If::Rg32f
1426            | If::Rg16f
1427            | If::R11fG11fB10f
1428            | If::R16f
1429            | If::Rgba16
1430            | If::Rgb10A2
1431            | If::Rg16
1432            | If::Rg8
1433            | If::R16
1434            | If::R8
1435            | If::Rgba16Snorm
1436            | If::Rg16Snorm
1437            | If::Rg8Snorm
1438            | If::R16Snorm
1439            | If::R8Snorm
1440            | If::Rg32i
1441            | If::Rg16i
1442            | If::Rg8i
1443            | If::R16i
1444            | If::R8i
1445            | If::Rgb10a2ui
1446            | If::Rg32ui
1447            | If::Rg16ui
1448            | If::Rg8ui
1449            | If::R16ui
1450            | If::R8ui => self.require_any(
1451                "storage image format",
1452                &[spirv::Capability::StorageImageExtendedFormats],
1453            ),
1454            If::R64ui | If::R64i => {
1455                self.use_extension("SPV_EXT_shader_image_int64");
1456                self.require_any(
1457                    "64-bit integer storage image format",
1458                    &[spirv::Capability::Int64ImageEXT],
1459                )
1460            }
1461            If::Unknown
1462            | If::Rgba32f
1463            | If::Rgba16f
1464            | If::R32f
1465            | If::Rgba8
1466            | If::Rgba8Snorm
1467            | If::Rgba32i
1468            | If::Rgba16i
1469            | If::Rgba8i
1470            | If::R32i
1471            | If::Rgba32ui
1472            | If::Rgba16ui
1473            | If::Rgba8ui
1474            | If::R32ui => Ok(()),
1475        }
1476    }
1477
1478    pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
1479        self.get_constant_scalar(crate::Literal::U32(index))
1480    }
1481
1482    pub(super) fn get_constant_scalar_with(
1483        &mut self,
1484        value: u8,
1485        scalar: crate::Scalar,
1486    ) -> Result<Word, Error> {
1487        Ok(
1488            self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
1489                Error::Validation("Unexpected kind and/or width for Literal"),
1490            )?),
1491        )
1492    }
1493
1494    pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
1495        let scalar = CachedConstant::Literal(value.into());
1496        if let Some(&id) = self.cached_constants.get(&scalar) {
1497            return id;
1498        }
1499        let id = self.id_gen.next();
1500        self.write_constant_scalar(id, &value, None);
1501        self.cached_constants.insert(scalar, id);
1502        id
1503    }
1504
1505    fn write_constant_scalar(
1506        &mut self,
1507        id: Word,
1508        value: &crate::Literal,
1509        debug_name: Option<&String>,
1510    ) {
1511        if self.flags.contains(WriterFlags::DEBUG) {
1512            if let Some(name) = debug_name {
1513                self.debugs.push(Instruction::name(id, name));
1514            }
1515        }
1516        let type_id = self.get_numeric_type_id(NumericType::Scalar(value.scalar()));
1517        let instruction = match *value {
1518            crate::Literal::F64(value) => {
1519                let bits = value.to_bits();
1520                Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
1521            }
1522            crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
1523            crate::Literal::F16(value) => {
1524                let low = value.to_bits();
1525                Instruction::constant_16bit(type_id, id, low as u32)
1526            }
1527            crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
1528            crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
1529            crate::Literal::U64(value) => {
1530                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
1531            }
1532            crate::Literal::I64(value) => {
1533                Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
1534            }
1535            crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
1536            crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
1537            crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
1538                unreachable!("Abstract types should not appear in IR presented to backends");
1539            }
1540        };
1541
1542        instruction.to_words(&mut self.logical_layout.declarations);
1543    }
1544
1545    pub(super) fn get_constant_composite(
1546        &mut self,
1547        ty: LookupType,
1548        constituent_ids: &[Word],
1549    ) -> Word {
1550        let composite = CachedConstant::Composite {
1551            ty,
1552            constituent_ids: constituent_ids.to_vec(),
1553        };
1554        if let Some(&id) = self.cached_constants.get(&composite) {
1555            return id;
1556        }
1557        let id = self.id_gen.next();
1558        self.write_constant_composite(id, ty, constituent_ids, None);
1559        self.cached_constants.insert(composite, id);
1560        id
1561    }
1562
1563    fn write_constant_composite(
1564        &mut self,
1565        id: Word,
1566        ty: LookupType,
1567        constituent_ids: &[Word],
1568        debug_name: Option<&String>,
1569    ) {
1570        if self.flags.contains(WriterFlags::DEBUG) {
1571            if let Some(name) = debug_name {
1572                self.debugs.push(Instruction::name(id, name));
1573            }
1574        }
1575        let type_id = self.get_type_id(ty);
1576        Instruction::constant_composite(type_id, id, constituent_ids)
1577            .to_words(&mut self.logical_layout.declarations);
1578    }
1579
1580    pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
1581        let null = CachedConstant::ZeroValue(type_id);
1582        if let Some(&id) = self.cached_constants.get(&null) {
1583            return id;
1584        }
1585        let id = self.write_constant_null(type_id);
1586        self.cached_constants.insert(null, id);
1587        id
1588    }
1589
1590    pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
1591        let null_id = self.id_gen.next();
1592        Instruction::constant_null(type_id, null_id)
1593            .to_words(&mut self.logical_layout.declarations);
1594        null_id
1595    }
1596
1597    fn write_constant_expr(
1598        &mut self,
1599        handle: Handle<crate::Expression>,
1600        ir_module: &crate::Module,
1601        mod_info: &ModuleInfo,
1602    ) -> Result<Word, Error> {
1603        let id = match ir_module.global_expressions[handle] {
1604            crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
1605            crate::Expression::Constant(constant) => {
1606                let constant = &ir_module.constants[constant];
1607                self.constant_ids[constant.init]
1608            }
1609            crate::Expression::ZeroValue(ty) => {
1610                let type_id = self.get_handle_type_id(ty);
1611                self.get_constant_null(type_id)
1612            }
1613            crate::Expression::Compose { ty, ref components } => {
1614                let component_ids: Vec<_> = crate::proc::flatten_compose(
1615                    ty,
1616                    components,
1617                    &ir_module.global_expressions,
1618                    &ir_module.types,
1619                )
1620                .map(|component| self.constant_ids[component])
1621                .collect();
1622                self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
1623            }
1624            crate::Expression::Splat { size, value } => {
1625                let value_id = self.constant_ids[value];
1626                let component_ids = &[value_id; 4][..size as usize];
1627
1628                let ty = self.get_expression_lookup_type(&mod_info[handle]);
1629
1630                self.get_constant_composite(ty, component_ids)
1631            }
1632            _ => {
1633                return Err(Error::Override);
1634            }
1635        };
1636
1637        self.constant_ids[handle] = id;
1638
1639        Ok(id)
1640    }
1641
1642    pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
1643        let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
1644            spirv::Scope::Device
1645        } else {
1646            spirv::Scope::Workgroup
1647        };
1648        let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
1649        semantics.set(
1650            spirv::MemorySemantics::UNIFORM_MEMORY,
1651            flags.contains(crate::Barrier::STORAGE),
1652        );
1653        semantics.set(
1654            spirv::MemorySemantics::WORKGROUP_MEMORY,
1655            flags.contains(crate::Barrier::WORK_GROUP),
1656        );
1657        semantics.set(
1658            spirv::MemorySemantics::IMAGE_MEMORY,
1659            flags.contains(crate::Barrier::TEXTURE),
1660        );
1661        let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
1662            self.get_index_constant(spirv::Scope::Subgroup as u32)
1663        } else {
1664            self.get_index_constant(spirv::Scope::Workgroup as u32)
1665        };
1666        let mem_scope_id = self.get_index_constant(memory_scope as u32);
1667        let semantics_id = self.get_index_constant(semantics.bits());
1668        block.body.push(Instruction::control_barrier(
1669            exec_scope_id,
1670            mem_scope_id,
1671            semantics_id,
1672        ));
1673    }
1674
    /// Emit code that zero-initializes the `WorkGroup` globals an entry point
    /// uses, and return the label that the emitted code branches to last.
    ///
    /// Only globals in the `WorkGroup` address space whose entry in `info` is
    /// non-empty are initialized. If there are none, nothing is emitted and
    /// `None` is returned.
    ///
    /// Otherwise three blocks are consumed into `function`:
    ///
    /// 1. a block labeled `entry_id` that computes
    ///    `all(local_invocation_id == 0)` and branches conditionally on it;
    /// 2. an "accept" block that stores an `OpConstantNull` into each of those
    ///    globals and then branches to the merge block;
    /// 3. the merge block, which emits a workgroup barrier (via
    ///    `write_barrier`) and branches to a freshly allocated label.
    ///
    /// That fresh label is returned as `Some(next_id)`. Presumably the caller
    /// starts the entry point's real body with it (TODO: confirm at the call
    /// site).
    ///
    /// `local_invocation_id`, when `Some`, is used directly as the `vec3<u32>`
    /// operand of the comparison. When it is `None`, a new `Input` variable
    /// decorated `BuiltIn LocalInvocationId` is declared and pushed onto
    /// `interface.varying_ids`, and its value is loaded at the top of the
    /// entry block.
    fn generate_workgroup_vars_init_block(
        &mut self,
        entry_id: Word,
        ir_module: &crate::Module,
        info: &FunctionInfo,
        local_invocation_id: Option<Word>,
        interface: &mut FunctionInterface,
        function: &mut Function,
    ) -> Option<Word> {
        // One `OpStore` of a null constant per used workgroup global; these
        // instructions become the body of the accept block below.
        let body = ir_module
            .global_variables
            .iter()
            .filter(|&(handle, var)| {
                !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
            })
            .map(|(handle, var)| {
                // It's safe to use `var_id` here, not `access_id`, because only
                // variables in the `Uniform` and `StorageBuffer` address spaces
                // get wrapped, and we're initializing `WorkGroup` variables.
                let var_id = self.global_variables[handle].var_id;
                let var_type_id = self.get_handle_type_id(var.ty);
                let init_word = self.get_constant_null(var_type_id);
                Instruction::store(var_id, init_word, None)
            })
            .collect::<Vec<_>>();

        if body.is_empty() {
            return None;
        }

        let uint3_type_id = self.get_vec3u_type_id();

        // The condition-evaluating block reuses the entry point's entry label.
        let mut pre_if_block = Block::new(entry_id);

        // Reuse the caller's local invocation id if it has one; otherwise
        // declare a `LocalInvocationId` input and load it here.
        let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
            local_invocation_id
        } else {
            let varying_id = self.id_gen.next();
            let class = spirv::StorageClass::Input;
            let pointer_type_id = self.get_vec3u_pointer_type_id(class);

            Instruction::variable(pointer_type_id, varying_id, class, None)
                .to_words(&mut self.logical_layout.declarations);

            self.decorate(
                varying_id,
                spirv::Decoration::BuiltIn,
                &[spirv::BuiltIn::LocalInvocationId as u32],
            );

            interface.varying_ids.push(varying_id);
            let id = self.id_gen.next();
            pre_if_block
                .body
                .push(Instruction::load(uint3_type_id, id, varying_id, None));

            id
        };

        let zero_id = self.get_constant_null(uint3_type_id);
        let bool3_type_id = self.get_vec3_bool_type_id();

        // Compare component-wise against the zero vector, then reduce with
        // `OpAll`, so only the invocation at (0, 0, 0) performs the stores.
        let eq_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::binary(
            spirv::Op::IEqual,
            bool3_type_id,
            eq_id,
            local_invocation_id,
            zero_id,
        ));

        let condition_id = self.id_gen.next();
        let bool_type_id = self.get_bool_type_id();
        pre_if_block.body.push(Instruction::relational(
            spirv::Op::All,
            bool_type_id,
            condition_id,
            eq_id,
        ));

        // Structured selection: both paths reconverge at `merge_id`.
        let merge_id = self.id_gen.next();
        pre_if_block.body.push(Instruction::selection_merge(
            merge_id,
            spirv::SelectionControl::NONE,
        ));

        let accept_id = self.id_gen.next();
        function.consume(
            pre_if_block,
            Instruction::branch_conditional(condition_id, accept_id, merge_id),
        );

        let accept_block = Block {
            label_id: accept_id,
            body,
        };
        function.consume(accept_block, Instruction::branch(merge_id));

        let mut post_if_block = Block::new(merge_id);

        // Barrier after the stores so the rest of the workgroup does not run
        // ahead of the zero-initialization.
        self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);

        let next_id = self.id_gen.next();
        function.consume(post_if_block, Instruction::branch(next_id));
        Some(next_id)
    }
1781
    /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
    ///
    /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
    /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
    /// interface is represented by global variables in the `Input` and `Output`
    /// storage classes, with decorations indicating which builtin or location
    /// each variable corresponds to.
    ///
    /// This function emits a single global `OpVariable` for a single value from
    /// the interface, and adds appropriate decorations to indicate which
    /// builtin or location it represents, how it should be interpolated, and so
    /// on. The `class` argument gives the variable's SPIR-V storage class,
    /// which should be either [`Input`] or [`Output`].
    ///
    /// Returns the id of the new variable. An `OpName` is attached only when
    /// both `WriterFlags::DEBUG` and `WriterFlags::LABEL_VARYINGS` are set and
    /// `debug_name` is `Some`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `require_any` when the binding needs a
    /// capability that is not available. Examples are `SampleRateShading` for
    /// per-sample interpolation or `sample_index`, and `MultiView` for
    /// `view_index`. Returns [`Error::Validation`] if a fragment-stage input
    /// built-in has an abstract scalar type.
    ///
    /// [`Binding`]: crate::Binding
    /// [`Function`]: crate::Function
    /// [`EntryPoint`]: crate::EntryPoint
    /// [`Input`]: spirv::StorageClass::Input
    /// [`Output`]: spirv::StorageClass::Output
    fn write_varying(
        &mut self,
        ir_module: &crate::Module,
        stage: crate::ShaderStage,
        class: spirv::StorageClass,
        debug_name: Option<&str>,
        ty: Handle<crate::Type>,
        binding: &crate::Binding,
    ) -> Result<Word, Error> {
        let id = self.id_gen.next();
        let pointer_type_id = self.get_handle_pointer_type_id(ty, class);
        Instruction::variable(pointer_type_id, id, class, None)
            .to_words(&mut self.logical_layout.declarations);

        // `contains` with a union of flags requires *both* flags to be set.
        if self
            .flags
            .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
        {
            if let Some(name) = debug_name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        use spirv::{BuiltIn, Decoration};

        match *binding {
            crate::Binding::Location {
                location,
                interpolation,
                sampling,
                blend_src,
            } => {
                self.decorate(id, Decoration::Location, &[location]);

                // Interpolation/sampling decorations are suppressed where the
                // Vulkan rules quoted below forbid them.
                let no_decorations =
                    // VUID-StandaloneSpirv-Flat-06202
                    // > The Flat, NoPerspective, Sample, and Centroid decorations
                    // > must not be used on variables with the Input storage class in a vertex shader
                    (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
                    // VUID-StandaloneSpirv-Flat-06201
                    // > The Flat, NoPerspective, Sample, and Centroid decorations
                    // > must not be used on variables with the Output storage class in a fragment shader
                    (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);

                if !no_decorations {
                    match interpolation {
                        // Perspective-correct interpolation is the default in SPIR-V.
                        None | Some(crate::Interpolation::Perspective) => (),
                        Some(crate::Interpolation::Flat) => {
                            self.decorate(id, Decoration::Flat, &[]);
                        }
                        Some(crate::Interpolation::Linear) => {
                            self.decorate(id, Decoration::NoPerspective, &[]);
                        }
                    }
                    match sampling {
                        // Center sampling is the default in SPIR-V.
                        None
                        | Some(
                            crate::Sampling::Center
                            | crate::Sampling::First
                            | crate::Sampling::Either,
                        ) => (),
                        Some(crate::Sampling::Centroid) => {
                            self.decorate(id, Decoration::Centroid, &[]);
                        }
                        Some(crate::Sampling::Sample) => {
                            self.require_any(
                                "per-sample interpolation",
                                &[spirv::Capability::SampleRateShading],
                            )?;
                            self.decorate(id, Decoration::Sample, &[]);
                        }
                    }
                }
                // Dual-source blending: the blend source becomes the `Index` decoration.
                if let Some(blend_src) = blend_src {
                    self.decorate(id, Decoration::Index, &[blend_src]);
                }
            }
            crate::Binding::BuiltIn(built_in) => {
                use crate::BuiltIn as Bi;
                // Translate the IR built-in to its SPIR-V counterpart, requesting
                // any capability it depends on along the way.
                let built_in = match built_in {
                    Bi::Position { invariant } => {
                        if invariant {
                            self.decorate(id, Decoration::Invariant, &[]);
                        }

                        // One IR built-in covers both directions; SPIR-V calls
                        // the non-output (fragment input) side `FragCoord`.
                        if class == spirv::StorageClass::Output {
                            BuiltIn::Position
                        } else {
                            BuiltIn::FragCoord
                        }
                    }
                    Bi::ViewIndex => {
                        self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
                        BuiltIn::ViewIndex
                    }
                    // vertex
                    Bi::BaseInstance => BuiltIn::BaseInstance,
                    Bi::BaseVertex => BuiltIn::BaseVertex,
                    Bi::ClipDistance => {
                        self.require_any(
                            "`clip_distance` built-in",
                            &[spirv::Capability::ClipDistance],
                        )?;
                        BuiltIn::ClipDistance
                    }
                    Bi::CullDistance => {
                        self.require_any(
                            "`cull_distance` built-in",
                            &[spirv::Capability::CullDistance],
                        )?;
                        BuiltIn::CullDistance
                    }
                    Bi::InstanceIndex => BuiltIn::InstanceIndex,
                    Bi::PointSize => BuiltIn::PointSize,
                    Bi::VertexIndex => BuiltIn::VertexIndex,
                    Bi::DrawID => BuiltIn::DrawIndex,
                    // fragment
                    Bi::FragDepth => BuiltIn::FragDepth,
                    Bi::PointCoord => BuiltIn::PointCoord,
                    Bi::FrontFacing => BuiltIn::FrontFacing,
                    Bi::PrimitiveIndex => {
                        self.require_any(
                            "`primitive_index` built-in",
                            &[spirv::Capability::Geometry],
                        )?;
                        BuiltIn::PrimitiveId
                    }
                    Bi::SampleIndex => {
                        self.require_any(
                            "`sample_index` built-in",
                            &[spirv::Capability::SampleRateShading],
                        )?;

                        BuiltIn::SampleId
                    }
                    Bi::SampleMask => BuiltIn::SampleMask,
                    // compute
                    Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
                    Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
                    Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
                    Bi::WorkGroupId => BuiltIn::WorkgroupId,
                    Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
                    Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
                    // Subgroup
                    Bi::NumSubgroups => {
                        self.require_any(
                            "`num_subgroups` built-in",
                            &[spirv::Capability::GroupNonUniform],
                        )?;
                        BuiltIn::NumSubgroups
                    }
                    Bi::SubgroupId => {
                        self.require_any(
                            "`subgroup_id` built-in",
                            &[spirv::Capability::GroupNonUniform],
                        )?;
                        BuiltIn::SubgroupId
                    }
                    Bi::SubgroupSize => {
                        self.require_any(
                            "`subgroup_size` built-in",
                            &[
                                spirv::Capability::GroupNonUniform,
                                spirv::Capability::SubgroupBallotKHR,
                            ],
                        )?;
                        BuiltIn::SubgroupSize
                    }
                    Bi::SubgroupInvocationId => {
                        self.require_any(
                            "`subgroup_invocation_id` built-in",
                            &[
                                spirv::Capability::GroupNonUniform,
                                spirv::Capability::SubgroupBallotKHR,
                            ],
                        )?;
                        BuiltIn::SubgroupLocalInvocationId
                    }
                };

                self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);

                use crate::ScalarKind as Sk;

                // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
                //
                // > Any variable with integer or double-precision floating-
                // > point type and with Input storage class in a fragment
                // > shader, must be decorated Flat
                //
                // NOTE(review): only integer/bool scalar and vector kinds are
                // checked here; float built-ins are never marked Flat, whatever
                // their width.
                if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
                    let is_flat = match ir_module.types[ty].inner {
                        crate::TypeInner::Scalar(scalar)
                        | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
                            Sk::Uint | Sk::Sint | Sk::Bool => true,
                            Sk::Float => false,
                            Sk::AbstractInt | Sk::AbstractFloat => {
                                return Err(Error::Validation(
                                    "Abstract types should not appear in IR presented to backends",
                                ))
                            }
                        },
                        _ => false,
                    };

                    if is_flat {
                        self.decorate(id, Decoration::Flat, &[]);
                    }
                }
            }
        }

        Ok(id)
    }
2016
    /// Declare the SPIR-V `OpVariable` for the IR global `global_variable` and
    /// return its id.
    ///
    /// Besides the variable itself, this emits:
    ///
    /// - an `OpName` when `WriterFlags::DEBUG` is set and the global is named;
    /// - `NonReadable`/`NonWritable` decorations derived from the access mode
    ///   of `Storage` globals and of storage-image-typed globals;
    /// - `DescriptorSet`/`Binding` decorations when the global has a resource
    ///   binding;
    /// - a `Block`-decorated wrapper struct plus a pointer type to it, when
    ///   `global_needs_wrapper` reports that one is needed;
    /// - a `Block` decoration on the element type of a `Storage` binding array,
    ///   unless that element is a struct whose last member is a runtime-sized
    ///   array.
    ///
    /// If `binding_map` supplies a `binding_array_size` for this global's
    /// binding and the global is a binding array, the variable's pointer type
    /// is built from that remapped size rather than from the IR type.
    ///
    /// Globals without an explicit `init` get an `OpConstantNull` initializer
    /// when they are `Private`, or when they are `WorkGroup` and
    /// `zero_initialize_workgroup_memory` is `Native`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `decorate_struct_member` while emitting the
    /// wrapper struct.
    fn write_global_variable(
        &mut self,
        ir_module: &crate::Module,
        global_variable: &crate::GlobalVariable,
    ) -> Result<Word, Error> {
        use spirv::Decoration;

        let id = self.id_gen.next();
        let class = map_storage_class(global_variable.space);

        // NOTE(review): disabled capability check; consider restoring or removing.
        //self.check(class.required_capabilities())?;

        if self.flags.contains(WriterFlags::DEBUG) {
            if let Some(ref name) = global_variable.name {
                self.debugs.push(Instruction::name(id, name));
            }
        }

        // Access restrictions come from the address space for storage buffers,
        // or from the image class for storage images in other spaces.
        let storage_access = match global_variable.space {
            crate::AddressSpace::Storage { access } => Some(access),
            _ => match ir_module.types[global_variable.ty].inner {
                crate::TypeInner::Image {
                    class: crate::ImageClass::Storage { access, .. },
                    ..
                } => Some(access),
                _ => None,
            },
        };
        if let Some(storage_access) = storage_access {
            if !storage_access.contains(crate::StorageAccess::LOAD) {
                self.decorate(id, Decoration::NonReadable, &[]);
            }
            if !storage_access.contains(crate::StorageAccess::STORE) {
                self.decorate(id, Decoration::NonWritable, &[]);
            }
        }

        // Note: we should be able to substitute `binding_array<Foo, 0>`,
        // but there is still code that tries to register the pre-substituted type,
        // and it is failing on 0.
        let mut substitute_inner_type_lookup = None;
        if let Some(ref res_binding) = global_variable.binding {
            self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
            self.decorate(id, Decoration::Binding, &[res_binding.binding]);

            if let Some(&BindingInfo {
                binding_array_size: Some(remapped_binding_array_size),
            }) = self.binding_map.get(res_binding)
            {
                if let crate::TypeInner::BindingArray { base, .. } =
                    ir_module.types[global_variable.ty].inner
                {
                    let binding_array_type_id =
                        self.get_type_id(LookupType::Local(LocalType::BindingArray {
                            base,
                            size: remapped_binding_array_size,
                        }));
                    // This lookup is a *pointer* type; see its use below.
                    substitute_inner_type_lookup = Some(LookupType::Local(LocalType::Pointer {
                        base: binding_array_type_id,
                        class,
                    }));
                }
            }
        };

        // Explicit IR initializer, if any; may be defaulted to a null constant below.
        let init_word = global_variable
            .init
            .map(|constant| self.constant_ids[constant]);
        let inner_type_id = self.get_type_id(
            substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
        );

        // generate the wrapping structure if needed
        let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
            let wrapper_type_id = self.id_gen.next();

            self.decorate(wrapper_type_id, Decoration::Block, &[]);
            let member = crate::StructMember {
                name: None,
                ty: global_variable.ty,
                binding: None,
                offset: 0,
            };
            self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;

            Instruction::type_struct(wrapper_type_id, &[inner_type_id])
                .to_words(&mut self.logical_layout.declarations);

            let pointer_type_id = self.id_gen.next();
            Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
                .to_words(&mut self.logical_layout.declarations);

            pointer_type_id
        } else {
            // This is a global variable in the Storage address space. The only
            // way it could have `global_needs_wrapper() == false` is if it has
            // a runtime-sized or binding array.
            // Runtime-sized arrays were decorated when iterating through struct content.
            // Now binding arrays require Block decorating.
            if let crate::AddressSpace::Storage { .. } = global_variable.space {
                match ir_module.types[global_variable.ty].inner {
                    crate::TypeInner::BindingArray { base, .. } => {
                        let ty = &ir_module.types[base];
                        let mut should_decorate = true;
                        // Check if the type has a runtime array.
                        // A normal runtime array gets validated out,
                        // so only structs can be with runtime arrays
                        if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
                            // only the last member in a struct can be dynamically sized
                            if let Some(last_member) = members.last() {
                                if let &crate::TypeInner::Array {
                                    size: crate::ArraySize::Dynamic,
                                    ..
                                } = &ir_module.types[last_member.ty].inner
                                {
                                    should_decorate = false;
                                }
                            }
                        }
                        if should_decorate {
                            let decorated_id = self.get_handle_type_id(base);
                            self.decorate(decorated_id, Decoration::Block, &[]);
                        }
                    }
                    _ => (),
                };
            }
            // The substituted lookup already is the pointer type, so use it as-is.
            if substitute_inner_type_lookup.is_some() {
                inner_type_id
            } else {
                self.get_handle_pointer_type_id(global_variable.ty, class)
            }
        };

        // Default missing initializers to null for Private globals, and for
        // WorkGroup globals when zero-initialization is done natively.
        let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
            (crate::AddressSpace::Private, _)
            | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
                init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
            }
            _ => init_word,
        };

        Instruction::variable(pointer_type_id, id, class, init_word)
            .to_words(&mut self.logical_layout.declarations);
        Ok(id)
    }
2163
2164    /// Write the necessary decorations for a struct member.
2165    ///
2166    /// Emit decorations for the `index`'th member of the struct type
2167    /// designated by `struct_id`, described by `member`.
2168    fn decorate_struct_member(
2169        &mut self,
2170        struct_id: Word,
2171        index: usize,
2172        member: &crate::StructMember,
2173        arena: &UniqueArena<crate::Type>,
2174    ) -> Result<(), Error> {
2175        use spirv::Decoration;
2176
2177        self.annotations.push(Instruction::member_decorate(
2178            struct_id,
2179            index as u32,
2180            Decoration::Offset,
2181            &[member.offset],
2182        ));
2183
2184        if self.flags.contains(WriterFlags::DEBUG) {
2185            if let Some(ref name) = member.name {
2186                self.debugs
2187                    .push(Instruction::member_name(struct_id, index as u32, name));
2188            }
2189        }
2190
2191        // Matrices and (potentially nested) arrays of matrices both require decorations,
2192        // so "see through" any arrays to determine if they're needed.
2193        let mut member_array_subty_inner = &arena[member.ty].inner;
2194        while let crate::TypeInner::Array { base, .. } = *member_array_subty_inner {
2195            member_array_subty_inner = &arena[base].inner;
2196        }
2197
2198        if let crate::TypeInner::Matrix {
2199            columns: _,
2200            rows,
2201            scalar,
2202        } = *member_array_subty_inner
2203        {
2204            let byte_stride = Alignment::from(rows) * scalar.width as u32;
2205            self.annotations.push(Instruction::member_decorate(
2206                struct_id,
2207                index as u32,
2208                Decoration::ColMajor,
2209                &[],
2210            ));
2211            self.annotations.push(Instruction::member_decorate(
2212                struct_id,
2213                index as u32,
2214                Decoration::MatrixStride,
2215                &[byte_stride],
2216            ));
2217        }
2218
2219        Ok(())
2220    }
2221
2222    pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
2223        match self
2224            .lookup_function_type
2225            .entry(lookup_function_type.clone())
2226        {
2227            Entry::Occupied(e) => *e.get(),
2228            Entry::Vacant(_) => {
2229                let id = self.id_gen.next();
2230                let instruction = Instruction::type_function(
2231                    id,
2232                    lookup_function_type.return_type_id,
2233                    &lookup_function_type.parameter_type_ids,
2234                );
2235                instruction.to_words(&mut self.logical_layout.declarations);
2236                self.lookup_function_type.insert(lookup_function_type, id);
2237                id
2238            }
2239        }
2240    }
2241
2242    fn write_physical_layout(&mut self) {
2243        self.physical_layout.bound = self.id_gen.0 + 1;
2244    }
2245
2246    fn write_logical_layout(
2247        &mut self,
2248        ir_module: &crate::Module,
2249        mod_info: &ModuleInfo,
2250        ep_index: Option<usize>,
2251        debug_info: &Option<DebugInfo>,
2252    ) -> Result<(), Error> {
2253        fn has_view_index_check(
2254            ir_module: &crate::Module,
2255            binding: Option<&crate::Binding>,
2256            ty: Handle<crate::Type>,
2257        ) -> bool {
2258            match ir_module.types[ty].inner {
2259                crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
2260                    has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
2261                }),
2262                _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
2263            }
2264        }
2265
2266        let has_storage_buffers =
2267            ir_module
2268                .global_variables
2269                .iter()
2270                .any(|(_, var)| match var.space {
2271                    crate::AddressSpace::Storage { .. } => true,
2272                    _ => false,
2273                });
2274        let has_view_index = ir_module
2275            .entry_points
2276            .iter()
2277            .flat_map(|entry| entry.function.arguments.iter())
2278            .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
2279        let mut has_ray_query = ir_module.special_types.ray_desc.is_some()
2280            | ir_module.special_types.ray_intersection.is_some();
2281        let has_vertex_return = ir_module.special_types.ray_vertex_return.is_some();
2282
2283        for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() {
2284            // spirv does not know whether these have vertex return - that is done by us
2285            if let &crate::TypeInner::AccelerationStructure { .. }
2286            | &crate::TypeInner::RayQuery { .. } = inner
2287            {
2288                has_ray_query = true
2289            }
2290        }
2291
2292        if self.physical_layout.version < 0x10300 && has_storage_buffers {
2293            // enable the storage buffer class on < SPV-1.3
2294            Instruction::extension("SPV_KHR_storage_buffer_storage_class")
2295                .to_words(&mut self.logical_layout.extensions);
2296        }
2297        if has_view_index {
2298            Instruction::extension("SPV_KHR_multiview")
2299                .to_words(&mut self.logical_layout.extensions)
2300        }
2301        if has_ray_query {
2302            Instruction::extension("SPV_KHR_ray_query")
2303                .to_words(&mut self.logical_layout.extensions)
2304        }
2305        if has_vertex_return {
2306            Instruction::extension("SPV_KHR_ray_tracing_position_fetch")
2307                .to_words(&mut self.logical_layout.extensions);
2308        }
2309        Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
2310        Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
2311            .to_words(&mut self.logical_layout.ext_inst_imports);
2312
2313        let mut debug_info_inner = None;
2314        if self.flags.contains(WriterFlags::DEBUG) {
2315            if let Some(debug_info) = debug_info.as_ref() {
2316                let source_file_id = self.id_gen.next();
2317                self.debugs.push(Instruction::string(
2318                    &debug_info.file_name.display().to_string(),
2319                    source_file_id,
2320                ));
2321
2322                debug_info_inner = Some(DebugInfoInner {
2323                    source_code: debug_info.source_code,
2324                    source_file_id,
2325                });
2326                self.debugs.append(&mut Instruction::source_auto_continued(
2327                    debug_info.language,
2328                    0,
2329                    &debug_info_inner,
2330                ));
2331            }
2332        }
2333
2334        // write all types
2335        for (handle, _) in ir_module.types.iter() {
2336            self.write_type_declaration_arena(ir_module, handle)?;
2337        }
2338
2339        // write all const-expressions as constants
2340        self.constant_ids
2341            .resize(ir_module.global_expressions.len(), 0);
2342        for (handle, _) in ir_module.global_expressions.iter() {
2343            self.write_constant_expr(handle, ir_module, mod_info)?;
2344        }
2345        debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
2346
2347        // write the name of constants on their respective const-expression initializer
2348        if self.flags.contains(WriterFlags::DEBUG) {
2349            for (_, constant) in ir_module.constants.iter() {
2350                if let Some(ref name) = constant.name {
2351                    let id = self.constant_ids[constant.init];
2352                    self.debugs.push(Instruction::name(id, name));
2353                }
2354            }
2355        }
2356
2357        // write all global variables
2358        for (handle, var) in ir_module.global_variables.iter() {
2359            // If a single entry point was specified, only write `OpVariable` instructions
2360            // for the globals it actually uses. Emit dummies for the others,
2361            // to preserve the indices in `global_variables`.
2362            let gvar = match ep_index {
2363                Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
2364                    GlobalVariable::dummy()
2365                }
2366                _ => {
2367                    let id = self.write_global_variable(ir_module, var)?;
2368                    GlobalVariable::new(id)
2369                }
2370            };
2371            self.global_variables.insert(handle, gvar);
2372        }
2373
2374        // write all functions
2375        for (handle, ir_function) in ir_module.functions.iter() {
2376            let info = &mod_info[handle];
2377            if let Some(index) = ep_index {
2378                let ep_info = mod_info.get_entry_point(index);
2379                // If this function uses globals that we omitted from the SPIR-V
2380                // because the entry point and its callees didn't use them,
2381                // then we must skip it.
2382                if !ep_info.dominates_global_use(info) {
2383                    log::info!("Skip function {:?}", ir_function.name);
2384                    continue;
2385                }
2386
2387                // Skip functions that that are not compatible with this entry point's stage.
2388                //
2389                // When validation is enabled, it rejects modules whose entry points try to call
2390                // incompatible functions, so if we got this far, then any functions incompatible
2391                // with our selected entry point must not be used.
2392                //
2393                // When validation is disabled, `fun_info.available_stages` is always just
2394                // `ShaderStages::all()`, so this will write all functions in the module, and
2395                // the downstream GLSL compiler will catch any problems.
2396                if !info.available_stages.contains(ep_info.available_stages) {
2397                    continue;
2398                }
2399            }
2400            let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
2401            self.lookup_function.insert(handle, id);
2402        }
2403
2404        // write all or one entry points
2405        for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
2406            if ep_index.is_some() && ep_index != Some(index) {
2407                continue;
2408            }
2409            let info = mod_info.get_entry_point(index);
2410            let ep_instruction =
2411                self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
2412            ep_instruction.to_words(&mut self.logical_layout.entry_points);
2413        }
2414
2415        for capability in self.capabilities_used.iter() {
2416            Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
2417        }
2418        for extension in self.extensions_used.iter() {
2419            Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
2420        }
2421        if ir_module.entry_points.is_empty() {
2422            // SPIR-V doesn't like modules without entry points
2423            Instruction::capability(spirv::Capability::Linkage)
2424                .to_words(&mut self.logical_layout.capabilities);
2425        }
2426
2427        let addressing_model = spirv::AddressingModel::Logical;
2428        let memory_model = spirv::MemoryModel::GLSL450;
2429        //self.check(addressing_model.required_capabilities())?;
2430        //self.check(memory_model.required_capabilities())?;
2431
2432        Instruction::memory_model(addressing_model, memory_model)
2433            .to_words(&mut self.logical_layout.memory_model);
2434
2435        if self.flags.contains(WriterFlags::DEBUG) {
2436            for debug in self.debugs.iter() {
2437                debug.to_words(&mut self.logical_layout.debugs);
2438            }
2439        }
2440
2441        for annotation in self.annotations.iter() {
2442            annotation.to_words(&mut self.logical_layout.annotations);
2443        }
2444
2445        Ok(())
2446    }
2447
2448    pub fn write(
2449        &mut self,
2450        ir_module: &crate::Module,
2451        info: &ModuleInfo,
2452        pipeline_options: Option<&PipelineOptions>,
2453        debug_info: &Option<DebugInfo>,
2454        words: &mut Vec<Word>,
2455    ) -> Result<(), Error> {
2456        self.reset();
2457
2458        // Try to find the entry point and corresponding index
2459        let ep_index = match pipeline_options {
2460            Some(po) => {
2461                let index = ir_module
2462                    .entry_points
2463                    .iter()
2464                    .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
2465                    .ok_or(Error::EntryPointNotFound)?;
2466                Some(index)
2467            }
2468            None => None,
2469        };
2470
2471        self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
2472        self.write_physical_layout();
2473
2474        self.physical_layout.in_words(words);
2475        self.logical_layout.in_words(words);
2476        Ok(())
2477    }
2478
2479    /// Return the set of capabilities the last module written used.
2480    pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
2481        &self.capabilities_used
2482    }
2483
2484    pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
2485        self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
2486        self.use_extension("SPV_EXT_descriptor_indexing");
2487        self.decorate(id, spirv::Decoration::NonUniform, &[]);
2488        Ok(())
2489    }
2490}
2491
/// `write_physical_layout` turns the ids allocated so far into the header's id bound.
#[test]
fn test_write_physical_layout() {
    let options = Options::default();
    let mut writer = Writer::new(&options).unwrap();

    // The bound stays zero until the physical layout is written.
    let bound_before = writer.physical_layout.bound;
    assert_eq!(bound_before, 0);

    writer.write_physical_layout();

    // NOTE(review): 3 presumably reflects the ids `Writer::new` reserves up
    // front (e.g. the void type and the GLSL.std.450 import); confirm.
    let bound_after = writer.physical_layout.bound;
    assert_eq!(bound_after, 3);
}