wgpu_core/command/
ray_tracing.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::{
3    cmp::max,
4    num::NonZeroU64,
5    ops::{Deref, Range},
6};
7
8use wgt::{math::align_to, BufferUsages, BufferUses, Features};
9
10use crate::device::resource::CommandIndices;
11use crate::lock::RwLockWriteGuard;
12use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
13use crate::{
14    command::CommandBufferMutable,
15    device::queue::TempResource,
16    global::Global,
17    hub::Hub,
18    id::CommandEncoderId,
19    init_tracker::MemoryInitKind,
20    ray_tracing::{
21        BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
22        TlasBuildEntry, TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
23        TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
24    },
25    resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas},
26    scratch::ScratchBuffer,
27    snatch::SnatchGuard,
28    track::PendingTransition,
29};
30
31use crate::id::{BlasId, TlasId};
32
33struct TriangleBufferStore<'a> {
34    vertex_buffer: Arc<Buffer>,
35    vertex_transition: Option<PendingTransition<BufferUses>>,
36    index_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
37    transform_buffer_transition: Option<(Arc<Buffer>, Option<PendingTransition<BufferUses>>)>,
38    geometry: BlasTriangleGeometry<'a>,
39    ending_blas: Option<Arc<Blas>>,
40}
41
42struct BlasStore<'a> {
43    blas: Arc<Blas>,
44    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
45    scratch_buffer_offset: u64,
46}
47
48struct UnsafeTlasStore<'a> {
49    tlas: Arc<Tlas>,
50    entries: hal::AccelerationStructureEntries<'a, dyn hal::DynBuffer>,
51    scratch_buffer_offset: u64,
52}
53
54struct TlasStore<'a> {
55    internal: UnsafeTlasStore<'a>,
56    range: Range<usize>,
57}
58
59struct TlasBufferStore {
60    buffer: Arc<Buffer>,
61    transition: Option<PendingTransition<BufferUses>>,
62    entry: TlasBuildEntry,
63}
64
65impl Global {
66    pub fn command_encoder_mark_acceleration_structures_built(
67        &self,
68        command_encoder_id: CommandEncoderId,
69        blas_ids: &[BlasId],
70        tlas_ids: &[TlasId],
71    ) -> Result<(), BuildAccelerationStructureError> {
72        profiling::scope!("CommandEncoder::mark_acceleration_structures_built");
73
74        let hub = &self.hub;
75
76        let cmd_buf = hub
77            .command_buffers
78            .get(command_encoder_id.into_command_buffer_id());
79
80        let device = &cmd_buf.device;
81
82        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
83
84        let mut build_command = AsBuild::default();
85
86        for blas in blas_ids {
87            let blas = hub.blas_s.get(*blas).get()?;
88            build_command.blas_s_built.push(blas);
89        }
90
91        for tlas in tlas_ids {
92            let tlas = hub.tlas_s.get(*tlas).get()?;
93            build_command.tlas_s_built.push(TlasBuild {
94                tlas,
95                dependencies: Vec::new(),
96            });
97        }
98
99        let mut cmd_buf_data = cmd_buf.data.lock();
100        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
101        let cmd_buf_data = &mut *cmd_buf_data_guard;
102
103        cmd_buf_data.as_actions.push(AsAction::Build(build_command));
104
105        cmd_buf_data_guard.mark_successful();
106
107        Ok(())
108    }
109    // Currently this function is very similar to its safe counterpart, however certain parts of it are very different,
110    // making for the two to be implemented differently, the main difference is this function has separate buffers for each
111    // of the TLAS instances while the other has one large buffer
112    // TODO: reconsider this function's usefulness once blas and tlas `as_hal` are added and some time has passed.
113    pub fn command_encoder_build_acceleration_structures_unsafe_tlas<'a>(
114        &self,
115        command_encoder_id: CommandEncoderId,
116        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
117        tlas_iter: impl Iterator<Item = TlasBuildEntry>,
118    ) -> Result<(), BuildAccelerationStructureError> {
119        profiling::scope!("CommandEncoder::build_acceleration_structures_unsafe_tlas");
120
121        let hub = &self.hub;
122
123        let cmd_buf = hub
124            .command_buffers
125            .get(command_encoder_id.into_command_buffer_id());
126
127        let device = &cmd_buf.device;
128
129        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
130
131        let mut build_command = AsBuild::default();
132
133        #[cfg(feature = "trace")]
134        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
135            .map(|blas_entry| {
136                let geometries = match blas_entry.geometries {
137                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
138                        TraceBlasGeometries::TriangleGeometries(
139                            triangle_geometries
140                                .map(|tg| TraceBlasTriangleGeometry {
141                                    size: tg.size.clone(),
142                                    vertex_buffer: tg.vertex_buffer,
143                                    index_buffer: tg.index_buffer,
144                                    transform_buffer: tg.transform_buffer,
145                                    first_vertex: tg.first_vertex,
146                                    vertex_stride: tg.vertex_stride,
147                                    first_index: tg.first_index,
148                                    transform_buffer_offset: tg.transform_buffer_offset,
149                                })
150                                .collect(),
151                        )
152                    }
153                };
154                TraceBlasBuildEntry {
155                    blas_id: blas_entry.blas_id,
156                    geometries,
157                }
158            })
159            .collect();
160
161        #[cfg(feature = "trace")]
162        let trace_tlas: Vec<TlasBuildEntry> = tlas_iter.collect();
163        #[cfg(feature = "trace")]
164        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
165            list.push(
166                crate::device::trace::Command::BuildAccelerationStructuresUnsafeTlas {
167                    blas: trace_blas.clone(),
168                    tlas: trace_tlas.clone(),
169                },
170            );
171            if !trace_tlas.is_empty() {
172                log::warn!("a trace of command_encoder_build_acceleration_structures_unsafe_tlas containing a tlas build is not replayable!");
173            }
174        }
175
176        #[cfg(feature = "trace")]
177        let blas_iter = trace_blas.iter().map(|blas_entry| {
178            let geometries = match &blas_entry.geometries {
179                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
180                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
181                        size: &tg.size,
182                        vertex_buffer: tg.vertex_buffer,
183                        index_buffer: tg.index_buffer,
184                        transform_buffer: tg.transform_buffer,
185                        first_vertex: tg.first_vertex,
186                        vertex_stride: tg.vertex_stride,
187                        first_index: tg.first_index,
188                        transform_buffer_offset: tg.transform_buffer_offset,
189                    });
190                    BlasGeometries::TriangleGeometries(Box::new(iter))
191                }
192            };
193            BlasBuildEntry {
194                blas_id: blas_entry.blas_id,
195                geometries,
196            }
197        });
198
199        #[cfg(feature = "trace")]
200        let tlas_iter = trace_tlas.iter();
201
202        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
203        let mut buf_storage = Vec::new();
204
205        let mut scratch_buffer_blas_size = 0;
206        let mut blas_storage = Vec::new();
207        let mut cmd_buf_data = cmd_buf.data.lock();
208        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
209        let cmd_buf_data = &mut *cmd_buf_data_guard;
210
211        iter_blas(
212            blas_iter,
213            cmd_buf_data,
214            &mut build_command,
215            &mut buf_storage,
216            hub,
217        )?;
218
219        let snatch_guard = device.snatchable_lock.read();
220        iter_buffers(
221            &mut buf_storage,
222            &snatch_guard,
223            &mut input_barriers,
224            cmd_buf_data,
225            &mut scratch_buffer_blas_size,
226            &mut blas_storage,
227            hub,
228            device.alignments.ray_tracing_scratch_buffer_alignment,
229        )?;
230
231        let mut scratch_buffer_tlas_size = 0;
232        let mut tlas_storage = Vec::<UnsafeTlasStore>::new();
233        let mut tlas_buf_storage = Vec::new();
234
235        for entry in tlas_iter {
236            let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
237            let data = cmd_buf_data.trackers.buffers.set_single(
238                &instance_buffer,
239                BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
240            );
241            tlas_buf_storage.push(TlasBufferStore {
242                buffer: instance_buffer,
243                transition: data,
244                entry: entry.clone(),
245            });
246        }
247
248        for tlas_buf in &mut tlas_buf_storage {
249            let entry = &tlas_buf.entry;
250            let instance_buffer = {
251                let (instance_buffer, instance_pending) =
252                    (&mut tlas_buf.buffer, &mut tlas_buf.transition);
253                let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
254                instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;
255
256                if let Some(barrier) = instance_pending
257                    .take()
258                    .map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
259                {
260                    input_barriers.push(barrier);
261                }
262                instance_raw
263            };
264
265            let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
266            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
267
268            build_command.tlas_s_built.push(TlasBuild {
269                tlas: tlas.clone(),
270                dependencies: Vec::new(),
271            });
272
273            let scratch_buffer_offset = scratch_buffer_tlas_size;
274            scratch_buffer_tlas_size += align_to(
275                tlas.size_info.build_scratch_size as u32,
276                device.alignments.ray_tracing_scratch_buffer_alignment,
277            ) as u64;
278
279            tlas_storage.push(UnsafeTlasStore {
280                tlas,
281                entries: hal::AccelerationStructureEntries::Instances(
282                    hal::AccelerationStructureInstances {
283                        buffer: Some(instance_buffer),
284                        offset: 0,
285                        count: entry.instance_count,
286                    },
287                ),
288                scratch_buffer_offset,
289            });
290        }
291
292        let scratch_size =
293            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
294                None => {
295                    cmd_buf_data_guard.mark_successful();
296                    return Ok(());
297                }
298                Some(size) => size,
299            };
300
301        let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;
302
303        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
304            buffer: scratch_buffer.raw(),
305            usage: hal::StateTransition {
306                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
307                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
308            },
309        };
310
311        let mut tlas_descriptors = Vec::new();
312
313        for UnsafeTlasStore {
314            tlas,
315            entries,
316            scratch_buffer_offset,
317        } in &tlas_storage
318        {
319            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
320                log::info!("only rebuild implemented")
321            }
322            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
323                entries,
324                mode: hal::AccelerationStructureBuildMode::Build,
325                flags: tlas.flags,
326                source_acceleration_structure: None,
327                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
328                scratch_buffer: scratch_buffer.raw(),
329                scratch_buffer_offset: *scratch_buffer_offset,
330            })
331        }
332
333        let blas_present = !blas_storage.is_empty();
334        let tlas_present = !tlas_storage.is_empty();
335
336        let cmd_buf_raw = cmd_buf_data.encoder.open()?;
337
338        let mut descriptors = Vec::new();
339
340        for storage in &blas_storage {
341            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
342        }
343
344        build_blas(
345            cmd_buf_raw,
346            blas_present,
347            tlas_present,
348            input_barriers,
349            &descriptors,
350            scratch_buffer_barrier,
351        );
352
353        if tlas_present {
354            unsafe {
355                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);
356
357                cmd_buf_raw.place_acceleration_structure_barrier(
358                    hal::AccelerationStructureBarrier {
359                        usage: hal::StateTransition {
360                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
361                            to: hal::AccelerationStructureUses::SHADER_INPUT,
362                        },
363                    },
364                );
365            }
366        }
367
368        cmd_buf_data
369            .temp_resources
370            .push(TempResource::ScratchBuffer(scratch_buffer));
371
372        cmd_buf_data.as_actions.push(AsAction::Build(build_command));
373
374        cmd_buf_data_guard.mark_successful();
375        Ok(())
376    }
377
378    pub fn command_encoder_build_acceleration_structures<'a>(
379        &self,
380        command_encoder_id: CommandEncoderId,
381        blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
382        tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
383    ) -> Result<(), BuildAccelerationStructureError> {
384        profiling::scope!("CommandEncoder::build_acceleration_structures");
385
386        let hub = &self.hub;
387
388        let cmd_buf = hub
389            .command_buffers
390            .get(command_encoder_id.into_command_buffer_id());
391
392        let device = &cmd_buf.device;
393
394        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
395
396        let mut build_command = AsBuild::default();
397
398        let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
399            .map(|blas_entry| {
400                let geometries = match blas_entry.geometries {
401                    BlasGeometries::TriangleGeometries(triangle_geometries) => {
402                        TraceBlasGeometries::TriangleGeometries(
403                            triangle_geometries
404                                .map(|tg| TraceBlasTriangleGeometry {
405                                    size: tg.size.clone(),
406                                    vertex_buffer: tg.vertex_buffer,
407                                    index_buffer: tg.index_buffer,
408                                    transform_buffer: tg.transform_buffer,
409                                    first_vertex: tg.first_vertex,
410                                    vertex_stride: tg.vertex_stride,
411                                    first_index: tg.first_index,
412                                    transform_buffer_offset: tg.transform_buffer_offset,
413                                })
414                                .collect(),
415                        )
416                    }
417                };
418                TraceBlasBuildEntry {
419                    blas_id: blas_entry.blas_id,
420                    geometries,
421                }
422            })
423            .collect();
424
425        let trace_tlas: Vec<TraceTlasPackage> = tlas_iter
426            .map(|package: TlasPackage| {
427                let instances = package
428                    .instances
429                    .map(|instance| {
430                        instance.map(|instance| TraceTlasInstance {
431                            blas_id: instance.blas_id,
432                            transform: *instance.transform,
433                            custom_data: instance.custom_data,
434                            mask: instance.mask,
435                        })
436                    })
437                    .collect();
438                TraceTlasPackage {
439                    tlas_id: package.tlas_id,
440                    instances,
441                    lowest_unmodified: package.lowest_unmodified,
442                }
443            })
444            .collect();
445
446        #[cfg(feature = "trace")]
447        if let Some(ref mut list) = cmd_buf.data.lock().get_inner()?.commands {
448            list.push(crate::device::trace::Command::BuildAccelerationStructures {
449                blas: trace_blas.clone(),
450                tlas: trace_tlas.clone(),
451            });
452        }
453
454        let blas_iter = trace_blas.iter().map(|blas_entry| {
455            let geometries = match &blas_entry.geometries {
456                TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
457                    let iter = triangle_geometries.iter().map(|tg| BlasTriangleGeometry {
458                        size: &tg.size,
459                        vertex_buffer: tg.vertex_buffer,
460                        index_buffer: tg.index_buffer,
461                        transform_buffer: tg.transform_buffer,
462                        first_vertex: tg.first_vertex,
463                        vertex_stride: tg.vertex_stride,
464                        first_index: tg.first_index,
465                        transform_buffer_offset: tg.transform_buffer_offset,
466                    });
467                    BlasGeometries::TriangleGeometries(Box::new(iter))
468                }
469            };
470            BlasBuildEntry {
471                blas_id: blas_entry.blas_id,
472                geometries,
473            }
474        });
475
476        let tlas_iter = trace_tlas.iter().map(|tlas_package| {
477            let instances = tlas_package.instances.iter().map(|instance| {
478                instance.as_ref().map(|instance| TlasInstance {
479                    blas_id: instance.blas_id,
480                    transform: &instance.transform,
481                    custom_data: instance.custom_data,
482                    mask: instance.mask,
483                })
484            });
485            TlasPackage {
486                tlas_id: tlas_package.tlas_id,
487                instances: Box::new(instances),
488                lowest_unmodified: tlas_package.lowest_unmodified,
489            }
490        });
491
492        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
493        let mut buf_storage = Vec::new();
494
495        let mut scratch_buffer_blas_size = 0;
496        let mut blas_storage = Vec::new();
497        let mut cmd_buf_data = cmd_buf.data.lock();
498        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
499        let cmd_buf_data = &mut *cmd_buf_data_guard;
500
501        iter_blas(
502            blas_iter,
503            cmd_buf_data,
504            &mut build_command,
505            &mut buf_storage,
506            hub,
507        )?;
508
509        let snatch_guard = device.snatchable_lock.read();
510        iter_buffers(
511            &mut buf_storage,
512            &snatch_guard,
513            &mut input_barriers,
514            cmd_buf_data,
515            &mut scratch_buffer_blas_size,
516            &mut blas_storage,
517            hub,
518            device.alignments.ray_tracing_scratch_buffer_alignment,
519        )?;
520        let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
521
522        for package in tlas_iter {
523            let tlas = hub.tlas_s.get(package.tlas_id).get()?;
524
525            cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
526
527            tlas_lock_store.push((Some(package), tlas))
528        }
529
530        let mut scratch_buffer_tlas_size = 0;
531        let mut tlas_storage = Vec::<TlasStore>::new();
532        let mut instance_buffer_staging_source = Vec::<u8>::new();
533
534        for (package, tlas) in &mut tlas_lock_store {
535            let package = package.take().unwrap();
536
537            let scratch_buffer_offset = scratch_buffer_tlas_size;
538            scratch_buffer_tlas_size += align_to(
539                tlas.size_info.build_scratch_size as u32,
540                device.alignments.ray_tracing_scratch_buffer_alignment,
541            ) as u64;
542
543            let first_byte_index = instance_buffer_staging_source.len();
544
545            let mut dependencies = Vec::new();
546
547            let mut instance_count = 0;
548            for instance in package.instances.flatten() {
549                if instance.custom_data >= (1u32 << 24u32) {
550                    return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
551                        tlas.error_ident(),
552                    ));
553                }
554                let blas = hub.blas_s.get(instance.blas_id).get()?;
555
556                cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
557
558                instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
559                    hal::TlasInstance {
560                        transform: *instance.transform,
561                        custom_data: instance.custom_data,
562                        mask: instance.mask,
563                        blas_address: blas.handle,
564                    },
565                ));
566
567                if tlas
568                    .flags
569                    .contains(wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
570                    && !blas.flags.contains(
571                        wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
572                    )
573                {
574                    return Err(
575                        BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
576                            tlas.error_ident(),
577                            blas.error_ident(),
578                        ),
579                    );
580                }
581
582                instance_count += 1;
583
584                dependencies.push(blas.clone());
585            }
586
587            build_command.tlas_s_built.push(TlasBuild {
588                tlas: tlas.clone(),
589                dependencies,
590            });
591
592            if instance_count > tlas.max_instance_count {
593                return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
594                    tlas.error_ident(),
595                    instance_count,
596                    tlas.max_instance_count,
597                ));
598            }
599
600            tlas_storage.push(TlasStore {
601                internal: UnsafeTlasStore {
602                    tlas: tlas.clone(),
603                    entries: hal::AccelerationStructureEntries::Instances(
604                        hal::AccelerationStructureInstances {
605                            buffer: Some(tlas.instance_buffer.as_ref()),
606                            offset: 0,
607                            count: instance_count,
608                        },
609                    ),
610                    scratch_buffer_offset,
611                },
612                range: first_byte_index..instance_buffer_staging_source.len(),
613            });
614        }
615
616        let scratch_size =
617            match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
618                // if the size is zero there is nothing to build
619                None => {
620                    cmd_buf_data_guard.mark_successful();
621                    return Ok(());
622                }
623                Some(size) => size,
624            };
625
626        let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;
627
628        let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
629            buffer: scratch_buffer.raw(),
630            usage: hal::StateTransition {
631                from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
632                to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
633            },
634        };
635
636        let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());
637
638        for &TlasStore {
639            internal:
640                UnsafeTlasStore {
641                    ref tlas,
642                    ref entries,
643                    ref scratch_buffer_offset,
644                },
645            ..
646        } in &tlas_storage
647        {
648            if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
649                log::info!("only rebuild implemented")
650            }
651            tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
652                entries,
653                mode: hal::AccelerationStructureBuildMode::Build,
654                flags: tlas.flags,
655                source_acceleration_structure: None,
656                destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
657                scratch_buffer: scratch_buffer.raw(),
658                scratch_buffer_offset: *scratch_buffer_offset,
659            })
660        }
661
662        let blas_present = !blas_storage.is_empty();
663        let tlas_present = !tlas_storage.is_empty();
664
665        let cmd_buf_raw = cmd_buf_data.encoder.open()?;
666
667        let mut descriptors = Vec::new();
668
669        for storage in &blas_storage {
670            descriptors.push(map_blas(storage, scratch_buffer.raw(), &snatch_guard)?);
671        }
672
673        build_blas(
674            cmd_buf_raw,
675            blas_present,
676            tlas_present,
677            input_barriers,
678            &descriptors,
679            scratch_buffer_barrier,
680        );
681
682        if tlas_present {
683            let staging_buffer = if !instance_buffer_staging_source.is_empty() {
684                let mut staging_buffer = StagingBuffer::new(
685                    device,
686                    wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
687                )?;
688                staging_buffer.write(&instance_buffer_staging_source);
689                let flushed = staging_buffer.flush();
690                Some(flushed)
691            } else {
692                None
693            };
694
695            unsafe {
696                if let Some(ref staging_buffer) = staging_buffer {
697                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
698                        buffer: staging_buffer.raw(),
699                        usage: hal::StateTransition {
700                            from: BufferUses::MAP_WRITE,
701                            to: BufferUses::COPY_SRC,
702                        },
703                    }]);
704                }
705            }
706
707            let mut instance_buffer_barriers = Vec::new();
708            for &TlasStore {
709                internal: UnsafeTlasStore { ref tlas, .. },
710                ref range,
711            } in &tlas_storage
712            {
713                let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
714                    None => continue,
715                    Some(size) => size,
716                };
717                instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
718                    buffer: tlas.instance_buffer.as_ref(),
719                    usage: hal::StateTransition {
720                        from: BufferUses::COPY_DST,
721                        to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
722                    },
723                });
724                unsafe {
725                    cmd_buf_raw.transition_buffers(&[hal::BufferBarrier::<dyn hal::DynBuffer> {
726                        buffer: tlas.instance_buffer.as_ref(),
727                        usage: hal::StateTransition {
728                            from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
729                            to: BufferUses::COPY_DST,
730                        },
731                    }]);
732                    let temp = hal::BufferCopy {
733                        src_offset: range.start as u64,
734                        dst_offset: 0,
735                        size,
736                    };
737                    cmd_buf_raw.copy_buffer_to_buffer(
738                        // the range whose size we just checked end is at (at that point in time) instance_buffer_staging_source.len()
739                        // and since instance_buffer_staging_source doesn't shrink we can un wrap this without a panic
740                        staging_buffer.as_ref().unwrap().raw(),
741                        tlas.instance_buffer.as_ref(),
742                        &[temp],
743                    );
744                }
745            }
746
747            unsafe {
748                cmd_buf_raw.transition_buffers(&instance_buffer_barriers);
749
750                cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);
751
752                cmd_buf_raw.place_acceleration_structure_barrier(
753                    hal::AccelerationStructureBarrier {
754                        usage: hal::StateTransition {
755                            from: hal::AccelerationStructureUses::BUILD_OUTPUT,
756                            to: hal::AccelerationStructureUses::SHADER_INPUT,
757                        },
758                    },
759                );
760            }
761
762            if let Some(staging_buffer) = staging_buffer {
763                cmd_buf_data
764                    .temp_resources
765                    .push(TempResource::StagingBuffer(staging_buffer));
766            }
767        }
768
769        cmd_buf_data
770            .temp_resources
771            .push(TempResource::ScratchBuffer(scratch_buffer));
772
773        cmd_buf_data.as_actions.push(AsAction::Build(build_command));
774
775        cmd_buf_data_guard.mark_successful();
776        Ok(())
777    }
778}
779
780impl CommandBufferMutable {
781    pub(crate) fn validate_acceleration_structure_actions(
782        &self,
783        snatch_guard: &SnatchGuard,
784        command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
785    ) -> Result<(), ValidateAsActionsError> {
786        profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
787        for action in &self.as_actions {
788            match action {
789                AsAction::Build(build) => {
790                    let build_command_index = NonZeroU64::new(
791                        command_index_guard.next_acceleration_structure_build_command_index,
792                    )
793                    .unwrap();
794
795                    command_index_guard.next_acceleration_structure_build_command_index += 1;
796                    for blas in build.blas_s_built.iter() {
797                        *blas.built_index.write() = Some(build_command_index);
798                    }
799
800                    for tlas_build in build.tlas_s_built.iter() {
801                        for blas in &tlas_build.dependencies {
802                            if blas.built_index.read().is_none() {
803                                return Err(ValidateAsActionsError::UsedUnbuiltBlas(
804                                    blas.error_ident(),
805                                    tlas_build.tlas.error_ident(),
806                                ));
807                            }
808                        }
809                        *tlas_build.tlas.built_index.write() = Some(build_command_index);
810                        tlas_build
811                            .tlas
812                            .dependencies
813                            .write()
814                            .clone_from(&tlas_build.dependencies)
815                    }
816                }
817                AsAction::UseTlas(tlas) => {
818                    let tlas_build_index = tlas.built_index.read();
819                    let dependencies = tlas.dependencies.read();
820
821                    if (*tlas_build_index).is_none() {
822                        return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
823                    }
824                    for blas in dependencies.deref() {
825                        let blas_build_index = *blas.built_index.read();
826                        if blas_build_index.is_none() {
827                            return Err(ValidateAsActionsError::UsedUnbuiltBlas(
828                                tlas.error_ident(),
829                                blas.error_ident(),
830                            ));
831                        }
832                        if blas_build_index.unwrap() > tlas_build_index.unwrap() {
833                            return Err(ValidateAsActionsError::BlasNewerThenTlas(
834                                blas.error_ident(),
835                                tlas.error_ident(),
836                            ));
837                        }
838                        blas.try_raw(snatch_guard)?;
839                    }
840                }
841            }
842        }
843        Ok(())
844    }
845}
846
847///iterates over the blas iterator, and it's geometry, pushing the buffers into a storage vector (and also some validation).
848fn iter_blas<'a>(
849    blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
850    cmd_buf_data: &mut CommandBufferMutable,
851    build_command: &mut AsBuild,
852    buf_storage: &mut Vec<TriangleBufferStore<'a>>,
853    hub: &Hub,
854) -> Result<(), BuildAccelerationStructureError> {
855    let mut temp_buffer = Vec::new();
856    for entry in blas_iter {
857        let blas = hub.blas_s.get(entry.blas_id).get()?;
858        cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
859
860        build_command.blas_s_built.push(blas.clone());
861
862        match entry.geometries {
863            BlasGeometries::TriangleGeometries(triangle_geometries) => {
864                for (i, mesh) in triangle_geometries.enumerate() {
865                    let size_desc = match &blas.sizes {
866                        wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => descriptors,
867                    };
868                    if i >= size_desc.len() {
869                        return Err(BuildAccelerationStructureError::IncompatibleBlasBuildSizes(
870                            blas.error_ident(),
871                        ));
872                    }
873                    let size_desc = &size_desc[i];
874
875                    if size_desc.flags != mesh.size.flags {
876                        return Err(BuildAccelerationStructureError::IncompatibleBlasFlags(
877                            blas.error_ident(),
878                            size_desc.flags,
879                            mesh.size.flags,
880                        ));
881                    }
882
883                    if size_desc.vertex_count < mesh.size.vertex_count {
884                        return Err(
885                            BuildAccelerationStructureError::IncompatibleBlasVertexCount(
886                                blas.error_ident(),
887                                size_desc.vertex_count,
888                                mesh.size.vertex_count,
889                            ),
890                        );
891                    }
892
893                    if size_desc.vertex_format != mesh.size.vertex_format {
894                        return Err(BuildAccelerationStructureError::DifferentBlasVertexFormats(
895                            blas.error_ident(),
896                            size_desc.vertex_format,
897                            mesh.size.vertex_format,
898                        ));
899                    }
900
901                    match (size_desc.index_count, mesh.size.index_count) {
902                        (Some(_), None) | (None, Some(_)) => {
903                            return Err(
904                                BuildAccelerationStructureError::BlasIndexCountProvidedMismatch(
905                                    blas.error_ident(),
906                                ),
907                            )
908                        }
909                        (Some(create), Some(build)) if create < build => {
910                            return Err(
911                                BuildAccelerationStructureError::IncompatibleBlasIndexCount(
912                                    blas.error_ident(),
913                                    create,
914                                    build,
915                                ),
916                            )
917                        }
918                        _ => {}
919                    }
920
921                    if size_desc.index_format != mesh.size.index_format {
922                        return Err(BuildAccelerationStructureError::DifferentBlasIndexFormats(
923                            blas.error_ident(),
924                            size_desc.index_format,
925                            mesh.size.index_format,
926                        ));
927                    }
928
929                    if size_desc.index_count.is_some() && mesh.index_buffer.is_none() {
930                        return Err(BuildAccelerationStructureError::MissingIndexBuffer(
931                            blas.error_ident(),
932                        ));
933                    }
934                    let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
935                    let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
936                        &vertex_buffer,
937                        BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
938                    );
939                    let index_data = if let Some(index_id) = mesh.index_buffer {
940                        let index_buffer = hub.buffers.get(index_id).get()?;
941                        if mesh.first_index.is_none()
942                            || mesh.size.index_count.is_none()
943                            || mesh.size.index_count.is_none()
944                        {
945                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
946                                index_buffer.error_ident(),
947                            ));
948                        }
949                        let data = cmd_buf_data.trackers.buffers.set_single(
950                            &index_buffer,
951                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
952                        );
953                        Some((index_buffer, data))
954                    } else {
955                        None
956                    };
957                    let transform_data = if let Some(transform_id) = mesh.transform_buffer {
958                        if !blas
959                            .flags
960                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
961                        {
962                            return Err(BuildAccelerationStructureError::UseTransformMissing(
963                                blas.error_ident(),
964                            ));
965                        }
966                        let transform_buffer = hub.buffers.get(transform_id).get()?;
967                        if mesh.transform_buffer_offset.is_none() {
968                            return Err(BuildAccelerationStructureError::MissingAssociatedData(
969                                transform_buffer.error_ident(),
970                            ));
971                        }
972                        let data = cmd_buf_data.trackers.buffers.set_single(
973                            &transform_buffer,
974                            BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
975                        );
976                        Some((transform_buffer, data))
977                    } else {
978                        if blas
979                            .flags
980                            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
981                        {
982                            return Err(BuildAccelerationStructureError::TransformMissing(
983                                blas.error_ident(),
984                            ));
985                        }
986                        None
987                    };
988                    temp_buffer.push(TriangleBufferStore {
989                        vertex_buffer,
990                        vertex_transition: vertex_pending,
991                        index_buffer_transition: index_data,
992                        transform_buffer_transition: transform_data,
993                        geometry: mesh,
994                        ending_blas: None,
995                    });
996                }
997
998                if let Some(last) = temp_buffer.last_mut() {
999                    last.ending_blas = Some(blas);
1000                    buf_storage.append(&mut temp_buffer);
1001                }
1002            }
1003        }
1004    }
1005    Ok(())
1006}
1007
1008/// Iterates over the buffers generated in [iter_blas], convert the barriers into hal barriers, and the triangles into [hal::AccelerationStructureEntries] (and also some validation).
1009fn iter_buffers<'a, 'b>(
1010    buf_storage: &'a mut Vec<TriangleBufferStore<'b>>,
1011    snatch_guard: &'a SnatchGuard,
1012    input_barriers: &mut Vec<hal::BufferBarrier<'a, dyn hal::DynBuffer>>,
1013    cmd_buf_data: &mut CommandBufferMutable,
1014    scratch_buffer_blas_size: &mut u64,
1015    blas_storage: &mut Vec<BlasStore<'a>>,
1016    hub: &Hub,
1017    ray_tracing_scratch_buffer_alignment: u32,
1018) -> Result<(), BuildAccelerationStructureError> {
1019    let mut triangle_entries =
1020        Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::new();
1021    for buf in buf_storage {
1022        let mesh = &buf.geometry;
1023        let vertex_buffer = {
1024            let vertex_buffer = buf.vertex_buffer.as_ref();
1025            let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
1026            vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1027
1028            if let Some(barrier) = buf
1029                .vertex_transition
1030                .take()
1031                .map(|pending| pending.into_hal(vertex_buffer, snatch_guard))
1032            {
1033                input_barriers.push(barrier);
1034            }
1035            if vertex_buffer.size
1036                < (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride
1037            {
1038                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
1039                    vertex_buffer.error_ident(),
1040                    vertex_buffer.size,
1041                    (mesh.size.vertex_count + mesh.first_vertex) as u64 * mesh.vertex_stride,
1042                ));
1043            }
1044            let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
1045            cmd_buf_data.buffer_memory_init_actions.extend(
1046                vertex_buffer.initialization_status.read().create_action(
1047                    &hub.buffers.get(mesh.vertex_buffer).get()?,
1048                    vertex_buffer_offset
1049                        ..(vertex_buffer_offset
1050                            + mesh.size.vertex_count as u64 * mesh.vertex_stride),
1051                    MemoryInitKind::NeedsInitializedMemory,
1052                ),
1053            );
1054            vertex_raw
1055        };
1056        let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
1057            buf.index_buffer_transition
1058        {
1059            let index_raw = index_buffer.try_raw(snatch_guard)?;
1060            index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1061
1062            if let Some(barrier) = index_pending
1063                .take()
1064                .map(|pending| pending.into_hal(index_buffer, snatch_guard))
1065            {
1066                input_barriers.push(barrier);
1067            }
1068            let index_stride = mesh.size.index_format.unwrap().byte_size() as u64;
1069            let offset = mesh.first_index.unwrap() as u64 * index_stride;
1070            let index_buffer_size = mesh.size.index_count.unwrap() as u64 * index_stride;
1071
1072            if mesh.size.index_count.unwrap() % 3 != 0 {
1073                return Err(BuildAccelerationStructureError::InvalidIndexCount(
1074                    index_buffer.error_ident(),
1075                    mesh.size.index_count.unwrap(),
1076                ));
1077            }
1078            if index_buffer.size < mesh.size.index_count.unwrap() as u64 * index_stride + offset {
1079                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
1080                    index_buffer.error_ident(),
1081                    index_buffer.size,
1082                    mesh.size.index_count.unwrap() as u64 * index_stride + offset,
1083                ));
1084            }
1085
1086            cmd_buf_data.buffer_memory_init_actions.extend(
1087                index_buffer.initialization_status.read().create_action(
1088                    index_buffer,
1089                    offset..(offset + index_buffer_size),
1090                    MemoryInitKind::NeedsInitializedMemory,
1091                ),
1092            );
1093            Some(index_raw)
1094        } else {
1095            None
1096        };
1097        let transform_buffer = if let Some((ref mut transform_buffer, ref mut transform_pending)) =
1098            buf.transform_buffer_transition
1099        {
1100            if mesh.transform_buffer_offset.is_none() {
1101                return Err(BuildAccelerationStructureError::MissingAssociatedData(
1102                    transform_buffer.error_ident(),
1103                ));
1104            }
1105            let transform_raw = transform_buffer.try_raw(snatch_guard)?;
1106            transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1107
1108            if let Some(barrier) = transform_pending
1109                .take()
1110                .map(|pending| pending.into_hal(transform_buffer, snatch_guard))
1111            {
1112                input_barriers.push(barrier);
1113            }
1114
1115            let offset = mesh.transform_buffer_offset.unwrap();
1116
1117            if offset % wgt::TRANSFORM_BUFFER_ALIGNMENT != 0 {
1118                return Err(
1119                    BuildAccelerationStructureError::UnalignedTransformBufferOffset(
1120                        transform_buffer.error_ident(),
1121                    ),
1122                );
1123            }
1124            if transform_buffer.size < 48 + offset {
1125                return Err(BuildAccelerationStructureError::InsufficientBufferSize(
1126                    transform_buffer.error_ident(),
1127                    transform_buffer.size,
1128                    48 + offset,
1129                ));
1130            }
1131            cmd_buf_data.buffer_memory_init_actions.extend(
1132                transform_buffer.initialization_status.read().create_action(
1133                    transform_buffer,
1134                    offset..(offset + 48),
1135                    MemoryInitKind::NeedsInitializedMemory,
1136                ),
1137            );
1138            Some(transform_raw)
1139        } else {
1140            None
1141        };
1142
1143        let triangles = hal::AccelerationStructureTriangles {
1144            vertex_buffer: Some(vertex_buffer),
1145            vertex_format: mesh.size.vertex_format,
1146            first_vertex: mesh.first_vertex,
1147            vertex_count: mesh.size.vertex_count,
1148            vertex_stride: mesh.vertex_stride,
1149            indices: index_buffer.map(|index_buffer| {
1150                let index_stride = mesh.size.index_format.unwrap().byte_size() as u32;
1151                hal::AccelerationStructureTriangleIndices::<dyn hal::DynBuffer> {
1152                    format: mesh.size.index_format.unwrap(),
1153                    buffer: Some(index_buffer),
1154                    offset: mesh.first_index.unwrap() * index_stride,
1155                    count: mesh.size.index_count.unwrap(),
1156                }
1157            }),
1158            transform: transform_buffer.map(|transform_buffer| {
1159                hal::AccelerationStructureTriangleTransform {
1160                    buffer: transform_buffer,
1161                    offset: mesh.transform_buffer_offset.unwrap() as u32,
1162                }
1163            }),
1164            flags: mesh.size.flags,
1165        };
1166        triangle_entries.push(triangles);
1167        if let Some(blas) = buf.ending_blas.take() {
1168            let scratch_buffer_offset = *scratch_buffer_blas_size;
1169            *scratch_buffer_blas_size += align_to(
1170                blas.size_info.build_scratch_size as u32,
1171                ray_tracing_scratch_buffer_alignment,
1172            ) as u64;
1173
1174            blas_storage.push(BlasStore {
1175                blas,
1176                entries: hal::AccelerationStructureEntries::Triangles(triangle_entries),
1177                scratch_buffer_offset,
1178            });
1179            triangle_entries = Vec::new();
1180        }
1181    }
1182    Ok(())
1183}
1184
1185fn map_blas<'a>(
1186    storage: &'a BlasStore<'_>,
1187    scratch_buffer: &'a dyn hal::DynBuffer,
1188    snatch_guard: &'a SnatchGuard,
1189) -> Result<
1190    hal::BuildAccelerationStructureDescriptor<
1191        'a,
1192        dyn hal::DynBuffer,
1193        dyn hal::DynAccelerationStructure,
1194    >,
1195    BuildAccelerationStructureError,
1196> {
1197    let BlasStore {
1198        blas,
1199        entries,
1200        scratch_buffer_offset,
1201    } = storage;
1202    if blas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
1203        log::info!("only rebuild implemented")
1204    }
1205    Ok(hal::BuildAccelerationStructureDescriptor {
1206        entries,
1207        mode: hal::AccelerationStructureBuildMode::Build,
1208        flags: blas.flags,
1209        source_acceleration_structure: None,
1210        destination_acceleration_structure: blas.try_raw(snatch_guard)?,
1211        scratch_buffer,
1212        scratch_buffer_offset: *scratch_buffer_offset,
1213    })
1214}
1215
1216fn build_blas<'a>(
1217    cmd_buf_raw: &mut dyn hal::DynCommandEncoder,
1218    blas_present: bool,
1219    tlas_present: bool,
1220    input_barriers: Vec<hal::BufferBarrier<dyn hal::DynBuffer>>,
1221    blas_descriptors: &[hal::BuildAccelerationStructureDescriptor<
1222        'a,
1223        dyn hal::DynBuffer,
1224        dyn hal::DynAccelerationStructure,
1225    >],
1226    scratch_buffer_barrier: hal::BufferBarrier<dyn hal::DynBuffer>,
1227) {
1228    unsafe {
1229        cmd_buf_raw.transition_buffers(&input_barriers);
1230    }
1231
1232    if blas_present {
1233        unsafe {
1234            cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1235                usage: hal::StateTransition {
1236                    from: hal::AccelerationStructureUses::BUILD_INPUT,
1237                    to: hal::AccelerationStructureUses::BUILD_OUTPUT,
1238                },
1239            });
1240
1241            cmd_buf_raw.build_acceleration_structures(blas_descriptors);
1242        }
1243    }
1244
1245    if blas_present && tlas_present {
1246        unsafe {
1247            cmd_buf_raw.transition_buffers(&[scratch_buffer_barrier]);
1248        }
1249    }
1250
1251    let mut source_usage = hal::AccelerationStructureUses::empty();
1252    let mut destination_usage = hal::AccelerationStructureUses::empty();
1253    if blas_present {
1254        source_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1255        destination_usage |= hal::AccelerationStructureUses::BUILD_INPUT
1256    }
1257    if tlas_present {
1258        source_usage |= hal::AccelerationStructureUses::SHADER_INPUT;
1259        destination_usage |= hal::AccelerationStructureUses::BUILD_OUTPUT;
1260    }
1261    unsafe {
1262        cmd_buf_raw.place_acceleration_structure_barrier(hal::AccelerationStructureBarrier {
1263            usage: hal::StateTransition {
1264                from: source_usage,
1265                to: destination_usage,
1266            },
1267        });
1268    }
1269}