1use thiserror::Error;
2use wgt::{
3    error::{ErrorType, WebGpuError},
4    AdapterInfo,
5};
6
7pub const HEADER_LENGTH: usize = size_of::<PipelineCacheHeader>();
8
9#[derive(Debug, PartialEq, Eq, Clone, Error)]
10#[non_exhaustive]
11pub enum PipelineCacheValidationError {
12    #[error("The pipeline cache data was truncated")]
13    Truncated,
14    #[error("The pipeline cache data was longer than recorded")]
15    Extended,
17    #[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")]
18    Corrupted,
19    #[error("The pipeline cacha data was out of date and so cannot be safely used")]
20    Outdated,
21    #[error("The cache data was created for a different device")]
22    DeviceMismatch,
23    #[error("Pipeline cacha data was created for a future version of wgpu")]
24    Unsupported,
25}
26
27impl PipelineCacheValidationError {
28    pub fn was_avoidable(&self) -> bool {
31        match self {
32            PipelineCacheValidationError::DeviceMismatch => true,
33            PipelineCacheValidationError::Truncated
34            | PipelineCacheValidationError::Unsupported
35            | PipelineCacheValidationError::Extended
36            | PipelineCacheValidationError::Outdated
38            | PipelineCacheValidationError::Corrupted => false,
39        }
40    }
41}
42
43impl WebGpuError for PipelineCacheValidationError {
44    fn webgpu_error_type(&self) -> ErrorType {
45        ErrorType::Validation
46    }
47}
48
49pub fn validate_pipeline_cache<'d>(
51    cache_data: &'d [u8],
52    adapter: &AdapterInfo,
53    validation_key: [u8; 16],
54) -> Result<&'d [u8], PipelineCacheValidationError> {
55    let adapter_key = adapter_key(adapter)?;
56    let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else {
57        return Err(PipelineCacheValidationError::Truncated);
58    };
59    if header.magic != MAGIC {
60        return Err(PipelineCacheValidationError::Corrupted);
61    }
62    if header.header_version != HEADER_VERSION {
63        return Err(PipelineCacheValidationError::Outdated);
64    }
65    if header.cache_abi != ABI {
66        return Err(PipelineCacheValidationError::Outdated);
67    }
68    if header.backend != adapter.backend as u8 {
69        return Err(PipelineCacheValidationError::DeviceMismatch);
70    }
71    if header.adapter_key != adapter_key {
72        return Err(PipelineCacheValidationError::DeviceMismatch);
73    }
74    if header.validation_key != validation_key {
75        return Err(PipelineCacheValidationError::Outdated);
79    }
80    let data_size: usize = header
81        .data_size
82        .try_into()
83        .map_err(|_| PipelineCacheValidationError::Corrupted)?;
86    if remaining_data.len() < data_size {
87        return Err(PipelineCacheValidationError::Truncated);
88    }
89    if remaining_data.len() > data_size {
90        return Err(PipelineCacheValidationError::Extended);
91    }
92    if header.hash_space != HASH_SPACE_VALUE {
93        return Err(PipelineCacheValidationError::Corrupted);
94    }
95    Ok(remaining_data)
96}
97
98pub fn add_cache_header(
99    in_region: &mut [u8],
100    data: &[u8],
101    adapter: &AdapterInfo,
102    validation_key: [u8; 16],
103) {
104    assert_eq!(in_region.len(), HEADER_LENGTH);
105    let header = PipelineCacheHeader {
106        adapter_key: adapter_key(adapter)
107            .expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"),
108        backend: adapter.backend as u8,
109        cache_abi: ABI,
110        magic: MAGIC,
111        header_version: HEADER_VERSION,
112        validation_key,
113        hash_space: HASH_SPACE_VALUE,
114        data_size: data
115            .len()
116            .try_into()
117            .expect("Cache larger than u64::MAX bytes"),
118    };
119    header.write(in_region);
120}
121
/// Fixed tag at the start of every header; a mismatch is reported as `Corrupted`.
const MAGIC: [u8; 8] = *b"WGPUPLCH";
/// Version of the header layout; a mismatch is reported as `Outdated`.
const HEADER_VERSION: u32 = 1;
/// Pointer width, in bytes, of the build that wrote the cache; a mismatch is reported as `Outdated`.
const ABI: u32 = size_of::<*const ()>() as u32;

