wgpu_core/command/
pass.rs

1//! Generic pass functions that both compute and render passes need.
2
3use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
4use crate::command::bind::Binder;
5use crate::command::memory_init::{CommandBufferTextureMemoryActions, SurfacesInDiscardState};
6use crate::command::{CommandBuffer, QueryResetMap, QueryUseError};
7use crate::device::{Device, DeviceError, MissingFeatures};
8use crate::init_tracker::BufferInitTrackerAction;
9use crate::pipeline::LateSizedBufferGroup;
10use crate::ray_tracing::AsAction;
11use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
12use crate::snatch::SnatchGuard;
13use crate::track::{ResourceUsageCompatibilityError, Tracker, UsageScope};
14use crate::{api_log, binding_model};
15use alloc::sync::Arc;
16use alloc::vec::Vec;
17use core::str;
18use thiserror::Error;
19use wgt::error::{ErrorType, WebGpuError};
20use wgt::DynamicOffset;
21
22#[derive(Clone, Debug, Error)]
23#[error(
24    "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
25)]
26pub struct BindGroupIndexOutOfRange {
27    pub index: u32,
28    pub max: u32,
29}
30
31#[derive(Clone, Debug, Error)]
32#[error("Pipeline must be set")]
33pub struct MissingPipeline;
34
35#[derive(Clone, Debug, Error)]
36#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
37pub struct InvalidValuesOffset;
38
39impl WebGpuError for InvalidValuesOffset {
40    fn webgpu_error_type(&self) -> ErrorType {
41        ErrorType::Validation
42    }
43}
44
45#[derive(Clone, Debug, Error)]
46#[error("Cannot pop debug group, because number of pushed debug groups is zero")]
47pub struct InvalidPopDebugGroup;
48
49impl WebGpuError for InvalidPopDebugGroup {
50    fn webgpu_error_type(&self) -> ErrorType {
51        ErrorType::Validation
52    }
53}
54
/// Encoding state shared by compute and render passes while their commands
/// are being recorded.
///
/// The generic helpers in this module (`set_bind_group`, `rebind_resources`,
/// `set_push_constant`, `write_timestamp` and the debug-marker functions) all
/// read and update this struct.
pub(crate) struct BaseState<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    /// Device whose limits, features and instance flags the helpers consult
    /// during validation.
    pub(crate) device: &'cmd_buf Arc<Device>,

    /// HAL encoder that receives the translated commands.
    pub(crate) raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    /// Tracker into which bind groups and query sets used by the pass are
    /// inserted.
    pub(crate) tracker: &'cmd_buf mut Tracker,
    /// Buffer initialization actions collected from bound groups.
    pub(crate) buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    /// Registers init actions for the texture ranges of bound groups.
    pub(crate) texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
    /// Acceleration-structure actions (TLAS uses) gathered from bound groups.
    pub(crate) as_actions: &'cmd_buf mut Vec<AsAction>,

    /// Immediate texture inits required because of prior discards. Need to
    /// be inserted before texture reads.
    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    /// Usage scope that bind group resource usages are merged into when
    /// `set_bind_group` is asked to merge.
    pub(crate) scope: UsageScope<'scope>,

    /// Holds the current pipeline layout and the bind groups assigned to it.
    pub(crate) binder: Binder,

    /// Scratch buffer with the dynamic offsets of the bind group currently
    /// being set; cleared and refilled on every `set_bind_group` call.
    pub(crate) temp_offsets: Vec<u32>,

    /// Number of entries already consumed from the pass's dynamic offset
    /// array.
    pub(crate) dynamic_offset_count: usize,

    /// Guard used to obtain the raw HAL objects of bind groups.
    pub(crate) snatch_guard: &'snatch_guard SnatchGuard<'snatch_guard>,

    /// Number of currently open debug groups (pushes minus pops).
    pub(crate) debug_scope_depth: u32,
    /// Byte offset into the pass's string data where the next debug label
    /// starts.
    pub(crate) string_offset: usize,
}
82
83pub(crate) fn set_bind_group<E>(
84    state: &mut BaseState,
85    cmd_buf: &CommandBuffer,
86    dynamic_offsets: &[DynamicOffset],
87    index: u32,
88    num_dynamic_offsets: usize,
89    bind_group: Option<Arc<BindGroup>>,
90    merge_bind_groups: bool,
91) -> Result<(), E>
92where
93    E: From<DeviceError>
94        + From<BindGroupIndexOutOfRange>
95        + From<ResourceUsageCompatibilityError>
96        + From<DestroyedResourceError>
97        + From<BindError>,
98{
99    if bind_group.is_none() {
100        api_log!("Pass::set_bind_group {index} None");
101    } else {
102        api_log!(
103            "Pass::set_bind_group {index} {}",
104            bind_group.as_ref().unwrap().error_ident()
105        );
106    }
107
108    let max_bind_groups = state.device.limits.max_bind_groups;
109    if index >= max_bind_groups {
110        return Err(BindGroupIndexOutOfRange {
111            index,
112            max: max_bind_groups,
113        }
114        .into());
115    }
116
117    state.temp_offsets.clear();
118    state.temp_offsets.extend_from_slice(
119        &dynamic_offsets
120            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
121    );
122    state.dynamic_offset_count += num_dynamic_offsets;
123
124    if bind_group.is_none() {
125        // TODO: Handle bind_group None.
126        return Ok(());
127    }
128
129    let bind_group = bind_group.unwrap();
130    let bind_group = state.tracker.bind_groups.insert_single(bind_group);
131
132    bind_group.same_device_as(cmd_buf)?;
133
134    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
135
136    if merge_bind_groups {
137        // merge the resource tracker in
138        unsafe {
139            state.scope.merge_bind_group(&bind_group.used)?;
140        }
141    }
142    //Note: stateless trackers are not merged: the lifetime reference
143    // is held to the bind group itself.
144
145    state
146        .buffer_memory_init_actions
147        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
148            action
149                .buffer
150                .initialization_status
151                .read()
152                .check_action(action)
153        }));
154    for action in bind_group.used_texture_ranges.iter() {
155        state
156            .pending_discard_init_fixups
157            .extend(state.texture_memory_actions.register_init_action(action));
158    }
159
160    let used_resource = bind_group
161        .used
162        .acceleration_structures
163        .into_iter()
164        .map(|tlas| AsAction::UseTlas(tlas.clone()));
165
166    state.as_actions.extend(used_resource);
167
168    let pipeline_layout = state.binder.pipeline_layout.clone();
169    let entries = state
170        .binder
171        .assign_group(index as usize, bind_group, &state.temp_offsets);
172    if !entries.is_empty() && pipeline_layout.is_some() {
173        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
174        for (i, e) in entries.iter().enumerate() {
175            if let Some(group) = e.group.as_ref() {
176                let raw_bg = group.try_raw(state.snatch_guard)?;
177                unsafe {
178                    state.raw_encoder.set_bind_group(
179                        pipeline_layout,
180                        index + i as u32,
181                        Some(raw_bg),
182                        &e.dynamic_offsets,
183                    );
184                }
185            }
186        }
187    }
188    Ok(())
189}
190
191/// After a pipeline has been changed, resources must be rebound
192pub(crate) fn rebind_resources<E, F: FnOnce()>(
193    state: &mut BaseState,
194    pipeline_layout: &Arc<binding_model::PipelineLayout>,
195    late_sized_buffer_groups: &[LateSizedBufferGroup],
196    f: F,
197) -> Result<(), E>
198where
199    E: From<DestroyedResourceError>,
200{
201    if state.binder.pipeline_layout.is_none()
202        || !state
203            .binder
204            .pipeline_layout
205            .as_ref()
206            .unwrap()
207            .is_equal(pipeline_layout)
208    {
209        let (start_index, entries) = state
210            .binder
211            .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
212        if !entries.is_empty() {
213            for (i, e) in entries.iter().enumerate() {
214                if let Some(group) = e.group.as_ref() {
215                    let raw_bg = group.try_raw(state.snatch_guard)?;
216                    unsafe {
217                        state.raw_encoder.set_bind_group(
218                            pipeline_layout.raw(),
219                            start_index as u32 + i as u32,
220                            Some(raw_bg),
221                            &e.dynamic_offsets,
222                        );
223                    }
224                }
225            }
226        }
227
228        f();
229
230        let non_overlapping =
231            super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
232
233        // Clear push constant ranges
234        for range in non_overlapping {
235            let offset = range.range.start;
236            let size_bytes = range.range.end - offset;
237            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
238                state.raw_encoder.set_push_constants(
239                    pipeline_layout.raw(),
240                    range.stages,
241                    clear_offset,
242                    clear_data,
243                );
244            });
245        }
246    }
247    Ok(())
248}
249
250pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
251    state: &mut BaseState,
252    push_constant_data: &[u32],
253    stages: wgt::ShaderStages,
254    offset: u32,
255    size_bytes: u32,
256    values_offset: Option<u32>,
257    f: F,
258) -> Result<(), E>
259where
260    E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
261{
262    api_log!("Pass::set_push_constants");
263
264    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;
265
266    let end_offset_bytes = offset + size_bytes;
267    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
268    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];
269
270    let pipeline_layout = state
271        .binder
272        .pipeline_layout
273        .as_ref()
274        .ok_or(MissingPipeline)?;
275
276    pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;
277
278    f(data_slice);
279
280    unsafe {
281        state
282            .raw_encoder
283            .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
284    }
285    Ok(())
286}
287
288pub(crate) fn write_timestamp<E>(
289    state: &mut BaseState,
290    cmd_buf: &CommandBuffer,
291    pending_query_resets: Option<&mut QueryResetMap>,
292    query_set: Arc<QuerySet>,
293    query_index: u32,
294) -> Result<(), E>
295where
296    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
297{
298    api_log!(
299        "Pass::write_timestamps {query_index} {}",
300        query_set.error_ident()
301    );
302
303    query_set.same_device_as(cmd_buf)?;
304
305    state
306        .device
307        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
308
309    let query_set = state.tracker.query_sets.insert_single(query_set);
310
311    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, pending_query_resets)?;
312    Ok(())
313}
314
315pub(crate) fn push_debug_group(state: &mut BaseState, string_data: &[u8], len: usize) {
316    state.debug_scope_depth += 1;
317    if !state
318        .device
319        .instance_flags
320        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
321    {
322        let label =
323            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
324
325        api_log!("Pass::push_debug_group {label:?}");
326        unsafe {
327            state.raw_encoder.begin_debug_marker(label);
328        }
329    }
330    state.string_offset += len;
331}
332
333pub(crate) fn pop_debug_group<E>(state: &mut BaseState) -> Result<(), E>
334where
335    E: From<InvalidPopDebugGroup>,
336{
337    api_log!("Pass::pop_debug_group");
338
339    if state.debug_scope_depth == 0 {
340        return Err(InvalidPopDebugGroup.into());
341    }
342    state.debug_scope_depth -= 1;
343    if !state
344        .device
345        .instance_flags
346        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
347    {
348        unsafe {
349            state.raw_encoder.end_debug_marker();
350        }
351    }
352    Ok(())
353}
354
355pub(crate) fn insert_debug_marker(state: &mut BaseState, string_data: &[u8], len: usize) {
356    if !state
357        .device
358        .instance_flags
359        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
360    {
361        let label =
362            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
363        api_log!("Pass::insert_debug_marker {label:?}");
364        unsafe {
365            state.raw_encoder.insert_debug_marker(label);
366        }
367    }
368    state.string_offset += len;
369}