1use thiserror::Error;
2use wgt::AdapterInfo;
3
4pub const HEADER_LENGTH: usize = size_of::<PipelineCacheHeader>();
5
6#[derive(Debug, PartialEq, Eq, Clone, Error)]
7#[non_exhaustive]
8pub enum PipelineCacheValidationError {
9 #[error("The pipeline cache data was truncated")]
10 Truncated,
11 #[error("The pipeline cache data was longer than recorded")]
12 Extended,
14 #[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")]
15 Corrupted,
16 #[error("The pipeline cacha data was out of date and so cannot be safely used")]
17 Outdated,
18 #[error("The cache data was created for a different device")]
19 DeviceMismatch,
20 #[error("Pipeline cacha data was created for a future version of wgpu")]
21 Unsupported,
22}
23
24impl PipelineCacheValidationError {
25 pub fn was_avoidable(&self) -> bool {
28 match self {
29 PipelineCacheValidationError::DeviceMismatch => true,
30 PipelineCacheValidationError::Truncated
31 | PipelineCacheValidationError::Unsupported
32 | PipelineCacheValidationError::Extended
33 | PipelineCacheValidationError::Outdated
35 | PipelineCacheValidationError::Corrupted => false,
36 }
37 }
38}
39
40pub fn validate_pipeline_cache<'d>(
42 cache_data: &'d [u8],
43 adapter: &AdapterInfo,
44 validation_key: [u8; 16],
45) -> Result<&'d [u8], PipelineCacheValidationError> {
46 let adapter_key = adapter_key(adapter)?;
47 let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else {
48 return Err(PipelineCacheValidationError::Truncated);
49 };
50 if header.magic != MAGIC {
51 return Err(PipelineCacheValidationError::Corrupted);
52 }
53 if header.header_version != HEADER_VERSION {
54 return Err(PipelineCacheValidationError::Outdated);
55 }
56 if header.cache_abi != ABI {
57 return Err(PipelineCacheValidationError::Outdated);
58 }
59 if header.backend != adapter.backend as u8 {
60 return Err(PipelineCacheValidationError::DeviceMismatch);
61 }
62 if header.adapter_key != adapter_key {
63 return Err(PipelineCacheValidationError::DeviceMismatch);
64 }
65 if header.validation_key != validation_key {
66 return Err(PipelineCacheValidationError::Outdated);
70 }
71 let data_size: usize = header
72 .data_size
73 .try_into()
74 .map_err(|_| PipelineCacheValidationError::Corrupted)?;
77 if remaining_data.len() < data_size {
78 return Err(PipelineCacheValidationError::Truncated);
79 }
80 if remaining_data.len() > data_size {
81 return Err(PipelineCacheValidationError::Extended);
82 }
83 if header.hash_space != HASH_SPACE_VALUE {
84 return Err(PipelineCacheValidationError::Corrupted);
85 }
86 Ok(remaining_data)
87}
88
89pub fn add_cache_header(
90 in_region: &mut [u8],
91 data: &[u8],
92 adapter: &AdapterInfo,
93 validation_key: [u8; 16],
94) {
95 assert_eq!(in_region.len(), HEADER_LENGTH);
96 let header = PipelineCacheHeader {
97 adapter_key: adapter_key(adapter)
98 .expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"),
99 backend: adapter.backend as u8,
100 cache_abi: ABI,
101 magic: MAGIC,
102 header_version: HEADER_VERSION,
103 validation_key,
104 hash_space: HASH_SPACE_VALUE,
105 data_size: data
106 .len()
107 .try_into()
108 .expect("Cache larger than u64::MAX bytes"),
109 };
110 header.write(in_region);
111}
112
/// Marks data as a wgpu pipeline cache. Any other value is reported as `Corrupted`.
const MAGIC: [u8; 8] = *b"WGPUPLCH";

/// Version of the header layout. Any other value is reported as `Outdated`.
const HEADER_VERSION: u32 = 1;

/// Pointer width of this build, in bytes.
///
/// A cache written by a build with a different pointer width is reported as `Outdated`.
const ABI: u32 = core::mem::size_of::<*const ()>() as u32;

