use alloc::{boxed::Box, sync::Arc, vec::Vec};
use core::{
    cmp::max,
    num::NonZeroU64,
    ops::{Deref, Range},
};

use wgt::{math::align_to, BufferUsages, BufferUses, Features};

use crate::device::resource::CommandIndices;
use crate::lock::RwLockWriteGuard;
use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
use crate::{
    command::CommandBufferMutable,
    device::queue::TempResource,
    global::Global,
    hub::Hub,
    id::CommandEncoderId,
    init_tracker::MemoryInitKind,
    ray_tracing::{
        BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
        TlasBuildEntry, TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
        TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
    },
    resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas},
    scratch::ScratchBuffer,
    snatch::SnatchGuard,
    track::PendingTransition,
};

use crate::id::{BlasId, TlasId};

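/// Per-geometry bookkeeping for a triangle-geometry BLAS build: the input
/// buffers together with their pending usage transitions, the geometry
/// description, and (only on the last geometry of an entry) the BLAS that the
/// preceding run of geometries belongs to.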
struct TriangleBufferStore<'a> {
    vertex_buffer: Arc<Buffer>,
    vertex_transition: Option<PendingTransition<BufferUses>>,
    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
    geometry: BlasTriangleGeometry<'a>,
    ending_blas: Option<Arc<Blas>>,
}

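/// A BLAS ready to be built: the resource, its HAL geometry entries, and its
/// offset into the shared scratch buffer.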
struct BlasStore<'a> {
    blas: Arc<Blas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

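/// A TLAS ready to be built whose instances come straight from a
/// caller-provided buffer (the unsafe path); also used as the inner part of
/// [`TlasStore`].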
struct UnsafeTlasStore<'a> {
    tlas: Arc<Tlas>,
    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
    scratch_buffer_offset: u64,
}

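/// A TLAS built through the validated path: the build data plus the byte
/// range of its packed instances within the staging data.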
struct TlasStore<'a> {
    internal: UnsafeTlasStore<'a>,
    range: Range<usize>,
}

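/// A TLAS build entry from the unsafe path paired with its resolved instance
/// buffer and that buffer's pending usage transition.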
struct TlasBufferStore {
    buffer: Arc<Buffer>,
    transition: Option<PendingTransition<BufferUses>>,
    entry: TlasBuildEntry,
}

impl Global {
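    /// Marks the given BLASes and TLASes as built: no build commands are
    /// recorded, but once this encoder is submitted they are treated as built
    /// (with no BLAS dependencies recorded for the TLASes).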
    pub fn command_encoder_mark_acceleration_structures_built(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_ids: &[BlasId],
        tlas_ids: &[TlasId],
    ) -> Result<(), BuildAccelerationStructureError> {
        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let device = &cmd_buf.device;

        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

        let mut build_command = AsBuild::default();

        for blas in blas_ids {
            let blas = hub.blas_s.get(*blas).get()?;
            build_command.blas_s_built.push(blas);
        }

        for tlas in tlas_ids {
            let tlas = hub.tlas_s.get(*tlas).get()?;
            build_command.tlas_s_built.push(TlasBuild {
                tlas,
                dependencies: Vec::new(),
            });
        }

        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        cmd_buf_data.as_actions.push(AsAction::Build(build_command));

        cmd_buf_data_guard.mark_successful();

        Ok(())
    }
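    /// Builds the given BLASes from their geometry entries and the given
    /// TLASes directly from caller-provided instance buffers. The instance
    /// data is used as-is, without the validation performed by
    /// [`Self::command_encoder_build_acceleration_structures`].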
    pub fn command_encoder_build_acceleration_structures_unsafe_tlas<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasBuildEntry>,
    ) -> Result<(), BuildAccelerationStructureError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures_unsafe_tlas");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let device = &cmd_buf.device;

        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

        let mut build_command = AsBuild::default();

        #[cfg(feature = "trace")]
        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        #[cfg(feature = "trace")]
        let trace_tlas: Vec<TlasBuildEntry> = tlas_iter.collect();
        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
            list.push(
                crate::device::trace::Command::BuildAccelerationStructuresUnsafeTlas {
                    blas: trace_blas.clone(),
                    tlas: trace_tlas.clone(),
                },
            );
            if !trace_tlas.is_empty() {
                log::warn!("a trace of command_encoder_build_acceleration_structures_unsafe_tlas containing a tlas build is not replayable!");
            }
        }

        #[cfg(feature = "trace")]
        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        #[cfg(feature = "trace")]
        let tlas_iter = trace_tlas.iter();

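        // Collect and validate the BLAS inputs, recording the buffer
        // transitions and scratch space they need.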
        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
        let mut buf_storage = Vec::new();

        let mut scratch_buffer_blas_size = 0;
        let mut blas_storage = Vec::new();
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        iter_blas(
            blas_iter,
            cmd_buf_data,
            &mut build_command,
            &mut buf_storage,
            hub,
        )?;

        let snatch_guard = device.snatchable_lock.read();
        iter_buffers(
            &mut buf_storage,
            &snatch_guard,
            &mut input_barriers,
            cmd_buf_data,
            &mut scratch_buffer_blas_size,
            &mut blas_storage,
            hub,
            device.alignments.ray_tracing_scratch_buffer_alignment,
        )?;

        let mut scratch_buffer_tlas_size = 0;
        let mut tlas_storage = Vec::<UnsafeTlasStore>::new();
        let mut tlas_buf_storage = Vec::new();

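        // Track each TLAS and its caller-provided instance buffer, and
        // reserve aligned scratch space for every TLAS build.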
        for entry in tlas_iter {
            let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
            let data = cmd_buf_data.trackers.buffers.set_single(
                &instance_buffer,
                BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
            );
            tlas_buf_storage.push(TlasBufferStore {
                buffer: instance_buffer,
                transition: data,
                entry: entry.clone(),
            });
        }

        for tlas_buf in &mut tlas_buf_storage {
            let entry = &tlas_buf.entry;
            let instance_buffer = {
                let (instance_buffer, instance_pending) =
                    (&mut tlas_buf.buffer, &mut tlas_buf.transition);
                let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
                instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;

                if let Some(barrier) = instance_pending
                    .take()
                    .map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
                {
                    input_barriers.push(barrier);
                }
                instance_raw
            };

            let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

            build_command.tlas_s_built.push(TlasBuild {
                tlas: tlas.clone(),
                dependencies: Vec::new(),
            });

            let scratch_buffer_offset = scratch_buffer_tlas_size;
            scratch_buffer_tlas_size += align_to(
                tlas.size_info.build_scratch_size as u32,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            tlas_storage.push(UnsafeTlasStore {
                tlas,
                entries: hal::AccelerationStructureEntries::Instances(
                    hal::AccelerationStructureInstances {
                        buffer: Some(instance_buffer),
                        offset: 0,
                        count: entry.instance_count,
                    },
                ),
                scratch_buffer_offset,
            });
        }

        let scratch_size =
            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
                None => {
                    cmd_buf_data_guard.mark_successful();
                    return Ok(());
                }
                Some(size) => size,
            };

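        // The BLAS and TLAS passes share one scratch buffer (hence the `max`
        // above); the self-transition barrier below separates the two uses.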
        let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
            buffer: scratch_buffer.raw(),
            usage: hal::StateTransition {
                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            },
        };

        let mut tlas_descriptors = Vec::new();

        for UnsafeTlasStore {
            tlas,
            entries,
            scratch_buffer_offset,
        } in &tlas_storage
        {
            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                log::info!("only rebuild implemented")
            }
            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                entries,
                mode: hal::AccelerationStructureBuildMode::Build,
                flags: tlas.flags,
                source_acceleration_structure: None,
                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                scratch_buffer: scratch_buffer.raw(),
                scratch_buffer_offset: *scratch_buffer_offset,
            })
        }

        let blas_present = !blas_storage.is_empty();
        let tlas_present = !tlas_storage.is_empty();

        let cmd_buf_raw = cmd_buf_data.encoder.open()?;

        let mut descriptors = Vec::new();

        for storage in &blas_storage {
            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
        }

        build_blas(
            cmd_buf_raw,
            blas_present,
            tlas_present,
            input_barriers,
            &descriptors,
            scratch_buffer_barrier,
        );

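        // Record the TLAS builds, then make the results visible to shaders.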
        if tlas_present {
            unsafe {
                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                cmd_buf_raw.place_acceleration_structure_barrier(
                    hal::AccelerationStructureBarrier {
                        usage: hal::StateTransition {
                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                            to: hal::AccelerationStructureUses::SHADER_INPUT,
                        },
                    },
                );
            }
        }

        cmd_buf_data
            .temp_resources
            .push(TempResource::ScratchBuffer(scratch_buffer));

        cmd_buf_data.as_actions.push(AsAction::Build(build_command));

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }

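    /// Builds the given BLASes from their geometry entries and the given
    /// TLASes from [`TlasPackage`]s: every instance is validated, packed into
    /// a staging buffer, and copied into the TLAS's internal instance buffer
    /// before the builds are recorded.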
    pub fn command_encoder_build_acceleration_structures<'a>(
        &self,
        command_encoder_id: CommandEncoderId,
        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
    ) -> Result<(), BuildAccelerationStructureError> {
        profiling::scope!("CommandEncoder::build_acceleration_structures");

        let hub = &self.hub;

        let cmd_buf = hub
            .command_buffers
            .get(command_encoder_id.into_command_buffer_id());

        let device = &cmd_buf.device;

        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;

        let mut build_command = AsBuild::default();

        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
            .map(|blas_entry| {
                let geometries = match blas_entry.geometries {
                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
                        TraceBlasGeometries::TriangleGeometries(
                            triangle_geometries
                                .map(|tg| TraceBlasTriangleGeometry {
                                    size: tg.size.clone(),
                                    vertex_buffer: tg.vertex_buffer,
                                    index_buffer: tg.index_buffer,
                                    transform_buffer: tg.transform_buffer,
                                    first_vertex: tg.first_vertex,
                                    vertex_stride: tg.vertex_stride,
                                    first_index: tg.first_index,
                                    transform_buffer_offset: tg.transform_buffer_offset,
                                })
                                .collect(),
                        )
                    }
                };
                TraceBlasBuildEntry {
                    blas_id: blas_entry.blas_id,
                    geometries,
                }
            })
            .collect();

        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
            .map(|package: TlasPackage| {
                let instances = package
                    .instances
                    .map(|instance| {
                        instance.map(|instance| TraceTlasInstance {
                            blas_id: instance.blas_id,
                            transform: *instance.transform,
                            custom_data: instance.custom_data,
                            mask: instance.mask,
                        })
                    })
                    .collect();
                TraceTlasPackage {
                    tlas_id: package.tlas_id,
                    instances,
                    lowest_unmodified: package.lowest_unmodified,
                }
            })
            .collect();

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
            list.push(crate::device::trace::Command::BuildAccelerationStructures {
                blas: trace_blas.clone(),
                tlas: trace_tlas.clone(),
            });
        }

        let blas_iter = trace_blas.iter().map(|blas_entry| {
            let geometries = match &blas_entry.geometries {
                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
                        size: &tg.size,
                        vertex_buffer: tg.vertex_buffer,
                        index_buffer: tg.index_buffer,
                        transform_buffer: tg.transform_buffer,
                        first_vertex: tg.first_vertex,
                        vertex_stride: tg.vertex_stride,
                        first_index: tg.first_index,
                        transform_buffer_offset: tg.transform_buffer_offset,
                    });
                    BlasGeometries::TriangleGeometries(Box::new(iter))
                }
            };
            BlasBuildEntry {
                blas_id: blas_entry.blas_id,
                geometries,
            }
        });

        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
            let instances = tlas_package.instances.iter().map(|instance| {
                instance.as_ref().map(|instance| TlasInstance {
                    blas_id: instance.blas_id,
                    transform: &instance.transform,
                    custom_data: instance.custom_data,
                    mask: instance.mask,
                })
            });
            TlasPackage {
                tlas_id: tlas_package.tlas_id,
                instances: Box::new(instances),
                lowest_unmodified: tlas_package.lowest_unmodified,
            }
        });

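        // The BLAS inputs are collected and validated exactly as in the
        // unsafe-TLAS path above.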
        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
        let mut buf_storage = Vec::new();

        let mut scratch_buffer_blas_size = 0;
        let mut blas_storage = Vec::new();
        let mut cmd_buf_data = cmd_buf.data.lock();
        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
        let cmd_buf_data = &mut *cmd_buf_data_guard;

        iter_blas(
            blas_iter,
            cmd_buf_data,
            &mut build_command,
            &mut buf_storage,
            hub,
        )?;

        let snatch_guard = device.snatchable_lock.read();
        iter_buffers(
            &mut buf_storage,
            &snatch_guard,
            &mut input_barriers,
            cmd_buf_data,
            &mut scratch_buffer_blas_size,
            &mut blas_storage,
            hub,
            device.alignments.ray_tracing_scratch_buffer_alignment,
        )?;
        let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

        for package in tlas_iter {
            let tlas = hub.tlas_s.get(package.tlas_id).get()?;

            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());

            tlas_lock_store.push((Some(package), tlas))
        }

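        // Validate every TLAS instance, convert it to the HAL instance layout
        // in a CPU-side staging vector, and remember the BLAS dependencies
        // for submission-time validation.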
        let mut scratch_buffer_tlas_size = 0;
        let mut tlas_storage = Vec::<TlasStore>::new();
        let mut instance_buffer_staging_source = Vec::<u8>::new();

        for (package, tlas) in &mut tlas_lock_store {
            let package = package.take().unwrap();

            let scratch_buffer_offset = scratch_buffer_tlas_size;
            scratch_buffer_tlas_size += align_to(
                tlas.size_info.build_scratch_size as u32,
                device.alignments.ray_tracing_scratch_buffer_alignment,
            ) as u64;

            let first_byte_index = instance_buffer_staging_source.len();

            let mut dependencies = Vec::new();

            let mut instance_count = 0;
            for instance in package.instances.flatten() {
                if instance.custom_data >= (1u32 << 24u32) {
                    return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                        tlas.error_ident(),
                    ));
                }
                let blas = hub.blas_s.get(instance.blas_id).get()?;

                cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

                instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                    hal::TlasInstance {
                        transform: *instance.transform,
                        custom_data: instance.custom_data,
                        mask: instance.mask,
                        blas_address: blas.handle,
                    },
                ));

                if tlas
                    .flags
                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
                    && !blas.flags.contains(
                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                    )
                {
                    return Err(
                        BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                            tlas.error_ident(),
                            blas.error_ident(),
                        ),
                    );
                }

                instance_count += 1;

                dependencies.push(blas.clone());
            }

            build_command.tlas_s_built.push(TlasBuild {
                tlas: tlas.clone(),
                dependencies,
            });

            if instance_count > tlas.max_instance_count {
                return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                    tlas.error_ident(),
                    instance_count,
                    tlas.max_instance_count,
                ));
            }

            tlas_storage.push(TlasStore {
                internal: UnsafeTlasStore {
                    tlas: tlas.clone(),
                    entries: hal::AccelerationStructureEntries::Instances(
                        hal::AccelerationStructureInstances {
                            buffer: Some(tlas.instance_buffer.as_ref()),
                            offset: 0,
                            count: instance_count,
                        },
                    ),
                    scratch_buffer_offset,
                },
                range: first_byte_index..instance_buffer_staging_source.len(),
            });
        }

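        // Size the shared scratch buffer for whichever of the BLAS or TLAS
        // passes needs more space; if there is nothing to build, finish
        // recording successfully and return early.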
        let scratch_size =
            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
                None => {
                    cmd_buf_data_guard.mark_successful();
                    return Ok(());
                }
                Some(size) => size,
            };

        let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
            buffer: scratch_buffer.raw(),
            usage: hal::StateTransition {
                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
            },
        };

        let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

        for &TlasStore {
            internal:
                UnsafeTlasStore {
                    ref tlas,
                    ref entries,
                    ref scratch_buffer_offset,
                },
            ..
        } in &tlas_storage
        {
            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                log::info!("only rebuild implemented")
            }
            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                entries,
                mode: hal::AccelerationStructureBuildMode::Build,
                flags: tlas.flags,
                source_acceleration_structure: None,
                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                scratch_buffer: scratch_buffer.raw(),
                scratch_buffer_offset: *scratch_buffer_offset,
            })
        }

        let blas_present = !blas_storage.is_empty();
        let tlas_present = !tlas_storage.is_empty();

        let cmd_buf_raw = cmd_buf_data.encoder.open()?;

        let mut descriptors = Vec::new();

        for storage in &blas_storage {
            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
        }

        build_blas(
            cmd_buf_raw,
            blas_present,
            tlas_present,
            input_barriers,
            &descriptors,
            scratch_buffer_barrier,
        );

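        // Upload the packed instance data through a staging buffer, copy each
        // TLAS's slice into its internal instance buffer, then record the
        // TLAS builds and expose the results to shaders.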
        if tlas_present {
            let staging_buffer = if !instance_buffer_staging_source.is_empty() {
                let mut staging_buffer = StagingBuffer::new(
                    device,
                    wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
                )?;
                staging_buffer.write(&instance_buffer_staging_source);
                let flushed = staging_buffer.flush();
                Some(flushed)
            } else {
                None
            };

            unsafe {
                if let Some(ref staging_buffer) = staging_buffer {
                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: staging_buffer.raw(),
                        usage: hal::StateTransition {
                            from: BufferUses::MAP_WRITE,
                            to: BufferUses::COPY_SRC,
                        },
                    }]);
                }
            }

            let mut instance_buffer_barriers = Vec::new();
            for &TlasStore {
                internal: UnsafeTlasStore { ref tlas, .. },
                ref range,
            } in &tlas_storage
            {
                let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                    None => continue,
                    Some(size) => size,
                };
                instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                    buffer: tlas.instance_buffer.as_ref(),
                    usage: hal::StateTransition {
                        from: BufferUses::COPY_DST,
                        to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    },
                });
                unsafe {
                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
                        buffer: tlas.instance_buffer.as_ref(),
                        usage: hal::StateTransition {
                            from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                            to: BufferUses::COPY_DST,
                        },
                    }]);
                    let temp = hal::BufferCopy {
                        src_offset: range.start as u64,
                        dst_offset: 0,
                        size,
                    };
                    cmd_buf_raw.copy_buffer_to_buffer(
                        staging_buffer.as_ref().unwrap().raw(),
                        tlas.instance_buffer.as_ref(),
                        &[temp],
                    );
                }
            }

            unsafe {
                cmd_buf_raw.transition_buffers(&instance_buffer_barriers);

                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);

                cmd_buf_raw.place_acceleration_structure_barrier(
                    hal::AccelerationStructureBarrier {
                        usage: hal::StateTransition {
                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                            to: hal::AccelerationStructureUses::SHADER_INPUT,
                        },
                    },
                );
            }

            if let Some(staging_buffer) = staging_buffer {
                cmd_buf_data
                    .temp_resources
                    .push(TempResource::StagingBuffer(staging_buffer));
            }
        }

        cmd_buf_data
            .temp_resources
            .push(TempResource::ScratchBuffer(scratch_buffer));

        cmd_buf_data.as_actions.push(AsAction::Build(build_command));

        cmd_buf_data_guard.mark_successful();
        Ok(())
    }
}

impl CommandBufferMutable {
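    /// Runs at submission: assigns a build index to every acceleration
    /// structure built by this command buffer, verifies that each TLAS built
    /// here only references already-built BLASes, and verifies that each TLAS
    /// used here was built from BLASes that still exist and are no newer than
    /// the TLAS itself.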
    pub(crate) fn validate_acceleration_structure_actions(
        &self,
        snatch_guard: &SnatchGuard,
        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
    ) -> Result<(), ValidateAsActionsError> {
        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
        for action in &self.as_actions {
            match action {
                AsAction::Build(build) => {
                    let build_command_index = NonZeroU64::new(
                        command_index_guard.next_acceleration_structure_build_command_index,
                    )
                    .unwrap();

                    command_index_guard.next_acceleration_structure_build_command_index += 1;
                    for blas in build.blas_s_built.iter() {
                        *blas.built_index.write() = Some(build_command_index);
                    }

                    for tlas_build in build.tlas_s_built.iter() {
                        for blas in &tlas_build.dependencies {
                            if blas.built_index.read().is_none() {
                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                    blas.error_ident(),
                                    tlas_build.tlas.error_ident(),
                                ));
                            }
                        }
                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
                        tlas_build
                            .tlas
                            .dependencies
                            .write()
                            .clone_from(&tlas_build.dependencies)
                    }
                }
                AsAction::UseTlas(tlas) => {
                    let tlas_build_index = tlas.built_index.read();
                    let dependencies = tlas.dependencies.read();

                    if (*tlas_build_index).is_none() {
                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
                    }
                    for blas in dependencies.deref() {
                        let blas_build_index = *blas.built_index.read();
                        if blas_build_index.is_none() {
                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
                                tlas.error_ident(),
                                blas.error_ident(),
                            ));
                        }
                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
                                blas.error_ident(),
                                tlas.error_ident(),
                            ));
                        }
                        blas.try_raw(snatch_guard)?;
                    }
                }
            }
        }
        Ok(())
    }
}

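/// Walks the BLAS build entries: resolves each BLAS, validates every triangle
/// geometry against the sizes the BLAS was created with, tracks the vertex,
/// index and transform buffers, and queues them in `buf_storage` for the
/// buffer pass in [`iter_buffers`].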
fn iter_blas<'a>(
    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
    cmd_buf_data: &mut CommandBufferMutable,
    build_command: &mut AsBuild,
    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
    hub: &Hub,
) -> Result<(), BuildAccelerationStructureError> {
    let mut temp_buffer = Vec::new();
    for entry in blas_iter {
        let blas = hub.blas_s.get(entry.blas_id).get()?;
        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

        build_command.blas_s_built.push(blas.clone());

        match entry.geometries {
            BlasGeometries::TriangleGeometries(triangle_geometries) => {
                for (i, mesh) in triangle_geometries.enumerate() {
                    let size_desc = match &blas.sizes {
                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
                    };
                    if i >= size_desc.len() {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
                            blas.error_ident(),
                        ));
                    }
                    let size_desc = &size_desc[i];

                    if size_desc.flags != mesh.size.flags {
                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
                            blas.error_ident(),
                            size_desc.flags,
                            mesh.size.flags,
                        ));
                    }

                    if size_desc.vertex_count < mesh.size.vertex_count {
                        return Err(
                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
                                blas.error_ident(),
                                size_desc.vertex_count,
                                mesh.size.vertex_count,
                            ),
                        );
                    }

                    if size_desc.vertex_format != mesh.size.vertex_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
                            blas.error_ident(),
                            size_desc.vertex_format,
                            mesh.size.vertex_format,
                        ));
                    }

                    match (size_desc.index_count, mesh.size.index_count) {
                        (Some(_), None) | (None, Some(_)) => {
                            return Err(
                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
                                    blas.error_ident(),
                                ),
                            )
                        }
                        (Some(create), Some(build)) if create < build => {
                            return Err(
                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
                                    blas.error_ident(),
                                    create,
                                    build,
                                ),
                            )
                        }
                        _ => {}
                    }

                    if size_desc.index_format != mesh.size.index_format {
                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
                            blas.error_ident(),
                            size_desc.index_format,
                            mesh.size.index_format,
                        ));
                    }

                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
                            blas.error_ident(),
                        ));
                    }
                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
                        &vertex_buffer,
                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                    );
                    let index_data = if let Some(index_id) = mesh.index_buffer {
                        let index_buffer = hub.buffers.get(index_id).get()?;
                        if mesh.first_index.is_none() || mesh.size.index_count.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                index_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &index_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((index_buffer, data))
                    } else {
                        None
                    };
                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
                        if !blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::UseTransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        let transform_buffer = hub.buffers.get(transform_id).get()?;
                        if mesh.transform_buffer_offset.is_none() {
                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
                                transform_buffer.error_ident(),
                            ));
                        }
                        let data = cmd_buf_data.trackers.buffers.set_single(
                            &transform_buffer,
                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                        );
                        Some((transform_buffer, data))
                    } else {
                        if blas
                            .flags
                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
                        {
                            return Err(BuildAccelerationStructureError::TransformMissing(
                                blas.error_ident(),
                            ));
                        }
                        None
                    };
                    temp_buffer.push(TriangleBufferStore {
                        vertex_buffer,
                        vertex_transition: vertex_pending,
                        index_buffer_transition: index_data,
                        transform_buffer_transition: transform_data,
                        geometry: mesh,
                        ending_blas: None,
                    });
                }

                if let Some(last) = temp_buffer.last_mut() {
                    last.ending_blas = Some(blas);
                    buf_storage.append(&mut temp_buffer);
                }
            }
        }
    }
    Ok(())
}

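/// Second pass over the collected triangle geometries: snatches the raw HAL
/// buffers, checks usage flags and sizes, queues memory-initialization
/// actions and input barriers, and emits one [`BlasStore`] (with its scratch
/// offset) per BLAS.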
fn iter_buffers<'a, 'b>(
    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
    snatch_guard: &'a SnatchGuard,
    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
    cmd_buf_data: &mut CommandBufferMutable,
    scratch_buffer_blas_size: &mut u64,
    blas_storage: &mut Vec<BlasStore<'a>>,
    hub: &Hub,
    ray_tracing_scratch_buffer_alignment: u32,
) -> Result<(), BuildAccelerationStructureError> {
    let mut triangle_entries =
        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
    for buf in buf_storage {
        let mesh = &buf.geometry;
        let vertex_buffer = {
            let vertex_buffer = buf.vertex_buffer.as_ref();
            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = buf
                .vertex_transition
                .take()
                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            if vertex_buffer.size
                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
            {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    vertex_buffer.error_ident(),
                    vertex_buffer.size,
                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
                ));
            }
            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
            cmd_buf_data.buffer_memory_init_actions.extend(
                vertex_buffer.initialization_status.read().create_action(
                    &hub.buffers.get(mesh.vertex_buffer).get()?,
                    vertex_buffer_offset
                        ..(vertex_buffer_offset
                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            vertex_raw
        };
        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
            buf.index_buffer_transition
        {
            let index_raw = index_buffer.try_raw(snatch_guard)?;
            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = index_pending
                .take()
                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }
            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
            let offset = mesh.first_index.unwrap() as u64 * index_stride;
            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;

            if mesh.size.index_count.unwrap() % 3 != 0 {
                return Err(BuildAccelerationStructureError::InvalidIndexCount(
                    index_buffer.error_ident(),
                    mesh.size.index_count.unwrap(),
                ));
            }
            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    index_buffer.error_ident(),
                    index_buffer.size,
                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
                ));
            }

            cmd_buf_data.buffer_memory_init_actions.extend(
                index_buffer.initialization_status.read().create_action(
                    index_buffer,
                    offset..(offset + index_buffer_size),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(index_raw)
        } else {
            None
        };
        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
            buf.transform_buffer_transition
        {
            if mesh.transform_buffer_offset.is_none() {
                return Err(BuildAccelerationStructureError::MissingAssociatedData(
                    transform_buffer.error_ident(),
                ));
            }
            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;

            if let Some(barrier) = transform_pending
                .take()
                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
            {
                input_barriers.push(barrier);
            }

            let offset = mesh.transform_buffer_offset.unwrap();

            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
                return Err(
                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
                        transform_buffer.error_ident(),
                    ),
                );
            }
            if transform_buffer.size < 48 + offset {
                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
                    transform_buffer.error_ident(),
                    transform_buffer.size,
                    48 + offset,
                ));
            }
            cmd_buf_data.buffer_memory_init_actions.extend(
                transform_buffer.initialization_status.read().create_action(
                    transform_buffer,
                    offset..(offset + 48),
                    MemoryInitKind::NeedsInitializedMemory,
                ),
            );
            Some(transform_raw)
        } else {
            None
        };

        let triangles = hal::AccelerationStructureTriangles {
            vertex_buffer: Some(vertex_buffer),
            vertex_format: mesh.size.vertex_format,
            first_vertex: mesh.first_vertex,
            vertex_count: mesh.size.vertex_count,
            vertex_stride: mesh.vertex_stride,
            indices: index_buffer.map(|index_buffer| {
                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
                    format: mesh.size.index_format.unwrap(),
                    buffer: Some(index_buffer),
                    offset: mesh.first_index.unwrap() * index_stride,
                    count: mesh.size.index_count.unwrap(),
                }
            }),
            transform: transform_buffer.map(|transform_buffer| {
                hal::AccelerationStructureTriangleTransform {
                    buffer: transform_buffer,
                    offset: mesh.transform_buffer_offset.unwrap() as u32,
                }
            }),
            flags: mesh.size.flags,
        };
        triangle_entries.push(triangles);
        if let Some(blas) = buf.ending_blas.take() {
            let scratch_buffer_offset = *scratch_buffer_blas_size;
            *scratch_buffer_blas_size += align_to(
                blas.size_info.build_scratch_size as u32,
                ray_tracing_scratch_buffer_alignment,
            ) as u64;

            blas_storage.push(BlasStore {
                blas,
                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
                scratch_buffer_offset,
            });
            triangle_entries = Vec::new();
        }
    }
    Ok(())
}

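/// Turns a collected [`BlasStore`] into the HAL build descriptor for that
/// BLAS, pointing it at the shared scratch buffer.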
fn map_blas<'a>(
    storage: &'a BlasStore<'_>,
    scratch_buffer: &'a dyn hal::DynBuffer,
    snatch_guard: &'a SnatchGuard,
) -> Result<
    hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >,
    BuildAccelerationStructureError,
> {
    let BlasStore {
        blas,
        entries,
        scratch_buffer_offset,
    } = storage;
    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
        log::info!("only rebuild implemented")
    }
    Ok(hal::BuildAccelerationStructureDescriptor {
        entries,
        mode: hal::AccelerationStructureBuildMode::Build,
        flags: blas.flags,
        source_acceleration_structure: None,
        destination_acceleration_structure: blas.try_raw(snatch_guard)?,
        scratch_buffer,
        scratch_buffer_offset: *scratch_buffer_offset,
    })
}

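/// Records the BLAS build pass: applies the input buffer barriers, builds the
/// BLASes (if any), and places the barriers that order the scratch buffer and
/// the BLAS outputs ahead of the TLAS pass and shader reads.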
fn build_blas<'a>(
    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
    blas_present: bool,
    tlas_present: bool,
    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
        'a,
        dyn hal::DynBuffer,
        dyn hal::DynAccelerationStructure,
    >],
    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
) {
    unsafe {
        cmd_buf_raw.transition_buffers(&input_barriers);
    }

    if blas_present {
        unsafe {
            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
                usage: hal::StateTransition {
                    from: hal::AccelerationStructureUses::BUILD_INPUT,
                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
                },
            });

            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
        }
    }

    if blas_present && tlas_present {
        unsafe {
            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
        }
    }

    let mut source_usage = hal::AccelerationStructureUses::empty();
    let mut destination_usage = hal::AccelerationStructureUses::empty();
    if blas_present {
        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
    }
    if tlas_present {
        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
    }
    unsafe {
        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
            usage: hal::StateTransition {
                from: source_usage,
                to: destination_usage,
            },
        });
    }
}
1269}