1use alloc::{boxed::Box, string::String, vec::Vec};
2use core::{fmt, num::NonZeroU32};
3
4use crate::{
5 binding_model,
6 hub::Hub,
7 id::{BindGroupLayoutId, PipelineLayoutId},
8 ray_tracing::BlasCompactReadyPendingClosure,
9 resource::{
10 Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, Labeled,
11 RawResourceAccess, ResourceErrorIdent,
12 },
13 snatch::SnatchGuard,
14 Label, DOWNLEVEL_ERROR_MESSAGE,
15};
16
17use arrayvec::ArrayVec;
18use smallvec::SmallVec;
19use thiserror::Error;
20use wgt::{
21 error::{ErrorType, WebGpuError},
22 BufferAddress, DeviceLostReason, TextureFormat,
23};
24
// Submodules that make up the device implementation.
pub(crate) mod bgl;
pub mod global;
mod life;
pub mod queue;
pub mod ray_tracing;
pub mod resource;
// Only compiled when the `trace` or `replay` feature is enabled.
#[cfg(any(feature = "trace", feature = "replay"))]
pub mod trace;
// Re-export the device type and the wait-idle error at this module's level.
pub use {life::WaitIdleError, resource::Device};
34
/// Number of shader stages; mirrors `hal::MAX_CONCURRENT_SHADER_STAGES`.
pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
/// Size in bytes of the zero buffer: `512 << 10` = 512 KiB.
// NOTE(review): presumably a shared source of zeros for clear operations — confirm at use sites.
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

/// Wait duration in milliseconds (60 000 ms = 60 s).
// NOTE(review): the name suggests the timeout used while cleaning up a device — confirm at use sites.
const CLEANUP_WAIT_MS: u32 = 60000;

/// Error text reported for an invalid shader entry point.
pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

/// Device descriptor whose label type is this crate's [`Label`].
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
47
/// Direction of host access requested for a buffer mapping.
///
/// `map_buffer` uses `Read` to decide whether non-coherent memory must be
/// invalidated before the host looks at it.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    /// The host will read the mapped range.
    Read,
    /// The host will write the mapped range.
    Write,
}
55
/// Per-attachment values of a render pass layout: one entry per color slot,
/// the resolve targets, and an optional depth-stencil entry.
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    /// Color attachment slots, at most `hal::MAX_COLOR_ATTACHMENTS`.
    /// `None` presumably marks an unused slot — confirm against callers.
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    /// Resolve targets, at most `hal::MAX_COLOR_ATTACHMENTS`.
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    /// Depth-stencil attachment, if present.
    pub depth_stencil: Option<T>,
}
// NOTE(review): `Eq` is asserted for every `T: PartialEq`, not only `T: Eq`.
// For a `T` whose equality is not reflexive (e.g. floats) this would be an
// incorrect claim; confirm all instantiations use `Eq` types before tightening.
impl<T: PartialEq> Eq for AttachmentData<T> {}
64
/// The attachment layout a render pass (or a resource used inside one) is
/// built for; compared with `check_compatible`.
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    /// Formats of the color, resolve and depth-stencil attachments.
    pub attachments: AttachmentData<TextureFormat>,
    /// Sample count shared by the attachments.
    pub sample_count: u32,
    /// Multiview setting, if any.
    pub multiview: Option<NonZeroU32>,
}
72#[derive(Clone, Debug, Error)]
73#[non_exhaustive]
74pub enum RenderPassCompatibilityError {
75 #[error(
76 "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {res} uses attachments with formats {actual:?}",
77 )]
78 IncompatibleColorAttachment {
79 indices: Vec<usize>,
80 expected: Vec<Option<TextureFormat>>,
81 actual: Vec<Option<TextureFormat>>,
82 res: ResourceErrorIdent,
83 },
84 #[error(
85 "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {res} uses an attachment with format {actual:?}",
86 )]
87 IncompatibleDepthStencilAttachment {
88 expected: Option<TextureFormat>,
89 actual: Option<TextureFormat>,
90 res: ResourceErrorIdent,
91 },
92 #[error(
93 "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {res} uses attachments with format {actual:?}",
94 )]
95 IncompatibleSampleCount {
96 expected: u32,
97 actual: u32,
98 res: ResourceErrorIdent,
99 },
100 #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {res} uses setting {actual:?}")]
101 IncompatibleMultiview {
102 expected: Option<NonZeroU32>,
103 actual: Option<NonZeroU32>,
104 res: ResourceErrorIdent,
105 },
106}
107
108impl WebGpuError for RenderPassCompatibilityError {
109 fn webgpu_error_type(&self) -> ErrorType {
110 ErrorType::Validation
111 }
112}
113
114impl RenderPassContext {
115 pub(crate) fn check_compatible<T: Labeled>(
117 &self,
118 other: &Self,
119 res: &T,
120 ) -> Result<(), RenderPassCompatibilityError> {
121 if self.attachments.colors != other.attachments.colors {
122 let indices = self
123 .attachments
124 .colors
125 .iter()
126 .zip(&other.attachments.colors)
127 .enumerate()
128 .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
129 .collect();
130 return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
131 indices,
132 expected: self.attachments.colors.iter().cloned().collect(),
133 actual: other.attachments.colors.iter().cloned().collect(),
134 res: res.error_ident(),
135 });
136 }
137 if self.attachments.depth_stencil != other.attachments.depth_stencil {
138 return Err(
139 RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
140 expected: self.attachments.depth_stencil,
141 actual: other.attachments.depth_stencil,
142 res: res.error_ident(),
143 },
144 );
145 }
146 if self.sample_count != other.sample_count {
147 return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
148 expected: self.sample_count,
149 actual: other.sample_count,
150 res: res.error_ident(),
151 });
152 }
153 if self.multiview != other.multiview {
154 return Err(RenderPassCompatibilityError::IncompatibleMultiview {
155 expected: self.multiview,
156 actual: other.multiview,
157 res: res.error_ident(),
158 });
159 }
160 Ok(())
161 }
162}
163
/// A buffer map operation paired with the result its callback will be given.
pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

