// wgpu_core/device/mod.rs

1use alloc::{boxed::Box, string::String, vec::Vec};
2use core::{fmt, num::NonZeroU32};
3
4use crate::{
5    binding_model,
6    hub::Hub,
7    id::{BindGroupLayoutId, PipelineLayoutId},
8    ray_tracing::BlasCompactReadyPendingClosure,
9    resource::{
10        Buffer, BufferAccessError, BufferAccessResult, BufferMapOperation, Labeled,
11        RawResourceAccess, ResourceErrorIdent,
12    },
13    snatch::SnatchGuard,
14    Label, DOWNLEVEL_ERROR_MESSAGE,
15};
16
17use arrayvec::ArrayVec;
18use smallvec::SmallVec;
19use thiserror::Error;
20use wgt::{
21    error::{ErrorType, WebGpuError},
22    BufferAddress, DeviceLostReason, TextureFormat,
23};
24
25pub(crate) mod bgl;
26pub mod global;
27mod life;
28pub mod queue;
29pub mod ray_tracing;
30pub mod resource;
31#[cfg(any(feature = "trace", feature = "replay"))]
32pub mod trace;
33pub use {life::WaitIdleError, resource::Device};
34
/// Maximum number of shader stages that can be bound at once (mirrors hal's limit).
pub const SHADER_STAGE_COUNT: usize = hal::MAX_CONCURRENT_SHADER_STAGES;
// Should be large enough for the largest possible texture row. This
// value is enough for a 16k texture with float4 format.
pub(crate) const ZERO_BUFFER_SIZE: BufferAddress = 512 << 10;

// If a submission is not completed within this time, we go off into UB land.
// See https://github.com/gfx-rs/wgpu/issues/4589. 60s to reduce the chances of this.
const CLEANUP_WAIT_MS: u32 = 60000;

/// Error message used when a shader's entry point cannot be resolved.
pub(crate) const ENTRYPOINT_FAILURE_ERROR: &str = "The given EntryPoint is Invalid";

