wgpu_core/indirect_validation/draw.rs

use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };
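
// A hedged sanity-check sketch (not in the original source): these
// compile-time assertions restate the doc math above, assuming the documented
// sizes of 16 bytes for `wgt::DrawIndirectArgs` and 20 bytes for
// `wgt::DrawIndexedIndirectArgs`.
const _: () = assert!(BUFFER_SIZE.get() == ((1 << 16) - 1) * (1 << 4));
const _: () = assert!(BUFFER_SIZE.get() as usize / size_of::<wgt::DrawIndirectArgs>() == 65535);
const _: () =
    assert!(BUFFER_SIZE.get() as usize / size_of::<wgt::DrawIndexedIndirectArgs>() == 52428);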

/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_push_constant_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer,
                offset: 0,
                size: Some(binding_size),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), DeviceError> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB

        let mut staging_buffers = Vec::new();

        // Assign each batch a subrange of a staging buffer, opening a new
        // staging buffer whenever the current one would exceed the limit.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        // Transition the staging buffers to be copy sources and the metadata
        // buffers to be copy destinations.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        // Make the metadata readable by the validation shader and the dst
        // buffers writable by it.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_push_constants(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            // Each workgroup of 64 invocations validates up to 64 metadata entries.
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        // Return the dst buffers to `INDIRECT` usage for the draws that follow.
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::PUSH_CONSTANTS,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {}", msg);
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

/// Returns the largest binding size that, when combined with dynamic offsets, can address the
/// whole buffer.
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Can the buffer remainder fit in the binding remainder?
        // If so, align the max binding size and add the buffer remainder.
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        // If not, align the max binding size, shorten it by a chunk and add the buffer remainder.
        else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}
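
// A hedged test sketch (not in the original source) exercising
// `calculate_src_buffer_binding_size` with hand-computed limits; the 1 KiB max
// binding size and 256 B offset alignment are assumptions for illustration.
#[cfg(test)]
mod src_binding_size_tests {
    use super::*;

    #[test]
    fn binding_size_covers_whole_buffer() {
        let limits = Limits {
            max_storage_buffer_binding_size: 1024,
            min_storage_buffer_offset_alignment: 256,
            ..Limits::default()
        };
        // A buffer within the max binding size binds in full.
        assert_eq!(calculate_src_buffer_binding_size(1000, &limits), 1000);
        // buffer_rem = 1030 % 256 = 6 does not fit in binding_rem = 1024 % 256 = 0,
        // so the binding is shortened by one chunk: 1024 - 0 - 256 + 6 = 774.
        assert_eq!(calculate_src_buffer_binding_size(1030, &limits), 774);
        // With dynamic offsets in steps of 256, the windows [0, 774) and
        // [256, 1030) together address the whole 1030-byte buffer.
    }
}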

/// Splits the given `offset` into a dynamic offset and a remaining offset.
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4-byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to two 8-byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle at most one boundary of 16 bytes or more:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}
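
// A hedged companion sketch (not in the original source): whatever stride the
// function picks, the returned pair must re-sum to `offset` and the dynamic
// part must be a multiple of the assumed 256 B alignment.
#[cfg(test)]
mod src_offsets_tests {
    use super::*;

    #[test]
    fn offsets_resum_and_stay_aligned() {
        let limits = Limits {
            max_storage_buffer_binding_size: 1024,
            min_storage_buffer_offset_alignment: 256,
            ..Limits::default()
        };
        // binding_size = 1024, so the stride is (1024 / 256 - 1) * 256 = 768.
        for offset in [0, 20, 2000, 4000] {
            let (dynamic, rest) = calculate_src_offsets(4096, &limits, offset);
            assert_eq!(dynamic + rest, offset);
            assert_eq!(dynamic % 256, 0);
        }
    }
}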

#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        buffers: &[hal::BufferBinding {
            buffer: buffer.as_ref(),
            offset: 0,
            size: Some(BUFFER_SIZE),
        }],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}
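
// A hedged sketch (not in the original source): `get_subrange_impl` is a bump
// allocator over fixed-size pooled buffers, so successive requests pack into
// entry 0 until `BUFFER_SIZE` would overflow, then roll over to entry 1.
#[cfg(test)]
mod subrange_tests {
    use super::*;

    #[test]
    fn subranges_bump_then_roll_over() {
        let mut current = None;
        let a = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, 16).unwrap();
        assert_eq!((a.index, a.offset), (0, 0));
        let b = DrawResources::get_subrange_impl(|_| Ok(()), &mut current, 16).unwrap();
        assert_eq!((b.index, b.offset), (0, 16));
        // A request that can no longer fit in entry 0 starts entry 1 at offset 0.
        let c =
            DrawResources::get_subrange_impl(|_| Ok(()), &mut current, BUFFER_SIZE.get()).unwrap();
        assert_eq!((c.index, c.offset), (1, 0));
    }
}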

/// This must match the `MetadataEntry` struct used by the shader.
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}

impl MetadataEntry {
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        let src_offset = src_offset as u32; // max_storage_buffer_binding_size is a u32
        let src_offset = src_offset / 4; // translate byte offset to offset in u32's

        // `src_offset` needs at most 30 bits,
        // pack `indexed` in bit 31 of `src_offset`
        let src_offset = src_offset | ((indexed as u32) << 31);

        // max value for limits since first_X and X_count indirect draw arguments are u32
        let max_limit = u32::MAX as u64 + u32::MAX as u64; // 1 11111111 11111111 11111111 11111110

        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
        let instance_limit = instance_limit as u32; // truncate the limit to a u32

        let dst_offset = dst_offset as u32; // max_storage_buffer_binding_size is a u32
        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's

        // `dst_offset` needs at most 30 bits,
        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}
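
// A hedged packing check (not in the original source), restating the comments
// in `MetadataEntry::new`: `indexed` lands in bit 31 of `src_offset`, and bit
// 32 of each clamped limit lands in bits 30/31 of `dst_offset`.
#[cfg(test)]
mod metadata_entry_tests {
    use super::*;

    #[test]
    fn limits_and_flags_are_packed() {
        let entry = MetadataEntry::new(true, 8, 4, u64::MAX, 0);
        // Byte offset 8 becomes u32-offset 2, with `indexed` packed in bit 31.
        assert_eq!(entry.src_offset, 2 | (1 << 31));
        // u64::MAX clamps to 2 * u32::MAX, whose bit 32 is packed into bit 30.
        assert_eq!(entry.dst_offset, 1 | (1 << 30));
        assert_eq!(entry.vertex_or_index_limit, u32::MAX - 1);
        assert_eq!(entry.instance_limit, 0);
    }
}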

struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    fn metadata(&self) -> &[u8] {
        // SAFETY: `MetadataEntry` is `#[repr(C)]` with four `u32` fields and no
        // padding, so its contents can be viewed as plain bytes.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Add an indirect draw to be validated.
    ///
    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
    /// and the offset to be used for the draw.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        indexed: bool,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // space for D3D12 special constants
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(indexed);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            indexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // these will be initialized once all entries for the batch
                    // have been accumulated
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}