Skip to main content

ml_dsa/
algebra.rs

1use ctutils::{CtEq, CtGt, CtLt, CtSelect};
2use hybrid_array::{
3    ArraySize,
4    typenum::{Shleft, U1, U13, Unsigned},
5};
6use module_lattice::{Field, Truncate};
7
8module_lattice::define_field!(BaseField, u32, u64, u128, 8_380_417);
9
10pub(crate) type Int = <BaseField as Field>::Int;
11
12pub(crate) type Elem = module_lattice::Elem<BaseField>;
13pub(crate) type Polynomial = module_lattice::Polynomial<BaseField>;
14pub(crate) type Vector<K> = module_lattice::Vector<BaseField, K>;
15pub(crate) type NttPolynomial = module_lattice::NttPolynomial<BaseField>;
16pub(crate) type NttVector<K> = module_lattice::NttVector<BaseField, K>;
17pub(crate) type NttMatrix<K, L> = module_lattice::NttMatrix<BaseField, K, L>;
18
19// We require modular reduction for three moduli: q, 2^d, and 2 * gamma2.  All three of these are
20// greater than sqrt(q), which means that a number reduced mod q will always be less than M^2,
21// which means that barrett reduction will work.
22pub(crate) trait BarrettReduce: Unsigned {
23    const SHIFT: usize;
24    const MULTIPLIER: u64;
25
26    fn reduce(x: u32) -> u32 {
27        let m = Self::U64;
28        let x: u64 = x.into();
29        let quotient = (x * Self::MULTIPLIER) >> Self::SHIFT;
30        let remainder = x - quotient * m;
31
32        let r_small: u32 = Truncate::truncate(remainder);
33        let r_large: u32 = Truncate::truncate(remainder.wrapping_sub(m));
34        u32::ct_select(&r_large, &r_small, remainder.ct_lt(&m))
35    }
36}
37
38impl<M> BarrettReduce for M
39where
40    M: Unsigned,
41{
42    #[allow(clippy::as_conversions)]
43    const SHIFT: usize = 2 * (M::U64.ilog2() + 1) as usize;
44    #[allow(clippy::integer_division_remainder_used, reason = "constant")]
45    const MULTIPLIER: u64 = (1 << Self::SHIFT) / M::U64;
46}
47
48pub(crate) trait Decompose {
49    fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
50}
51
52/// Constant-time division by a compile-time constant divisor.
53///
54/// This trait provides a constant-time alternative to the hardware division
55/// instruction, which has variable timing based on operand values.
56/// Uses Barrett reduction to compute `x / M` where M is a compile-time constant.
57pub(crate) trait ConstantTimeDiv: Unsigned {
58    /// Bit shift for Barrett reduction, chosen to provide sufficient precision
59    const CT_DIV_SHIFT: usize;
60    /// Precomputed multiplier: ceil(2^SHIFT / M)
61    const CT_DIV_MULTIPLIER: u64;
62
63    /// Perform constant-time division of x by `Self::U32`
64    /// Requires: x < Q (the field modulus, ~2^23)
65    #[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
66    #[inline(always)]
67    fn ct_div(x: u32) -> u32 {
68        // Barrett reduction: q = (x * MULTIPLIER) >> SHIFT
69        // This gives us floor(x / M) for x < 2^SHIFT / MULTIPLIER * M
70        let x64 = u64::from(x);
71        let quotient = (x64 * Self::CT_DIV_MULTIPLIER) >> Self::CT_DIV_SHIFT;
72        // Quotient is guaranteed to fit in u32 because:
73        // - x < Q (~2^23), so quotient = x / M < x < 2^23 < 2^32
74        #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
75        let result = quotient as u32;
76        result
77    }
78}
79
80impl<M> ConstantTimeDiv for M
81where
82    M: Unsigned,
83{
84    // Use a shift that provides enough precision for the ML-DSA field (Q ~ 2^23)
85    // We need SHIFT > log2(Q) + log2(M) to ensure accuracy
86    // With Q < 2^24 and M < 2^20, SHIFT = 48 is sufficient
87    const CT_DIV_SHIFT: usize = 48;
88
89    // Precompute the multiplier at compile time
90    // We add (M-1) before dividing to get ceiling division, ensuring we never underestimate
91    #[allow(clippy::integer_division_remainder_used, reason = "constant")]
92    const CT_DIV_MULTIPLIER: u64 = (1u64 << Self::CT_DIV_SHIFT).div_ceil(M::U64);
93}
94
95impl Decompose for Elem {
96    // Algorithm 36 Decompose
97    //
98    // This implementation uses constant-time division to avoid timing side-channels.
99    // The original algorithm used hardware division which has variable timing based
100    // on operand values, potentially leaking secret information during signing.
101    fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
102        let r_plus = self.clone();
103        let r0 = r_plus.mod_plus_minus::<TwoGamma2>();
104
105        let diff = r_plus - r0;
106        let is_edge = diff.0.ct_eq(&(BaseField::Q - 1));
107
108        // Compute both branches unconditionally
109        let edge = (Elem::new(0), r0 - Elem::new(1));
110        let r1 = Elem::new(TwoGamma2::ct_div(diff.0));
111        let normal = (r1, r0);
112
113        let r1_out = Elem::new(u32::ct_select(&normal.0.0, &edge.0.0, is_edge));
114        let r0_out = Elem::new(u32::ct_select(&normal.1.0, &edge.1.0, is_edge));
115        (r1_out, r0_out)
116    }
117}
118
119#[allow(clippy::module_name_repetitions)] // I can't think of a better name
120pub(crate) trait AlgebraExt: Sized {
121    fn mod_plus_minus<M: Unsigned>(&self) -> Self;
122    fn infinity_norm(&self) -> Int;
123    fn power2round(&self) -> (Self, Self);
124    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self;
125    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self;
126}
127
128impl AlgebraExt for Elem {
129    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
130        let raw_mod = Elem::new(M::reduce(self.0));
131        let in_lower_half = !raw_mod.0.ct_gt(&(M::U32 >> 1));
132        Elem::new(u32::ct_select(
133            &(raw_mod - Elem::new(M::U32)).0,
134            &raw_mod.0,
135            in_lower_half,
136        ))
137    }
138
139    // FIPS 204 defines the infinity norm differently for signed vs. unsigned integers:
140    //
141    // * For w in Z, |w|_\infinity = |w|, the absolute value of w
142    // * For w in Z_q, |W|_infinity = |w mod^\pm q|
143    //
144    // Note that these two definitions are equivalent if |w| < q/2.  This property holds for all of
145    // the signed integers used in this crate, so we can safely use the unsigned version.  However,
146    // since mod_plus_minus is also unsigned, we need to unwrap the "negative" values.
147    fn infinity_norm(&self) -> u32 {
148        let in_lower_half = !self.0.ct_gt(&(BaseField::Q >> 1));
149        u32::ct_select(&(BaseField::Q - self.0), &self.0, in_lower_half)
150    }
151
152    // Algorithm 35 Power2Round
153    //
154    // In the specification, this function maps to signed integers rather than modular integers.
155    // To avoid the need for a whole separate type for signed integer polynomials, we represent
156    // these values using integers mod Q.  This is safe because Q is much larger than 2^13, so
157    // there's no risk of overlap between positive numbers (x) and negative numbers (Q-x).
158    fn power2round(&self) -> (Self, Self) {
159        type D = U13;
160        type Pow2D = Shleft<U1, D>;
161
162        let r_plus = self.clone();
163        let r0 = r_plus.mod_plus_minus::<Pow2D>();
164        let r1 = Elem::new((r_plus - r0).0 >> D::USIZE);
165
166        (r1, r0)
167    }
168
169    // Algorithm 37 HighBits
170    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
171        self.decompose::<TwoGamma2>().0
172    }
173
174    // Algorithm 38 LowBits
175    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
176        self.decompose::<TwoGamma2>().1
177    }
178}
179
180impl AlgebraExt for Polynomial {
181    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
182        Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
183    }
184
185    fn infinity_norm(&self) -> u32 {
186        self.0
187            .iter()
188            .map(AlgebraExt::infinity_norm)
189            .max()
190            .expect("should have a maximum")
191    }
192
193    fn power2round(&self) -> (Self, Self) {
194        let mut r1 = Self::default();
195        let mut r0 = Self::default();
196
197        for (i, x) in self.0.iter().enumerate() {
198            (r1.0[i], r0.0[i]) = x.power2round();
199        }
200
201        (r1, r0)
202    }
203
204    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
205        Self(
206            self.0
207                .iter()
208                .map(AlgebraExt::high_bits::<TwoGamma2>)
209                .collect(),
210        )
211    }
212
213    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
214        Self(
215            self.0
216                .iter()
217                .map(AlgebraExt::low_bits::<TwoGamma2>)
218                .collect(),
219        )
220    }
221}
222
223impl<K: ArraySize> AlgebraExt for Vector<K> {
224    fn mod_plus_minus<M: Unsigned>(&self) -> Self {
225        Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
226    }
227
228    fn infinity_norm(&self) -> u32 {
229        self.0
230            .iter()
231            .map(AlgebraExt::infinity_norm)
232            .max()
233            .expect("should have a maximum")
234    }
235
236    fn power2round(&self) -> (Self, Self) {
237        let mut r1 = Self::default();
238        let mut r0 = Self::default();
239
240        for (i, x) in self.0.iter().enumerate() {
241            (r1.0[i], r0.0[i]) = x.power2round();
242        }
243
244        (r1, r0)
245    }
246
247    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
248        Self(
249            self.0
250                .iter()
251                .map(AlgebraExt::high_bits::<TwoGamma2>)
252                .collect(),
253        )
254    }
255
256    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
257        Self(
258            self.0
259                .iter()
260                .map(AlgebraExt::low_bits::<TwoGamma2>)
261                .collect(),
262        )
263    }
264}
265
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "tests")]
mod test {
    use super::*;