/// Device descriptor specialized to wgpu-core's [`Label`] type.
pub type DeviceDescriptor<'a> = wgt::DeviceDescriptor<Label<'a>>;
47
/// Direction of a host-side buffer mapping.
#[repr(C)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum HostMap {
    /// The host will read the mapped range.
    Read,
    /// The host will write the mapped range.
    Write,
}
55
/// Per-attachment data for a render pass, generic over the payload stored for
/// each attachment (e.g. [`TextureFormat`] in [`RenderPassContext`]).
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct AttachmentData<T> {
    // One slot per color attachment; `None` marks an unused slot.
    pub colors: ArrayVec<Option<T>, { hal::MAX_COLOR_ATTACHMENTS }>,
    // Resolve targets for multisampled color attachments.
    pub resolves: ArrayVec<T, { hal::MAX_COLOR_ATTACHMENTS }>,
    // Optional depth/stencil attachment.
    pub depth_stencil: Option<T>,
}
// Manual `Eq` so that only `T: PartialEq` is required; `derive(Eq)` would
// demand the stronger `T: Eq` bound.
impl<T: PartialEq> Eq for AttachmentData<T> {}
64
/// The subset of render-pass state that render pipelines must be compatible
/// with, checked via [`RenderPassContext::check_compatible`].
#[derive(Clone, Debug, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct RenderPassContext {
    // Formats of the pass's color/resolve/depth-stencil attachments.
    pub attachments: AttachmentData<TextureFormat>,
    // Sample count shared by all attachments.
    pub sample_count: u32,
    // Multiview layer count, if multiview rendering is used.
    pub multiview: Option<NonZeroU32>,
}
72#[derive(Clone, Debug, Error)]
73#[non_exhaustive]
74pub enum RenderPassCompatibilityError {
75    #[error(
76        "Incompatible color attachments at indices {indices:?}: the RenderPass uses textures with formats {expected:?} but the {res} uses attachments with formats {actual:?}",
77    )]
78    IncompatibleColorAttachment {
79        indices: Vec<usize>,
80        expected: Vec<Option<TextureFormat>>,
81        actual: Vec<Option<TextureFormat>>,
82        res: ResourceErrorIdent,
83    },
84    #[error(
85        "Incompatible depth-stencil attachment format: the RenderPass uses a texture with format {expected:?} but the {res} uses an attachment with format {actual:?}",
86    )]
87    IncompatibleDepthStencilAttachment {
88        expected: Option<TextureFormat>,
89        actual: Option<TextureFormat>,
90        res: ResourceErrorIdent,
91    },
92    #[error(
93        "Incompatible sample count: the RenderPass uses textures with sample count {expected:?} but the {res} uses attachments with format {actual:?}",
94    )]
95    IncompatibleSampleCount {
96        expected: u32,
97        actual: u32,
98        res: ResourceErrorIdent,
99    },
100    #[error("Incompatible multiview setting: the RenderPass uses setting {expected:?} but the {res} uses setting {actual:?}")]
101    IncompatibleMultiview {
102        expected: Option<NonZeroU32>,
103        actual: Option<NonZeroU32>,
104        res: ResourceErrorIdent,
105    },
106}
107
impl WebGpuError for RenderPassCompatibilityError {
    // Every compatibility failure is surfaced to WebGPU as a validation error.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
113
114impl RenderPassContext {
115    // Assumes the renderpass only contains one subpass
116    pub(crate) fn check_compatible<T: Labeled>(
117        &self,
118        other: &Self,
119        res: &T,
120    ) -> Result<(), RenderPassCompatibilityError> {
121        if self.attachments.colors != other.attachments.colors {
122            let indices = self
123                .attachments
124                .colors
125                .iter()
126                .zip(&other.attachments.colors)
127                .enumerate()
128                .filter_map(|(idx, (left, right))| (left != right).then_some(idx))
129                .collect();
130            return Err(RenderPassCompatibilityError::IncompatibleColorAttachment {
131                indices,
132                expected: self.attachments.colors.iter().cloned().collect(),
133                actual: other.attachments.colors.iter().cloned().collect(),
134                res: res.error_ident(),
135            });
136        }
137        if self.attachments.depth_stencil != other.attachments.depth_stencil {
138            return Err(
139                RenderPassCompatibilityError::IncompatibleDepthStencilAttachment {
140                    expected: self.attachments.depth_stencil,
141                    actual: other.attachments.depth_stencil,
142                    res: res.error_ident(),
143                },
144            );
145        }
146        if self.sample_count != other.sample_count {
147            return Err(RenderPassCompatibilityError::IncompatibleSampleCount {
148                expected: self.sample_count,
149                actual: other.sample_count,
150                res: res.error_ident(),
151            });
152        }
153        if self.multiview != other.multiview {
154            return Err(RenderPassCompatibilityError::IncompatibleMultiview {
155                expected: self.multiview,
156                actual: other.multiview,
157                res: res.error_ident(),
158            });
159        }
160        Ok(())
161    }
162}
163
/// A buffer-map operation paired with the result its callback should receive.
pub type BufferMapPendingClosure = (BufferMapOperation, BufferAccessResult);

/// User callbacks accumulated during device maintenance, fired later via
/// [`UserClosures::fire`] once no locks are held.
#[derive(Default)]
pub struct UserClosures {
    // Completed buffer-map operations awaiting their callbacks.
    pub mappings: Vec<BufferMapPendingClosure>,
    // BLAS compaction-ready callbacks awaiting invocation.
    pub blas_compact_ready: Vec<BlasCompactReadyPendingClosure>,
    // `on_submitted_work_done` callbacks awaiting invocation.
    pub submissions: SmallVec<[queue::SubmittedWorkDoneClosure; 1]>,
    // Device-lost callbacks with their reason and message.
    pub device_lost_invocations: SmallVec<[DeviceLostInvocation; 1]>,
}
173
174impl UserClosures {
175    fn extend(&mut self, other: Self) {
176        self.mappings.extend(other.mappings);
177        self.blas_compact_ready.extend(other.blas_compact_ready);
178        self.submissions.extend(other.submissions);
179        self.device_lost_invocations
180            .extend(other.device_lost_invocations);
181    }
182
183    fn fire(self) {
184        // Note: this logic is specifically moved out of `handle_mapping()` in order to
185        // have nothing locked by the time we execute users callback code.
186
187        // Mappings _must_ be fired before submissions, as the spec requires all mapping callbacks that are registered before
188        // a on_submitted_work_done callback to be fired before the on_submitted_work_done callback.
189        for (mut operation, status) in self.mappings {
190            if let Some(callback) = operation.callback.take() {
191                callback(status);
192            }
193        }
194        for (mut operation, status) in self.blas_compact_ready {
195            if let Some(callback) = operation.take() {
196                callback(status);
197            }
198        }
199        for closure in self.submissions {
200            closure();
201        }
202        for invocation in self.device_lost_invocations {
203            (invocation.closure)(invocation.reason, invocation.message);
204        }
205    }
206}
207
/// Callback invoked when a device is lost; receives the reason and a message.
/// `Send` is required only on platforms where resources cross threads.
#[cfg(send_sync)]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(send_sync))]
pub type DeviceLostClosure = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;

/// A device-lost callback bundled with the arguments it will be called with.
pub struct DeviceLostInvocation {
    closure: DeviceLostClosure,
    reason: DeviceLostReason,
    message: String,
}
218
/// Map `size` bytes of `buffer` starting at `offset` for host access of the
/// given `kind`, and zero any portion of the range the initialization tracker
/// reports as uninitialized (the spec requires resources to behave as if
/// zero-initialized).
///
/// Both `offset` and `size` must be multiples of `wgt::COPY_BUFFER_ALIGNMENT`
/// (asserted below). Returns the raw hal mapping on success.
pub(crate) fn map_buffer(
    buffer: &Buffer,
    offset: BufferAddress,
    size: BufferAddress,
    kind: HostMap,
    snatch_guard: &SnatchGuard,
) -> Result<hal::BufferMapping, BufferAccessError> {
    let raw_device = buffer.device.raw();
    let raw_buffer = buffer.try_raw(snatch_guard)?;
    let mapping = unsafe {
        raw_device
            .map_buffer(raw_buffer, offset..offset + size)
            .map_err(|e| buffer.device.handle_hal_error(e))?
    };

    // Non-coherent memory: make the GPU's writes visible to the host before
    // the caller reads through the mapping.
    if !mapping.is_coherent && kind == HostMap::Read {
        #[allow(clippy::single_range_in_vec_init)]
        unsafe {
            raw_device.invalidate_mapped_ranges(raw_buffer, &[offset..offset + size]);
        }
    }

    assert_eq!(offset % wgt::COPY_BUFFER_ALIGNMENT, 0);
    assert_eq!(size % wgt::COPY_BUFFER_ALIGNMENT, 0);
    // Zero out uninitialized parts of the mapping. (Spec dictates all resources
    // behave as if they were initialized with zero)
    //
    // If this is a read mapping, ideally we would use a `clear_buffer` command
    // before reading the data from GPU (i.e. `invalidate_range`). However, this
    // would require us to kick off and wait for a command buffer or piggy back
    // on an existing one (the later is likely the only worthwhile option). As
    // reading uninitialized memory isn't a particular important path to
    // support, we instead just initialize the memory here and make sure it is
    // GPU visible, so this happens at max only once for every buffer region.
    //
    // If this is a write mapping zeroing out the memory here is the only
    // reasonable way as all data is pushed to GPU anyways.

    // SAFETY: presumed valid for `size` bytes because the hal `map_buffer`
    // call above succeeded for the range `offset..offset + size`.
    let mapped = unsafe { core::slice::from_raw_parts_mut(mapping.ptr.as_ptr(), size as usize) };

    // We can't call flush_mapped_ranges in this case, so we can't drain the uninitialized ranges either
    if !mapping.is_coherent
        && kind == HostMap::Read
        && !buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
    {
        // Zero the bytes for the host but leave them tracked as uninitialized:
        // without MAP_WRITE the zeroes can never be flushed back to the GPU.
        for uninitialized in buffer
            .initialization_status
            .write()
            .uninitialized(offset..(size + offset))
        {
            // The mapping's pointer is already offset, however we track the
            // uninitialized range relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);
        }
    } else {
        // Zero and mark the ranges as initialized (drained from the tracker).
        for uninitialized in buffer
            .initialization_status
            .write()
            .drain(offset..(size + offset))
        {
            // The mapping's pointer is already offset, however we track the
            // uninitialized range relative to the buffer's start.
            let fill_range =
                (uninitialized.start - offset) as usize..(uninitialized.end - offset) as usize;
            mapped[fill_range].fill(0);

            // NOTE: This is only possible when MAPPABLE_PRIMARY_BUFFERS is enabled.
            if !mapping.is_coherent
                && kind == HostMap::Read
                && buffer.usage.contains(wgt::BufferUsages::MAP_WRITE)
            {
                unsafe { raw_device.flush_mapped_ranges(raw_buffer, &[uninitialized]) };
            }
        }
    }

    Ok(mapping)
}
299
/// A resource was used with a device other than the one it was created from.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DeviceMismatch {
    // The offending resource and the device it belongs to.
    pub(super) res: ResourceErrorIdent,
    pub(super) res_device: ResourceErrorIdent,
    // The resource (if any) it was used with, and that resource's device.
    pub(super) target: Option<ResourceErrorIdent>,
    pub(super) target_device: ResourceErrorIdent,
}
308
309impl fmt::Display for DeviceMismatch {
310    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
311        write!(
312            f,
313            "{} of {} doesn't match {}",
314            self.res_device, self.res, self.target_device
315        )?;
316        if let Some(target) = self.target.as_ref() {
317            write!(f, " of {target}")?;
318        }
319        Ok(())
320    }
321}
322
impl core::error::Error for DeviceMismatch {}

