wgpu_core/timestamp_normalization/
mod.rs

1//! Utility for normalizing GPU timestamp queries to have a consistent
2//! 1GHz period. This uses a compute shader to do the normalization,
3//! so the timestamps exist in their correct format on the GPU, as
4//! is required by the WebGPU specification.
5//!
6//! ## Algorithm
7//!
8//! The fundamental operation is multiplying a u64 timestamp by an f32
9//! value. We have neither f64s nor u64s in shaders, so we need to do
10//! something more complicated.
11//!
//! We first decompose the f32 into a fraction with a u32 numerator and a
//! power-of-two denominator. The decomposition itself is done in f64,
//! as f64 can represent every u32 losslessly.
15//!
16//! Because the denominator is a power of two, this means the shader can evaluate
17//! this divide by using a shift. Additionally, we always choose the largest denominator
18//! we can, so that the fraction is as precise as possible.
19//!
20//! To evaluate this function, we have two helper operations (both in common.wgsl).
21//!
22//! 1. `u64_mul_u32` multiplies a u64 by a u32 and returns a u96.
23//! 2. `shift_right_u96` shifts a u96 right by a given amount, returning a u96.
24//!
25//! See their implementations for more details.
26//!
//! We then multiply the timestamp by the numerator and shift the product right
//! by the exponent of the denominator. This gives us the normalized timestamp.
29
30use core::num::NonZeroU64;
31
32use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
33
34use hashbrown::HashMap;
35use wgt::PushConstantRange;
36
37use crate::{
38    device::{Device, DeviceError},
39    pipeline::{CreateComputePipelineError, CreateShaderModuleError},
40    resource::Buffer,
41    snatch::SnatchGuard,
42    track::BufferTracker,
43};
44
/// Buffer usage associated with timestamp normalization (read-write storage).
///
/// [`TimestampNormalizer::normalize`] transitions the target buffer to
/// `STORAGE_READ_WRITE` before dispatching the normalization compute pass.
pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
    wgt::BufferUses::STORAGE_READ_WRITE;
