use std::mem::size_of;
use thiserror::Error;
use wgt::AdapterInfo;
pub const HEADER_LENGTH: usize = size_of::<PipelineCacheHeader>();
#[derive(Debug, PartialEq, Eq, Clone, Error)]
#[non_exhaustive]
pub enum PipelineCacheValidationError {
#[error("The pipeline cache data was truncated")]
Truncated,
#[error("The pipeline cache data was longer than recorded")]
Extended,
#[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")]
Corrupted,
#[error("The pipeline cacha data was out of date and so cannot be safely used")]
Outdated,
#[error("The cache data was created for a different device")]
DeviceMismatch,
#[error("Pipeline cacha data was created for a future version of wgpu")]
Unsupported,
}
impl PipelineCacheValidationError {
pub fn was_avoidable(&self) -> bool {
match self {
PipelineCacheValidationError::DeviceMismatch => true,
PipelineCacheValidationError::Truncated
| PipelineCacheValidationError::Unsupported
| PipelineCacheValidationError::Extended
| PipelineCacheValidationError::Outdated
| PipelineCacheValidationError::Corrupted => false,
}
}
}
pub fn validate_pipeline_cache<'d>(
cache_data: &'d [u8],
adapter: &AdapterInfo,
validation_key: [u8; 16],
) -> Result<&'d [u8], PipelineCacheValidationError> {
let adapter_key = adapter_key(adapter)?;
let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else {
return Err(PipelineCacheValidationError::Truncated);
};
if header.magic != MAGIC {
return Err(PipelineCacheValidationError::Corrupted);
}
if header.header_version != HEADER_VERSION {
return Err(PipelineCacheValidationError::Outdated);
}
if header.cache_abi != ABI {
return Err(PipelineCacheValidationError::Outdated);
}
if header.backend != adapter.backend as u8 {
return Err(PipelineCacheValidationError::DeviceMismatch);
}
if header.adapter_key != adapter_key {
return Err(PipelineCacheValidationError::DeviceMismatch);
}
if header.validation_key != validation_key {
return Err(PipelineCacheValidationError::Outdated);
}
let data_size: usize = header
.data_size
.try_into()
.map_err(|_| PipelineCacheValidationError::Corrupted)?;
if remaining_data.len() < data_size {
return Err(PipelineCacheValidationError::Truncated);
}
if remaining_data.len() > data_size {
return Err(PipelineCacheValidationError::Extended);
}
if header.hash_space != HASH_SPACE_VALUE {
return Err(PipelineCacheValidationError::Corrupted);
}
Ok(remaining_data)
}
pub fn add_cache_header(
in_region: &mut [u8],
data: &[u8],
adapter: &AdapterInfo,
validation_key: [u8; 16],
) {
assert_eq!(in_region.len(), HEADER_LENGTH);
let header = PipelineCacheHeader {
adapter_key: adapter_key(adapter)
.expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"),
backend: adapter.backend as u8,
cache_abi: ABI,
magic: MAGIC,
header_version: HEADER_VERSION,
validation_key,
hash_space: HASH_SPACE_VALUE,
data_size: data
.len()
.try_into()
.expect("Cache larger than u64::MAX bytes"),
};
header.write(in_region);
}
const MAGIC: [u8; 8] = *b"WGPUPLCH";
const HEADER_VERSION: u32 = 1;
const ABI: u32 = size_of::<*const ()>() as u32;
const HASH_SPACE_VALUE: u64 = 0xFEDCBA9_876543210;
#[repr(C)]
#[derive(PartialEq, Eq)]
struct PipelineCacheHeader {
magic: [u8; 8],
header_version: u32,
cache_abi: u32,
backend: u8,
adapter_key: [u8; 15],
validation_key: [u8; 16],
data_size: u64,
hash_space: u64,
}
impl PipelineCacheHeader {
fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> {
let mut reader = Reader {
data,
total_read: 0,
};
let magic = reader.read_array()?;
let header_version = reader.read_u32()?;
let cache_abi = reader.read_u32()?;
let backend = reader.read_byte()?;
let adapter_key = reader.read_array()?;
let validation_key = reader.read_array()?;
let data_size = reader.read_u64()?;
let data_hash = reader.read_u64()?;
assert_eq!(reader.total_read, size_of::<PipelineCacheHeader>());
Some((
PipelineCacheHeader {
magic,
header_version,
cache_abi,
backend,
adapter_key,
validation_key,
data_size,
hash_space: data_hash,
},
reader.data,
))
}
fn write(&self, into: &mut [u8]) -> Option<()> {
let mut writer = Writer { data: into };
writer.write_array(&self.magic)?;
writer.write_u32(self.header_version)?;
writer.write_u32(self.cache_abi)?;
writer.write_byte(self.backend)?;
writer.write_array(&self.adapter_key)?;
writer.write_array(&self.validation_key)?;
writer.write_u64(self.data_size)?;
writer.write_u64(self.hash_space)?;
assert_eq!(writer.data.len(), 0);
Some(())
}
}
fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> {
match adapter.backend {
wgt::Backend::Vulkan => {
let v: [u8; 4] = adapter.vendor.to_be_bytes();
let d: [u8; 4] = adapter.device.to_be_bytes();
let adapter = [
255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255,
];
Ok(adapter)
}
_ => Err(PipelineCacheValidationError::Unsupported),
}
}
struct Reader<'a> {
data: &'a [u8],
total_read: usize,
}
impl<'a> Reader<'a> {
fn read_byte(&mut self) -> Option<u8> {
let res = *self.data.first()?;
self.total_read += 1;
self.data = &self.data[1..];
Some(res)
}
fn read_array<const N: usize>(&mut self) -> Option<[u8; N]> {
if N > self.data.len() {
return None;
}
let (start, data) = self.data.split_at(N);
self.total_read += N;
self.data = data;
Some(start.try_into().expect("off-by-one-error in array size"))
}
fn read_u32(&mut self) -> Option<u32> {
self.read_array().map(u32::from_be_bytes)
}
fn read_u64(&mut self) -> Option<u64> {
self.read_array().map(u64::from_be_bytes)
}
}
struct Writer<'a> {
data: &'a mut [u8],
}
impl<'a> Writer<'a> {
fn write_byte(&mut self, byte: u8) -> Option<()> {
self.write_array(&[byte])
}
fn write_array<const N: usize>(&mut self, array: &[u8; N]) -> Option<()> {
if N > self.data.len() {
return None;
}
let data = std::mem::take(&mut self.data);
let (start, data) = data.split_at_mut(N);
self.data = data;
start.copy_from_slice(array);
Some(())
}
fn write_u32(&mut self, value: u32) -> Option<()> {
self.write_array(&value.to_be_bytes())
}
fn write_u64(&mut self, value: u64) -> Option<()> {
self.write_array(&value.to_be_bytes())
}
}
#[cfg(test)]
mod tests {
use wgt::AdapterInfo;
use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH};
use super::ABI;
const _: [(); HEADER_LENGTH] = [(); 64];
const ADAPTER: AdapterInfo = AdapterInfo {
name: String::new(),
vendor: 0x0002_FEED,
device: 0xFEFE_FEFE,
device_type: wgt::DeviceType::Other,
driver: String::new(),
driver_info: String::new(),
backend: wgt::Backend::Vulkan,
};
const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888);
#[test]
fn written_header() {
let mut result = [0; HEADER_LENGTH];
super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY);
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let expected = cache.into_iter().flatten().collect::<Vec<u8>>();
assert_eq!(result.as_slice(), expected.as_slice());
}
#[test]
fn valid_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let expected: &[u8] = &[];
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Ok(expected));
}
#[test]
fn invalid_magic() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"NOT_WGPU", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Corrupted));
}
#[test]
fn wrong_version() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 2, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Outdated));
}
#[test]
fn wrong_abi() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, 14], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Outdated));
}
#[test]
fn wrong_backend() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [2, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::DeviceMismatch));
}
#[test]
fn wrong_adapter() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0x00], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::DeviceMismatch));
}
#[test]
fn wrong_validation() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_00000000u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Outdated));
}
#[test]
fn too_little_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x064u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Truncated));
}
#[test]
fn not_no_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 100u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache
.into_iter()
.flatten()
.chain(std::iter::repeat(0u8).take(100))
.collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
let expected: &[u8] = &[0; 100];
assert_eq!(validation_result, Ok(expected));
}
#[test]
fn too_much_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x064u64.to_be_bytes(), 0xFEDCBA9_876543210u64.to_be_bytes(), ];
let cache = cache
.into_iter()
.flatten()
.chain(std::iter::repeat(0u8).take(200))
.collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Extended));
}
#[test]
fn wrong_hash() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", [0, 0, 0, 1, 0, 0, 0, ABI as u8], [1, 255, 255, 255, 0, 2, 0xFE, 0xED], [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), 0x88888888_88888888u64.to_be_bytes(), 0x0u64.to_be_bytes(), 0x00000000_00000000u64.to_be_bytes(), ];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Corrupted));
}
}