/// Fixed value stored in the header's `hash_space` slot. Any other value is reported as
/// `Corrupted`.
// NOTE(review): `PipelineCacheHeader::read` calls this field `data_hash`, so the slot is
// presumably reserved for a future payload checksum — confirm before relying on that.
const HASH_SPACE_VALUE: u64 = 0xFEDC_BA98_7654_3210;
127
/// Fixed-size header placed in front of pipeline cache data.
///
/// `read`/`write` (de)serialize the fields in declaration order, with integers in
/// big-endian byte order. With `#[repr(C)]` and this field order the struct has no padding,
/// so `size_of::<PipelineCacheHeader>()` equals the serialized length (`HEADER_LENGTH`).
/// `read` asserts this equality.
/// Do not reorder fields without updating the serialization and `HEADER_VERSION`.
#[repr(C)]
#[derive(PartialEq, Eq)]
struct PipelineCacheHeader {
    /// Must equal `MAGIC` (`b"WGPUPLCH"`); otherwise the data is `Corrupted`.
    magic: [u8; 8],
    /// Must equal `HEADER_VERSION`; otherwise the data is `Outdated`.
    header_version: u32,
    /// Pointer width in bytes of the writing build (`ABI`); a mismatch is `Outdated`.
    cache_abi: u32,
    /// The adapter's `wgt::Backend`, stored as `u8`; a mismatch is `DeviceMismatch`.
    backend: u8,
    /// Identity produced by `adapter_key` (vendor and device IDs padded with `0xFF`);
    /// a mismatch is `DeviceMismatch`.
    adapter_key: [u8; 15],
    /// Key supplied by the caller of `add_cache_header`; a mismatch is `Outdated`.
    validation_key: [u8; 16],
    /// Length in bytes of the payload that directly follows the header.
    data_size: u64,
    /// Must equal `HASH_SPACE_VALUE`; otherwise the data is `Corrupted`.
    hash_space: u64,
}
166
167impl PipelineCacheHeader {
168 fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> {
169 let mut reader = Reader {
170 data,
171 total_read: 0,
172 };
173 let magic = reader.read_array()?;
174 let header_version = reader.read_u32()?;
175 let cache_abi = reader.read_u32()?;
176 let backend = reader.read_byte()?;
177 let adapter_key = reader.read_array()?;
178 let validation_key = reader.read_array()?;
179 let data_size = reader.read_u64()?;
180 let data_hash = reader.read_u64()?;
181
182 assert_eq!(reader.total_read, size_of::<PipelineCacheHeader>());
183
184 Some((
185 PipelineCacheHeader {
186 magic,
187 header_version,
188 cache_abi,
189 backend,
190 adapter_key,
191 validation_key,
192 data_size,
193 hash_space: data_hash,
194 },
195 reader.data,
196 ))
197 }
198
199 fn write(&self, into: &mut [u8]) -> Option<()> {
200 let mut writer = Writer { data: into };
201 writer.write_array(&self.magic)?;
202 writer.write_u32(self.header_version)?;
203 writer.write_u32(self.cache_abi)?;
204 writer.write_byte(self.backend)?;
205 writer.write_array(&self.adapter_key)?;
206 writer.write_array(&self.validation_key)?;
207 writer.write_u64(self.data_size)?;
208 writer.write_u64(self.hash_space)?;
209
210 assert_eq!(writer.data.len(), 0);
211 Some(())
212 }
213}
214
215fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> {
216 match adapter.backend {
217 wgt::Backend::Vulkan => {
218 let v: [u8; 4] = adapter.vendor.to_be_bytes();
221 let d: [u8; 4] = adapter.device.to_be_bytes();
222 let adapter = [
223 255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255,
224 ];
225 Ok(adapter)
226 }
227 _ => Err(PipelineCacheValidationError::Unsupported),
228 }
229}
230
/// Cursor that consumes fixed-size fields from the front of a byte slice.
struct Reader<'a> {
    /// Bytes not yet consumed.
    data: &'a [u8],
    /// Number of bytes consumed so far.
    total_read: usize,
}

impl<'a> Reader<'a> {
    /// Consumes a single byte, or returns `None` if the input is exhausted.
    fn read_byte(&mut self) -> Option<u8> {
        let [byte] = self.read_array::<1>()?;
        Some(byte)
    }

    /// Consumes the next `N` bytes as an array.
    ///
    /// Returns `None` and consumes nothing if fewer than `N` bytes remain.
    fn read_array<const N: usize>(&mut self) -> Option<[u8; N]> {
        let (head, rest) = self.data.split_first_chunk::<N>()?;
        self.data = rest;
        self.total_read += N;
        Some(*head)
    }

    /// Consumes a big-endian `u32`.
    fn read_u32(&mut self) -> Option<u32> {
        let bytes = self.read_array()?;
        Some(u32::from_be_bytes(bytes))
    }

    /// Consumes a big-endian `u64`.
    fn read_u64(&mut self) -> Option<u64> {
        let bytes = self.read_array()?;
        Some(u64::from_be_bytes(bytes))
    }
}
264
/// Cursor that fills fixed-size fields from the front of a mutable byte slice.
struct Writer<'a> {
    /// Space not yet written.
    data: &'a mut [u8],
}

impl<'a> Writer<'a> {
    /// Writes a single byte, or returns `None` if no space remains.
    fn write_byte(&mut self, byte: u8) -> Option<()> {
        self.write_array(&byte.to_be_bytes())
    }