/// User callbacks collected so they can be invoked together by `fire`.
#[derive(Default)]
pub struct UserClosures {
    /// Pending buffer-map callbacks and their results.
    pub mappings: Vec<BufferMapPendingClosure>,
    /// Pending BLAS compaction-ready callbacks.
    pub blas_compact_ready: Vec<BlasCompactReadyPendingClosure>,
    /// Submitted-work-done closures.
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    /// Device-lost callbacks with their arguments.
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}
173
174impl UserClosures {
175 fn extend(&mut self, other: Self) {
176 self.mappings.extend(other.mappings);
177 self.blas_compact_ready.extend(other.blas_compact_ready);
178 self.submissions.extend(other.submissions);
179 self.device_lost_invocations
180 .extend(other.device_lost_invocations);
181 }
182
183 fn fire(self) {
184 for (mut operation, status) in self.mappings {
190 if let Some(callback) = operation.callback.take() {
191 callback(status);
192 }
193 }
194 for (mut operation, status) in self.blas_compact_ready {
195 if let Some(callback) = operation.take() {
196 callback(status);
197 }
198 }
199 for closure in self.submissions {
200 closure();
201 }
202 for invocation in self.device_lost_invocations {
203 (invocation.closure)(invocation.reason, invocation.message);
204 }
205 }
206}
207
/// Callback run when a device is lost; receives the reason and a message.
/// With the `send_sync` cfg enabled, the callback must also be `Send`.
#[cfg(send_sync)]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
/// Callback run when a device is lost; receives the reason and a message.
/// Without the `send_sync` cfg, no `Send` bound is required.
#[cfg(not(send_sync))]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;

