use super::{
    utils::{BufferBarrierScratch, BufferBarriers, UniqueIndexExt as _, UniqueIndexScratch},
    CreateIndirectValidationPipelineError,
};
use crate::{
    device::{queue::TempResource, Device, DeviceError},
    lock::{rank, Mutex},
    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
    resource::{StagingBuffer, Trackable},
    snatch::SnatchGuard,
    track::TrackerIndex,
    FastHashMap,
};
use alloc::{boxed::Box, string::ToString, sync::Arc, vec, vec::Vec};
use core::{
    mem::{size_of, size_of_val},
    num::NonZeroU64,
};
use wgt::Limits;

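/// Fixed size of each buffer in the indirect-args and metadata pools; it is
/// also used as the `min_binding_size` of the metadata and dst bind group
/// layouts.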
const BUFFER_SIZE: wgt::BufferSize = unsafe { wgt::BufferSize::new_unchecked(1_048_560) };

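/// Holds the compute pipeline, bind group layouts, and buffer pools used to
/// validate indirect draw arguments before they are consumed by the GPU.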
#[derive(Debug)]
pub(crate) struct Draw {
    module: Box<dyn hal::DynShaderModule>,
    metadata_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    src_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    dst_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    pipeline: Box<dyn hal::DynComputePipeline>,

    free_indirect_entries: Mutex<Vec<BufferPoolEntry>>,
    free_metadata_entries: Mutex<Vec<BufferPoolEntry>>,
}

impl Draw {
    pub(super) fn new(
        device: &dyn hal::DynDevice,
        required_features: &wgt::Features,
        backend: wgt::Backend,
    ) -> Result<Self, CreateIndirectValidationPipelineError> {
        let module = create_validation_module(device)?;

        let metadata_bind_group_layout =
            create_bind_group_layout(device, true, false, BUFFER_SIZE)?;
        let src_bind_group_layout =
            create_bind_group_layout(device, true, true, wgt::BufferSize::new(4 * 4).unwrap())?;
        let dst_bind_group_layout = create_bind_group_layout(device, false, false, BUFFER_SIZE)?;

        let pipeline_layout_desc = hal::PipelineLayoutDescriptor {
            label: None,
            flags: hal::PipelineLayoutFlags::empty(),
            bind_group_layouts: &[
                metadata_bind_group_layout.as_ref(),
                src_bind_group_layout.as_ref(),
                dst_bind_group_layout.as_ref(),
            ],
            push_constant_ranges: &[wgt::PushConstantRange {
                stages: wgt::ShaderStages::COMPUTE,
                range: 0..8,
            }],
        };
        let pipeline_layout = unsafe {
            device
                .create_pipeline_layout(&pipeline_layout_desc)
                .map_err(DeviceError::from_hal)?
        };

        let supports_indirect_first_instance =
            required_features.contains(wgt::Features::INDIRECT_FIRST_INSTANCE);
        let write_d3d12_special_constants = backend == wgt::Backend::Dx12;
        let pipeline = create_validation_pipeline(
            device,
            module.as_ref(),
            pipeline_layout.as_ref(),
            supports_indirect_first_instance,
            write_d3d12_special_constants,
        )?;

        Ok(Self {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
            free_metadata_entries: Mutex::new(rank::BUFFER_POOL, Vec::new()),
        })
    }

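    /// Creates the bind group for a source (user) indirect buffer.
    ///
    /// Returns `Ok(None)` when the computed binding size is zero, i.e. there
    /// is nothing to bind.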
    pub(super) fn create_src_bind_group(
        &self,
        device: &dyn hal::DynDevice,
        limits: &Limits,
        buffer_size: u64,
        buffer: &dyn hal::DynBuffer,
    ) -> Result<Option<Box<dyn hal::DynBindGroup>>, DeviceError> {
        let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);
        let Some(binding_size) = NonZeroU64::new(binding_size) else {
            return Ok(None);
        };
        let hal_desc = hal::BindGroupDescriptor {
            label: None,
            layout: self.src_bind_group_layout.as_ref(),
            entries: &[hal::BindGroupEntry {
                binding: 0,
                resource_index: 0,
                count: 1,
            }],
            buffers: &[hal::BufferBinding {
                buffer,
                offset: 0,
                size: Some(binding_size),
            }],
            samplers: &[],
            textures: &[],
            acceleration_structures: &[],
        };
        unsafe {
            device
                .create_bind_group(&hal_desc)
                .map(Some)
                .map_err(DeviceError::from_hal)
        }
    }

    fn acquire_dst_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_indirect_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::INDIRECT | wgt::BufferUses::STORAGE_READ_WRITE;
                create_buffer_and_bind_group(device, usage, self.dst_bind_group_layout.as_ref())
            }
        }
    }

    fn release_dst_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_indirect_entries.lock().extend(entries);
    }

    fn acquire_metadata_entry(
        &self,
        device: &dyn hal::DynDevice,
    ) -> Result<BufferPoolEntry, hal::DeviceError> {
        let mut free_buffers = self.free_metadata_entries.lock();
        match free_buffers.pop() {
            Some(buffer) => Ok(buffer),
            None => {
                let usage = wgt::BufferUses::COPY_DST | wgt::BufferUses::STORAGE_READ_ONLY;
                create_buffer_and_bind_group(
                    device,
                    usage,
                    self.metadata_bind_group_layout.as_ref(),
                )
            }
        }
    }

    fn release_metadata_entries(&self, entries: impl Iterator<Item = BufferPoolEntry>) {
        self.free_metadata_entries.lock().extend(entries);
    }

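    /// Encodes the full validation pass: stages each batch's metadata, copies
    /// it into pooled metadata buffers, and dispatches the validation pipeline
    /// so that validated arguments land in the pooled indirect (dst) buffers.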
    pub(crate) fn inject_validation_pass(
        &self,
        device: &Arc<Device>,
        snatch_guard: &SnatchGuard,
        resources: &mut DrawResources,
        temp_resources: &mut Vec<TempResource>,
        encoder: &mut dyn hal::DynCommandEncoder,
        batcher: DrawBatcher,
    ) -> Result<(), DeviceError> {
        let mut batches = batcher.batches;

        if batches.is_empty() {
            return Ok(());
        }

        let max_staging_buffer_size = 1 << 26; // 64 MiB

        let mut staging_buffers = Vec::new();

        // Assign each batch a slice of a staging buffer, starting a new
        // staging buffer whenever the current one would exceed the limit.
        let mut current_size = 0;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let offset = if current_size + data.len() > max_staging_buffer_size {
                let staging_buffer =
                    StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
                staging_buffers.push(staging_buffer);
                current_size = data.len();
                0
            } else {
                let offset = current_size;
                current_size += data.len();
                offset as u64
            };
            batch.staging_buffer_index = staging_buffers.len();
            batch.staging_buffer_offset = offset;
        }
        if current_size != 0 {
            let staging_buffer =
                StagingBuffer::new(device, NonZeroU64::new(current_size as u64).unwrap())?;
            staging_buffers.push(staging_buffer);
        }

        // Write each batch's metadata into its assigned staging range.
        for batch in batches.values() {
            let data = batch.metadata();
            let staging_buffer = &mut staging_buffers[batch.staging_buffer_index];
            unsafe {
                staging_buffer.write_with_offset(
                    data,
                    0,
                    batch.staging_buffer_offset as isize,
                    data.len(),
                )
            };
        }

        let staging_buffers: Vec<_> = staging_buffers
            .into_iter()
            .map(|buffer| buffer.flush())
            .collect();

        // Assign each batch a range within a pooled metadata buffer.
        let mut current_metadata_entry = None;
        for batch in batches.values_mut() {
            let data = batch.metadata();
            let (metadata_resource_index, metadata_buffer_offset) =
                resources.get_metadata_subrange(data.len() as u64, &mut current_metadata_entry)?;
            batch.metadata_resource_index = metadata_resource_index;
            batch.metadata_buffer_offset = metadata_buffer_offset;
        }

        let buffer_barrier_scratch = &mut BufferBarrierScratch::new();
        let unique_index_scratch = &mut UniqueIndexScratch::new();

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.staging_buffer_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: staging_buffers[index].raw(),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::MAP_WRITE,
                            to: wgt::BufferUses::COPY_SRC,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_ONLY,
                            to: wgt::BufferUses::COPY_DST,
                        },
                    }),
            )
            .encode(encoder);

        for batch in batches.values() {
            let data = batch.metadata();
            let data_size = NonZeroU64::new(data.len() as u64).unwrap();

            let staging_buffer = &staging_buffers[batch.staging_buffer_index];

            let metadata_buffer = resources.get_metadata_buffer(batch.metadata_resource_index);

            unsafe {
                encoder.copy_buffer_to_buffer(
                    staging_buffer.raw(),
                    metadata_buffer,
                    &[hal::BufferCopy {
                        src_offset: batch.staging_buffer_offset,
                        dst_offset: batch.metadata_buffer_offset,
                        size: data_size,
                    }],
                );
            }
        }

        for staging_buffer in staging_buffers {
            temp_resources.push(TempResource::StagingBuffer(staging_buffer));
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.metadata_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_metadata_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::COPY_DST,
                            to: wgt::BufferUses::STORAGE_READ_ONLY,
                        },
                    }),
            )
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::INDIRECT,
                            to: wgt::BufferUses::STORAGE_READ_WRITE,
                        },
                    }),
            )
            .encode(encoder);

        let desc = hal::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        };
        unsafe {
            encoder.begin_compute_pass(&desc);
        }
        unsafe {
            encoder.set_compute_pipeline(self.pipeline.as_ref());
        }

        for batch in batches.values() {
            let pipeline_layout = self.pipeline_layout.as_ref();

            let metadata_start =
                (batch.metadata_buffer_offset / size_of::<MetadataEntry>() as u64) as u32;
            let metadata_count = batch.entries.len() as u32;
            unsafe {
                encoder.set_push_constants(
                    pipeline_layout,
                    wgt::ShaderStages::COMPUTE,
                    0,
                    &[metadata_start, metadata_count],
                );
            }

            let metadata_bind_group =
                resources.get_metadata_bind_group(batch.metadata_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 0, Some(metadata_bind_group), &[]);
            }

            let src_bind_group = batch
                .src_buffer
                .indirect_validation_bind_groups
                .get(snatch_guard)
                .unwrap()
                .draw
                .as_ref();
            unsafe {
                encoder.set_bind_group(
                    pipeline_layout,
                    1,
                    Some(src_bind_group),
                    &[batch.src_dynamic_offset as u32],
                );
            }

            let dst_bind_group = resources.get_dst_bind_group(batch.dst_resource_index);
            unsafe {
                encoder.set_bind_group(pipeline_layout, 2, Some(dst_bind_group), &[]);
            }

            // One dispatch per batch; each workgroup validates 64 entries
            // (the workgroup size used by the validation shader).
            unsafe {
                encoder.dispatch([(batch.entries.len() as u32).div_ceil(64), 1, 1]);
            }
        }

        unsafe {
            encoder.end_compute_pass();
        }

        BufferBarriers::new(buffer_barrier_scratch)
            .extend(
                batches
                    .values()
                    .map(|batch| batch.dst_resource_index)
                    .unique(unique_index_scratch)
                    .map(|index| hal::BufferBarrier {
                        buffer: resources.get_dst_buffer(index),
                        usage: hal::StateTransition {
                            from: wgt::BufferUses::STORAGE_READ_WRITE,
                            to: wgt::BufferUses::INDIRECT,
                        },
                    }),
            )
            .encode(encoder);

        Ok(())
    }

    pub(super) fn dispose(self, device: &dyn hal::DynDevice) {
        let Draw {
            module,
            metadata_bind_group_layout,
            src_bind_group_layout,
            dst_bind_group_layout,
            pipeline_layout,
            pipeline,

            free_indirect_entries,
            free_metadata_entries,
        } = self;

        for entry in free_indirect_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        for entry in free_metadata_entries.into_inner().drain(..) {
            unsafe {
                device.destroy_bind_group(entry.bind_group);
                device.destroy_buffer(entry.buffer);
            }
        }

        unsafe {
            device.destroy_compute_pipeline(pipeline);
            device.destroy_pipeline_layout(pipeline_layout);
            device.destroy_bind_group_layout(metadata_bind_group_layout);
            device.destroy_bind_group_layout(src_bind_group_layout);
            device.destroy_bind_group_layout(dst_bind_group_layout);
            device.destroy_shader_module(module);
        }
    }
}

fn create_validation_module(
    device: &dyn hal::DynDevice,
) -> Result<Box<dyn hal::DynShaderModule>, CreateIndirectValidationPipelineError> {
    let src = include_str!("./validate_draw.wgsl");

    #[cfg(feature = "wgsl")]
    let module = naga::front::wgsl::parse_str(src).map_err(|inner| {
        CreateShaderModuleError::Parsing(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    #[cfg(not(feature = "wgsl"))]
    #[allow(clippy::diverging_sub_expression)]
    let module = panic!("Indirect validation requires the wgsl feature flag to be enabled!");

    let info = crate::device::create_validator(
        wgt::Features::PUSH_CONSTANTS,
        wgt::DownlevelFlags::empty(),
        naga::valid::ValidationFlags::all(),
    )
    .validate(&module)
    .map_err(|inner| {
        CreateShaderModuleError::Validation(naga::error::ShaderError {
            source: src.to_string(),
            label: None,
            inner: Box::new(inner),
        })
    })?;
    let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
        module: alloc::borrow::Cow::Owned(module),
        info,
        debug_source: None,
    });
    let hal_desc = hal::ShaderModuleDescriptor {
        label: None,
        runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
    };
    let module = unsafe { device.create_shader_module(&hal_desc, hal_shader) }.map_err(
        |error| match error {
            hal::ShaderError::Device(error) => {
                CreateShaderModuleError::Device(DeviceError::from_hal(error))
            }
            hal::ShaderError::Compilation(ref msg) => {
                log::error!("Shader error: {}", msg);
                CreateShaderModuleError::Generation
            }
        },
    )?;

    Ok(module)
}

fn create_validation_pipeline(
    device: &dyn hal::DynDevice,
    module: &dyn hal::DynShaderModule,
    pipeline_layout: &dyn hal::DynPipelineLayout,
    supports_indirect_first_instance: bool,
    write_d3d12_special_constants: bool,
) -> Result<Box<dyn hal::DynComputePipeline>, CreateIndirectValidationPipelineError> {
    let pipeline_desc = hal::ComputePipelineDescriptor {
        label: None,
        layout: pipeline_layout,
        stage: hal::ProgrammableStage {
            module,
            entry_point: "main",
            constants: &hashbrown::HashMap::from([
                (
                    "supports_indirect_first_instance".to_string(),
                    f64::from(supports_indirect_first_instance),
                ),
                (
                    "write_d3d12_special_constants".to_string(),
                    f64::from(write_d3d12_special_constants),
                ),
            ]),
            zero_initialize_workgroup_memory: false,
        },
        cache: None,
    };
    let pipeline =
        unsafe { device.create_compute_pipeline(&pipeline_desc) }.map_err(|err| match err {
            hal::PipelineError::Device(error) => {
                CreateComputePipelineError::Device(DeviceError::from_hal(error))
            }
            hal::PipelineError::Linkage(_stages, msg) => CreateComputePipelineError::Internal(msg),
            hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
                crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
            ),
            hal::PipelineError::PipelineConstants(_, error) => {
                CreateComputePipelineError::PipelineConstants(error)
            }
        })?;

    Ok(pipeline)
}

fn create_bind_group_layout(
    device: &dyn hal::DynDevice,
    read_only: bool,
    has_dynamic_offset: bool,
    min_binding_size: wgt::BufferSize,
) -> Result<Box<dyn hal::DynBindGroupLayout>, CreateIndirectValidationPipelineError> {
    let bind_group_layout_desc = hal::BindGroupLayoutDescriptor {
        label: None,
        flags: hal::BindGroupLayoutFlags::empty(),
        entries: &[wgt::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgt::ShaderStages::COMPUTE,
            ty: wgt::BindingType::Buffer {
                ty: wgt::BufferBindingType::Storage { read_only },
                has_dynamic_offset,
                min_binding_size: Some(min_binding_size),
            },
            count: None,
        }],
    };
    let bind_group_layout = unsafe {
        device
            .create_bind_group_layout(&bind_group_layout_desc)
            .map_err(DeviceError::from_hal)?
    };

    Ok(bind_group_layout)
}

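/// Computes the size of the binding used for a source indirect buffer.
///
/// When the buffer is larger than `max_storage_buffer_binding_size`, the
/// result is the largest size not exceeding that limit which keeps
/// `buffer_size - binding_size` a multiple of
/// `min_storage_buffer_offset_alignment`, so the binding can slide to the very
/// end of the buffer with an aligned dynamic offset.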
fn calculate_src_buffer_binding_size(buffer_size: u64, limits: &Limits) -> u64 {
    let max_storage_buffer_binding_size = limits.max_storage_buffer_binding_size as u64;
    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    if buffer_size <= max_storage_buffer_binding_size {
        buffer_size
    } else {
        let buffer_rem = buffer_size % min_storage_buffer_offset_alignment;
        let binding_rem = max_storage_buffer_binding_size % min_storage_buffer_offset_alignment;

        // Shrink the maximum binding size until its remainder modulo the
        // alignment matches the buffer's.
        if buffer_rem <= binding_rem {
            max_storage_buffer_binding_size - binding_rem + buffer_rem
        } else {
            max_storage_buffer_binding_size - binding_rem - min_storage_buffer_offset_alignment
                + buffer_rem
        }
    }
}

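/// Splits `offset` into a dynamic offset for the source bind group plus a
/// remaining offset relative to it, returning `(src_dynamic_offset, src_offset)`.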
fn calculate_src_offsets(buffer_size: u64, limits: &Limits, offset: u64) -> (u64, u64) {
    let binding_size = calculate_src_buffer_binding_size(buffer_size, limits);

    let min_storage_buffer_offset_alignment = limits.min_storage_buffer_offset_alignment as u64;

    // Number of alignment-sized chunks shaved off the binding when computing
    // the dynamic offset stride. The alignment is a power of two and at least
    // 4, so these arms are exhaustive.
    let chunk_adjustment = match min_storage_buffer_offset_alignment {
        4 => 0,
        8 => 2,
        16.. => 1,
        _ => unreachable!(),
    };

    let chunks = binding_size / min_storage_buffer_offset_alignment;
    let dynamic_offset_stride =
        chunks.saturating_sub(chunk_adjustment) * min_storage_buffer_offset_alignment;

    if dynamic_offset_stride == 0 {
        return (0, offset);
    }

    // Clamp the dynamic offset so the binding never extends past the buffer.
    let max_dynamic_offset = buffer_size - binding_size;
    let max_dynamic_offset_index = max_dynamic_offset / dynamic_offset_stride;

    let src_dynamic_offset_index = offset / dynamic_offset_stride;

    let src_dynamic_offset =
        src_dynamic_offset_index.min(max_dynamic_offset_index) * dynamic_offset_stride;
    let src_offset = offset - src_dynamic_offset;

    (src_dynamic_offset, src_offset)
}

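/// A pooled buffer together with the bind group that exposes it to the
/// validation pipeline.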
#[derive(Debug)]
struct BufferPoolEntry {
    buffer: Box<dyn hal::DynBuffer>,
    bind_group: Box<dyn hal::DynBindGroup>,
}

fn create_buffer_and_bind_group(
    device: &dyn hal::DynDevice,
    usage: wgt::BufferUses,
    bind_group_layout: &dyn hal::DynBindGroupLayout,
) -> Result<BufferPoolEntry, hal::DeviceError> {
    let buffer_desc = hal::BufferDescriptor {
        label: None,
        size: BUFFER_SIZE.get(),
        usage,
        memory_flags: hal::MemoryFlags::empty(),
    };
    let buffer = unsafe { device.create_buffer(&buffer_desc) }?;
    let bind_group_desc = hal::BindGroupDescriptor {
        label: None,
        layout: bind_group_layout,
        entries: &[hal::BindGroupEntry {
            binding: 0,
            resource_index: 0,
            count: 1,
        }],
        buffers: &[hal::BufferBinding {
            buffer: buffer.as_ref(),
            offset: 0,
            size: Some(BUFFER_SIZE),
        }],
        samplers: &[],
        textures: &[],
        acceleration_structures: &[],
    };
    let bind_group = unsafe { device.create_bind_group(&bind_group_desc) }?;
    Ok(BufferPoolEntry { buffer, bind_group })
}

/// Tracks the pool entry currently being filled and the next free offset in it.
#[derive(Clone)]
struct CurrentEntry {
    index: usize,
    offset: u64,
}

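/// Pool entries acquired from [`Draw`] for the lifetime of one command
/// encoder; they are returned to the device's free lists on drop.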
pub(crate) struct DrawResources {
    device: Arc<Device>,
    dst_entries: Vec<BufferPoolEntry>,
    metadata_entries: Vec<BufferPoolEntry>,
}

impl Drop for DrawResources {
    fn drop(&mut self) {
        if let Some(ref indirect_validation) = self.device.indirect_validation {
            let indirect_draw_validation = &indirect_validation.draw;
            indirect_draw_validation.release_dst_entries(self.dst_entries.drain(..));
            indirect_draw_validation.release_metadata_entries(self.metadata_entries.drain(..));
        }
    }
}

impl DrawResources {
    pub(crate) fn new(device: Arc<Device>) -> Self {
        DrawResources {
            device,
            dst_entries: Vec::new(),
            metadata_entries: Vec::new(),
        }
    }

    pub(crate) fn get_dst_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.dst_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_dst_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.dst_entries.get(index).unwrap().bind_group.as_ref()
    }

    fn get_metadata_buffer(&self, index: usize) -> &dyn hal::DynBuffer {
        self.metadata_entries.get(index).unwrap().buffer.as_ref()
    }

    fn get_metadata_bind_group(&self, index: usize) -> &dyn hal::DynBindGroup {
        self.metadata_entries
            .get(index)
            .unwrap()
            .bind_group
            .as_ref()
    }

    fn get_dst_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.dst_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_dst_entry(self.device.raw())?;
                self.dst_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_metadata_subrange(
        &mut self,
        size: u64,
        current_entry: &mut Option<CurrentEntry>,
    ) -> Result<(usize, u64), DeviceError> {
        let indirect_draw_validation = &self.device.indirect_validation.as_ref().unwrap().draw;
        let ensure_entry = |index: usize| {
            if self.metadata_entries.len() <= index {
                let entry = indirect_draw_validation.acquire_metadata_entry(self.device.raw())?;
                self.metadata_entries.push(entry);
            }
            Ok(())
        };
        let entry_data = Self::get_subrange_impl(ensure_entry, current_entry, size)?;
        Ok((entry_data.index, entry_data.offset))
    }

    fn get_subrange_impl(
        ensure_entry: impl FnOnce(usize) -> Result<(), hal::DeviceError>,
        current_entry: &mut Option<CurrentEntry>,
        size: u64,
    ) -> Result<CurrentEntry, DeviceError> {
        let index = if let Some(current_entry) = current_entry.as_mut() {
            if current_entry.offset + size <= BUFFER_SIZE.get() {
                // The requested range still fits in the current entry.
                let entry_data = current_entry.clone();
                current_entry.offset += size;
                return Ok(entry_data);
            } else {
                // Move on to the next entry.
                current_entry.index + 1
            }
        } else {
            0
        };

        ensure_entry(index).map_err(DeviceError::from_hal)?;

        let entry_data = CurrentEntry { index, offset: 0 };

        *current_entry = Some(CurrentEntry {
            index,
            offset: size,
        });

        Ok(entry_data)
    }
}

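/// Per-draw metadata consumed by the validation shader.
///
/// Several values are packed into spare bits; see [`MetadataEntry::new`].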
#[repr(C)]
struct MetadataEntry {
    src_offset: u32,
    dst_offset: u32,
    vertex_or_index_limit: u32,
    instance_limit: u32,
}

impl MetadataEntry {
    fn new(
        indexed: bool,
        src_offset: u64,
        dst_offset: u64,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Self {
        // `max_storage_buffer_binding_size` is a `u32`, so the offsets below
        // fit in 32 bits.
        debug_assert_eq!(
            4,
            size_of_val(&Limits::default().max_storage_buffer_binding_size)
        );

        // Offsets are multiples of 4, so `src_offset / 4` fits in 30 bits,
        // leaving bit 31 free to encode whether the draw is indexed.
        let src_offset = src_offset as u32;
        let src_offset = src_offset / 4;
        let src_offset = src_offset | ((indexed as u32) << 31);

        // The limits are clamped to 33 bits; bit 32 of each is stashed in the
        // two high bits of `dst_offset` below.
        let max_limit = u32::MAX as u64 + u32::MAX as u64;

        let vertex_or_index_limit = vertex_or_index_limit.min(max_limit);
        let vertex_or_index_limit_bit_32 = (vertex_or_index_limit >> 32) as u32;
        let vertex_or_index_limit = vertex_or_index_limit as u32;

        let instance_limit = instance_limit.min(max_limit);
        let instance_limit_bit_32 = (instance_limit >> 32) as u32;
        let instance_limit = instance_limit as u32;

        // Like `src_offset`, `dst_offset / 4` fits in 30 bits, freeing bits 30
        // and 31 for the limits' 33rd bits.
        let dst_offset = dst_offset as u32;
        let dst_offset = dst_offset / 4;
        let dst_offset =
            dst_offset | (vertex_or_index_limit_bit_32 << 30) | (instance_limit_bit_32 << 31);

        Self {
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        }
    }
}

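/// A set of draws that share a source buffer, source dynamic offset, and
/// destination buffer, and can therefore be validated by a single dispatch.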
struct DrawIndirectValidationBatch {
    src_buffer: Arc<crate::resource::Buffer>,
    src_dynamic_offset: u64,
    dst_resource_index: usize,
    entries: Vec<MetadataEntry>,

    // Filled in by `Draw::inject_validation_pass`:
    staging_buffer_index: usize,
    staging_buffer_offset: u64,
    metadata_resource_index: usize,
    metadata_buffer_offset: u64,
}

impl DrawIndirectValidationBatch {
    fn metadata(&self) -> &[u8] {
        // SAFETY: `MetadataEntry` is `#[repr(C)]` and consists only of `u32`s,
        // so the entries can be viewed as plain bytes.
        unsafe {
            core::slice::from_raw_parts(
                self.entries.as_ptr().cast::<u8>(),
                self.entries.len() * size_of::<MetadataEntry>(),
            )
        }
    }
}

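/// Accumulates indirect draws into batches that are later flushed to the GPU
/// by [`Draw::inject_validation_pass`].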
pub(crate) struct DrawBatcher {
    /// Batches keyed by `(src buffer tracker index, src dynamic offset, dst resource index)`.
    batches: FastHashMap<(TrackerIndex, u64, usize), DrawIndirectValidationBatch>,
    current_dst_entry: Option<CurrentEntry>,
}

impl DrawBatcher {
    pub(crate) fn new() -> Self {
        Self {
            batches: FastHashMap::default(),
            current_dst_entry: None,
        }
    }

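    /// Registers an indirect draw for validation, returning the index of the
    /// destination resource and the offset at which its validated arguments
    /// will be written.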
    pub(crate) fn add<'a>(
        &mut self,
        indirect_draw_validation_resources: &'a mut DrawResources,
        device: &Device,
        src_buffer: &Arc<crate::resource::Buffer>,
        offset: u64,
        indexed: bool,
        vertex_or_index_limit: u64,
        instance_limit: u64,
    ) -> Result<(usize, u64), DeviceError> {
        // On DX12 the validation shader also writes 3 extra u32s ("special
        // constants") in front of each set of draw arguments.
        let extra = if device.backend() == wgt::Backend::Dx12 {
            3 * size_of::<u32>() as u64
        } else {
            0
        };
        let stride = extra + crate::command::get_stride_of_indirect_args(indexed);

        let (dst_resource_index, dst_offset) = indirect_draw_validation_resources
            .get_dst_subrange(stride, &mut self.current_dst_entry)?;

        let buffer_size = src_buffer.size;
        let limits = device.adapter.limits();
        let (src_dynamic_offset, src_offset) = calculate_src_offsets(buffer_size, &limits, offset);

        let src_buffer_tracker_index = src_buffer.tracker_index();

        let entry = MetadataEntry::new(
            indexed,
            src_offset,
            dst_offset,
            vertex_or_index_limit,
            instance_limit,
        );

        match self.batches.entry((
            src_buffer_tracker_index,
            src_dynamic_offset,
            dst_resource_index,
        )) {
            hashbrown::hash_map::Entry::Occupied(mut occupied_entry) => {
                occupied_entry.get_mut().entries.push(entry)
            }
            hashbrown::hash_map::Entry::Vacant(vacant_entry) => {
                vacant_entry.insert(DrawIndirectValidationBatch {
                    src_buffer: src_buffer.clone(),
                    src_dynamic_offset,
                    dst_resource_index,
                    entries: vec![entry],

                    // Filled in by `Draw::inject_validation_pass`.
                    staging_buffer_index: 0,
                    staging_buffer_offset: 0,
                    metadata_resource_index: 0,
                    metadata_buffer_offset: 0,
                });
            }
        }

        Ok((dst_resource_index, dst_offset))
    }
}