47
/// HAL objects owned by an enabled [`TimestampNormalizer`].
///
/// They are created in [`TimestampNormalizer::new`] and destroyed in
/// [`TimestampNormalizer::dispose`].
struct InternalState {
    /// Layout with a single read-write storage buffer at binding 0. Used to
    /// create the per-buffer normalization bind groups.
    temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
    /// Pipeline layout: the bind group layout above plus 8 bytes of
    /// compute-stage push constants.
    pipeline_layout: Box<dyn hal::DynPipelineLayout>,
    /// Compute pipeline running the `main` entry point of the normalization shader.
    pipeline: Box<dyn hal::DynComputePipeline>,
}
53
/// Errors that can occur while building a [`TimestampNormalizer`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum TimestampNormalizerInitError {
    /// The HAL device failed to create the bind group layout.
    #[error("Failed to initialize bind group layout")]
    BindGroupLayout(#[source] DeviceError),
    /// The built-in WGSL source could not be parsed.
    #[cfg(feature = "wgsl")]
    #[error("Failed to parse shader")]
    ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
    /// The parsed shader module was rejected by the naga validator.
    #[error("Failed to validate shader module")]
    ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
    /// The HAL device failed to create the shader module.
    #[error("Failed to create shader module")]
    CreateShaderModule(#[from] CreateShaderModuleError),
    /// The HAL device failed to create the pipeline layout.
    #[error("Failed to create pipeline layout")]
    PipelineLayout(#[source] DeviceError),
    /// The HAL device failed to create the compute pipeline.
    #[error("Failed to create compute pipeline")]
    ComputePipeline(#[from] CreateComputePipelineError),
}
70
/// Normalizes GPU timestamps to have a consistent 1GHz period.
/// See module documentation for more information.
pub struct TimestampNormalizer {
    /// GPU resources for the normalization pass. `None` when normalization is
    /// disabled or not needed, in which case `normalize` does nothing.
    state: Option<InternalState>,
}
76
77impl TimestampNormalizer {
78    /// Creates a new timestamp normalizer.
79    ///
80    /// If the device cannot support automatic timestamp normalization,
81    /// this will return a normalizer that does nothing.
82    ///
83    /// # Errors
84    ///
85    /// If any resources are invalid, this will return an error.
86    pub fn new(
87        device: &Device,
88        timestamp_period: f32,
89    ) -> Result<Self, TimestampNormalizerInitError> {
90        unsafe {
91            if !device
92                .instance_flags
93                .contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
94            {
95                return Ok(Self { state: None });
96            }
97
98            if !device
99                .downlevel
100                .flags
101                .contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
102            {
103                log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
104                return Ok(Self { state: None });
105            }
106
107            if timestamp_period == 1.0 {
108                // If the period is 1, we don't need to do anything to them.
109                return Ok(Self { state: None });
110            }
111
112            let temporary_bind_group_layout = device
113                .raw()
114                .create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
115                    label: Some("Timestamp Normalization Bind Group Layout"),
116                    flags: hal::BindGroupLayoutFlags::empty(),
117                    entries: &[wgt::BindGroupLayoutEntry {
118                        binding: 0,
119                        visibility: wgt::ShaderStages::COMPUTE,
120                        ty: wgt::BindingType::Buffer {
121                            ty: wgt::BufferBindingType::Storage { read_only: false },
122                            has_dynamic_offset: false,
123                            min_binding_size: Some(NonZeroU64::new(8).unwrap()),
124                        },
125                        count: None,
126                    }],
127                })
128                .map_err(|e| {
129                    TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
130                })?;
131
132            let common_src = include_str!("common.wgsl");
133            let src = include_str!("timestamp_normalization.wgsl");
134
135            let preprocessed_src = alloc::format!("{common_src}\n{src}");
136
137            #[cfg(feature = "wgsl")]
138            let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
139                TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
140                    source: preprocessed_src.clone(),
141                    label: None,
142                    inner: Box::new(inner),
143                })
144            })?;
145            #[cfg(not(feature = "wgsl"))]
146            #[allow(clippy::diverging_sub_expression)]
147            let module =
148                panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
149
150            let info = crate::device::create_validator(
151                wgt::Features::PUSH_CONSTANTS,
152                wgt::DownlevelFlags::empty(),
153                naga::valid::ValidationFlags::all(),
154            )
155            .validate(&module)
156            .map_err(|inner| {
157                TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
158                    source: preprocessed_src.clone(),
159                    label: None,
160                    inner: Box::new(inner),
161                })
162            })?;
163            let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
164                module: alloc::borrow::Cow::Owned(module),
165                info,
166                debug_source: None,
167            });
168            let hal_desc = hal::ShaderModuleDescriptor {
169                label: None,
170                runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
171            };
172            let module = device
173                .raw()
174                .create_shader_module(&hal_desc, hal_shader)
175                .map_err(|error| match error {
176                    hal::ShaderError::Device(error) => {
177                        CreateShaderModuleError::Device(device.handle_hal_error(error))
178                    }
179                    hal::ShaderError::Compilation(ref msg) => {
180                        log::error!("Shader error: {}", msg);
181                        CreateShaderModuleError::Generation
182                    }
183                })?;
184
185            let pipeline_layout = device
186                .raw()
187                .create_pipeline_layout(&hal::PipelineLayoutDescriptor {
188                    label: None,
189                    bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
190                    push_constant_ranges: &[PushConstantRange {
191                        stages: wgt::ShaderStages::COMPUTE,
192                        range: 0..8,
193                    }],
194                    flags: hal::PipelineLayoutFlags::empty(),
195                })
196                .map_err(|e| {
197                    TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
198                })?;
199
200            let (multiplier, shift) = compute_timestamp_period(timestamp_period);
201
202            let mut constants = HashMap::with_capacity(2);
203            constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
204            constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
205
206            let pipeline_desc = hal::ComputePipelineDescriptor {
207                label: None,
208                layout: pipeline_layout.as_ref(),
209                stage: hal::ProgrammableStage {
210                    module: module.as_ref(),
211                    entry_point: "main",
212                    constants: &constants,
213                    zero_initialize_workgroup_memory: false,
214                },
215                cache: None,
216            };
217            let pipeline = device
218                .raw()
219                .create_compute_pipeline(&pipeline_desc)
220                .map_err(|err| match err {
221                    hal::PipelineError::Device(error) => {
222                        CreateComputePipelineError::Device(device.handle_hal_error(error))
223                    }
224                    hal::PipelineError::Linkage(_stages, msg) => {
225                        CreateComputePipelineError::Internal(msg)
226                    }
227                    hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
228                        crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
229                    ),
230                    hal::PipelineError::PipelineConstants(_, error) => {
231                        CreateComputePipelineError::PipelineConstants(error)
232                    }
233                })?;
234
235            Ok(Self {
236                state: Some(InternalState {
237                    temporary_bind_group_layout,
238                    pipeline_layout,
239                    pipeline,
240                }),
241            })
242        }
243    }
244
245    pub fn create_normalization_bind_group(
246        &self,
247        device: &Device,
248        buffer: &dyn hal::DynBuffer,
249        buffer_label: Option<&str>,
250        buffer_size: u64,
251        buffer_usages: wgt::BufferUsages,
252    ) -> Result<TimestampNormalizationBindGroup, DeviceError> {
253        unsafe {
254            let Some(ref state) = &self.state else {
255                return Ok(TimestampNormalizationBindGroup { raw: None });
256            };
257
258            if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
259                return Ok(TimestampNormalizationBindGroup { raw: None });
260            }
261
262            // If this buffer is large enough that we wouldn't be able to bind the entire thing
263            // at once to normalize the timestamps, we can't use it. We force the buffer to fail
264            // to allocate. The lowest max binding size is 128MB, and query sets must be small
265            // (no more than 4096), so this should never be hit in practice by sane programs.
266            if buffer_size > device.adapter.limits().max_storage_buffer_binding_size as u64 {
267                return Err(DeviceError::OutOfMemory);
268            }
269
270            let bg_label_alloc;
271            let label = match buffer_label {
272                Some(label) => {
273                    bg_label_alloc =
274                        alloc::format!("Timestamp normalization bind group ({})", label);
275                    &*bg_label_alloc
276                }
277                None => "Timestamp normalization bind group",
278            };
279
280            let bg = device
281                .raw()
282                .create_bind_group(&hal::BindGroupDescriptor {
283                    label: Some(label),
284                    layout: &*state.temporary_bind_group_layout,
285                    buffers: &[hal::BufferBinding {
286                        buffer,
287                        offset: 0,
288                        size: None,
289                    }],
290                    samplers: &[],
291                    textures: &[],
292                    acceleration_structures: &[],
293                    entries: &[hal::BindGroupEntry {
294                        binding: 0,
295                        resource_index: 0,
296                        count: 1,
297                    }],
298                })
299                .map_err(|e| device.handle_hal_error(e))?;
300
301            Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
302        }
303    }
304
305    pub fn normalize(
306        &self,
307        snatch_guard: &SnatchGuard<'_>,
308        encoder: &mut dyn hal::DynCommandEncoder,
309        tracker: &mut BufferTracker,
310        bind_group: &TimestampNormalizationBindGroup,
311        buffer: &Arc<Buffer>,
312        buffer_offset_bytes: u64,
313        total_timestamps: u32,
314    ) {
315        let Some(ref state) = &self.state else {
316            return;
317        };
318
319        let Some(bind_group) = bind_group.raw.as_deref() else {
320            return;
321        };
322
323        let buffer_offset_timestamps: u32 = (buffer_offset_bytes / 8).try_into().unwrap(); // Unreachable as MAX_QUERIES is way less than u32::MAX
324
325        let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
326
327        let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
328
329        let needed_workgroups = total_timestamps.div_ceil(64);
330
331        unsafe {
332            encoder.transition_buffers(barrier.as_slice());
333            encoder.begin_compute_pass(&hal::ComputePassDescriptor {
334                label: Some("Timestamp normalization pass"),
335                timestamp_writes: None,
336            });
337            encoder.set_compute_pipeline(&*state.pipeline);
338            encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
339            encoder.set_push_constants(
340                &*state.pipeline_layout,
341                wgt::ShaderStages::COMPUTE,
342                0,
343                &[buffer_offset_timestamps, total_timestamps],
344            );
345            encoder.dispatch([needed_workgroups, 1, 1]);
346            encoder.end_compute_pass();
347        }
348    }
349
350    pub fn dispose(self, device: &dyn hal::DynDevice) {
351        unsafe {
352            let Some(state) = self.state else {
353                return;
354            };
355
356            device.destroy_compute_pipeline(state.pipeline);
357            device.destroy_pipeline_layout(state.pipeline_layout);
358            device.destroy_bind_group_layout(state.temporary_bind_group_layout);
359        }
360    }
361
362    pub fn enabled(&self) -> bool {
363        self.state.is_some()
364    }
365}
366
/// Bind group binding a buffer for timestamp normalization.
///
/// Created by [`TimestampNormalizer::create_normalization_bind_group`] and
/// released with [`TimestampNormalizationBindGroup::dispose`].
#[derive(Debug)]
pub struct TimestampNormalizationBindGroup {
    /// The HAL bind group, or `None` when the normalizer is disabled or the
    /// buffer lacks `QUERY_RESOLVE` usage.
    raw: Option<Box<dyn hal::DynBindGroup>>,
}
371
372impl TimestampNormalizationBindGroup {
373    pub fn dispose(self, device: &dyn hal::DynDevice) {
374        unsafe {
375            if let Some(raw) = self.raw {
376                device.destroy_bind_group(raw);
377            }
378        }
379    }
380}
381
/// Converts the timestamp period `input` into a fixed-point fraction.
///
/// Returns `(multiplier, shift)` such that `multiplier / 2^shift`
/// approximates `input`, with `shift` in `0..=32`. The shader multiplies each
/// timestamp by `multiplier` and shifts the product right by `shift`.
///
/// Periods that are exact powers of two are represented exactly: when the
/// scaled value would be 2^32 (one past `u32::MAX`), the shift is lowered by
/// one instead of letting the multiplier saturate.
///
/// Degenerate inputs never panic: zero, negative and NaN periods give a
/// multiplier of 0, and periods too large for a `u32` saturate to `u32::MAX`
/// with a shift of 0.
fn compute_timestamp_period(input: f32) -> (u32, u32) {
    /// One past the largest multiplier that fits in a `u32` (2^32).
    const MULTIPLIER_LIMIT: f64 = 4_294_967_296.0;

    // `input` scaled by the power-of-two denominator `2^shift`.
    let scale = |shift: u32| (input as f64 * (1u64 << shift) as f64).round();

    let pow2 = input.log2().ceil() as i32;
    let clamped_pow2 = pow2.clamp(-32, 32).unsigned_abs();
    let mut shift = 32 - clamped_pow2;
    let mut scaled = scale(shift);

    // An exact power of two (or a `log2` that rounded down) lands on 2^32,
    // which does not fit in a u32. Halving the denominator keeps the value
    // exact rather than saturating.
    if scaled >= MULTIPLIER_LIMIT && shift > 0 {
        shift -= 1;
        scaled = scale(shift);
    }

    // float -> int conversions are defined to saturate (and map NaN to 0).
    (scaled as u32, shift)
}
394
#[cfg(test)]
mod tests {
    use core::f64;

    /// Largest absolute error allowed between a period and its fixed-point form.
    const TOLERANCE: f64 = 0.0000001;

    /// Checks that the `(multiplier, shift)` pair computed for `input`
    /// reconstructs `input` to within [`TOLERANCE`].
    fn assert_timestamp_case(input: f32) {
        let (multiplier, shift) = super::compute_timestamp_period(input);

        let denominator = (1u64 << shift) as f64;
        let reconstructed = multiplier as f64 / denominator;
        let error = (input as f64 - reconstructed).abs();

        assert!(error < TOLERANCE);
    }

    #[test]
    fn compute_timestamp_period() {
        // Periods below, at and above 1, including exact powers of two.
        for period in [0.01, 0.5, 1.0, 2.0, 2.7, 1000.7] {
            assert_timestamp_case(period);
        }
    }
}