/// A device-lost callback together with the arguments it will be called with.
pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}
218
/// Maps `size` bytes of `buffer` starting at `offset` for host access of the
/// given `kind`, and zero-fills any part of that range that the buffer's
/// `initialization_status` reports as uninitialized.
///
/// # Errors
///
/// Fails if the raw buffer cannot be obtained under `snatch_guard`
/// (`try_raw`), or if the HAL `map_buffer` call fails (converted through
/// `handle_hal_error`).
///
/// # Panics
///
/// Panics if `offset` or `size` is not a multiple of
/// `wgt::COPY_BUFFER_ALIGNMENT`.
// NOTE(review): the alignment asserts run only after the HAL mapping has been
// created, so a failing assert leaves that mapping in place.
pub(crate) fn map_buffer(
    buffer: &Buffer,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<hal::BufferMapping, BufferAccessError> {
    let raw_device = buffer.device.raw();
    let raw_buffer = buffer.try_raw(snatch_guard)?;
    // SAFETY: assumed hal contract — `raw_buffer` belongs to `raw_device` and
    // `offset..offset + size` lies within it; verify at call sites.
    let mapping = unsafe {
        raw_device
            .map_buffer(raw_buffer, offset..offset + size)
            .map_err(|e| buffer.device.handle_hal_error(e))?
    };

    // For non-coherent memory that the host is about to read, invalidate the
    // mapped range. Presumably this makes device writes visible to the host
    // (see hal `invalidate_mapped_ranges`).
    if !mapping.is_coherent && kind == HostMap::Read {
        // `&[offset..offset + size]` is intentionally a one-element slice of
        // ranges, which this clippy lint would otherwise flag.
        #[allow(clippy::single_range_in_vec_init)]
        unsafe {
            raw_device.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
        }
    }

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    // The host-visible slice covers exactly the `size` mapped bytes.
    // `mapping.ptr` is treated as pointing at buffer byte `offset`, which is
    // why the fill ranges below subtract `offset`.
    // SAFETY: assumes hal returns a pointer valid for `size` writable bytes
    // while the mapping is alive — confirm against hal docs.
    let mapped = unsafe { core::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    // Non-coherent read mapping of a buffer without MAP_WRITE: no flush is
    // issued on this path, so the zeroing happens only in host memory. The
    // ranges are visited with `uninitialized` rather than `drain`, presumably
    // leaving them tracked as uninitialized — confirm in the init tracker.
    if !mapping.is_coherent
        && kind == HostMap::Read
        && !buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
    {
        for uninitialized in buffer
            .initialization_status
            .write()
            .uninitialized(offset..(size + offset))
        {
            // Tracked ranges are buffer-relative; the slice starts at `offset`.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);
        }
    } else {
        // Every other case: zero each uninitialized range while draining it
        // from the tracker (presumably marking it initialized).
        for uninitialized in buffer
            .initialization_status
            .write()
            .drain(offset..(size + offset))
        {
            // Tracked ranges are buffer-relative; the slice starts at `offset`.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);

            // Non-coherent read mapping of a MAP_WRITE buffer: flush the range
            // just zeroed so the write is pushed out of host memory.
            if !mapping.is_coherent
                && kind == HostMap::Read
                && buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
            {
                unsafe { raw_device.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
            }
        }
    }

    Ok(mapping)
}
299
300#[derive(Clone, Debug)]
301#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
302pub struct DeviceMismatch {
303 pub(super) res: ResourceErrorIdent,
304 pub(super) res_device: ResourceErrorIdent,
305 pub(super) target: Option<ResourceErrorIdent>,
306 pub(super) target_device: ResourceErrorIdent,
307}
308
309impl fmt::Display for DeviceMismatch {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
311 write!(
312 f,
313 "{} of {} doesn't match {}",
314 self.res_device, self.res, self.target_device
315 )?;
316 if let Some(target) = self.target.as_ref() {
317 write!(f, " of {target}")?;
318 }
319 Ok(())
320 }
321}
322
323impl core::error::Error for DeviceMismatch {}
324
325impl WebGpuError for DeviceMismatch {
326 fn webgpu_error_type(&self) -> ErrorType {
327 ErrorType::Validation
328 }
329}
330
331#[derive(Clone, Debug, Error)]
332#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
333#[non_exhaustive]
334pub enum DeviceError {
335 #[error("Parent device is lost")]
336 Lost,
337 #[error("Not enough memory left.")]
338 OutOfMemory,
339 #[error(transparent)]
340 DeviceMismatch(#[from] Box<DeviceMismatch>),
341}
342
343impl WebGpuError for DeviceError {
344 fn webgpu_error_type(&self) -> ErrorType {
345 match self {
346 Self::DeviceMismatch(e) => e.webgpu_error_type(),
347 Self::Lost => ErrorType::DeviceLost,
348 Self::OutOfMemory => ErrorType::OutOfMemory,
349 }
350 }
351}
352
353impl DeviceError {
354 pub fn from_hal(error: hal::DeviceError) -> Self {
358 match error {
359 hal::DeviceError::Lost => Self::Lost,
360 hal::DeviceError::OutOfMemory => Self::OutOfMemory,
361 hal::DeviceError::Unexpected => Self::Lost,
362 }
363 }
364}
365
/// Error carrying the features an operation needed but the device lacks.
#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

impl WebGpuError for MissingFeatures {
    /// Always a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

/// Error carrying the downlevel flags an operation needed but the device does
/// not support; the message appends `DOWNLEVEL_ERROR_MESSAGE`.
#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{DOWNLEVEL_ERROR_MESSAGE}",
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);

impl WebGpuError for MissingDownlevelFlags {
    /// Always a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
387
/// IDs produced for an implicitly created pipeline layout and its bind group
/// layouts (built by `ImplicitPipelineIds::prepare`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImplicitPipelineContext {
    /// ID of the pipeline layout.
    pub root_id: PipelineLayoutId,
    /// IDs of the bind group layouts, at most `hal::MAX_BIND_GROUPS`.
    pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
}

/// Caller-supplied IDs for an implicit pipeline layout and its bind group layouts.
pub struct ImplicitPipelineIds<'a> {
    /// ID requested for the pipeline layout.
    pub root_id: PipelineLayoutId,
    /// IDs requested for the bind group layouts.
    pub group_ids: &'a [BindGroupLayoutId],
}
399
400impl ImplicitPipelineIds<'_> {
401 fn prepare(self, hub: &Hub) -> ImplicitPipelineContext {
402 ImplicitPipelineContext {
403 root_id: hub.pipeline_layouts.prepare(Some(self.root_id)).id(),
404 group_ids: self
405 .group_ids
406 .iter()
407 .map(|id_in| hub.bind_group_layouts.prepare(Some(*id_in)).id())
408 .collect(),
409 }
410 }
411}
412
413pub fn create_validator(
415 features: wgt::Features,
416 downlevel: wgt::DownlevelFlags,
417 flags: naga::valid::ValidationFlags,
418) -> naga::valid::Validator {
419 use naga::valid::Capabilities as Caps;
420 let mut caps = Caps::empty();
421 caps.set(
422 Caps::PUSH_CONSTANT,
423 features.contains(wgt::Features::PUSH_CONSTANTS),
424 );
425 caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
426 caps.set(
427 Caps::SHADER_FLOAT16,
428 features.contains(wgt::Features::SHADER_F16),
429 );
430 caps.set(
431 Caps::PRIMITIVE_INDEX,
432 features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
433 );
434 caps.set(
435 Caps::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
436 features
437 .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
438 );
439 caps.set(
440 Caps::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
441 features.contains(wgt::Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
442 );
443 caps.set(
444 Caps::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
445 features.contains(wgt::Features::UNIFORM_BUFFER_BINDING_ARRAYS),
446 );
447 caps.set(
449 Caps::SAMPLER_NON_UNIFORM_INDEXING,
450 features
451 .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
452 );
453 caps.set(
454 Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
455 features.contains(wgt::Features::TEXTURE_FORMAT_16BIT_NORM),
456 );
457 caps.set(Caps::MULTIVIEW, features.contains(wgt::Features::MULTIVIEW));
458 caps.set(
459 Caps::EARLY_DEPTH_TEST,
460 features.contains(wgt::Features::SHADER_EARLY_DEPTH_TEST),
461 );
462 caps.set(
463 Caps::SHADER_INT64,
464 features.contains(wgt::Features::SHADER_INT64),
465 );
466 caps.set(
467 Caps::SHADER_INT64_ATOMIC_MIN_MAX,
468 features.intersects(
469 wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX | wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS,
470 ),
471 );
472 caps.set(
473 Caps::SHADER_INT64_ATOMIC_ALL_OPS,
474 features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
475 );
476 caps.set(
477 Caps::TEXTURE_ATOMIC,
478 features.contains(wgt::Features::TEXTURE_ATOMIC),
479 );
480 caps.set(
481 Caps::TEXTURE_INT64_ATOMIC,
482 features.contains(wgt::Features::TEXTURE_INT64_ATOMIC),
483 );
484 caps.set(
485 Caps::SHADER_FLOAT32_ATOMIC,
486 features.contains(wgt::Features::SHADER_FLOAT32_ATOMIC),
487 );
488 caps.set(
489 Caps::MULTISAMPLED_SHADING,
490 downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
491 );
492 caps.set(
493 Caps::DUAL_SOURCE_BLENDING,
494 features.contains(wgt::Features::DUAL_SOURCE_BLENDING),
495 );
496 caps.set(
497 Caps::CLIP_DISTANCE,
498 features.contains(wgt::Features::CLIP_DISTANCES),
499 );
500 caps.set(
501 Caps::CUBE_ARRAY_TEXTURES,
502 downlevel.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
503 );
504 caps.set(
505 Caps::SUBGROUP,
506 features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
507 );
508 caps.set(
509 Caps::SUBGROUP_BARRIER,
510 features.intersects(wgt::Features::SUBGROUP_BARRIER),
511 );
512 caps.set(
513 Caps::RAY_QUERY,
514 features.intersects(wgt::Features::EXPERIMENTAL_RAY_QUERY),
515 );
516 caps.set(
517 Caps::SUBGROUP_VERTEX_STAGE,
518 features.contains(wgt::Features::SUBGROUP_VERTEX),
519 );
520 caps.set(
521 Caps::RAY_HIT_VERTEX_POSITION,
522 features.intersects(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
523 );
524
525 naga::valid::Validator::new(flags, caps)
526}