/// Fixed marker stored in the header's `hash_space` field; a mismatch is reported as
/// `Corrupted`. NOTE(review): the field name suggests room reserved for a future hash.
const HASH_SPACE_VALUE: u64 = 0xFEDC_BA98_7654_3210;
136
/// In-memory form of the fixed-size header that precedes pipeline cache data.
///
/// `repr(C)` with fields ordered so that no padding is needed: the struct is 64 bytes, the same
/// number of bytes that `read`/`write` consume or produce. Integers are stored big-endian.
#[repr(C)]
#[derive(Eq, PartialEq)]
struct PipelineCacheHeader {
    /// Must equal `MAGIC`.
    magic: [u8; 8],
    /// Must equal `HEADER_VERSION`.
    header_version: u32,
    /// Pointer width of the writing build; must equal `ABI`.
    cache_abi: u32,
    /// The adapter's backend, stored as `u8`.
    backend: u8,
    /// Adapter identity produced by `adapter_key`.
    adapter_key: [u8; 15],
    /// Caller-supplied key; a mismatch marks the cache as outdated.
    validation_key: [u8; 16],
    /// Length in bytes of the payload that follows the header.
    data_size: u64,
    /// Must equal `HASH_SPACE_VALUE`.
    hash_space: u64,
}
175
176impl PipelineCacheHeader {
177    fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> {
178        let mut reader = Reader {
179            data,
180            total_read: 0,
181        };
182        let magic = reader.read_array()?;
183        let header_version = reader.read_u32()?;
184        let cache_abi = reader.read_u32()?;
185        let backend = reader.read_byte()?;
186        let adapter_key = reader.read_array()?;
187        let validation_key = reader.read_array()?;
188        let data_size = reader.read_u64()?;
189        let data_hash = reader.read_u64()?;
190
191        assert_eq!(reader.total_read, size_of::<PipelineCacheHeader>());
192
193        Some((
194            PipelineCacheHeader {
195                magic,
196                header_version,
197                cache_abi,
198                backend,
199                adapter_key,
200                validation_key,
201                data_size,
202                hash_space: data_hash,
203            },
204            reader.data,
205        ))
206    }
207
208    fn write(&self, into: &mut [u8]) -> Option<()> {
209        let mut writer = Writer { data: into };
210        writer.write_array(&self.magic)?;
211        writer.write_u32(self.header_version)?;
212        writer.write_u32(self.cache_abi)?;
213        writer.write_byte(self.backend)?;
214        writer.write_array(&self.adapter_key)?;
215        writer.write_array(&self.validation_key)?;
216        writer.write_u64(self.data_size)?;
217        writer.write_u64(self.hash_space)?;
218
219        assert_eq!(writer.data.len(), 0);
220        Some(())
221    }
222}
223
224fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> {
225    match adapter.backend {
226        wgt::Backend::Vulkan => {
227            let v: [u8; 4] = adapter.vendor.to_be_bytes();
230            let d: [u8; 4] = adapter.device.to_be_bytes();
231            let adapter = [
232                255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255,
233            ];
234            Ok(adapter)
235        }
236        _ => Err(PipelineCacheValidationError::Unsupported),
237    }
238}
239
/// Forward-only cursor over a byte slice that counts how many bytes it has consumed.
struct Reader<'a> {
    /// Bytes not yet consumed.
    data: &'a [u8],
    /// Number of bytes consumed so far.
    total_read: usize,
}

impl<'a> Reader<'a> {
    /// Consumes one byte; returns `None` (consuming nothing) when no bytes remain.
    fn read_byte(&mut self) -> Option<u8> {
        self.read_array::<1>().map(|[byte]| byte)
    }

    /// Consumes the next `N` bytes as an array; returns `None` (consuming nothing) when fewer
    /// than `N` bytes remain.
    fn read_array<const N: usize>(&mut self) -> Option<[u8; N]> {
        let head = self.data.get(..N)?;
        let array: [u8; N] = head.try_into().expect("off-by-one-error in array size");
        self.data = &self.data[N..];
        self.total_read += N;
        Some(array)
    }

    /// Consumes 4 bytes and decodes them as a big-endian `u32`.
    fn read_u32(&mut self) -> Option<u32> {
        let bytes = self.read_array()?;
        Some(u32::from_be_bytes(bytes))
    }

    /// Consumes 8 bytes and decodes them as a big-endian `u64`.
    fn read_u64(&mut self) -> Option<u64> {
        let bytes = self.read_array()?;
        Some(u64::from_be_bytes(bytes))
    }
}
273
/// Forward-only cursor that fills a mutable byte slice from the front.
struct Writer<'a> {
    /// The part of the destination not yet written.
    data: &'a mut [u8],
}

impl<'a> Writer<'a> {
    /// Writes a single byte; returns `None` (writing nothing) when the destination is full.
    fn write_byte(&mut self, byte: u8) -> Option<()> {
        let single = [byte];
        self.write_array::<1>(&single)
    }

    /// Writes all `N` bytes of `array`; returns `None` (writing nothing) when fewer than `N`
    /// bytes of space remain.
    fn write_array<const N: usize>(&mut self, array: &[u8; N]) -> Option<()> {
        if self.data.len() < N {
            return None;
        }
        // Temporarily move the slice out so it can be split into two disjoint borrows.
        let remaining = core::mem::take(&mut self.data);
        let (head, tail) = remaining.split_at_mut(N);
        head.copy_from_slice(array);
        self.data = tail;
        Some(())
    }

