wgpu_core/indirect_validation/dispatch.rs

use super::CreateIndirectValidationPipelineError;
use crate::{
    device::DeviceError,
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
};
use alloc::{boxed::Box, format, string::ToString as _};
use core::num::NonZeroU64;

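/// GPU state used to validate the arguments of `dispatch_workgroups_indirect` calls.
///
/// A small compute pipeline reads the three indirect dispatch arguments from the caller's
/// buffer and writes either those arguments or zeroes into `dst_buffer`, depending on
/// whether any component exceeds `max_compute_workgroups_per_dimension`.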
#[derive(Debug)]
pub(crate) struct Dispatch {
    module: Box<dyn hal::DynShaderModule>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,
    dst_buffer: Box<dyn hal::DynBuffer>,
    dst_bind_group: Box<dyn hal::DynBindGroup>,
}

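/// Per-dispatch parameters returned by [`Dispatch::params`]: the HAL objects needed to run
/// the validation pass, the aligned offset at which to bind the source buffer, and the
/// remaining byte offset of the indirect arguments within that binding.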
pub struct Params<'a> {
    pub pipeline_layout: &'a dyn hal::DynPipelineLayout,
    pub pipeline: &'a dyn hal::DynComputePipeline,
    pub dst_buffer: &'a dyn hal::DynBuffer,
    pub dst_bind_group: &'a dyn hal::DynBindGroup,
    pub aligned_offset: u64,
    pub offset_remainder: u64,
}

impl Dispatch {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        limits: &wgt::Limits,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let max_compute_workgroups_per_dimension = limits.max_compute_workgroups_per_dimension;

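        // Validation shader: read the three dispatch arguments from `src` at the index given
        // by the push constant, then write them (duplicated into six `u32`s) to `dst`, or
        // write zeroes if any component exceeds `max_compute_workgroups_per_dimension`.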
        let src = format!(
            "
            @group(0) @binding(0)
            var<storage, read_write> dst: array<u32, 6>;
            @group(1) @binding(0)
            var<storage, read> src: array<u32>;
            struct OffsetPc {{
                inner: u32,
            }}
            var<push_constant> offset: OffsetPc;

            @compute @workgroup_size(1)
            fn main() {{
                let src = vec3(src[offset.inner], src[offset.inner + 1], src[offset.inner + 2]);
                let max_compute_workgroups_per_dimension = {max_compute_workgroups_per_dimension}u;
                if (
                    src.x > max_compute_workgroups_per_dimension ||
                    src.y > max_compute_workgroups_per_dimension ||
                    src.z > max_compute_workgroups_per_dimension
                ) {{
                    dst = array(0u, 0u, 0u, 0u, 0u, 0u);
                }} else {{
                    dst = array(src.x, src.y, src.z, src.x, src.y, src.z);
                }}
            }}
        "
        );

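        // An indirect dispatch is three `u32` workgroup counts.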
        const SRC_BUFFER_SIZE: NonZeroU64 =
            unsafe { NonZeroU64::new_unchecked(size_of::<u32>() as u64 * 3) };

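        // The destination buffer is twice the source size: the shader writes six `u32`s,
        // i.e. the arguments duplicated (or all zeroes when validation fails).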
        const DST_BUFFER_SIZE: NonZeroU64 =
            unsafe { NonZeroU64::new_unchecked(SRC_BUFFER_SIZE.get() * 2) };

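        // Parse and validate the WGSL above with naga, then create the HAL shader module.
        // Indirect validation only works with the `wgsl` frontend enabled.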
        #[cfg(feature = "wgsl")]
        let module = naga::front::wgsl::parse_str(&src).map_err(|inner| {
            CreateShaderModuleError::Parsing(naga::error::ShaderError {
                source: src.clone(),
                label: None,
                inner: Box::new(inner),
            })
        })?;
        #[cfg(not(feature = "wgsl"))]
        #[allow(clippy::diverging_sub_expression)]
        let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

        let info = crate::device::create_validator(
            wgt::Features::PUSH_CONSTANTS,
            wgt::DownlevelFlags::empty(),
            naga::valid::ValidationFlags::all(),
        )
        .validate(&module)
        .map_err(|inner| {
            CreateShaderModuleError::Validation(naga::error::ShaderError {
                source: src,
                label: None,
                inner: Box::new(inner),
            })
        })?;
        let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
            module: alloc::borrow::Cow::Owned(module),
            info,
            debug_source: None,
        });
        let hal_desc = hal::ShaderModuleDescriptor {
            label: None,
            runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
        };
        let module =
            unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(|error| {
                match error {
                    hal::ShaderError::Device(error) => {
                        CreateShaderModuleError::Device(DeviceError::from_hal(error))
                    }
                    hal::ShaderError::Compilation(ref msg) => {
                        log::error!("Shader error: {}", msg);
                        CreateShaderModuleError::Generation
                    }
                }
            })?;

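        // Layout for the fixed destination buffer: a single read-write storage binding.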
        let dst_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: None,
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: Some(DST_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let dst_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&dst_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

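        // Layout for the caller's indirect buffer: read-only storage bound with a dynamic
        // offset, so one bind group can cover any aligned position in the buffer.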
        let src_bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
            label: None,
            flags: hal::BindGroupLayoutFlags::empty(),
            entries: &[wgt::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgt::ShaderStages::COMPUTE,
                ty: wgt::BindingType::Buffer {
                    ty: wgt::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: true,
                    min_binding_size: Some(SRC_BUFFER_SIZE),
                },
                count: None,
            }],
        };
        let src_bind_group_layout = unsafe {
            device
                .create_bind_group_layout(&src_bind_group_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

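        // Pipeline layout: group 0 is the destination, group 1 the source, plus a single
        // `u32` push constant carrying the element offset into `src`.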
        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                dst_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..4,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let pipeline_desc = hal::ComputePipelineDescriptor {
            label: None,
            layout: pipeline_layout.as_ref(),
            stage: hal::ProgrammableStage {
                module: module.as_ref(),
                entry_point: "main",
                constants: &Default::default(),
                zero_initialize_workgroup_memory: false,
            },
            cache: None,
        };
        let pipeline =
            unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
                hal::PipelineError::Device(error) => {
                    CreateComputePipelineError::Device(DeviceError::from_hal(error))
                }
                hal::PipelineError::Linkage(_stages, msg) => {
                    CreateComputePipelineError::Internal(msg)
                }
                hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                    crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
                ),
                hal::PipelineError::PipelineConstants(_, error) => {
                    CreateComputePipelineError::PipelineConstants(error)
                }
            })?;

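        // Destination buffer written by the validation shader; its INDIRECT usage indicates
        // it is later read as the indirect buffer for the actual dispatch.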
        let dst_buffer_desc = hal::BufferDescriptor {
            label: None,
            size: DST_BUFFER_SIZE.get(),
            usage: wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE,
            memory_flags: hal::MemoryFlags::empty(),
        };
        let dst_buffer =
            unsafe { device.create_buffer(&dst_buffer_desc) }.map_err(DeviceError::from_hal)?;

        let dst_bind_group_desc = hal::BindGroupDescriptor {
            label: None,
            layout: dst_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer: dst_buffer.as_ref(),
                offset: 0,
                size: Some(DST_BUFFER_SIZE),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        let dst_bind_group = unsafe {
            device
                .create_bind_group(&dst_bind_group_desc)
                .map_err(DeviceError::from_hal)
        }?;

        Ok(Self {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        })
    }

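    /// Creates the group-1 bind group for a caller's indirect buffer.
    ///
    /// Returns `Ok(None)` when the computed binding size is zero, i.e. when the buffer is empty.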
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &wgt::Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer,
                offset: 0,
                size: Some(binding_size),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

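    /// Computes the dispatch parameters for indirect arguments at `offset` in a buffer of
    /// `buffer_size` bytes: the dynamic offset is rounded down to the storage buffer offset
    /// alignment (and clamped so the binding stays inside the buffer), and the leftover byte
    /// offset within the binding is returned as `offset_remainder`.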
    pub fn params<'a>(&'a self, limits: &wgt::Limits, offset: u64, buffer_size: u64) -> Params<'a> {
        let alignment = limits.min_storage_buffer_offset_alignment as u64;
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let aligned_offset = offset - offset % alignment;
        let max_aligned_offset = buffer_size - binding_size;
        let aligned_offset = aligned_offset.min(max_aligned_offset);
        let offset_remainder = offset - aligned_offset;

        Params {
            pipeline_layout: self.pipeline_layout.as_ref(),
            pipeline: self.pipeline.as_ref(),
            dst_buffer: self.dst_buffer.as_ref(),
            dst_bind_group: self.dst_bind_group.as_ref(),
            aligned_offset,
            offset_remainder,
        }
    }

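    /// Destroys all HAL resources owned by this `Dispatch`.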
    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Dispatch {
            module,
            dst_bind_group_layout,
            src_bind_group_layout,
            pipeline_layout,
            pipeline,
            dst_buffer,
            dst_bind_group,
        } = self;

        unsafe {
            device.destroy_bind_group(dst_bind_group);
            device.destroy_buffer(dst_buffer);
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

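/// Size of the group-1 binding for a source buffer of `buffer_size` bytes.
///
/// The binding must be large enough that, for any argument offset in the buffer, the three
/// `u32`s still fall inside the binding after the dynamic offset has been rounded down to
/// `min_storage_buffer_offset_alignment` and clamped by `params`; two alignment units plus
/// the buffer's tail remainder covers that worst case, capped at the full buffer size.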
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &wgt::Limits) -> u64 {
    let alignment = limits.min_storage_buffer_offset_alignment as u64;

    let binding_size = 2 * alignment + (buffer_size % alignment);
    binding_size.min(buffer_size)
}