impl WebGpuError for DeviceMismatch {
    // Using a resource with the wrong device is a validation error per WebGPU.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
330
/// Device-level failures shared across many operations.
#[derive(Clone, Debug, Error)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum DeviceError {
    /// The device has been lost (or hit an unexpected backend error).
    #[error("Parent device is lost")]
    Lost,
    /// An allocation failed on the device.
    #[error("Not enough memory left.")]
    OutOfMemory,
    /// A resource was used with the wrong device (boxed payload).
    #[error(transparent)]
    DeviceMismatch(#[from] Box<DeviceMismatch>),
}
342
343impl WebGpuError for DeviceError {
344    fn webgpu_error_type(&self) -> ErrorType {
345        match self {
346            Self::DeviceMismatch(e) => e.webgpu_error_type(),
347            Self::Lost => ErrorType::DeviceLost,
348            Self::OutOfMemory => ErrorType::OutOfMemory,
349        }
350    }
351}
352
353impl DeviceError {
354    /// Only use this function in contexts where there is no `Device`.
355    ///
356    /// Use [`Device::handle_hal_error`] otherwise.
357    pub fn from_hal(error: hal::DeviceError) -> Self {
358        match error {
359            hal::DeviceError::Lost => Self::Lost,
360            hal::DeviceError::OutOfMemory => Self::OutOfMemory,
361            hal::DeviceError::Unexpected => Self::Lost,
362        }
363    }
364}
365
/// An operation required features that are not enabled on the device.
#[derive(Clone, Debug, Error)]
#[error("Features {0:?} are required but not enabled on the device")]
pub struct MissingFeatures(pub wgt::Features);

impl WebGpuError for MissingFeatures {
    // Missing features are a validation error per WebGPU.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
375
/// An operation required downlevel flags the device does not support.
#[derive(Clone, Debug, Error)]
#[error(
    "Downlevel flags {0:?} are required but not supported on the device.\n{DOWNLEVEL_ERROR_MESSAGE}",
)]
pub struct MissingDownlevelFlags(pub wgt::DownlevelFlags);

