ml_dsa/
algebra.rs

1pub use crate::module_lattice::algebra::Field;
2pub use crate::module_lattice::util::Truncate;
3use hybrid_array::{
4    ArraySize,
5    typenum::{Shleft, U1, U13, Unsigned},
6};
7
8use crate::define_field;
9use crate::module_lattice::algebra;
10
11define_field!(BaseField, u32, u64, u128, 8_380_417);
12
13pub type Int = <BaseField as Field>::Int;
14
15pub type Elem = algebra::Elem<BaseField>;
16pub type Polynomial = algebra::Polynomial<BaseField>;
17pub type Vector<K> = algebra::Vector<BaseField, K>;
18pub type NttPolynomial = algebra::NttPolynomial<BaseField>;
19pub type NttVector<K> = algebra::NttVector<BaseField, K>;
20pub type NttMatrix<K, L> = algebra::NttMatrix<BaseField, K, L>;
21
22// We require modular reduction for three moduli: q, 2^d, and 2 * gamma2.  All three of these are
23// greater than sqrt(q), which means that a number reduced mod q will always be less than M^2,
24// which means that barrett reduction will work.
25pub trait BarrettReduce: Unsigned {
26    const SHIFT: usize;
27    const MULTIPLIER: u64;
28
29    fn reduce(x: u32) -> u32 {
30        let m = Self::U64;
31        let x: u64 = x.into();
32        let quotient = (x * Self::MULTIPLIER) >> Self::SHIFT;
33        let remainder = x - quotient * m;
34
35        if remainder < m {
36            Truncate::truncate(remainder)
37        } else {
38            Truncate::truncate(remainder - m)
39        }
40    }
41}
42
43impl<M> BarrettReduce for M
44where
45    M: Unsigned,
46{
47    #[allow(clippy::as_conversions)]
48    const SHIFT: usize = 2 * (M::U64.ilog2() + 1) as usize;
49    #[allow(clippy::integer_division_remainder_used)]
50    const MULTIPLIER: u64 = (1 << Self::SHIFT) / M::U64;
51}
52
53pub trait Decompose {
54    fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
55}
56
57impl Decompose for Elem {
58    // Algorithm 36 Decompose
59    fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
60        let r_plus = self.clone();
61        let r0 = r_plus.mod_plus_minus::<TwoGamma2>();
62
63        if r_plus - r0 == Elem::new(BaseField::Q - 1) {
64            (Elem::new(0), r0 - Elem::new(1))
65        } else {
66            let mut r1 = r_plus - r0;
67            r1.0 /= TwoGamma2::U32;
68            (r1, r0)
69        }
70    }
71}
72
73#[allow(clippy::module_name_repetitions)] // I can't think of a better name
74pub trait AlgebraExt: Sized {
75    fn mod_plus_minus<M: Unsigned>(&self) -> Self;
76    fn infinity_norm(&self) -> Int;
77    fn power2round(&self) -> (Self, Self);
78    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self;
79    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self;
80}
81
82impl AlgebraExt for Elem {
83    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
84        let raw_mod = Elem::new(M::reduce(self.0));
85        if raw_mod.0 <= M::U32 >> 1 {
86            raw_mod
87        } else {
88            raw_mod - Elem::new(M::U32)
89        }
90    }
91
92    // FIPS 204 defines the infinity norm differently for signed vs. unsigned integers:
93    //
94    // * For w in Z, |w|_\infinity = |w|, the absolute value of w
95    // * For w in Z_q, |W|_infinity = |w mod^\pm q|
96    //
97    // Note that these two definitions are equivalent if |w| < q/2.  This property holds for all of
98    // the signed integers used in this crate, so we can safely use the unsigned version.  However,
99    // since mod_plus_minus is also unsigned, we need to unwrap the "negative" values.
100    fn infinity_norm(&self) -> u32 {
101        if self.0 <= BaseField::Q >> 1 {
102            self.0
103        } else {
104            BaseField::Q - self.0
105        }
106    }
107
108    // Algorithm 35 Power2Round
109    //
110    // In the specification, this function maps to signed integers rather than modular integers.
111    // To avoid the need for a whole separate type for signed integer polynomials, we represent
112    // these values using integers mod Q.  This is safe because Q is much larger than 2^13, so
113    // there's no risk of overlap between positive numbers (x) and negative numbers (Q-x).
114    fn power2round(&self) -> (Self, Self) {
115        type D = U13;
116        type Pow2D = Shleft<U1, D>;
117
118        let r_plus = self.clone();
119        let r0 = r_plus.mod_plus_minus::<Pow2D>();
120        let r1 = Elem::new((r_plus - r0).0 >> D::USIZE);
121
122        (r1, r0)
123    }
124
125    // Algorithm 37 HighBits
126    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
127        self.decompose::<TwoGamma2>().0
128    }
129
130    // Algorithm 38 LowBits
131    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
132        self.decompose::<TwoGamma2>().1
133    }
134}
135
136impl AlgebraExt for Polynomial {
137    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
138        Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
139    }
140
141    fn infinity_norm(&self) -> u32 {
142        self.0.iter().map(AlgebraExt::infinity_norm).max().unwrap()
143    }
144
145    fn power2round(&self) -> (Self, Self) {
146        let mut r1 = Self::default();
147        let mut r0 = Self::default();
148
149        for (i, x) in self.0.iter().enumerate() {
150            (r1.0[i], r0.0[i]) = x.power2round();
151        }
152
153        (r1, r0)
154    }
155
156    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
157        Self(
158            self.0
159                .iter()
160                .map(AlgebraExt::high_bits::<TwoGamma2>)
161                .collect(),
162        )
163    }
164
165    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
166        Self(
167            self.0
168                .iter()
169                .map(AlgebraExt::low_bits::<TwoGamma2>)
170                .collect(),
171        )
172    }
173}
174
175impl<K: ArraySize> AlgebraExt for Vector<K> {
176    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
177        Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
178    }
179
180    fn infinity_norm(&self) -> u32 {
181        self.0.iter().map(AlgebraExt::infinity_norm).max().unwrap()
182    }
183
184    fn power2round(&self) -> (Self, Self) {
185        let mut r1 = Self::default();
186        let mut r0 = Self::default();
187
188        for (i, x) in self.0.iter().enumerate() {
189            (r1.0[i], r0.0[i]) = x.power2round();
190        }
191
192        (r1, r0)
193    }
194
195    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
196        Self(
197            self.0
198                .iter()
199                .map(AlgebraExt::high_bits::<TwoGamma2>)
200                .collect(),
201        )
202    }
203
204    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
205        Self(
206            self.0
207                .iter()
208                .map(AlgebraExt::low_bits::<TwoGamma2>)
209                .collect(),
210        )
211    }
212}
213
#[cfg(test)]
mod test {
    use super::*;