    use crate::{MlDsa65, ParameterSet};

    type Mod = <MlDsa65 as ParameterSet>::TwoGamma2;
    const MOD: u32 = Mod::U32;
    const MOD_ELEM: Elem = Elem::new(MOD);

    /// 2^d with d = 13, the modulus used by `power2round`.
    const POW2D: u32 = 1 << 13;

    #[test]
    fn mod_plus_minus() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let x0 = x.mod_plus_minus::<Mod>();

            // Outputs from mod+- should be in the half-open interval (-gamma2, gamma2]
            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 > BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            // The output should be equivalent to the input, mod 2 * gamma2. Adding 2 * gamma2
            // to both before comparing keeps both values "positive", which avoids interactions
            // between the mod-Q and mod-M operations.
            let xn = x + MOD_ELEM;
            let x0n = x0 + MOD_ELEM;
            assert_eq!(xn.0 % MOD, x0n.0 % MOD);
        }
    }

    #[test]
    fn decompose() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let (x1, x0) = x.decompose::<Mod>();

            // The low-order output from decompose() is a mod+- output, optionally minus one,
            // so it should lie in the closed interval [-gamma2, gamma2].
            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 >= BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            // The low-order and high-order outputs should combine to form the input.
            let xx = (MOD * x1.0 + x0.0) % BaseField::Q;
            assert_eq!(xx, x.0);
        }
    }

    #[test]
    fn power2round() {
        let half = POW2D / 2;
        let max_r1 = (BaseField::Q - 1) / POW2D;
        let samples = (0..2 * POW2D)
            .chain((0..BaseField::Q).step_by(1009))
            .chain((BaseField::Q - 2 * POW2D)..BaseField::Q);

        for x in samples {
            let x = Elem::new(x);
            let (x1, x0) = x.power2round();

            // r0 = r mod± 2^d lies in the half-open interval (-2^(d-1), 2^(d-1)].
            let positive_bound = x0.0 <= half;
            let negative_bound = x0.0 > BaseField::Q - half;
            assert!(positive_bound || negative_bound);

            // r1 fits in the expected range, and r1 * 2^d + r0 recombines to the input mod q.
            assert!(x1.0 <= max_r1);
            let xx = (POW2D * x1.0 + x0.0) % BaseField::Q;
            assert_eq!(xx, x.0);
        }
    }

    #[test]
    fn infinity_norm() {
        let q = BaseField::Q;
        let cases: [(u32, u32); 5] = [
            (0, 0),
            (1, 1),
            (q - 1, 1),
            (q / 2, q / 2),
            (q / 2 + 1, q / 2),
        ];
        for (x, expected) in cases {
            assert_eq!(Elem::new(x).infinity_norm(), expected);
        }
    }

    #[test]
    fn polynomial_infinity_norm() {
        let mut p = Polynomial::default();
        assert_eq!(p.infinity_norm(), 0);

        p.0[3] = Elem::new(7);
        p.0[100] = Elem::new(BaseField::Q - 9);
        assert_eq!(p.infinity_norm(), 9);
    }

    #[test]
    fn barrett_reduce_boundary() {
        let m_minus_1 = Mod::U32 - 1;
        assert_eq!(Mod::reduce(m_minus_1), m_minus_1);
        assert_eq!(Mod::reduce(Mod::U32), 0);
        assert_eq!(Mod::reduce(Mod::U32 + 1), 1);
        assert_eq!(Mod::reduce(2 * Mod::U32 - 1), m_minus_1);
        assert_eq!(Mod::reduce(2 * Mod::U32), 0);
    }

    #[test]
    fn barrett_reduce_matches_remainder() {
        for x in (0..4 * MOD).chain((0..BaseField::Q).step_by(101)) {
            assert_eq!(Mod::reduce(x), x % MOD);
        }
    }

    #[test]
    fn constant_time_div_accuracy() {
        // `ct_div` is documented as exact for every input below Q, so check all of them.
        for x in 0..BaseField::Q {
            assert_eq!(Mod::ct_div(x), x / MOD);
        }
    }

    #[test]
    fn decompose_edge_case() {
        let q_minus_1 = Elem::new(BaseField::Q - 1);
        let (r1, r0) = q_minus_1.decompose::<Mod>();
        let reconstructed = (MOD * r1.0 + r0.0) % BaseField::Q;
        assert_eq!(reconstructed, q_minus_1.0);
    }

    #[test]
    fn high_low_bits_consistency() {
        for x in [0, 1, MOD / 2, MOD - 1, MOD, MOD + 1, BaseField::Q - 1] {
            let elem = Elem::new(x);
            let (decomp_high, decomp_low) = elem.decompose::<Mod>();
            assert_eq!(elem.high_bits::<Mod>(), decomp_high);
            assert_eq!(elem.low_bits::<Mod>(), decomp_low);
        }
    }
}