impl WebGpuError for MissingDownlevelFlags {
    // Missing downlevel support is a validation error per WebGPU.
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}
387
/// Registered IDs for an implicitly created pipeline layout and its
/// bind group layouts.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImplicitPipelineContext {
    pub root_id: PipelineLayoutId,
    pub group_ids: ArrayVec<BindGroupLayoutId, { hal::MAX_BIND_GROUPS }>,
}

/// Caller-supplied IDs to use for implicitly created pipeline layouts;
/// turned into an [`ImplicitPipelineContext`] by [`ImplicitPipelineIds::prepare`].
pub struct ImplicitPipelineIds<'a> {
    pub root_id: PipelineLayoutId,
    pub group_ids: &'a [BindGroupLayoutId],
}
399
400impl ImplicitPipelineIds<'_> {
401    fn prepare(self, hub: &Hub) -> ImplicitPipelineContext {
402        ImplicitPipelineContext {
403            root_id: hub.pipeline_layouts.prepare(Some(self.root_id)).id(),
404            group_ids: self
405                .group_ids
406                .iter()
407                .map(|id_in| hub.bind_group_layouts.prepare(Some(*id_in)).id())
408                .collect(),
409        }
410    }
411}
412
413/// Create a validator with the given validation flags.
414pub fn create_validator(
415    features: wgt::Features,
416    downlevel: wgt::DownlevelFlags,
417    flags: naga::valid::ValidationFlags,
418) -> naga::valid::Validator {
419    use naga::valid::Capabilities as Caps;
420    let mut caps = Caps::empty();
421    caps.set(
422        Caps::PUSH_CONSTANT,
423        features.contains(wgt::Features::PUSH_CONSTANTS),
424    );
425    caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64));
426    caps.set(
427        Caps::SHADER_FLOAT16,
428        features.contains(wgt::Features::SHADER_F16),
429    );
430    caps.set(
431        Caps::PRIMITIVE_INDEX,
432        features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX),
433    );
434    caps.set(
435        Caps::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
436        features
437            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
438    );
439    caps.set(
440        Caps::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
441        features.contains(wgt::Features::STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING),
442    );
443    caps.set(
444        Caps::UNIFORM_BUFFER_ARRAY_NON_UNIFORM_INDEXING,
445        features.contains(wgt::Features::UNIFORM_BUFFER_BINDING_ARRAYS),
446    );
447    // TODO: This needs a proper wgpu feature
448    caps.set(
449        Caps::SAMPLER_NON_UNIFORM_INDEXING,
450        features
451            .contains(wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING),
452    );
453    caps.set(
454        Caps::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
455        features.contains(wgt::Features::TEXTURE_FORMAT_16BIT_NORM),
456    );
457    caps.set(Caps::MULTIVIEW, features.contains(wgt::Features::MULTIVIEW));
458    caps.set(
459        Caps::EARLY_DEPTH_TEST,
460        features.contains(wgt::Features::SHADER_EARLY_DEPTH_TEST),
461    );
462    caps.set(
463        Caps::SHADER_INT64,
464        features.contains(wgt::Features::SHADER_INT64),
465    );
466    caps.set(
467        Caps::SHADER_INT64_ATOMIC_MIN_MAX,
468        features.intersects(
469            wgt::Features::SHADER_INT64_ATOMIC_MIN_MAX | wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS,
470        ),
471    );
472    caps.set(
473        Caps::SHADER_INT64_ATOMIC_ALL_OPS,
474        features.contains(wgt::Features::SHADER_INT64_ATOMIC_ALL_OPS),
475    );
476    caps.set(
477        Caps::TEXTURE_ATOMIC,
478        features.contains(wgt::Features::TEXTURE_ATOMIC),
479    );
480    caps.set(
481        Caps::TEXTURE_INT64_ATOMIC,
482        features.contains(wgt::Features::TEXTURE_INT64_ATOMIC),
483    );
484    caps.set(
485        Caps::SHADER_FLOAT32_ATOMIC,
486        features.contains(wgt::Features::SHADER_FLOAT32_ATOMIC),
487    );
488    caps.set(
489        Caps::MULTISAMPLED_SHADING,
490        downlevel.contains(wgt::DownlevelFlags::MULTISAMPLED_SHADING),
491    );
492    caps.set(
493        Caps::DUAL_SOURCE_BLENDING,
494        features.contains(wgt::Features::DUAL_SOURCE_BLENDING),
495    );
496    caps.set(
497        Caps::CLIP_DISTANCE,
498        features.contains(wgt::Features::CLIP_DISTANCES),
499    );
500    caps.set(
501        Caps::CUBE_ARRAY_TEXTURES,
502        downlevel.contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
503    );
504    caps.set(
505        Caps::SUBGROUP,
506        features.intersects(wgt::Features::SUBGROUP | wgt::Features::SUBGROUP_VERTEX),
507    );
508    caps.set(
509        Caps::SUBGROUP_BARRIER,
510        features.intersects(wgt::Features::SUBGROUP_BARRIER),
511    );
512    caps.set(
513        Caps::RAY_QUERY,
514        features.intersects(wgt::Features::EXPERIMENTAL_RAY_QUERY),
515    );
516    caps.set(
517        Caps::SUBGROUP_VERTEX_STAGE,
518        features.contains(wgt::Features::SUBGROUP_VERTEX),
519    );
520    caps.set(
521        Caps::RAY_HIT_VERTEX_POSITION,
522        features.intersects(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN),
523    );
524
525    naga::valid::Validator::new(flags, caps)
526}