1use core::{fmt::Display, num::NonZeroU32, ops};
2
3use crate::arena::{Handle, HandleVec};
4
/// Alignment of a type, stored as a non-zero byte count.
///
/// [`Alignment::new`] only accepts powers of two; the bit-mask arithmetic in
/// [`Alignment::is_aligned`] and [`Alignment::round_up`] relies on that.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    // SAFETY (all constants below): each literal handed to `new_unchecked` is
    // non-zero, which is the only requirement of that function.
    pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
    pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
    pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
    pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
    pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });

    /// Alias for [`Alignment::SIXTEEN`]. Judging by the name, this is the minimum
    /// alignment used for uniform-buffer layouts (confirm against callers).
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Returns `Some` if `n` is a power of two and `None` otherwise, including
    /// for `n == 0`.
    pub const fn new(n: u32) -> Option<Self> {
        if !n.is_power_of_two() {
            return None;
        }
        // A power of two is never zero, so this always takes the `Some` arm.
        match NonZeroU32::new(n) {
            Some(nonzero) => Some(Self(nonzero)),
            None => None,
        }
    }

    /// Builds an alignment equal to `width`.
    ///
    /// # Panics
    ///
    /// Panics if `width` is not a power of two.
    pub fn from_width(width: u8) -> Self {
        Self::new(u32::from(width)).unwrap()
    }

    /// Returns `true` if `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // For a power of two, the bits below it are exactly the remainder.
        let low_bits = self.0.get() - 1;
        (n & low_bits) == 0
    }

    /// Rounds `n` up to the next multiple of this alignment.
    ///
    /// The addition is unchecked. If `n + (alignment - 1)` exceeds `u32::MAX`, it
    /// overflows, which panics when overflow checks are enabled.
    pub const fn round_up(&self, n: u32) -> u32 {
        let low_bits = self.0.get() - 1;
        (n + low_bits) & !low_bits
    }
}
52
53impl Display for Alignment {
54 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
55 self.0.get().fmt(f)
56 }
57}
58
59impl ops::Mul<u32> for Alignment {
60 type Output = u32;
61
62 fn mul(self, rhs: u32) -> Self::Output {
63 self.0.get() * rhs
64 }
65}
66
67impl ops::Mul for Alignment {
68 type Output = Alignment;
69
70 fn mul(self, rhs: Alignment) -> Self::Output {
71 Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
73 }
74}
75
76impl From<crate::VectorSize> for Alignment {
77 fn from(size: crate::VectorSize) -> Self {
78 match size {
79 crate::VectorSize::Bi => Alignment::TWO,
80 crate::VectorSize::Tri => Alignment::FOUR,
81 crate::VectorSize::Quad => Alignment::FOUR,
82 }
83 }
84}
85
86#[derive(Clone, Copy, Debug, Hash, PartialEq)]
88#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
89#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
90pub struct TypeLayout {
91 pub size: u32,
92 pub alignment: Alignment,
93}
94
95impl TypeLayout {
96 pub const fn to_stride(&self) -> u32 {
98 self.alignment.round_up(self.size)
99 }
100}
101
102#[derive(Debug, Default)]
112pub struct Layouter {
113 layouts: HandleVec<crate::Type, TypeLayout>,
115}
116
117impl ops::Index<Handle<crate::Type>> for Layouter {
118 type Output = TypeLayout;
119 fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
120 &self.layouts[handle]
121 }
122}
123
124#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
125pub enum LayoutErrorInner {
126 #[error("Array element type {0:?} doesn't exist")]
127 InvalidArrayElementType(Handle<crate::Type>),
128 #[error("Struct member[{0}] type {1:?} doesn't exist")]
129 InvalidStructMemberType(u32, Handle<crate::Type>),
130 #[error("Type width must be a power of two")]
131 NonPowerOfTwoWidth,
132}
133
134#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
135#[error("Error laying out type {ty:?}: {inner}")]
136pub struct LayoutError {
137 pub ty: Handle<crate::Type>,
138 pub inner: LayoutErrorInner,
139}
140
141impl LayoutErrorInner {
142 const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
143 LayoutError { ty, inner: self }
144 }
145}
146
147impl Layouter {
148 pub fn clear(&mut self) {
150 self.layouts.clear();
151 }
152
153 #[allow(clippy::or_fun_call)]
167 pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
168 use crate::TypeInner as Ti;
169
170 for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
171 let size = ty.inner.size(gctx);
172 let layout = match ty.inner {
173 Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
174 let alignment = Alignment::new(scalar.width as u32)
175 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
176 TypeLayout { size, alignment }
177 }
178 Ti::Vector {
179 size: vec_size,
180 scalar,
181 } => {
182 let alignment = Alignment::new(scalar.width as u32)
183 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
184 TypeLayout {
185 size,
186 alignment: Alignment::from(vec_size) * alignment,
187 }
188 }
189 Ti::Matrix {
190 columns: _,
191 rows,
192 scalar,
193 } => {
194 let alignment = Alignment::new(scalar.width as u32)
195 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
196 TypeLayout {
197 size,
198 alignment: Alignment::from(rows) * alignment,
199 }
200 }
201 Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
202 size,
203 alignment: Alignment::ONE,
204 },
205 Ti::Array {
206 base,
207 stride: _,
208 size: _,
209 } => TypeLayout {
210 size,
211 alignment: if base < ty_handle {
212 self[base].alignment
213 } else {
214 return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
215 },
216 },
217 Ti::Struct { span, ref members } => {
218 let mut alignment = Alignment::ONE;
219 for (index, member) in members.iter().enumerate() {
220 alignment = if member.ty < ty_handle {
221 alignment.max(self[member.ty].alignment)
222 } else {
223 return Err(LayoutErrorInner::InvalidStructMemberType(
224 index as u32,
225 member.ty,
226 )
227 .with(ty_handle));
228 };
229 }
230 TypeLayout {
231 size: span,
232 alignment,
233 }
234 }
235 Ti::Image { .. }
236 | Ti::Sampler { .. }
237 | Ti::AccelerationStructure { .. }
238 | Ti::RayQuery { .. }
239 | Ti::BindingArray { .. } => TypeLayout {
240 size,
241 alignment: Alignment::ONE,
242 },
243 };
244 debug_assert!(size <= layout.size);
245 self.layouts.insert(ty_handle, layout);
246 }
247
248 Ok(())
249 }
250}