1use alloc::{string::ToString as _, sync::Arc, vec::Vec};
2use core::mem::ManuallyDrop;
3
4use crate::api_log;
5#[cfg(feature = "trace")]
6use crate::device::trace;
7use crate::lock::rank;
8use crate::resource::{Fallible, TrackingData};
9use crate::snatch::Snatchable;
10use crate::{
11 device::{Device, DeviceError},
12 global::Global,
13 id::{self, BlasId, TlasId},
14 lock::RwLock,
15 ray_tracing::{CreateBlasError, CreateTlasError},
16 resource, LabelHelpers,
17};
18use hal::AccelerationStructureTriangleIndices;
19use wgt::Features;
20
21impl Device {
22 fn create_blas(
23 self: &Arc<Self>,
24 blas_desc: &resource::BlasDescriptor,
25 sizes: wgt::BlasGeometrySizeDescriptors,
26 ) -> Result<Arc<resource::Blas>, CreateBlasError> {
27 self.check_is_valid()?;
28 self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
29
30 if blas_desc
31 .flags
32 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
33 {
34 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
35 }
36
37 let size_info = match &sizes {
38 wgt::BlasGeometrySizeDescriptors::Triangles { descriptors } => {
39 let mut entries =
40 Vec::<hal::AccelerationStructureTriangles<dyn hal::DynBuffer>>::with_capacity(
41 descriptors.len(),
42 );
43 for desc in descriptors {
44 if desc.index_count.is_some() != desc.index_format.is_some() {
45 return Err(CreateBlasError::MissingIndexData);
46 }
47 let indices =
48 desc.index_count
49 .map(|count| AccelerationStructureTriangleIndices::<
50 dyn hal::DynBuffer,
51 > {
52 format: desc.index_format.unwrap(),
53 buffer: None,
54 offset: 0,
55 count,
56 });
57 if !self
58 .features
59 .allowed_vertex_formats_for_blas()
60 .contains(&desc.vertex_format)
61 {
62 return Err(CreateBlasError::InvalidVertexFormat(
63 desc.vertex_format,
64 self.features.allowed_vertex_formats_for_blas(),
65 ));
66 }
67
68 let mut transform = None;
69
70 if blas_desc
71 .flags
72 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
73 {
74 transform = Some(wgpu_hal::AccelerationStructureTriangleTransform {
75 buffer: self.zero_buffer.as_ref(),
76 offset: 0,
77 })
78 }
79
80 entries.push(hal::AccelerationStructureTriangles::<dyn hal::DynBuffer> {
81 vertex_buffer: None,
82 vertex_format: desc.vertex_format,
83 first_vertex: 0,
84 vertex_count: desc.vertex_count,
85 vertex_stride: 0,
86 indices,
87 transform,
88 flags: desc.flags,
89 });
90 }
91 unsafe {
92 self.raw().get_acceleration_structure_build_sizes(
93 &hal::GetAccelerationStructureBuildSizesDescriptor {
94 entries: &hal::AccelerationStructureEntries::Triangles(entries),
95 flags: blas_desc.flags,
96 },
97 )
98 }
99 }
100 };
101
102 let raw = unsafe {
103 self.raw()
104 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
105 label: blas_desc.label.as_deref(),
106 size: size_info.acceleration_structure_size,
107 format: hal::AccelerationStructureFormat::BottomLevel,
108 allow_compaction: false,
110 })
111 }
112 .map_err(DeviceError::from_hal)?;
113
114 let handle = unsafe {
115 self.raw()
116 .get_acceleration_structure_device_address(raw.as_ref())
117 };
118
119 Ok(Arc::new(resource::Blas {
120 raw: Snatchable::new(raw),
121 device: self.clone(),
122 size_info,
123 sizes,
124 flags: blas_desc.flags,
125 update_mode: blas_desc.update_mode,
126 handle,
127 label: blas_desc.label.to_string(),
128 built_index: RwLock::new(rank::BLAS_BUILT_INDEX, None),
129 tracking_data: TrackingData::new(self.tracker_indices.blas_s.clone()),
130 }))
131 }
132
133 fn create_tlas(
134 self: &Arc<Self>,
135 desc: &resource::TlasDescriptor,
136 ) -> Result<Arc<resource::Tlas>, CreateTlasError> {
137 self.check_is_valid()?;
138 self.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
139
140 if desc
141 .flags
142 .contains(wgt::AccelerationStructureFlags::USE_TRANSFORM)
143 {
144 return Err(CreateTlasError::DisallowedFlag(
145 wgt::AccelerationStructureFlags::USE_TRANSFORM,
146 ));
147 }
148
149 if desc
150 .flags
151 .contains(wgt::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN)
152 {
153 self.require_features(Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN)?;
154 }
155
156 let size_info = unsafe {
157 self.raw().get_acceleration_structure_build_sizes(
158 &hal::GetAccelerationStructureBuildSizesDescriptor {
159 entries: &hal::AccelerationStructureEntries::Instances(
160 hal::AccelerationStructureInstances {
161 buffer: None,
162 offset: 0,
163 count: desc.max_instances,
164 },
165 ),
166 flags: desc.flags,
167 },
168 )
169 };
170
171 let raw = unsafe {
172 self.raw()
173 .create_acceleration_structure(&hal::AccelerationStructureDescriptor {
174 label: desc.label.as_deref(),
175 size: size_info.acceleration_structure_size,
176 format: hal::AccelerationStructureFormat::TopLevel,
177 allow_compaction: false,
178 })
179 }
180 .map_err(DeviceError::from_hal)?;
181
182 let instance_buffer_size =
183 self.alignments.raw_tlas_instance_size * desc.max_instances.max(1) as usize;
184 let instance_buffer = unsafe {
185 self.raw().create_buffer(&hal::BufferDescriptor {
186 label: Some("(wgpu-core) instances_buffer"),
187 size: instance_buffer_size as u64,
188 usage: wgt::BufferUses::COPY_DST
189 | wgt::BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
190 memory_flags: hal::MemoryFlags::PREFER_COHERENT,
191 })
192 }
193 .map_err(DeviceError::from_hal)?;
194
195 Ok(Arc::new(resource::Tlas {
196 raw: Snatchable::new(raw),
197 device: self.clone(),
198 size_info,
199 flags: desc.flags,
200 update_mode: desc.update_mode,
201 built_index: RwLock::new(rank::TLAS_BUILT_INDEX, None),
202 dependencies: RwLock::new(rank::TLAS_DEPENDENCIES, Vec::new()),
203 instance_buffer: ManuallyDrop::new(instance_buffer),
204 label: desc.label.to_string(),
205 max_instance_count: desc.max_instances,
206 tracking_data: TrackingData::new(self.tracker_indices.tlas_s.clone()),
207 }))
208 }
209}
210
211impl Global {
212 pub fn device_create_blas(
213 &self,
214 device_id: id::DeviceId,
215 desc: &resource::BlasDescriptor,
216 sizes: wgt::BlasGeometrySizeDescriptors,
217 id_in: Option<BlasId>,
218 ) -> (BlasId, Option<u64>, Option<CreateBlasError>) {
219 profiling::scope!("Device::create_blas");
220
221 let fid = self.hub.blas_s.prepare(id_in);
222
223 let error = 'error: {
224 let device = self.hub.devices.get(device_id);
225
226 #[cfg(feature = "trace")]
227 if let Some(trace) = device.trace.lock().as_mut() {
228 trace.add(trace::Action::CreateBlas {
229 id: fid.id(),
230 desc: desc.clone(),
231 sizes: sizes.clone(),
232 });
233 }
234
235 let blas = match device.create_blas(desc, sizes) {
236 Ok(blas) => blas,
237 Err(e) => break 'error e,
238 };
239 let handle = blas.handle;
240
241 let id = fid.assign(Fallible::Valid(blas));
242 api_log!("Device::create_blas -> {id:?}");
243
244 return (id, Some(handle), None);
245 };
246
247 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
248 (id, None, Some(error))
249 }
250
251 pub fn device_create_tlas(
252 &self,
253 device_id: id::DeviceId,
254 desc: &resource::TlasDescriptor,
255 id_in: Option<TlasId>,
256 ) -> (TlasId, Option<CreateTlasError>) {
257 profiling::scope!("Device::create_tlas");
258
259 let fid = self.hub.tlas_s.prepare(id_in);
260
261 let error = 'error: {
262 let device = self.hub.devices.get(device_id);
263
264 #[cfg(feature = "trace")]
265 if let Some(trace) = device.trace.lock().as_mut() {
266 trace.add(trace::Action::CreateTlas {
267 id: fid.id(),
268 desc: desc.clone(),
269 });
270 }
271
272 let tlas = match device.create_tlas(desc) {
273 Ok(tlas) => tlas,
274 Err(e) => break 'error e,
275 };
276
277 let id = fid.assign(Fallible::Valid(tlas));
278 api_log!("Device::create_tlas -> {id:?}");
279
280 return (id, None);
281 };
282
283 let id = fid.assign(Fallible::Invalid(Arc::new(error.to_string())));
284 (id, Some(error))
285 }
286
287 pub fn blas_drop(&self, blas_id: BlasId) {
288 profiling::scope!("Blas::drop");
289 api_log!("Blas::drop {blas_id:?}");
290
291 let _blas = self.hub.blas_s.remove(blas_id);
292
293 #[cfg(feature = "trace")]
294 if let Ok(blas) = _blas.get() {
295 if let Some(t) = blas.device.trace.lock().as_mut() {
296 t.add(trace::Action::DestroyBlas(blas_id));
297 }
298 }
299 }
300
301 pub fn tlas_drop(&self, tlas_id: TlasId) {
302 profiling::scope!("Tlas::drop");
303 api_log!("Tlas::drop {tlas_id:?}");
304
305 let _tlas = self.hub.tlas_s.remove(tlas_id);
306
307 #[cfg(feature = "trace")]
308 if let Ok(tlas) = _tlas.get() {
309 if let Some(t) = tlas.device.trace.lock().as_mut() {
310 t.add(trace::Action::DestroyTlas(tlas_id));
311 }
312 }
313 }
314}