naga/back/spv/
ray.rs

1/*!
2Generating SPIR-V for ray query operations.
3*/
4
5use alloc::vec;
6
7use super::{
8    Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
9    Writer,
10};
11use crate::arena::Handle;
12
13impl Writer {
14    pub(super) fn write_ray_query_get_intersection_function(
15        &mut self,
16        is_committed: bool,
17        ir_module: &crate::Module,
18    ) -> spirv::Word {
19        if let Some(func_id) = self.ray_get_intersection_function {
20            return func_id;
21        }
22        let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
23        let intersection_type_id = self.get_handle_type_id(ray_intersection);
24        let intersection_pointer_type_id =
25            self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
26
27        let flag_type_id = self.get_u32_type_id();
28        let flag_pointer_type_id =
29            self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
30
31        let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
32            columns: crate::VectorSize::Quad,
33            rows: crate::VectorSize::Tri,
34            scalar: crate::Scalar::F32,
35        });
36        let transform_pointer_type_id =
37            self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
38
39        let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
40            size: crate::VectorSize::Bi,
41            scalar: crate::Scalar::F32,
42        });
43        let barycentrics_pointer_type_id =
44            self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
45
46        let bool_type_id = self.get_bool_type_id();
47        let bool_pointer_type_id =
48            self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
49
50        let scalar_type_id = self.get_f32_type_id();
51        let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
52
53        let argument_type_id = self.get_ray_query_pointer_id(ir_module);
54
55        let func_ty = self.get_function_type(LookupFunctionType {
56            parameter_type_ids: vec![argument_type_id],
57            return_type_id: intersection_type_id,
58        });
59
60        let mut function = Function::default();
61        let func_id = self.id_gen.next();
62        function.signature = Some(Instruction::function(
63            intersection_type_id,
64            func_id,
65            spirv::FunctionControl::empty(),
66            func_ty,
67        ));
68        let blank_intersection = self.get_constant_null(intersection_type_id);
69        let query_id = self.id_gen.next();
70        let instruction = Instruction::function_parameter(argument_type_id, query_id);
71        function.parameters.push(FunctionArgument {
72            instruction,
73            handle_id: 0,
74        });
75
76        let label_id = self.id_gen.next();
77        let mut block = Block::new(label_id);
78
79        let blank_intersection_id = self.id_gen.next();
80        block.body.push(Instruction::variable(
81            intersection_pointer_type_id,
82            blank_intersection_id,
83            spirv::StorageClass::Function,
84            Some(blank_intersection),
85        ));
86
87        let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
88            spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
89        } else {
90            spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
91        } as _));
92        let raw_kind_id = self.id_gen.next();
93        block.body.push(Instruction::ray_query_get_intersection(
94            spirv::Op::RayQueryGetIntersectionTypeKHR,
95            flag_type_id,
96            raw_kind_id,
97            query_id,
98            intersection_id,
99        ));
100        let kind_id = if is_committed {
101            // Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
102            raw_kind_id
103        } else {
104            // Remap from the candidate kind to IR
105            let condition_id = self.id_gen.next();
106            let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
107                spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
108                    as _,
109            ));
110            block.body.push(Instruction::binary(
111                spirv::Op::IEqual,
112                self.get_bool_type_id(),
113                condition_id,
114                raw_kind_id,
115                committed_triangle_kind_id,
116            ));
117            let kind_id = self.id_gen.next();
118            block.body.push(Instruction::select(
119                flag_type_id,
120                kind_id,
121                condition_id,
122                self.get_constant_scalar(crate::Literal::U32(
123                    crate::RayQueryIntersection::Triangle as _,
124                )),
125                self.get_constant_scalar(crate::Literal::U32(
126                    crate::RayQueryIntersection::Aabb as _,
127                )),
128            ));
129            kind_id
130        };
131        let idx_id = self.get_index_constant(0);
132        let access_idx = self.id_gen.next();
133        block.body.push(Instruction::access_chain(
134            flag_pointer_type_id,
135            access_idx,
136            blank_intersection_id,
137            &[idx_id],
138        ));
139        block
140            .body
141            .push(Instruction::store(access_idx, kind_id, None));
142
143        let not_none_comp_id = self.id_gen.next();
144        let none_id =
145            self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
146        block.body.push(Instruction::binary(
147            spirv::Op::INotEqual,
148            self.get_bool_type_id(),
149            not_none_comp_id,
150            kind_id,
151            none_id,
152        ));
153
154        let not_none_label_id = self.id_gen.next();
155        let mut not_none_block = Block::new(not_none_label_id);
156
157        let final_label_id = self.id_gen.next();
158        let mut final_block = Block::new(final_label_id);
159
160        block.body.push(Instruction::selection_merge(
161            final_label_id,
162            spirv::SelectionControl::NONE,
163        ));
164        function.consume(
165            block,
166            Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id),
167        );
168
169        let instance_custom_index_id = self.id_gen.next();
170        not_none_block
171            .body
172            .push(Instruction::ray_query_get_intersection(
173                spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
174                flag_type_id,
175                instance_custom_index_id,
176                query_id,
177                intersection_id,
178            ));
179        let instance_id = self.id_gen.next();
180        not_none_block
181            .body
182            .push(Instruction::ray_query_get_intersection(
183                spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
184                flag_type_id,
185                instance_id,
186                query_id,
187                intersection_id,
188            ));
189        let sbt_record_offset_id = self.id_gen.next();
190        not_none_block
191            .body
192            .push(Instruction::ray_query_get_intersection(
193                spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
194                flag_type_id,
195                sbt_record_offset_id,
196                query_id,
197                intersection_id,
198            ));
199        let geometry_index_id = self.id_gen.next();
200        not_none_block
201            .body
202            .push(Instruction::ray_query_get_intersection(
203                spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
204                flag_type_id,
205                geometry_index_id,
206                query_id,
207                intersection_id,
208            ));
209        let primitive_index_id = self.id_gen.next();
210        not_none_block
211            .body
212            .push(Instruction::ray_query_get_intersection(
213                spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
214                flag_type_id,
215                primitive_index_id,
216                query_id,
217                intersection_id,
218            ));
219
220        //Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
221        // but it's not a property of an intersection.
222
223        let object_to_world_id = self.id_gen.next();
224        not_none_block
225            .body
226            .push(Instruction::ray_query_get_intersection(
227                spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
228                transform_type_id,
229                object_to_world_id,
230                query_id,
231                intersection_id,
232            ));
233        let world_to_object_id = self.id_gen.next();
234        not_none_block
235            .body
236            .push(Instruction::ray_query_get_intersection(
237                spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
238                transform_type_id,
239                world_to_object_id,
240                query_id,
241                intersection_id,
242            ));
243
244        // instance custom index
245        let idx_id = self.get_index_constant(2);
246        let access_idx = self.id_gen.next();
247        not_none_block.body.push(Instruction::access_chain(
248            flag_pointer_type_id,
249            access_idx,
250            blank_intersection_id,
251            &[idx_id],
252        ));
253        not_none_block.body.push(Instruction::store(
254            access_idx,
255            instance_custom_index_id,
256            None,
257        ));
258
259        // instance
260        let idx_id = self.get_index_constant(3);
261        let access_idx = self.id_gen.next();
262        not_none_block.body.push(Instruction::access_chain(
263            flag_pointer_type_id,
264            access_idx,
265            blank_intersection_id,
266            &[idx_id],
267        ));
268        not_none_block
269            .body
270            .push(Instruction::store(access_idx, instance_id, None));
271
272        let idx_id = self.get_index_constant(4);
273        let access_idx = self.id_gen.next();
274        not_none_block.body.push(Instruction::access_chain(
275            flag_pointer_type_id,
276            access_idx,
277            blank_intersection_id,
278            &[idx_id],
279        ));
280        not_none_block
281            .body
282            .push(Instruction::store(access_idx, sbt_record_offset_id, None));
283
284        let idx_id = self.get_index_constant(5);
285        let access_idx = self.id_gen.next();
286        not_none_block.body.push(Instruction::access_chain(
287            flag_pointer_type_id,
288            access_idx,
289            blank_intersection_id,
290            &[idx_id],
291        ));
292        not_none_block
293            .body
294            .push(Instruction::store(access_idx, geometry_index_id, None));
295
296        let idx_id = self.get_index_constant(6);
297        let access_idx = self.id_gen.next();
298        not_none_block.body.push(Instruction::access_chain(
299            flag_pointer_type_id,
300            access_idx,
301            blank_intersection_id,
302            &[idx_id],
303        ));
304        not_none_block
305            .body
306            .push(Instruction::store(access_idx, primitive_index_id, None));
307
308        let idx_id = self.get_index_constant(9);
309        let access_idx = self.id_gen.next();
310        not_none_block.body.push(Instruction::access_chain(
311            transform_pointer_type_id,
312            access_idx,
313            blank_intersection_id,
314            &[idx_id],
315        ));
316        not_none_block
317            .body
318            .push(Instruction::store(access_idx, object_to_world_id, None));
319
320        let idx_id = self.get_index_constant(10);
321        let access_idx = self.id_gen.next();
322        not_none_block.body.push(Instruction::access_chain(
323            transform_pointer_type_id,
324            access_idx,
325            blank_intersection_id,
326            &[idx_id],
327        ));
328        not_none_block
329            .body
330            .push(Instruction::store(access_idx, world_to_object_id, None));
331
332        let tri_comp_id = self.id_gen.next();
333        let tri_id = self.get_constant_scalar(crate::Literal::U32(
334            crate::RayQueryIntersection::Triangle as _,
335        ));
336        not_none_block.body.push(Instruction::binary(
337            spirv::Op::IEqual,
338            self.get_bool_type_id(),
339            tri_comp_id,
340            kind_id,
341            tri_id,
342        ));
343
344        let tri_label_id = self.id_gen.next();
345        let mut tri_block = Block::new(tri_label_id);
346
347        let merge_label_id = self.id_gen.next();
348        let merge_block = Block::new(merge_label_id);
349        // t
350        {
351            let block = if is_committed {
352                &mut not_none_block
353            } else {
354                &mut tri_block
355            };
356            let t_id = self.id_gen.next();
357            block.body.push(Instruction::ray_query_get_intersection(
358                spirv::Op::RayQueryGetIntersectionTKHR,
359                scalar_type_id,
360                t_id,
361                query_id,
362                intersection_id,
363            ));
364            let idx_id = self.get_index_constant(1);
365            let access_idx = self.id_gen.next();
366            block.body.push(Instruction::access_chain(
367                float_pointer_type_id,
368                access_idx,
369                blank_intersection_id,
370                &[idx_id],
371            ));
372            block.body.push(Instruction::store(access_idx, t_id, None));
373        }
374        not_none_block.body.push(Instruction::selection_merge(
375            merge_label_id,
376            spirv::SelectionControl::NONE,
377        ));
378        function.consume(
379            not_none_block,
380            Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
381        );
382
383        let barycentrics_id = self.id_gen.next();
384        tri_block.body.push(Instruction::ray_query_get_intersection(
385            spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
386            barycentrics_type_id,
387            barycentrics_id,
388            query_id,
389            intersection_id,
390        ));
391
392        let front_face_id = self.id_gen.next();
393        tri_block.body.push(Instruction::ray_query_get_intersection(
394            spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
395            bool_type_id,
396            front_face_id,
397            query_id,
398            intersection_id,
399        ));
400
401        let idx_id = self.get_index_constant(7);
402        let access_idx = self.id_gen.next();
403        tri_block.body.push(Instruction::access_chain(
404            barycentrics_pointer_type_id,
405            access_idx,
406            blank_intersection_id,
407            &[idx_id],
408        ));
409        tri_block
410            .body
411            .push(Instruction::store(access_idx, barycentrics_id, None));
412
413        let idx_id = self.get_index_constant(8);
414        let access_idx = self.id_gen.next();
415        tri_block.body.push(Instruction::access_chain(
416            bool_pointer_type_id,
417            access_idx,
418            blank_intersection_id,
419            &[idx_id],
420        ));
421        tri_block
422            .body
423            .push(Instruction::store(access_idx, front_face_id, None));
424        function.consume(tri_block, Instruction::branch(merge_label_id));
425        function.consume(merge_block, Instruction::branch(final_label_id));
426
427        let loaded_blank_intersection_id = self.id_gen.next();
428        final_block.body.push(Instruction::load(
429            intersection_type_id,
430            loaded_blank_intersection_id,
431            blank_intersection_id,
432            None,
433        ));
434        function.consume(
435            final_block,
436            Instruction::return_value(loaded_blank_intersection_id),
437        );
438
439        function.to_words(&mut self.logical_layout.function_definitions);
440        self.ray_get_intersection_function = Some(func_id);
441        func_id
442    }
443}
444
445impl BlockContext<'_> {
446    pub(super) fn write_ray_query_function(
447        &mut self,
448        query: Handle<crate::Expression>,
449        function: &crate::RayQueryFunction,
450        block: &mut Block,
451    ) {
452        let query_id = self.cached[query];
453        match *function {
454            crate::RayQueryFunction::Initialize {
455                acceleration_structure,
456                descriptor,
457            } => {
458                //Note: composite extract indices and types must match `generate_ray_desc_type`
459                let desc_id = self.cached[descriptor];
460                let acc_struct_id = self.get_handle_id(acceleration_structure);
461
462                let flag_type_id =
463                    self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
464                let ray_flags_id = self.gen_id();
465                block.body.push(Instruction::composite_extract(
466                    flag_type_id,
467                    ray_flags_id,
468                    desc_id,
469                    &[0],
470                ));
471                let cull_mask_id = self.gen_id();
472                block.body.push(Instruction::composite_extract(
473                    flag_type_id,
474                    cull_mask_id,
475                    desc_id,
476                    &[1],
477                ));
478
479                let scalar_type_id =
480                    self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32));
481                let tmin_id = self.gen_id();
482                block.body.push(Instruction::composite_extract(
483                    scalar_type_id,
484                    tmin_id,
485                    desc_id,
486                    &[2],
487                ));
488                let tmax_id = self.gen_id();
489                block.body.push(Instruction::composite_extract(
490                    scalar_type_id,
491                    tmax_id,
492                    desc_id,
493                    &[3],
494                ));
495
496                let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
497                    size: crate::VectorSize::Tri,
498                    scalar: crate::Scalar::F32,
499                });
500                let ray_origin_id = self.gen_id();
501                block.body.push(Instruction::composite_extract(
502                    vector_type_id,
503                    ray_origin_id,
504                    desc_id,
505                    &[4],
506                ));
507                let ray_dir_id = self.gen_id();
508                block.body.push(Instruction::composite_extract(
509                    vector_type_id,
510                    ray_dir_id,
511                    desc_id,
512                    &[5],
513                ));
514
515                block.body.push(Instruction::ray_query_initialize(
516                    query_id,
517                    acc_struct_id,
518                    ray_flags_id,
519                    cull_mask_id,
520                    ray_origin_id,
521                    tmin_id,
522                    ray_dir_id,
523                    tmax_id,
524                ));
525            }
526            crate::RayQueryFunction::Proceed { result } => {
527                let id = self.gen_id();
528                self.cached[result] = id;
529                let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
530
531                block
532                    .body
533                    .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
534            }
535            crate::RayQueryFunction::GenerateIntersection { hit_t } => {
536                let hit_id = self.cached[hit_t];
537                block
538                    .body
539                    .push(Instruction::ray_query_generate_intersection(
540                        query_id, hit_id,
541                    ));
542            }
543            crate::RayQueryFunction::ConfirmIntersection => {
544                block
545                    .body
546                    .push(Instruction::ray_query_confirm_intersection(query_id));
547            }
548            crate::RayQueryFunction::Terminate => {}
549        }
550    }
551
552    pub(super) fn write_ray_query_return_vertex_position(
553        &mut self,
554        query: Handle<crate::Expression>,
555        block: &mut Block,
556        is_committed: bool,
557    ) -> spirv::Word {
558        let query_id = self.cached[query];
559        let id = self.gen_id();
560        let ray_vertex_return_ty = self
561            .ir_module
562            .special_types
563            .ray_vertex_return
564            .expect("type should have been populated");
565        let ray_vertex_return_ty_id = self.writer.get_handle_type_id(ray_vertex_return_ty);
566        let intersection_id =
567            self.writer
568                .get_constant_scalar(crate::Literal::U32(if is_committed {
569                    spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
570                } else {
571                    spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
572                } as _));
573        block
574            .body
575            .push(Instruction::ray_query_return_vertex_position(
576                ray_vertex_return_ty_id,
577                id,
578                query_id,
579                intersection_id,
580            ));
581        id
582    }
583}