wgpu_core/device/
ray_tracing.rs

1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::ManuallyDrop;
3
4use crate::api_log;
5#[cfg(feature = "trace")]
6use crate::device::trace;
7use crate::lock::rank;
8use crate::resource::{Fallible, TrackingData};
9use crate::snatch::Snatchable;
10use crate::{
11    device::{Device, DeviceError},
12    global::Global,
13    id::{self, BlasId, TlasId},
14    lock::RwLock,
15    ray_tracing::{CreateBlasError, CreateTlasError},
16    resource, LabelHelpers,
17};
18use hal::AccelerationStructureTriangleIndices;
19use wgt::Features;
20
21impl Device {
22    fn create_blas(
23        self: &Arc<Self>,
24        blas_desc: &resource::BlasDescriptor,
25        sizes: wgt::BlasGeometrySizeDescriptors,
26    ) -> Result<Arc<resource::Blas>, CreateBlasError> {
27        self.check_is_valid()?;
28        self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
29
30        if blas_desc
31            .flags
32            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
33        {
34            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
35        }
36
37        let size_info = match &sizes {
38            wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
39                let mut entries =
40                    Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
41                        descriptors.len(),
42                    );
43                for desc in descriptors {
44                    if desc.index_count.is_some() != desc.index_format.is_some() {
45                        return Err(CreateBlasError::MissingIndexData);
46                    }
47                    let indices =
48                        desc.index_count
49                            .map(|count| AccelerationStructureTriangleIndices::<
50                                dyn hal::DynBuffer,
51                            > {
52                                format: desc.index_format.unwrap(),
53                                buffer: None,
54                                offset: 0,
55                                count,
56                            });
57                    if !self
58                        .features
59                        .allowed_vertex_formats_for_blas()
60                        .contains(&desc.vertex_format)
61                    {
62                        return Err(CreateBlasError::InvalidVertexFormat(
63                            desc.vertex_format,
64                            self.features.allowed_vertex_formats_for_blas(),
65                        ));
66                    }
67
68                    let mut transform = None;
69
70                    if blas_desc
71                        .flags
72                        .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
73                    {
74                        transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
75                            buffer: self.zero_buffer.as_ref(),
76                            offset: 0,
77                        })
78                    }
79
80                    entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
81                        vertex_buffer: None,
82                        vertex_format: desc.vertex_format,
83                        first_vertex: 0,
84                        vertex_count: desc.vertex_count,
85                        vertex_stride: 0,
86                        indices,
87                        transform,
88                        flags: desc.flags,
89                    });
90                }
91                unsafe {
92                    self.raw().get_acceleration_structure_build_sizes(
93                        &hal::GetAccelerationStructureBuildSizesDescriptor {
94                            entries: &hal::AccelerationStructureEntries::Triangles(entries),
95                            flags: blas_desc.flags,
96                        },
97                    )
98                }
99            }
100        };
101
102        let raw = unsafe {
103            self.raw()
104                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
105                    label: blas_desc.label.as_deref(),
106                    size: size_info.acceleration_structure_size,
107                    format: hal::AccelerationStructureFormat::BottomLevel,
108                    // change this once compaction is implemented in wgpu-core
109                    allow_compaction: false,
110                })
111        }
112        .map_err(DeviceError::from_hal)?;
113
114        let handle = unsafe {
115            self.raw()
116                .get_acceleration_structure_device_address(raw.as_ref())
117        };
118
119        Ok(Arc::new(resource::Blas {
120            raw: Snatchable::new(raw),
121            device: self.clone(),
122            size_info,
123            sizes,
124            flags: blas_desc.flags,
125            update_mode: blas_desc.update_mode,
126            handle,
127            label: blas_desc.label.to_string(),
128            built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
129            tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
130        }))
131    }
132
133    fn create_tlas(
134        self: &Arc<Self>,
135        desc: &resource::TlasDescriptor,
136    ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
137        self.check_is_valid()?;
138        self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
139
140        if desc
141            .flags
142            .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
143        {
144            return Err(CreateTlasError::DisallowedFlag(
145                wgt::AccelerationStructureFlags::USE_TRANSFORM,
146            ));
147        }
148
149        if desc
150            .flags
151            .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
152        {
153            self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
154        }
155
156        let size_info = unsafe {
157            self.raw().get_acceleration_structure_build_sizes(
158                &hal::GetAccelerationStructureBuildSizesDescriptor {
159                    entries: &hal::AccelerationStructureEntries::Instances(
160                        hal::AccelerationStructureInstances {
161                            buffer: None,
162                            offset: 0,
163                            count: desc.max_instances,
164                        },
165                    ),
166                    flags: desc.flags,
167                },
168            )
169        };
170
171        let raw = unsafe {
172            self.raw()
173                .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
174                    label: desc.label.as_deref(),
175                    size: size_info.acceleration_structure_size,
176                    format: hal::AccelerationStructureFormat::TopLevel,
177                    allow_compaction: false,
178                })
179        }
180        .map_err(DeviceError::from_hal)?;
181
182        let instance_buffer_size =
183            self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
184        let instance_buffer = unsafe {
185            self.raw().create_buffer(&hal::BufferDescriptor {
186                label: Some("(wgpu-core) instances_buffer"),
187                size: instance_buffer_size as u64,
188                usage: wgt::BufferUses::COPY_DST
189                    | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
190                memory_flags: hal::MemoryFlags::PREFER_COHERENT,
191            })
192        }
193        .map_err(DeviceError::from_hal)?;
194
195        Ok(Arc::new(resource::Tlas {
196            raw: Snatchable::new(raw),
197            device: self.clone(),
198            size_info,
199            flags: desc.flags,
200            update_mode: desc.update_mode,
201            built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
202            dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
203            instance_buffer: ManuallyDrop::new(instance_buffer),
204            label: desc.label.to_string(),
205            max_instance_count: desc.max_instances,
206            tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
207        }))
208    }
209}
210
211impl Global {
212    pub fn device_create_blas(
213        &self,
214        device_id: id::DeviceId,
215        desc: &resource::BlasDescriptor,
216        sizes: wgt::BlasGeometrySizeDescriptors,
217        id_in: Option<BlasId>,
218    ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
219        profiling::scope!("Device::create_blas");
220
221        let fid = self.hub.blas_s.prepare(id_in);
222
223        let error = 'error: {
224            let device = self.hub.devices.get(device_id);
225
226            #[cfg(feature = "trace")]
227            if let Some(trace) = device.trace.lock().as_mut() {
228                trace.add(trace::Action::CreateBlas {
229                    id: fid.id(),
230                    desc: desc.clone(),
231                    sizes: sizes.clone(),
232                });
233            }
234
235            let blas = match device.create_blas(desc, sizes) {
236                Ok(blas) => blas,
237                Err(e) => break 'error e,
238            };
239            let handle = blas.handle;
240
241            let id = fid.assign(Fallible::Valid(blas));
242            api_log!("Device::create_blas -> {id:?}");
243
244            return (id, Some(handle), None);
245        };
246
247        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
248        (id, None, Some(error))
249    }
250
251    pub fn device_create_tlas(
252        &self,
253        device_id: id::DeviceId,
254        desc: &resource::TlasDescriptor,
255        id_in: Option<TlasId>,
256    ) -> (TlasId, Option<CreateTlasError>) {
257        profiling::scope!("Device::create_tlas");
258
259        let fid = self.hub.tlas_s.prepare(id_in);
260
261        let error = 'error: {
262            let device = self.hub.devices.get(device_id);
263
264            #[cfg(feature = "trace")]
265            if let Some(trace) = device.trace.lock().as_mut() {
266                trace.add(trace::Action::CreateTlas {
267                    id: fid.id(),
268                    desc: desc.clone(),
269                });
270            }
271
272            let tlas = match device.create_tlas(desc) {
273                Ok(tlas) => tlas,
274                Err(e) => break 'error e,
275            };
276
277            let id = fid.assign(Fallible::Valid(tlas));
278            api_log!("Device::create_tlas -> {id:?}");
279
280            return (id, None);
281        };
282
283        let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
284        (id, Some(error))
285    }
286
287    pub fn blas_drop(&self, blas_id: BlasId) {
288        profiling::scope!("Blas::drop");
289        api_log!("Blas::drop {blas_id:?}");
290
291        let _blas = self.hub.blas_s.remove(blas_id);
292
293        #[cfg(feature = "trace")]
294        if let Ok(blas) = _blas.get() {
295            if let Some(t) = blas.device.trace.lock().as_mut() {
296                t.add(trace::Action::DestroyBlas(blas_id));
297            }
298        }
299    }
300
301    pub fn tlas_drop(&self, tlas_id: TlasId) {
302        profiling::scope!("Tlas::drop");
303        api_log!("Tlas::drop {tlas_id:?}");
304
305        let _tlas = self.hub.tlas_s.remove(tlas_id);
306
307        #[cfg(feature = "trace")]
308        if let Ok(tlas) = _tlas.get() {
309            if let Some(t) = tlas.device.trace.lock().as_mut() {
310                t.add(trace::Action::DestroyTlas(tlas_id));
311            }
312        }
313    }
314}