1use alloc::vec;
6
7use super::{
8 Block, BlockContext, Function, FunctionArgument, Instruction, LookupFunctionType, NumericType,
9 Writer,
10};
11use crate::arena::Handle;
12
13impl Writer {
14 pub(super) fn write_ray_query_get_intersection_function(
15 &mut self,
16 is_committed: bool,
17 ir_module: &crate::Module,
18 ) -> spirv::Word {
19 if let Some(func_id) = self.ray_get_intersection_function {
20 return func_id;
21 }
22 let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
23 let intersection_type_id = self.get_handle_type_id(ray_intersection);
24 let intersection_pointer_type_id =
25 self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);
26
27 let flag_type_id = self.get_u32_type_id();
28 let flag_pointer_type_id =
29 self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);
30
31 let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
32 columns: crate::VectorSize::Quad,
33 rows: crate::VectorSize::Tri,
34 scalar: crate::Scalar::F32,
35 });
36 let transform_pointer_type_id =
37 self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);
38
39 let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
40 size: crate::VectorSize::Bi,
41 scalar: crate::Scalar::F32,
42 });
43 let barycentrics_pointer_type_id =
44 self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);
45
46 let bool_type_id = self.get_bool_type_id();
47 let bool_pointer_type_id =
48 self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);
49
50 let scalar_type_id = self.get_f32_type_id();
51 let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
52
53 let argument_type_id = self.get_ray_query_pointer_id(ir_module);
54
55 let func_ty = self.get_function_type(LookupFunctionType {
56 parameter_type_ids: vec![argument_type_id],
57 return_type_id: intersection_type_id,
58 });
59
60 let mut function = Function::default();
61 let func_id = self.id_gen.next();
62 function.signature = Some(Instruction::function(
63 intersection_type_id,
64 func_id,
65 spirv::FunctionControl::empty(),
66 func_ty,
67 ));
68 let blank_intersection = self.get_constant_null(intersection_type_id);
69 let query_id = self.id_gen.next();
70 let instruction = Instruction::function_parameter(argument_type_id, query_id);
71 function.parameters.push(FunctionArgument {
72 instruction,
73 handle_id: 0,
74 });
75
76 let label_id = self.id_gen.next();
77 let mut block = Block::new(label_id);
78
79 let blank_intersection_id = self.id_gen.next();
80 block.body.push(Instruction::variable(
81 intersection_pointer_type_id,
82 blank_intersection_id,
83 spirv::StorageClass::Function,
84 Some(blank_intersection),
85 ));
86
87 let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
88 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
89 } else {
90 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
91 } as _));
92 let raw_kind_id = self.id_gen.next();
93 block.body.push(Instruction::ray_query_get_intersection(
94 spirv::Op::RayQueryGetIntersectionTypeKHR,
95 flag_type_id,
96 raw_kind_id,
97 query_id,
98 intersection_id,
99 ));
100 let kind_id = if is_committed {
101 raw_kind_id
103 } else {
104 let condition_id = self.id_gen.next();
106 let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
107 spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
108 as _,
109 ));
110 block.body.push(Instruction::binary(
111 spirv::Op::IEqual,
112 self.get_bool_type_id(),
113 condition_id,
114 raw_kind_id,
115 committed_triangle_kind_id,
116 ));
117 let kind_id = self.id_gen.next();
118 block.body.push(Instruction::select(
119 flag_type_id,
120 kind_id,
121 condition_id,
122 self.get_constant_scalar(crate::Literal::U32(
123 crate::RayQueryIntersection::Triangle as _,
124 )),
125 self.get_constant_scalar(crate::Literal::U32(
126 crate::RayQueryIntersection::Aabb as _,
127 )),
128 ));
129 kind_id
130 };
131 let idx_id = self.get_index_constant(0);
132 let access_idx = self.id_gen.next();
133 block.body.push(Instruction::access_chain(
134 flag_pointer_type_id,
135 access_idx,
136 blank_intersection_id,
137 &[idx_id],
138 ));
139 block
140 .body
141 .push(Instruction::store(access_idx, kind_id, None));
142
143 let not_none_comp_id = self.id_gen.next();
144 let none_id =
145 self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
146 block.body.push(Instruction::binary(
147 spirv::Op::INotEqual,
148 self.get_bool_type_id(),
149 not_none_comp_id,
150 kind_id,
151 none_id,
152 ));
153
154 let not_none_label_id = self.id_gen.next();
155 let mut not_none_block = Block::new(not_none_label_id);
156
157 let final_label_id = self.id_gen.next();
158 let mut final_block = Block::new(final_label_id);
159
160 block.body.push(Instruction::selection_merge(
161 final_label_id,
162 spirv::SelectionControl::NONE,
163 ));
164 function.consume(
165 block,
166 Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id),
167 );
168
169 let instance_custom_index_id = self.id_gen.next();
170 not_none_block
171 .body
172 .push(Instruction::ray_query_get_intersection(
173 spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
174 flag_type_id,
175 instance_custom_index_id,
176 query_id,
177 intersection_id,
178 ));
179 let instance_id = self.id_gen.next();
180 not_none_block
181 .body
182 .push(Instruction::ray_query_get_intersection(
183 spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
184 flag_type_id,
185 instance_id,
186 query_id,
187 intersection_id,
188 ));
189 let sbt_record_offset_id = self.id_gen.next();
190 not_none_block
191 .body
192 .push(Instruction::ray_query_get_intersection(
193 spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
194 flag_type_id,
195 sbt_record_offset_id,
196 query_id,
197 intersection_id,
198 ));
199 let geometry_index_id = self.id_gen.next();
200 not_none_block
201 .body
202 .push(Instruction::ray_query_get_intersection(
203 spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
204 flag_type_id,
205 geometry_index_id,
206 query_id,
207 intersection_id,
208 ));
209 let primitive_index_id = self.id_gen.next();
210 not_none_block
211 .body
212 .push(Instruction::ray_query_get_intersection(
213 spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
214 flag_type_id,
215 primitive_index_id,
216 query_id,
217 intersection_id,
218 ));
219
220 let object_to_world_id = self.id_gen.next();
224 not_none_block
225 .body
226 .push(Instruction::ray_query_get_intersection(
227 spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
228 transform_type_id,
229 object_to_world_id,
230 query_id,
231 intersection_id,
232 ));
233 let world_to_object_id = self.id_gen.next();
234 not_none_block
235 .body
236 .push(Instruction::ray_query_get_intersection(
237 spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
238 transform_type_id,
239 world_to_object_id,
240 query_id,
241 intersection_id,
242 ));
243
244 let idx_id = self.get_index_constant(2);
246 let access_idx = self.id_gen.next();
247 not_none_block.body.push(Instruction::access_chain(
248 flag_pointer_type_id,
249 access_idx,
250 blank_intersection_id,
251 &[idx_id],
252 ));
253 not_none_block.body.push(Instruction::store(
254 access_idx,
255 instance_custom_index_id,
256 None,
257 ));
258
259 let idx_id = self.get_index_constant(3);
261 let access_idx = self.id_gen.next();
262 not_none_block.body.push(Instruction::access_chain(
263 flag_pointer_type_id,
264 access_idx,
265 blank_intersection_id,
266 &[idx_id],
267 ));
268 not_none_block
269 .body
270 .push(Instruction::store(access_idx, instance_id, None));
271
272 let idx_id = self.get_index_constant(4);
273 let access_idx = self.id_gen.next();
274 not_none_block.body.push(Instruction::access_chain(
275 flag_pointer_type_id,
276 access_idx,
277 blank_intersection_id,
278 &[idx_id],
279 ));
280 not_none_block
281 .body
282 .push(Instruction::store(access_idx, sbt_record_offset_id, None));
283
284 let idx_id = self.get_index_constant(5);
285 let access_idx = self.id_gen.next();
286 not_none_block.body.push(Instruction::access_chain(
287 flag_pointer_type_id,
288 access_idx,
289 blank_intersection_id,
290 &[idx_id],
291 ));
292 not_none_block
293 .body
294 .push(Instruction::store(access_idx, geometry_index_id, None));
295
296 let idx_id = self.get_index_constant(6);
297 let access_idx = self.id_gen.next();
298 not_none_block.body.push(Instruction::access_chain(
299 flag_pointer_type_id,
300 access_idx,
301 blank_intersection_id,
302 &[idx_id],
303 ));
304 not_none_block
305 .body
306 .push(Instruction::store(access_idx, primitive_index_id, None));
307
308 let idx_id = self.get_index_constant(9);
309 let access_idx = self.id_gen.next();
310 not_none_block.body.push(Instruction::access_chain(
311 transform_pointer_type_id,
312 access_idx,
313 blank_intersection_id,
314 &[idx_id],
315 ));
316 not_none_block
317 .body
318 .push(Instruction::store(access_idx, object_to_world_id, None));
319
320 let idx_id = self.get_index_constant(10);
321 let access_idx = self.id_gen.next();
322 not_none_block.body.push(Instruction::access_chain(
323 transform_pointer_type_id,
324 access_idx,
325 blank_intersection_id,
326 &[idx_id],
327 ));
328 not_none_block
329 .body
330 .push(Instruction::store(access_idx, world_to_object_id, None));
331
332 let tri_comp_id = self.id_gen.next();
333 let tri_id = self.get_constant_scalar(crate::Literal::U32(
334 crate::RayQueryIntersection::Triangle as _,
335 ));
336 not_none_block.body.push(Instruction::binary(
337 spirv::Op::IEqual,
338 self.get_bool_type_id(),
339 tri_comp_id,
340 kind_id,
341 tri_id,
342 ));
343
344 let tri_label_id = self.id_gen.next();
345 let mut tri_block = Block::new(tri_label_id);
346
347 let merge_label_id = self.id_gen.next();
348 let merge_block = Block::new(merge_label_id);
349 {
351 let block = if is_committed {
352 &mut not_none_block
353 } else {
354 &mut tri_block
355 };
356 let t_id = self.id_gen.next();
357 block.body.push(Instruction::ray_query_get_intersection(
358 spirv::Op::RayQueryGetIntersectionTKHR,
359 scalar_type_id,
360 t_id,
361 query_id,
362 intersection_id,
363 ));
364 let idx_id = self.get_index_constant(1);
365 let access_idx = self.id_gen.next();
366 block.body.push(Instruction::access_chain(
367 float_pointer_type_id,
368 access_idx,
369 blank_intersection_id,
370 &[idx_id],
371 ));
372 block.body.push(Instruction::store(access_idx, t_id, None));
373 }
374 not_none_block.body.push(Instruction::selection_merge(
375 merge_label_id,
376 spirv::SelectionControl::NONE,
377 ));
378 function.consume(
379 not_none_block,
380 Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
381 );
382
383 let barycentrics_id = self.id_gen.next();
384 tri_block.body.push(Instruction::ray_query_get_intersection(
385 spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
386 barycentrics_type_id,
387 barycentrics_id,
388 query_id,
389 intersection_id,
390 ));
391
392 let front_face_id = self.id_gen.next();
393 tri_block.body.push(Instruction::ray_query_get_intersection(
394 spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
395 bool_type_id,
396 front_face_id,
397 query_id,
398 intersection_id,
399 ));
400
401 let idx_id = self.get_index_constant(7);
402 let access_idx = self.id_gen.next();
403 tri_block.body.push(Instruction::access_chain(
404 barycentrics_pointer_type_id,
405 access_idx,
406 blank_intersection_id,
407 &[idx_id],
408 ));
409 tri_block
410 .body
411 .push(Instruction::store(access_idx, barycentrics_id, None));
412
413 let idx_id = self.get_index_constant(8);
414 let access_idx = self.id_gen.next();
415 tri_block.body.push(Instruction::access_chain(
416 bool_pointer_type_id,
417 access_idx,
418 blank_intersection_id,
419 &[idx_id],
420 ));
421 tri_block
422 .body
423 .push(Instruction::store(access_idx, front_face_id, None));
424 function.consume(tri_block, Instruction::branch(merge_label_id));
425 function.consume(merge_block, Instruction::branch(final_label_id));
426
427 let loaded_blank_intersection_id = self.id_gen.next();
428 final_block.body.push(Instruction::load(
429 intersection_type_id,
430 loaded_blank_intersection_id,
431 blank_intersection_id,
432 None,
433 ));
434 function.consume(
435 final_block,
436 Instruction::return_value(loaded_blank_intersection_id),
437 );
438
439 function.to_words(&mut self.logical_layout.function_definitions);
440 self.ray_get_intersection_function = Some(func_id);
441 func_id
442 }
443}
444
445impl BlockContext<'_> {
446 pub(super) fn write_ray_query_function(
447 &mut self,
448 query: Handle<crate::Expression>,
449 function: &crate::RayQueryFunction,
450 block: &mut Block,
451 ) {
452 let query_id = self.cached[query];
453 match *function {
454 crate::RayQueryFunction::Initialize {
455 acceleration_structure,
456 descriptor,
457 } => {
458 let desc_id = self.cached[descriptor];
460 let acc_struct_id = self.get_handle_id(acceleration_structure);
461
462 let flag_type_id =
463 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::U32));
464 let ray_flags_id = self.gen_id();
465 block.body.push(Instruction::composite_extract(
466 flag_type_id,
467 ray_flags_id,
468 desc_id,
469 &[0],
470 ));
471 let cull_mask_id = self.gen_id();
472 block.body.push(Instruction::composite_extract(
473 flag_type_id,
474 cull_mask_id,
475 desc_id,
476 &[1],
477 ));
478
479 let scalar_type_id =
480 self.get_numeric_type_id(NumericType::Scalar(crate::Scalar::F32));
481 let tmin_id = self.gen_id();
482 block.body.push(Instruction::composite_extract(
483 scalar_type_id,
484 tmin_id,
485 desc_id,
486 &[2],
487 ));
488 let tmax_id = self.gen_id();
489 block.body.push(Instruction::composite_extract(
490 scalar_type_id,
491 tmax_id,
492 desc_id,
493 &[3],
494 ));
495
496 let vector_type_id = self.get_numeric_type_id(NumericType::Vector {
497 size: crate::VectorSize::Tri,
498 scalar: crate::Scalar::F32,
499 });
500 let ray_origin_id = self.gen_id();
501 block.body.push(Instruction::composite_extract(
502 vector_type_id,
503 ray_origin_id,
504 desc_id,
505 &[4],
506 ));
507 let ray_dir_id = self.gen_id();
508 block.body.push(Instruction::composite_extract(
509 vector_type_id,
510 ray_dir_id,
511 desc_id,
512 &[5],
513 ));
514
515 block.body.push(Instruction::ray_query_initialize(
516 query_id,
517 acc_struct_id,
518 ray_flags_id,
519 cull_mask_id,
520 ray_origin_id,
521 tmin_id,
522 ray_dir_id,
523 tmax_id,
524 ));
525 }
526 crate::RayQueryFunction::Proceed { result } => {
527 let id = self.gen_id();
528 self.cached[result] = id;
529 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
530
531 block
532 .body
533 .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
534 }
535 crate::RayQueryFunction::GenerateIntersection { hit_t } => {
536 let hit_id = self.cached[hit_t];
537 block
538 .body
539 .push(Instruction::ray_query_generate_intersection(
540 query_id, hit_id,
541 ));
542 }
543 crate::RayQueryFunction::ConfirmIntersection => {
544 block
545 .body
546 .push(Instruction::ray_query_confirm_intersection(query_id));
547 }
548 crate::RayQueryFunction::Terminate => {}
549 }
550 }
551
552 pub(super) fn write_ray_query_return_vertex_position(
553 &mut self,
554 query: Handle<crate::Expression>,
555 block: &mut Block,
556 is_committed: bool,
557 ) -> spirv::Word {
558 let query_id = self.cached[query];
559 let id = self.gen_id();
560 let ray_vertex_return_ty = self
561 .ir_module
562 .special_types
563 .ray_vertex_return
564 .expect("type should have been populated");
565 let ray_vertex_return_ty_id = self.writer.get_handle_type_id(ray_vertex_return_ty);
566 let intersection_id =
567 self.writer
568 .get_constant_scalar(crate::Literal::U32(if is_committed {
569 spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
570 } else {
571 spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
572 } as _));
573 block
574 .body
575 .push(Instruction::ray_query_return_vertex_position(
576 ray_vertex_return_ty_id,
577 id,
578 query_id,
579 intersection_id,
580 ));
581 id
582 }
583}