use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
use crate::command::bind::Binder;
use crate::command::memory_init::{CommandBufferTextureMemoryActions, SurfacesInDiscardState};
use crate::command::{CommandBuffer, QueryResetMap, QueryUseError};
use crate::device::{Device, DeviceError, MissingFeatures};
use crate::init_tracker::BufferInitTrackerAction;
use crate::pipeline::LateSizedBufferGroup;
use crate::ray_tracing::AsAction;
use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
use crate::snatch::SnatchGuard;
use crate::track::{ResourceUsageCompatibilityError, Tracker, UsageScope};
use crate::{api_log, binding_model};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::str;
use thiserror::Error;
use wgt::error::{ErrorType, WebGpuError};
use wgt::DynamicOffset;

#[derive(Clone, Debug, Error)]
#[error(
    "Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}"
)]
pub struct BindGroupIndexOutOfRange {
    pub index: u32,
    pub max: u32,
}

#[derive(Clone, Debug, Error)]
#[error("Pipeline must be set")]
pub struct MissingPipeline;

#[derive(Clone, Debug, Error)]
#[error("Setting `values_offset` to be `None` is only for internal use in render bundles")]
pub struct InvalidValuesOffset;

impl WebGpuError for InvalidValuesOffset {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

#[derive(Clone, Debug, Error)]
#[error("Cannot pop debug group, because number of pushed debug groups is zero")]
pub struct InvalidPopDebugGroup;

impl WebGpuError for InvalidPopDebugGroup {
    fn webgpu_error_type(&self) -> ErrorType {
        ErrorType::Validation
    }
}

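/// Common state used while encoding the commands of a single pass.
///
/// The free functions in this module (`set_bind_group`, `set_push_constant`,
/// `write_timestamp`, the debug-group helpers) read and update this state as
/// they record commands into `raw_encoder`.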
pub(crate) struct BaseState<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
    pub(crate) device: &'cmd_buf Arc<Device>,

    pub(crate) raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    pub(crate) tracker: &'cmd_buf mut Tracker,
    pub(crate) buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
    pub(crate) texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
    pub(crate) as_actions: &'cmd_buf mut Vec<AsAction>,

    pub(crate) pending_discard_init_fixups: SurfacesInDiscardState,

    pub(crate) scope: UsageScope<'scope>,

    pub(crate) binder: Binder,

    pub(crate) temp_offsets: Vec<u32>,

    pub(crate) dynamic_offset_count: usize,

    pub(crate) snatch_guard: &'snatch_guard SnatchGuard<'snatch_guard>,

    pub(crate) debug_scope_depth: u32,
    pub(crate) string_offset: usize,
}

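/// Records a `set_bind_group` call for the current pass.
///
/// Validates the group index against `max_bind_groups` and the dynamic
/// offsets against the bind group layout, merges the group's resource usage
/// into the pass scope when `merge_bind_groups` is set, registers buffer and
/// texture memory-initialization actions, and, if a pipeline layout is
/// already bound, sets the affected groups on the raw encoder.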
pub(crate) fn set_bind_group<E>(
    state: &mut BaseState,
    cmd_buf: &CommandBuffer,
    dynamic_offsets: &[DynamicOffset],
    index: u32,
    num_dynamic_offsets: usize,
    bind_group: Option<Arc<BindGroup>>,
    merge_bind_groups: bool,
) -> Result<(), E>
where
    E: From<DeviceError>
        + From<BindGroupIndexOutOfRange>
        + From<ResourceUsageCompatibilityError>
        + From<DestroyedResourceError>
        + From<BindError>,
{
    if bind_group.is_none() {
        api_log!("Pass::set_bind_group {index} None");
    } else {
        api_log!(
            "Pass::set_bind_group {index} {}",
            bind_group.as_ref().unwrap().error_ident()
        );
    }

    let max_bind_groups = state.device.limits.max_bind_groups;
    if index >= max_bind_groups {
        return Err(BindGroupIndexOutOfRange {
            index,
            max: max_bind_groups,
        }
        .into());
    }

    state.temp_offsets.clear();
    state.temp_offsets.extend_from_slice(
        &dynamic_offsets
            [state.dynamic_offset_count..state.dynamic_offset_count + num_dynamic_offsets],
    );
    state.dynamic_offset_count += num_dynamic_offsets;

    if bind_group.is_none() {
        return Ok(());
    }

    let bind_group = bind_group.unwrap();
    let bind_group = state.tracker.bind_groups.insert_single(bind_group);

    bind_group.same_device_as(cmd_buf)?;

    bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;

    if merge_bind_groups {
        unsafe {
            state.scope.merge_bind_group(&bind_group.used)?;
        }
    }
    state
        .buffer_memory_init_actions
        .extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
            action
                .buffer
                .initialization_status
                .read()
                .check_action(action)
        }));
    for action in bind_group.used_texture_ranges.iter() {
        state
            .pending_discard_init_fixups
            .extend(state.texture_memory_actions.register_init_action(action));
    }

    let used_resource = bind_group
        .used
        .acceleration_structures
        .into_iter()
        .map(|tlas| AsAction::UseTlas(tlas.clone()));

    state.as_actions.extend(used_resource);

    let pipeline_layout = state.binder.pipeline_layout.clone();
    let entries = state
        .binder
        .assign_group(index as usize, bind_group, &state.temp_offsets);
    if !entries.is_empty() && pipeline_layout.is_some() {
        let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
        for (i, e) in entries.iter().enumerate() {
            if let Some(group) = e.group.as_ref() {
                let raw_bg = group.try_raw(state.snatch_guard)?;
                unsafe {
                    state.raw_encoder.set_bind_group(
                        pipeline_layout,
                        index + i as u32,
                        Some(raw_bg),
                        &e.dynamic_offsets,
                    );
                }
            }
        }
    }
    Ok(())
}

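/// Switches the binder to `pipeline_layout` and re-establishes bindings after
/// a pipeline change.
///
/// When the new layout differs from the currently bound one, the compatible
/// bind groups are re-set on the raw encoder, the caller-supplied closure `f`
/// is invoked, and every push constant range of the new layout is cleared to
/// zero.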
pub(crate) fn rebind_resources<E, F: FnOnce()>(
    state: &mut BaseState,
    pipeline_layout: &Arc<binding_model::PipelineLayout>,
    late_sized_buffer_groups: &[LateSizedBufferGroup],
    f: F,
) -> Result<(), E>
where
    E: From<DestroyedResourceError>,
{
    if state.binder.pipeline_layout.is_none()
        || !state
            .binder
            .pipeline_layout
            .as_ref()
            .unwrap()
            .is_equal(pipeline_layout)
    {
        let (start_index, entries) = state
            .binder
            .change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
        if !entries.is_empty() {
            for (i, e) in entries.iter().enumerate() {
                if let Some(group) = e.group.as_ref() {
                    let raw_bg = group.try_raw(state.snatch_guard)?;
                    unsafe {
                        state.raw_encoder.set_bind_group(
                            pipeline_layout.raw(),
                            start_index as u32 + i as u32,
                            Some(raw_bg),
                            &e.dynamic_offsets,
                        );
                    }
                }
            }
        }

        f();

        let non_overlapping =
            super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);

        for range in non_overlapping {
            let offset = range.range.start;
            let size_bytes = range.range.end - offset;
            super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
                state.raw_encoder.set_push_constants(
                    pipeline_layout.raw(),
                    range.stages,
                    clear_offset,
                    clear_data,
                );
            });
        }
    }
    Ok(())
}

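/// Records a `set_push_constants` call for the current pass.
///
/// `values_offset` must be `Some`; the selected slice of `push_constant_data`
/// is validated against the bound pipeline layout's push constant ranges,
/// handed to `f`, and then uploaded through the raw encoder.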
pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
    state: &mut BaseState,
    push_constant_data: &[u32],
    stages: wgt::ShaderStages,
    offset: u32,
    size_bytes: u32,
    values_offset: Option<u32>,
    f: F,
) -> Result<(), E>
where
    E: From<PushConstantUploadError> + From<InvalidValuesOffset> + From<MissingPipeline>,
{
    api_log!("Pass::set_push_constants");

    let values_offset = values_offset.ok_or(InvalidValuesOffset)?;

    let end_offset_bytes = offset + size_bytes;
    let values_end_offset = (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
    let data_slice = &push_constant_data[(values_offset as usize)..values_end_offset];

    let pipeline_layout = state
        .binder
        .pipeline_layout
        .as_ref()
        .ok_or(MissingPipeline)?;

    pipeline_layout.validate_push_constant_ranges(stages, offset, end_offset_bytes)?;

    f(data_slice);

    unsafe {
        state
            .raw_encoder
            .set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
    }
    Ok(())
}

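/// Records a timestamp query write inside the current pass.
///
/// Requires the `TIMESTAMP_QUERY_INSIDE_PASSES` feature and a query set that
/// belongs to the same device as the command buffer.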
pub(crate) fn write_timestamp<E>(
    state: &mut BaseState,
    cmd_buf: &CommandBuffer,
    pending_query_resets: Option<&mut QueryResetMap>,
    query_set: Arc<QuerySet>,
    query_index: u32,
) -> Result<(), E>
where
    E: From<MissingFeatures> + From<QueryUseError> + From<DeviceError>,
{
    api_log!(
        "Pass::write_timestamps {query_index} {}",
        query_set.error_ident()
    );

    query_set.same_device_as(cmd_buf)?;

    state
        .device
        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;

    let query_set = state.tracker.query_sets.insert_single(query_set);

    query_set.validate_and_write_timestamp(state.raw_encoder, query_index, pending_query_resets)?;
    Ok(())
}

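/// Opens a debug group whose label is read from `string_data`.
///
/// The scope depth and string offset are always advanced; the label is only
/// forwarded to the HAL encoder when `DISCARD_HAL_LABELS` is not set.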
pub(crate) fn push_debug_group(state: &mut BaseState, string_data: &[u8], len: usize) {
    state.debug_scope_depth += 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();

        api_log!("Pass::push_debug_group {label:?}");
        unsafe {
            state.raw_encoder.begin_debug_marker(label);
        }
    }
    state.string_offset += len;
}

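/// Closes the innermost debug group.
///
/// Fails with [`InvalidPopDebugGroup`] when no debug group is currently open.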
pub(crate) fn pop_debug_group<E>(state: &mut BaseState) -> Result<(), E>
where
    E: From<InvalidPopDebugGroup>,
{
    api_log!("Pass::pop_debug_group");

    if state.debug_scope_depth == 0 {
        return Err(InvalidPopDebugGroup.into());
    }
    state.debug_scope_depth -= 1;
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        unsafe {
            state.raw_encoder.end_debug_marker();
        }
    }
    Ok(())
}

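/// Inserts a debug marker whose label is read from `string_data`.
///
/// The string offset is always advanced; the marker is only forwarded to the
/// HAL encoder when `DISCARD_HAL_LABELS` is not set.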
pub(crate) fn insert_debug_marker(state: &mut BaseState, string_data: &[u8], len: usize) {
    if !state
        .device
        .instance_flags
        .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
    {
        let label =
            str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
        api_log!("Pass::insert_debug_marker {label:?}");
        unsafe {
            state.raw_encoder.insert_debug_marker(label);
        }
    }
    state.string_offset += len;
}