naga/proc/
layouter.rs

1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::arena::{Handle, HandleVec};
4
/// A newtype struct where its only valid values are powers of 2.
///
/// The invariant is upheld by the checked constructor [`Alignment::new`], which
/// rejects every value that is not a power of two (including zero), and by the
/// predefined constants `ONE` through `SIXTEEN`.
///
/// NOTE(review): the `deserialize` feature derives `Deserialize` directly on the
/// wrapped `NonZeroU32`, which accepts any nonzero value (e.g. `3`) without
/// re-checking the power-of-two invariant. Consider routing deserialization
/// through `Alignment::new` (e.g. `#[serde(try_from = "u32")]`) — confirm.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);
10
11impl Alignment {
12    pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
13    pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
14    pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
15    pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
16    pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
17
18    pub const MIN_UNIFORM: Self = Self::SIXTEEN;
19
20    pub const fn new(n: u32) -> Option<Self> {
21        if n.is_power_of_two() {
22            // SAFETY: value can't be 0 since we just checked if it's a power of 2
23            Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
24        } else {
25            None
26        }
27    }
28
29    /// # Panics
30    /// If `width` is not a power of 2
31    pub fn from_width(width: u8) -> Self {
32        Self::new(width as u32).unwrap()
33    }
34
35    /// Returns whether or not `n` is a multiple of this alignment.
36    pub const fn is_aligned(&self, n: u32) -> bool {
37        // equivalent to: `n % self.0.get() == 0` but much faster
38        n & (self.0.get() - 1) == 0
39    }
40
41    /// Round `n` up to the nearest alignment boundary.
42    pub const fn round_up(&self, n: u32) -> u32 {
43        // equivalent to:
44        // match n % self.0.get() {
45        //     0 => n,
46        //     rem => n + (self.0.get() - rem),
47        // }
48        let mask = self.0.get() - 1;
49        (n + mask) & !mask
50    }
51}
52
53impl Display for Alignment {
54    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
55        self.0.get().fmt(f)
56    }
57}
58
59impl ops::Mul<u32> for Alignment {
60    type Output = u32;
61
62    fn mul(self, rhs: u32) -> Self::Output {
63        self.0.get() * rhs
64    }
65}
66
67impl ops::Mul for Alignment {
68    type Output = Alignment;
69
70    fn mul(self, rhs: Alignment) -> Self::Output {
71        // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2
72        Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
73    }
74}
75
76impl From<crate::VectorSize> for Alignment {
77    fn from(size: crate::VectorSize) -> Self {
78        match size {
79            crate::VectorSize::Bi => Alignment::TWO,
80            crate::VectorSize::Tri => Alignment::FOUR,
81            crate::VectorSize::Quad => Alignment::FOUR,
82        }
83    }
84}
85
86/// Size and alignment information for a type.
87#[derive(Clone, Copy, Debug, Hash, PartialEq)]
88#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
89#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
90pub struct TypeLayout {
91    pub size: u32,
92    pub alignment: Alignment,
93}
94
95impl TypeLayout {
96    /// Produce the stride as if this type is a base of an array.
97    pub const fn to_stride(&self) -> u32 {
98        self.alignment.round_up(self.size)
99    }
100}
101
/// Helper processor that derives the sizes and alignments of all types.
///
/// `Layouter` uses the default layout algorithm/table, described in
/// [WGSL §4.3.7, "Memory Layout"].
///
/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
/// layout of the type whose handle is `handle`. Only handles already covered by
/// [`Layouter::update`] have an entry.
///
/// [WGSL §4.3.7, "Memory Layout"]: https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts
#[derive(Debug, Default)]
pub struct Layouter {
    /// Layouts for types in an arena, stored by type handle. `update` appends
    /// one entry per type, in arena order.
    layouts: HandleVec<crate::Type, TypeLayout>,
}
116
117impl ops::Index<Handle<crate::Type>> for Layouter {
118    type Output = TypeLayout;
119    fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
120        &self.layouts[handle]
121    }
122}
123
/// The specific reason a type could not be laid out by [`Layouter::update`].
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// An array's element type (the payload) does not come before the array
    /// in the type arena, so no layout is available for it yet.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// A struct member's type does not come before the struct in the type
    /// arena. Payload: the member's index within the struct, and its type.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// A scalar width, which is used as the scalar's alignment, is not a
    /// power of two.
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
}
133
/// A layout failure, together with the type whose layout was being computed.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// Handle of the type being laid out when the error occurred (e.g. the
    /// array or struct, not the offending element/member type).
    pub ty: Handle<crate::Type>,
    /// What went wrong.
    pub inner: LayoutErrorInner,
}
140
141impl LayoutErrorInner {
142    const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
143        LayoutError { ty, inner: self }
144    }
145}
146
147impl Layouter {
148    /// Remove all entries from this `Layouter`, retaining storage.
149    pub fn clear(&mut self) {
150        self.layouts.clear();
151    }
152
153    /// Extend this `Layouter` with layouts for any new entries in `gctx.types`.
154    ///
155    /// Ensure that every type in `gctx.types` has a corresponding [TypeLayout]
156    /// in [`self.layouts`].
157    ///
158    /// Some front ends need to be able to compute layouts for existing types
159    /// while module construction is still in progress and new types are still
160    /// being added. This function assumes that the `TypeLayout` values already
161    /// present in `self.layouts` cover their corresponding entries in `types`,
162    /// and extends `self.layouts` as needed to cover the rest. Thus, a front
163    /// end can call this function at any time, passing its current type and
164    /// constant arenas, and then assume that layouts are available for all
165    /// types.
166    #[allow(clippy::or_fun_call)]
167    pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
168        use crate::TypeInner as Ti;
169
170        for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
171            let size = ty.inner.size(gctx);
172            let layout = match ty.inner {
173                Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
174                    let alignment = Alignment::new(scalar.width as u32)
175                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
176                    TypeLayout { size, alignment }
177                }
178                Ti::Vector {
179                    size: vec_size,
180                    scalar,
181                } => {
182                    let alignment = Alignment::new(scalar.width as u32)
183                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
184                    TypeLayout {
185                        size,
186                        alignment: Alignment::from(vec_size) * alignment,
187                    }
188                }
189                Ti::Matrix {
190                    columns: _,
191                    rows,
192                    scalar,
193                } => {
194                    let alignment = Alignment::new(scalar.width as u32)
195                        .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
196                    TypeLayout {
197                        size,
198                        alignment: Alignment::from(rows) * alignment,
199                    }
200                }
201                Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
202                    size,
203                    alignment: Alignment::ONE,
204                },
205                Ti::Array {
206                    base,
207                    stride: _,
208                    size: _,
209                } => TypeLayout {
210                    size,
211                    alignment: if base < ty_handle {
212                        self[base].alignment
213                    } else {
214                        return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
215                    },
216                },
217                Ti::Struct { span, ref members } => {
218                    let mut alignment = Alignment::ONE;
219                    for (index, member) in members.iter().enumerate() {
220                        alignment = if member.ty < ty_handle {
221                            alignment.max(self[member.ty].alignment)
222                        } else {
223                            return Err(LayoutErrorInner::InvalidStructMemberType(
224                                index as u32,
225                                member.ty,
226                            )
227                            .with(ty_handle));
228                        };
229                    }
230                    TypeLayout {
231                        size: span,
232                        alignment,
233                    }
234                }
235                Ti::Image { .. }
236                | Ti::Sampler { .. }
237                | Ti::AccelerationStructure { .. }
238                | Ti::RayQuery { .. }
239                | Ti::BindingArray { .. } => TypeLayout {
240                    size,
241                    alignment: Alignment::ONE,
242                },
243            };
244            debug_assert!(size <= layout.size);
245            self.layouts.insert(ty_handle, layout);
246        }
247
248        Ok(())
249    }
250}