    /// Copies `array` into the next `N` bytes and advances past them.
    ///
    /// Returns `None` and writes nothing if fewer than `N` bytes of space remain.
    fn write_array<const N: usize>(&mut self, array: &[u8; N]) -> Option<()> {
        // Check before taking the slice so that a failed write leaves `data` untouched.
        if self.data.len() < N {
            return None;
        }
        let (slot, rest) = core::mem::take(&mut self.data).split_first_chunk_mut::<N>()?;
        *slot = *array;
        self.data = rest;
        Some(())
    }

    /// Writes `value` as a big-endian `u32`.
    fn write_u32(&mut self, value: u32) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }

    /// Writes `value` as a big-endian `u64`.
    fn write_u64(&mut self, value: u64) -> Option<()> {
        let bytes = value.to_be_bytes();
        self.write_array(&bytes)
    }
}
295
#[cfg(test)]
mod tests {
    use alloc::{string::String, vec::Vec};
    use wgt::AdapterInfo;

    use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH};

    use super::ABI;

    // Compile-time check that the serialized header is exactly 64 bytes.
    const _: [(); HEADER_LENGTH] = [(); 64];

    const ADAPTER: AdapterInfo = AdapterInfo {
        name: String::new(),
        vendor: 0x0002_FEED,
        device: 0xFEFE_FEFE,
        device_type: wgt::DeviceType::Other,
        driver: String::new(),
        driver_info: String::new(),
        backend: wgt::Backend::Vulkan,
    };

    const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888);

    /// A header accepted for `ADAPTER` and `VALIDATION_KEY` with an empty payload, as
    /// 8-byte rows. It is spelled out byte by byte to pin the on-disk layout.
    /// Each test corrupts a single row to trigger one specific error.
    fn valid_header() -> [[u8; 8]; HEADER_LENGTH / 8] {
        [
            *b"WGPUPLCH",
            [0, 0, 0, 1, 0, 0, 0, ABI as u8],
            [1, 255, 255, 255, 0, 2, 0xFE, 0xED],
            [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255],
            0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(),
            0x88888888_88888888u64.to_be_bytes(),
            0x0u64.to_be_bytes(),
            0xFEDC_BA98_7654_3210u64.to_be_bytes(),
        ]
    }

    /// Flattens `rows` and appends `payload_len` zero bytes.
    fn encode(rows: [[u8; 8]; HEADER_LENGTH / 8], payload_len: usize) -> Vec<u8> {
        rows.into_iter()
            .flatten()
            .chain(core::iter::repeat_n(0u8, payload_len))
            .collect()
    }

    /// Validates `cache` against the test adapter and validation key.
    fn validate(cache: &[u8]) -> Result<&[u8], E> {
        super::validate_pipeline_cache(cache, &ADAPTER, VALIDATION_KEY)
    }

    #[test]
    fn written_header() {
        let mut result = [0; HEADER_LENGTH];
        super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY);
        let expected = encode(valid_header(), 0);
        assert_eq!(result.as_slice(), expected.as_slice());
    }

    #[test]
    fn valid_data() {
        let cache = encode(valid_header(), 0);
        let expected: &[u8] = &[];
        assert_eq!(validate(&cache), Ok(expected));
    }

    #[test]
    fn invalid_magic() {
        let mut rows = valid_header();
        rows[0] = *b"NOT_WGPU";
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::Corrupted));
    }

    #[test]
    fn wrong_version() {
        let mut rows = valid_header();
        rows[1][3] = 2;
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::Outdated));
    }

    #[test]
    fn wrong_abi() {
        let mut rows = valid_header();
        rows[1][7] = 14;
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::Outdated));
    }

    #[test]
    fn wrong_backend() {
        let mut rows = valid_header();
        rows[2][0] = 2;
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_adapter() {
        let mut rows = valid_header();
        rows[2][7] = 0x00;
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::DeviceMismatch));
    }

    #[test]
    fn wrong_validation() {
        let mut rows = valid_header();
        rows[5] = 0x88888888_00000000u64.to_be_bytes();
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::Outdated));
    }

    #[test]
    fn too_little_data() {
        let mut rows = valid_header();
        rows[6] = 100u64.to_be_bytes();
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::Truncated));
    }

    #[test]
    fn not_no_data() {
        let mut rows = valid_header();
        rows[6] = 100u64.to_be_bytes();
        let cache = encode(rows, 100);
        let expected: &[u8] = &[0; 100];
        assert_eq!(validate(&cache), Ok(expected));
    }

    #[test]
    fn too_much_data() {
        let mut rows = valid_header();
        rows[6] = 100u64.to_be_bytes();
        let cache = encode(rows, 200);
        assert_eq!(validate(&cache), Err(E::Extended));
    }

    #[test]
    fn wrong_hash() {
        let mut rows = valid_header();
        rows[7] = 0x0u64.to_be_bytes();
        let cache = encode(rows, 0);
        assert_eq!(validate(&cache), Err(E::Corrupted));
    }
}