use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

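/// Size of each pooled buffer. This is the largest value below 1 MiB (2^20)
/// that is a multiple of both indirect-argument strides (16 bytes for a
/// non-indexed draw, 20 bytes for an indexed draw).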
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };

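/// Device state used to validate indirect draw arguments: the validation
/// shader module, its bind group layouts and pipeline, plus pools of free
/// destination and metadata buffers.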
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

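    /// Creates a bind group exposing `buffer` as the source of indirect draw
    /// arguments. Returns `Ok(None)` if the computed binding size is zero.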
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer,
                offset: 0,
                size: Some(binding_size),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

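    /// Encodes a compute pass that validates all draws collected in `batcher`:
    /// batch metadata is staged and copied into pooled metadata buffers, then
    /// the validation pipeline writes the validated arguments into the pooled
    /// destination buffers.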
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), DeviceError> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

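        // Lay out each batch's metadata across staging buffers, starting a new
        // staging buffer whenever the current one would exceed 64 MiB (1 << 26).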
        let max_staging_buffer_size = 1 << 26;
        let mut staging_buffers = Vec::new();

        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

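        // Transition the staging buffers for copy-out and the metadata buffers
        // for copy-in.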
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

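        // Make the metadata buffers shader-readable again and the destination
        // buffers writable for the validation dispatches.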
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_push_constants(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

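            // One thread per metadata entry; the divisor matches the
            // validation shader's workgroup size of 64.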
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

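        // Transition the destination buffers back so the validated arguments
        // can be consumed by indirect draws.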
        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::PUSH_CONSTANTS,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {}", msg);
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([(
                "supports_indirect_first_instance".to_string(),
                f64::from(supports_indirect_first_instance),
            )]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

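/// Returns the binding size to use for the source bind group: the whole buffer
/// if it fits within `max_storage_buffer_binding_size`, otherwise the largest
/// size within that limit that is congruent to `buffer_size` modulo
/// `min_storage_buffer_offset_alignment`, so that an aligned dynamic offset
/// can reach the end of the buffer.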
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        } else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}

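/// Splits `offset` into a dynamic offset for the source bind group (a multiple
/// of the dynamic-offset stride) and the remaining offset within the binding.
///
/// For example, with `max_storage_buffer_binding_size = 1024`,
/// `min_storage_buffer_offset_alignment = 256`, `buffer_size = 4096` and
/// `offset = 1000`: the binding size is 1024, the stride is
/// `(4 - 1) * 256 = 768`, and the result is `(768, 232)`.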
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}

#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        buffers: &[hal::BufferBinding {
            buffer: buffer.as_ref(),
            offset: 0,
            size: Some(BUFFER_SIZE),
        }],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

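    /// Bump-allocates a `size`-byte subrange from the current pooled buffer,
    /// advancing to the next buffer (created on demand via `ensure_entry`)
    /// once the current one cannot fit the request.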
    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}

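/// Per-draw validation metadata, bit-packed for consumption by the validation
/// shader: both offsets are stored in units of `u32` words, bit 31 of
/// `src_offset` flags an indexed draw, and bits 30/31 of `dst_offset` carry
/// bit 32 of the two limits (which may not fit in 32 bits).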
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}

impl MetadataEntry {
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        // `max_storage_buffer_binding_size` is a u32 (asserted above), so the
        // source offset fits in a u32. Store it in units of u32 words and use
        // bit 31 to flag indexed draws.
        let src_offset = src_offset as u32;
        let src_offset = src_offset / 4;
        let src_offset = src_offset | ((indexed as u32) << 31);

        // Clamp the limits so that they fit in 33 bits: the lower 32 bits are
        // stored directly, bit 32 is packed into `dst_offset` below.
        let max_limit = u32::MAX as u64 + u32::MAX as u64;
        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32;
        let vertex_or_index_limit = vertex_or_index_limit as u32;

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32;
        let instance_limit = instance_limit as u32;

        // The destination offset is likewise stored in u32 words; bits 30 and
        // 31 carry bit 32 of the vertex/index and instance limits.
        let dst_offset = dst_offset as u32;
        let dst_offset = dst_offset / 4;
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}

struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    fn metadata(&self) -> &[u8] {
        // SAFETY: `MetadataEntry` is `#[repr(C)]` and consists of four `u32`s
        // with no padding, so viewing the entries as raw bytes is sound.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

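/// Collects indirect draws to validate, grouped by source buffer, source
/// dynamic offset and destination resource, so each batch can be dispatched
/// with a single set of bind groups.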
pub(crate) struct DrawBatcher {
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

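    /// Adds an indirect draw to the matching batch (creating it if needed) and
    /// returns the index of the destination resource and the offset at which
    /// the validated arguments will be written.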
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        indexed: bool,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        let stride = crate::command::get_stride_of_indirect_args(indexed);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            indexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // Filled in later by `inject_validation_pass`.
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}