    /// Writes `value` as 4 big-endian bytes.
    fn write_u32(&mut self, value: u32) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }

    /// Writes `value` as 8 big-endian bytes.
    fn write_u64(&mut self, value: u64) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }
}
304
#[cfg(test)]
mod tests {
    use alloc::{string::String, vec::Vec};
    use wgt::AdapterInfo;

    use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH};

    use super::ABI;

    // Compile-time check that the serialized header is exactly 64 bytes.
    const _: [(); HEADER_LENGTH] = [(); 64];

    const ADAPTER: AdapterInfo = AdapterInfo {
        name: String::new(),
        vendor: 0x0002_FEED,
        device: 0xFEFE_FEFE,
        device_type: wgt::DeviceType::Other,
        driver: String::new(),
        driver_info: String::new(),
        backend: wgt::Backend::Vulkan,
    };

    const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888);

    /// The header expected for `ADAPTER`, `VALIDATION_KEY` and an empty payload, split into
    /// eight 8-byte rows. Individual tests corrupt single fields of this baseline.
    fn valid_rows() -> [[u8; 8]; HEADER_LENGTH / 8] {
        [
            *b"WGPUPLCH",                                 // magic
            [0, 0, 0, 1, 0, 0, 0, ABI as u8],             // header_version, cache_abi
            [1, 255, 255, 255, 0, 2, 0xFE, 0xED],         // backend, adapter_key[..7]
            [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // adapter_key[7..]
            0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(),         // validation_key[..8]
            0x88888888_88888888u64.to_be_bytes(),         // validation_key[8..]
            0x0u64.to_be_bytes(),                         // data_size
            0xFEDCBA9_876543210u64.to_be_bytes(),         // hash_space
        ]
    }

    /// Flattens header rows and appends `payload_len` zero bytes after them.
    fn cache_bytes(rows: [[u8; 8]; HEADER_LENGTH / 8], payload_len: usize) -> Vec<u8> {
        let mut bytes: Vec<u8> = rows.into_iter().flatten().collect();
        bytes.resize(bytes.len() + payload_len, 0);
        bytes
    }

    /// Validates `cache` against the fixed test adapter and validation key.
    fn validate(cache: &[u8]) -> Result<&[u8], E> {
        super::validate_pipeline_cache(cache, &ADAPTER, VALIDATION_KEY)
    }

    #[test]
    fn written_header() {
        let mut result = [0; HEADER_LENGTH];
        super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY);
        let expected = cache_bytes(valid_rows(), 0);
        assert_eq!(result.as_slice(), expected.as_slice());
    }

    #[test]
    fn valid_data() {
        let cache = cache_bytes(valid_rows(), 0);
        let expected: &[u8] = &[];
        assert_eq!(validate(&cache), Ok(expected));
    }

    #[test]
    fn invalid_magic() {
        let mut rows = valid_rows();
        rows[0] = *b"NOT_WGPU";
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::Corrupted));
    }

    #[test]
    fn wrong_version() {
        let mut rows = valid_rows();
        rows[1][3] = 2;
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::Outdated));
    }

    #[test]
    fn wrong_abi() {
        let mut rows = valid_rows();
        rows[1][7] = 14;
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::Outdated));
    }

    #[test]
    fn wrong_backend() {
        let mut rows = valid_rows();
        rows[2][0] = 2;
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_adapter() {
        let mut rows = valid_rows();
        rows[2][7] = 0x00;
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_validation() {
        let mut rows = valid_rows();
        rows[5] = 0x88888888_00000000u64.to_be_bytes();
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::Outdated));
    }

    #[test]
    fn too_little_data() {
        let mut rows = valid_rows();
        rows[6] = 0x064u64.to_be_bytes();
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::Truncated));
    }

    #[test]
    fn not_no_data() {
        let mut rows = valid_rows();
        rows[6] = 100u64.to_be_bytes();
        let cache = cache_bytes(rows, 100);
        let expected: &[u8] = &[0; 100];
        assert_eq!(validate(&cache), Ok(expected));
    }

    #[test]
    fn too_much_data() {
        let mut rows = valid_rows();
        rows[6] = 0x064u64.to_be_bytes();
        assert_eq!(validate(&cache_bytes(rows, 200)), Err(E::Extended));
    }

    #[test]
    fn wrong_hash() {
        let mut rows = valid_rows();
        rows[7] = 0x00000000_00000000u64.to_be_bytes();
        assert_eq!(validate(&cache_bytes(rows, 0)), Err(E::Corrupted));
    }
}