wgpu_core/timestamp_normalization/
mod.rs1use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35use wgt::PushConstantRange;
36
37use crate::{
38 device::{Device, DeviceError},
39 pipeline::{CreateComputePipelineError, CreateShaderModuleError},
40 resource::Buffer,
41 snatch::SnatchGuard,
42 track::BufferTracker,
43};
44
45pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
46 wgt::BufferUses::STORAGE_READ_WRITE;
47
48struct InternalState {
49 temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
50 pipeline_layout: Box<dyn hal::DynPipelineLayout>,
51 pipeline: Box<dyn hal::DynComputePipeline>,
52}
53
54#[derive(Debug, Clone, thiserror::Error)]
55pub enum TimestampNormalizerInitError {
56 #[error("Failed to initialize bind group layout")]
57 BindGroupLayout(#[source] DeviceError),
58 #[cfg(feature = "wgsl")]
59 #[error("Failed to parse shader")]
60 ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
61 #[error("Failed to validate shader module")]
62 ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
63 #[error("Failed to create shader module")]
64 CreateShaderModule(#[from] CreateShaderModuleError),
65 #[error("Failed to create pipeline layout")]
66 PipelineLayout(#[source] DeviceError),
67 #[error("Failed to create compute pipeline")]
68 ComputePipeline(#[from] CreateComputePipelineError),
69}
70
71pub struct TimestampNormalizer {
74 state: Option<InternalState>,
75}
76
77impl TimestampNormalizer {
78 pub fn new(
87 device: &Device,
88 timestamp_period: f32,
89 ) -> Result<Self, TimestampNormalizerInitError> {
90 unsafe {
91 if !device
92 .instance_flags
93 .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
94 {
95 return Ok(Self { state: None });
96 }
97
98 if !device
99 .downlevel
100 .flags
101 .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
102 {
103 log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
104 return Ok(Self { state: None });
105 }
106
107 if timestamp_period == 1.0 {
108 return Ok(Self { state: None });
110 }
111
112 let temporary_bind_group_layout = device
113 .raw()
114 .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
115 label: Some("Timestamp Normalization Bind Group Layout"),
116 flags: hal::BindGroupLayoutFlags::empty(),
117 entries: &[wgt::BindGroupLayoutEntry {
118 binding: 0,
119 visibility: wgt::ShaderStages::COMPUTE,
120 ty: wgt::BindingType::Buffer {
121 ty: wgt::BufferBindingType::Storage { read_only: false },
122 has_dynamic_offset: false,
123 min_binding_size: Some(NonZeroU64::new(8).unwrap()),
124 },
125 count: None,
126 }],
127 })
128 .map_err(|e| {
129 TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
130 })?;
131
132 let common_src = include_str!("common.wgsl");
133 let src = include_str!("timestamp_normalization.wgsl");
134
135 let preprocessed_src = alloc::format!("{common_src}\n{src}");
136
137 #[cfg(feature = "wgsl")]
138 let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
139 TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
140 source: preprocessed_src.clone(),
141 label: None,
142 inner: Box::new(inner),
143 })
144 })?;
145 #[cfg(not(feature = "wgsl"))]
146 #[allow(clippy::diverging_sub_expression)]
147 let module =
148 panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
149
150 let info = crate::device::create_validator(
151 wgt::Features::PUSH_CONSTANTS,
152 wgt::DownlevelFlags::empty(),
153 naga::valid::ValidationFlags::all(),
154 )
155 .validate(&module)
156 .map_err(|inner| {
157 TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
158 source: preprocessed_src.clone(),
159 label: None,
160 inner: Box::new(inner),
161 })
162 })?;
163 let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
164 module: alloc::borrow::Cow::Owned(module),
165 info,
166 debug_source: None,
167 });
168 let hal_desc = hal::ShaderModuleDescriptor {
169 label: None,
170 runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
171 };
172 let module = device
173 .raw()
174 .create_shader_module(&hal_desc, hal_shader)
175 .map_err(|error| match error {
176 hal::ShaderError::Device(error) => {
177 CreateShaderModuleError::Device(device.handle_hal_error(error))
178 }
179 hal::ShaderError::Compilation(ref msg) => {
180 log::error!("Shader error: {}", msg);
181 CreateShaderModuleError::Generation
182 }
183 })?;
184
185 let pipeline_layout = device
186 .raw()
187 .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
188 label: None,
189 bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
190 push_constant_ranges: &[PushConstantRange {
191 stages: wgt::ShaderStages::COMPUTE,
192 range: 0..8,
193 }],
194 flags: hal::PipelineLayoutFlags::empty(),
195 })
196 .map_err(|e| {
197 TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
198 })?;
199
200 let (multiplier, shift) = compute_timestamp_period(timestamp_period);
201
202 let mut constants = HashMap::with_capacity(2);
203 constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
204 constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
205
206 let pipeline_desc = hal::ComputePipelineDescriptor {
207 label: None,
208 layout: pipeline_layout.as_ref(),
209 stage: hal::ProgrammableStage {
210 module: module.as_ref(),
211 entry_point: "main",
212 constants: &constants,
213 zero_initialize_workgroup_memory: false,
214 },
215 cache: None,
216 };
217 let pipeline = device
218 .raw()
219 .create_compute_pipeline(&pipeline_desc)
220 .map_err(|err| match err {
221 hal::PipelineError::Device(error) => {
222 CreateComputePipelineError::Device(device.handle_hal_error(error))
223 }
224 hal::PipelineError::Linkage(_stages, msg) => {
225 CreateComputePipelineError::Internal(msg)
226 }
227 hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
228 crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
229 ),
230 hal::PipelineError::PipelineConstants(_, error) => {
231 CreateComputePipelineError::PipelineConstants(error)
232 }
233 })?;
234
235 Ok(Self {
236 state: Some(InternalState {
237 temporary_bind_group_layout,
238 pipeline_layout,
239 pipeline,
240 }),
241 })
242 }
243 }
244
245 pub fn create_normalization_bind_group(
246 &self,
247 device: &Device,
248 buffer: &dyn hal::DynBuffer,
249 buffer_label: Option<&str>,
250 buffer_size: u64,
251 buffer_usages: wgt::BufferUsages,
252 ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
253 unsafe {
254 let Some(ref state) = &self.state else {
255 return Ok(TimestampNormalizationBindGroup { raw: None });
256 };
257
258 if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
259 return Ok(TimestampNormalizationBindGroup { raw: None });
260 }
261
262 if buffer_size > device.adapter.limits().max_storage_buffer_binding_size as u64 {
267 return Err(DeviceError::OutOfMemory);
268 }
269
270 let bg_label_alloc;
271 let label = match buffer_label {
272 Some(label) => {
273 bg_label_alloc =
274 alloc::format!("Timestamp normalization bind group ({})", label);
275 &*bg_label_alloc
276 }
277 None => "Timestamp normalization bind group",
278 };
279
280 let bg = device
281 .raw()
282 .create_bind_group(&hal::BindGroupDescriptor {
283 label: Some(label),
284 layout: &*state.temporary_bind_group_layout,
285 buffers: &[hal::BufferBinding {
286 buffer,
287 offset: 0,
288 size: None,
289 }],
290 samplers: &[],
291 textures: &[],
292 acceleration_structures: &[],
293 entries: &[hal::BindGroupEntry {
294 binding: 0,
295 resource_index: 0,
296 count: 1,
297 }],
298 })
299 .map_err(|e| device.handle_hal_error(e))?;
300
301 Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
302 }
303 }
304
305 pub fn normalize(
306 &self,
307 snatch_guard: &SnatchGuard<'_>,
308 encoder: &mut dyn hal::DynCommandEncoder,
309 tracker: &mut BufferTracker,
310 bind_group: &TimestampNormalizationBindGroup,
311 buffer: &Arc<Buffer>,
312 buffer_offset_bytes: u64,
313 total_timestamps: u32,
314 ) {
315 let Some(ref state) = &self.state else {
316 return;
317 };
318
319 let Some(bind_group) = bind_group.raw.as_deref() else {
320 return;
321 };
322
323 let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
326
327 let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
328
329 let needed_workgroups = total_timestamps.div_ceil(64);
330
331 unsafe {
332 encoder.transition_buffers(barrier.as_slice());
333 encoder.begin_compute_pass(&hal::ComputePassDescriptor {
334 label: Some("Timestamp normalization pass"),
335 timestamp_writes: None,
336 });
337 encoder.set_compute_pipeline(&*state.pipeline);
338 encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
339 encoder.set_push_constants(
340 &*state.pipeline_layout,
341 wgt::ShaderStages::COMPUTE,
342 0,
343 &[buffer_offset_timestamps, total_timestamps],
344 );
345 encoder.dispatch([needed_workgroups, 1, 1]);
346 encoder.end_compute_pass();
347 }
348 }
349
350 pub fn dispose(self, device: &dyn hal::DynDevice) {
351 unsafe {
352 let Some(state) = self.state else {
353 return;
354 };
355
356 device.destroy_compute_pipeline(state.pipeline);
357 device.destroy_pipeline_layout(state.pipeline_layout);
358 device.destroy_bind_group_layout(state.temporary_bind_group_layout);
359 }
360 }
361
362 pub fn enabled(&self) -> bool {
363 self.state.is_some()
364 }
365}
366
367#[derive(Debug)]
368pub struct TimestampNormalizationBindGroup {
369 raw: Option<Box<dyn hal::DynBindGroup>>,
370}
371
372impl TimestampNormalizationBindGroup {
373 pub fn dispose(self, device: &dyn hal::DynDevice) {
374 unsafe {
375 if let Some(raw) = self.raw {
376 device.destroy_bind_group(raw);
377 }
378 }
379 }
380}
381
/// Converts a timestamp period into a fixed-point `(multiplier, shift)` pair such that
/// `multiplier / 2^shift` approximates `input`.
///
/// Picks the largest `shift` in `0..=32` for which the rounded multiplier still fits in a
/// `u32`, which maximizes the fractional precision carried by the shader constants. A period
/// too large to fit even at `shift == 0` saturates the multiplier at `u32::MAX`. Non-positive
/// and NaN inputs yield a multiplier of 0.
///
/// The shift is found by search rather than from `ceil(log2(input))`. For exact powers of two
/// ≥ 1, the `log2` formula yields a multiplier of exactly `2^32`. It also does so for f32 values
/// just above a power of two, whose `log2` rounds to an integer. That multiplier then silently
/// saturates to `u32::MAX` and skews every normalized timestamp.
///
/// NOTE(review): the `0..=32` shift range is assumed to be what the normalization shader
/// supports — confirm against `timestamp_normalization.wgsl`.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    let period = f64::from(input);
    let mut shift = 32u32;
    loop {
        // Exact: an f32 value has a 24-bit mantissa and scaling by a power of two is lossless.
        let multiplier = (period * (1u64 << shift) as f64).round();
        if multiplier <= f64::from(u32::MAX) || shift == 0 {
            // `as` saturates out-of-range values and maps NaN to 0.
            return (multiplier as u32, shift);
        }
        shift -= 1;
    }
}

#[cfg(test)]
mod tests {
    /// Asserts that the `(multiplier, shift)` pair produced for `input` reproduces `input` to
    /// within `1e-7` when evaluated as `multiplier / 2^shift`.
    fn assert_timestamp_case(input: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(input);

        let reconstructed = f64::from(multiplier) / (1u64 << shift) as f64;
        let error = (f64::from(input) - reconstructed).abs();

        assert!(error < 0.0000001);
    }

    #[test]
    fn compute_timestamp_period() {
        for input in [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7] {
            assert_timestamp_case(input);
        }
    }
}