ml_dsa/module_lattice/
util.rs

1use core::mem::ManuallyDrop;
2use core::ops::{Div, Mul, Rem};
3use core::ptr;
4use hybrid_array::{
5    Array, ArraySize,
6    typenum::{Prod, Quot, U0, Unsigned},
7};
8
/// Safely truncate an unsigned integer value to a shorter representation.
///
/// Only the low-order bits of the input that fit in `Self` are kept; all higher-order bits are
/// discarded.
pub trait Truncate<T> {
    /// Return the low-order bits of `x` that fit in `Self`.
    fn truncate(x: T) -> Self;
}

macro_rules! define_truncate {
    ($from:ident, $to:ident) => {
        impl Truncate<$from> for $to {
            // Truncation here is deliberate and always well-defined:
            // - `$from::from($to::MAX)` only compiles when `$to` converts losslessly into
            //   `$from`, so the macro cannot be instantiated for a widening conversion.
            // - After masking, the value is at most `$to::MAX`, so the `as` cast keeps every
            //   remaining bit.  A plain `as` cast replaces `try_into().unwrap_unchecked()`, so
            //   neither `unsafe` code nor a panic path is needed.
            #[allow(clippy::cast_possible_truncation)]
            fn truncate(x: $from) -> $to {
                (x & $from::from($to::MAX)) as $to
            }
        }
    };
}

define_truncate!(u128, u32);
define_truncate!(u64, u32);
define_truncate!(usize, u8);
define_truncate!(usize, u16);
32
33/// Defines a sequence of sequences that can be merged into a bigger overall seequence
34pub trait Flatten<T, M: ArraySize> {
35    type OutputSize: ArraySize;
36
37    fn flatten(self) -> Array<T, Self::OutputSize>;
38}
39
40impl<T, N, M> Flatten<T, Prod<M, N>> for Array<Array<T, M>, N>
41where
42    N: ArraySize,
43    M: ArraySize + Mul<N>,
44    Prod<M, N>: ArraySize,
45{
46    type OutputSize = Prod<M, N>;
47
48    // This is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed to be
49    // safe by the Rust memory layout of these types.
50    fn flatten(self) -> Array<T, Self::OutputSize> {
51        let whole = ManuallyDrop::new(self);
52        unsafe { ptr::read(whole.as_ptr().cast()) }
53    }
54}
55
56/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size
57pub trait Unflatten<M>
58where
59    M: ArraySize,
60{
61    type Part;
62
63    fn unflatten(self) -> Array<Self::Part, M>;
64}
65
66impl<T, N, M> Unflatten<M> for Array<T, N>
67where
68    T: Default,
69    N: ArraySize + Div<M> + Rem<M, Output = U0>,
70    M: ArraySize,
71    Quot<N, M>: ArraySize,
72{
73    type Part = Array<T, Quot<N, M>>;
74
75    // This requires some unsafeness, but it is the same as what is done in Array::split.
76    // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
77    // be safe by the Rust memory layout of these types.
78    fn unflatten(self) -> Array<Self::Part, M> {
79        let part_size = Quot::<N, M>::USIZE;
80        let whole = ManuallyDrop::new(self);
81        Array::from_fn(|i| unsafe { ptr::read(whole.as_ptr().add(i * part_size).cast()) })
82    }
83}
84
85impl<'a, T, N, M> Unflatten<M> for &'a Array<T, N>
86where
87    T: Default,
88    N: ArraySize + Div<M> + Rem<M, Output = U0>,
89    M: ArraySize,
90    Quot<N, M>: ArraySize,
91{
92    type Part = &'a Array<T, Quot<N, M>>;
93
94    // This requires some unsafeness, but it is the same as what is done in Array::split.
95    // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
96    // be safe by the Rust memory layout of these types.
97    fn unflatten(self) -> Array<Self::Part, M> {
98        let part_size = Quot::<N, M>::USIZE;
99        let mut ptr: *const T = self.as_ptr();
100        Array::from_fn(|_i| unsafe {
101            let part = &*(ptr.cast());
102            ptr = ptr.add(part_size);
103            part
104        })
105    }
106}
107
#[cfg(test)]
mod test {
    use super::*;
    use hybrid_array::{
        Array,
        typenum::{U2, U5, U10},
    };

    #[test]
    fn flatten() {
        let flat: Array<u8, U10> = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let unflat2: Array<Array<u8, U2>, U5> = Array([
            Array([1, 2]),
            Array([3, 4]),
            Array([5, 6]),
            Array([7, 8]),
            Array([9, 10]),
        ]);
        let unflat5: Array<Array<u8, U5>, U2> =
            Array([Array([1, 2, 3, 4, 5]), Array([6, 7, 8, 9, 10])]);

        // Flatten
        let actual = unflat2.flatten();
        assert_eq!(flat, actual);

        let actual = unflat5.flatten();
        assert_eq!(flat, actual);

        // Unflatten
        let actual: Array<Array<u8, U2>, U5> = flat.unflatten();
        assert_eq!(unflat2, actual);

        let actual: Array<Array<u8, U5>, U2> = flat.unflatten();
        assert_eq!(unflat5, actual);

        // Unflatten on references
        let actual: Array<&Array<u8, U2>, U5> = (&flat).unflatten();
        for (i, part) in actual.iter().enumerate() {
            assert_eq!(&unflat2[i], *part);
        }

        let actual: Array<&Array<u8, U5>, U2> = (&flat).unflatten();
        for (i, part) in actual.iter().enumerate() {
            assert_eq!(&unflat5[i], *part);
        }
    }

    /// Unflattening and then flattening again must reproduce the original sequence.
    #[test]
    fn unflatten_then_flatten_round_trips() {
        let flat: Array<u8, U10> = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);

        let parts: Array<Array<u8, U2>, U5> = flat.unflatten();
        assert_eq!(flat, parts.flatten());

        let parts: Array<Array<u8, U5>, U2> = flat.unflatten();
        assert_eq!(flat, parts.flatten());
    }

    /// `Truncate` must keep exactly the low-order bits that fit in the target type.
    #[test]
    fn truncate_keeps_low_order_bits() {
        assert_eq!(
            <u32 as Truncate<u128>>::truncate((1u128 << 100) | 0x1234_5678),
            0x1234_5678
        );
        assert_eq!(<u32 as Truncate<u128>>::truncate(u128::MAX), u32::MAX);
        assert_eq!(<u32 as Truncate<u64>>::truncate(0x1_2345_6789), 0x2345_6789);
        assert_eq!(<u32 as Truncate<u64>>::truncate(u64::MAX), u32::MAX);
        assert_eq!(<u8 as Truncate<usize>>::truncate(0x1ff), 0xff);
        assert_eq!(<u8 as Truncate<usize>>::truncate(0x100), 0);
        assert_eq!(<u16 as Truncate<usize>>::truncate(0x1_2345), 0x2345);
    }
}
154}