    use crate::{MlDsa65, ParameterSet};

    type Mod = <MlDsa65 as ParameterSet>::TwoGamma2;
    const MOD: u32 = Mod::U32;
    const MOD_ELEM: Elem = Elem::new(MOD);

    #[test]
    fn mod_plus_minus() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let x0 = x.mod_plus_minus::<Mod>();

            // Outputs from mod+- should be in the half-open interval (-gamma2, gamma2]
            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 > BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            // The output should be equivalent to the input, mod 2 * gamma2.  We add 2 * gamma2
            // before comparing so that both values are "positive", avoiding interactions between
            // the mod-Q and mod-M operations.
            let xn = x + MOD_ELEM;
            let x0n = x0 + MOD_ELEM;
            assert_eq!(xn.0 % MOD, x0n.0 % MOD);
        }
    }

    /// Checks the Decompose invariants for one input and returns `(r1, r0)`.
    fn check_decompose(x: u32) -> (Elem, Elem) {
        let x = Elem::new(x);
        let (x1, x0) = x.decompose::<Mod>();

        // The low-order output is a mod+- output, optionally minus one.  So it should be in the
        // closed interval [-gamma2, gamma2].
        let positive_bound = x0.0 <= MOD / 2;
        let negative_bound = x0.0 >= BaseField::Q - MOD / 2;
        assert!(positive_bound || negative_bound);

        // The high-order output lies in [0, (q - 1) / (2 * gamma2)).
        assert!(x1.0 < (BaseField::Q - 1) / MOD);

        // The low-order and high-order outputs should combine to form the input.
        let xx = (MOD * x1.0 + x0.0) % BaseField::Q;
        assert_eq!(xx, x.0);

        (x1, x0)
    }

    #[test]
    fn decompose() {
        for x in 0..MOD {
            check_decompose(x);
        }
    }

    #[test]
    fn decompose_near_q() {
        // Inputs in [q - gamma2, q - 1] satisfy r+ - r0 = q - 1.  They take the corner-case
        // branch of Algorithm 36, which must produce r1 = 0.
        for x in (BaseField::Q - MOD)..BaseField::Q {
            let (x1, _) = check_decompose(x);
            if x >= BaseField::Q - MOD / 2 {
                assert_eq!(x1.0, 0);
            }
        }
    }
}