// wgpu_core/indirect_validation/draw.rs

use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

/// Note: This needs to be under:
///
/// default max_compute_workgroups_per_dimension * size_of::<wgt::DrawIndirectArgs>() * `workgroup_size` used by the shader
///
/// = (2^16 - 1) * 2^4 * 2^6
///
/// It is currently set to:
///
/// = (2^16 - 1) * 2^4
///
/// This is enough space for:
///
/// - 65535 [`wgt::DrawIndirectArgs`] / [`MetadataEntry`]
/// - 52428 [`wgt::DrawIndexedIndirectArgs`]
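///
/// Worked check: (2^16 - 1) * 2^4 = 65_535 * 16 = 1_048_560 bytes, which holds
/// 1_048_560 / 16 = 65_535 [`wgt::DrawIndirectArgs`] or 1_048_560 / 20 = 52_428
/// [`wgt::DrawIndexedIndirectArgs`].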
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };

/// Holds all device-level resources that are needed to validate indirect draws.
///
/// This machinery requires the following limits:
///
/// - max_bind_groups: 3,
/// - max_dynamic_storage_buffers_per_pipeline_layout: 1,
/// - max_storage_buffers_per_shader_stage: 3,
/// - max_push_constant_size: 8,
///
/// These are all indirectly satisfied by `DownlevelFlags::INDIRECT_EXECUTION`, which is also
/// required for this module's functionality to work.
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

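        // Bind group layout arguments below are (read_only, has_dynamic_offset,
        // min_binding_size). The src layout is the only one with a dynamic
        // offset; its 16-byte minimum binding is the size of one
        // `wgt::DrawIndirectArgs`.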
        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

    /// `Ok(None)` will only be returned if `buffer_size` is `0`.
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer,
                offset: 0,
                size: Some(binding_size),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

    /// Injects a compute pass that will validate all indirect draws in the current render pass.
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), DeviceError> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB

        let mut staging_buffers = Vec::new();

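        // First pass: assign each batch a (staging buffer, offset) pair, cutting
        // a new staging buffer whenever the current one would exceed the size
        // cap. The metadata itself is written in the second pass below.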
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

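        // Transition the staging buffers for copy-out (MAP_WRITE -> COPY_SRC)
        // and the metadata pool buffers for copy-in (STORAGE_READ_ONLY ->
        // COPY_DST) before recording the buffer-to-buffer copies below.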
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

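        // Copy each batch's metadata from its staging subrange into its
        // subrange of the metadata pool buffer.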
        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

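        // Make the metadata readable and the dst buffers writable by the
        // validation dispatches that follow.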
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

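        // One dispatch per batch: push constants carry the batch's start index
        // and entry count within the metadata buffer, and the `div_ceil(64)`
        // below matches the shader's 64-wide workgroup (see the `BUFFER_SIZE`
        // note at the top of this file).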
        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_push_constants(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

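        // Return the dst buffers to INDIRECT usage so the render pass can
        // consume the validated draw arguments.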
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::PUSH_CONSTANTS,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {}", msg);
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([(
                "supports_indirect_first_instance".to_string(),
                f64::from(supports_indirect_first_instance),
            )]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

/// Returns the largest binding size that, when combined with dynamic offsets, can address the whole buffer.
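///
/// A worked example with hypothetical limits: for `buffer_size = 1000`,
/// `max_storage_buffer_binding_size = 256` and
/// `min_storage_buffer_offset_alignment = 32`, `buffer_rem` is `8` and
/// `binding_rem` is `0`; since `8 > 0` the result is `256 - 0 - 32 + 8 = 232`,
/// and `1000 - 232 = 768` is a multiple of the alignment, so an aligned dynamic
/// offset can slide the binding all the way to the end of the buffer.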
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Can the buffer remainder fit in the binding remainder?
        // If so, align max binding size and add buffer remainder
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        }
        // If not, align max binding size, shorten it by a chunk and add buffer remainder
        else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}

/// Splits the given `offset` into a dynamic offset and a remaining offset.
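///
/// A worked example with hypothetical values: with
/// `min_storage_buffer_offset_alignment = 32` (so `chunk_adjustment` is `1`) and
/// a binding size of `232`, `chunks` is `7` and `dynamic_offset_stride` is
/// `(7 - 1) * 32 = 192`; an `offset` of `500` then maps to a dynamic offset of
/// `(500 / 192) * 192 = 384` and a remaining offset of `500 - 384 = 116`, which
/// keeps a full 20-byte indexed argument struct inside the binding.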
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        // No need to adjust since the src_offset is 4 byte aligned.
        4 => 0,
        // With 16/20 bytes of data we can straddle up to two 8-byte boundaries:
        //  - 16 bytes of data: (4|8|4)
        //  - 20 bytes of data: (4|8|8, 8|8|4)
        8 => 2,
        // With 16/20 bytes of data we can straddle up to one 16-byte-or-larger boundary:
        //  - 16 bytes of data: (4|12, 8|8, 12|4)
        //  - 20 bytes of data: (4|16, 8|12, 12|8, 16|4)
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}

#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        buffers: &[hal::BufferBinding {
            buffer: buffer.as_ref(),
            offset: 0,
            size: Some(BUFFER_SIZE),
        }],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

/// Holds all command buffer-level resources that are needed to validate indirect draws.
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

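    /// Bump-allocates a `size`-byte subrange out of the fixed-size pool
    /// buffers, acquiring a new pool buffer when the current one is exhausted.
    ///
    /// A worked example with a hypothetical pool buffer size of 100 bytes:
    /// successive requests of 60, 30 and 30 bytes land at (buffer 0, offset 0)
    /// and (buffer 0, offset 60), while the third request no longer fits
    /// (`90 + 30 > 100`) and starts (buffer 1, offset 0).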
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}

/// This must match the `MetadataEntry` struct used by the shader.
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}

impl MetadataEntry {
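    /// Packs the draw parameters into the four `u32` fields.
    ///
    /// A worked example with hypothetical values: an indexed draw with
    /// `src_offset = 256` stores `256 / 4 = 64` and sets bit 31, giving
    /// `src_offset = 0x8000_0040`; a `vertex_or_index_limit` of
    /// `u32::MAX as u64 + 1` (bit 32 set) truncates to `0` and sets bit 30 of
    /// `dst_offset`.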
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        let src_offset = src_offset as u32; // max_storage_buffer_binding_size is a u32
        let src_offset = src_offset / 4; // translate byte offset to offset in u32's

        // `src_offset` needs at most 30 bits,
        // pack `indexed` in bit 31 of `src_offset`
        let src_offset = src_offset | ((indexed as u32) << 31);

        // max value for limits since first_X and X_count indirect draw arguments are u32
        let max_limit = u32::MAX as u64 + u32::MAX as u64; // 1 11111111 11111111 11111111 11111110

        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32; // extract bit 32
        let vertex_or_index_limit = vertex_or_index_limit as u32; // truncate the limit to a u32

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32; // extract bit 32
        let instance_limit = instance_limit as u32; // truncate the limit to a u32

        let dst_offset = dst_offset as u32; // max_storage_buffer_binding_size is a u32
        let dst_offset = dst_offset / 4; // translate byte offset to offset in u32's

        // `dst_offset` needs at most 30 bits,
        // pack `vertex_or_index_limit_bit_32` in bit 30 of `dst_offset` and
        // pack `instance_limit_bit_32` in bit 31 of `dst_offset`
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}

struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    /// Data to be written to the metadata buffer.
    fn metadata(&self) -> &[u8] {
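        // SAFETY: `MetadataEntry` is `#[repr(C)]` with four `u32` fields and no
        // padding, so the entries can be reinterpreted as plain bytes.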
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

/// Accumulates all data needed to validate indirect draws.
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

    /// Add an indirect draw to be validated.
    ///
    /// Returns the index of the indirect buffer in `indirect_draw_validation_resources`
    /// and the offset to be used for the draw.
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        indexed: bool,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        let stride = crate::command::get_stride_of_indirect_args(indexed);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            indexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

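        // Batches are keyed by (source buffer, source dynamic offset,
        // destination pool buffer): draws that share all three can be validated
        // by a single dispatch using the same bind groups and dynamic offset.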
        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // these will be initialized once all entries for the batch have been accumulated
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
961}