Skip to main content

ml_dsa/
ntt.rs

1use crate::algebra::{BaseField, Elem, NttPolynomial, NttVector, Polynomial, Vector};
2use module_lattice::{ArraySize, Field, MultiplyNtt};
3
4// Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables
5// to avoid the need to compute the exponentiations at runtime.
6//
7//   ZETA_POW_BITREV[i] = zeta^{BitRev_8(i)}
8//
9// Note that the const environment here imposes some annoying conditions.  Because operator
10// overloading can't be const, we have to do all the reductions here manually.  Because `for` loops
11// are forbidden in `const` functions, we do them manually with `while` loops.
12//
13// The values computed here match those provided in Appendix B of FIPS 204.
14#[allow(clippy::cast_possible_truncation)]
15#[allow(clippy::as_conversions)]
16#[allow(clippy::integer_division_remainder_used, reason = "constant")]
17const ZETA_POW_BITREV: [Elem; 256] = {
18    const ZETA: u64 = 1753;
19    const fn bitrev8(x: usize) -> usize {
20        (x as u8).reverse_bits() as usize
21    }
22
23    // Compute the powers of zeta
24    let mut pow = [Elem::new(0); 256];
25    let mut i = 0;
26    let mut curr = 1u64;
27    while i < 256 {
28        pow[i] = Elem::new(curr as u32);
29        i += 1;
30        curr = (curr * ZETA) % BaseField::QL;
31    }
32
33    // Reorder the powers according to bitrev8
34    // Note that entry 0 is left as zero, in order to match the `zetas` array in the
35    // specification.
36    let mut pow_bitrev = [Elem::new(0); 256];
37    let mut i = 1;
38    while i < 256 {
39        pow_bitrev[i] = pow[bitrev8(i)];
40        i += 1;
41    }
42    pow_bitrev
43};
44
/// Forward number-theoretic transform.
///
/// Implemented in this file for `Polynomial` and for `Vector<K>`.
pub(crate) trait Ntt {
    /// Type produced by the transform.
    type Output;

    /// Returns the transformed counterpart of `self`, leaving `self` untouched.
    fn ntt(&self) -> Self::Output;
}
49
50/// Constant-time NTT butterfly layer.
51///
52/// Uses const generics to ensure loop bounds are compile-time constants,
53/// avoiding UDIV instructions from runtime `step_by` calculations.
54#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
55#[inline(always)]
56fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(w: &mut [Elem; 256], m: &mut usize) {
57    for i in 0..ITERATIONS {
58        let start = i * 2 * LEN;
59        *m += 1;
60        let z = ZETA_POW_BITREV[*m];
61        for j in start..(start + LEN) {
62            let t = z * w[j + LEN];
63            w[j + LEN] = w[j] - t;
64            w[j] = w[j] + t;
65        }
66    }
67}
68
69impl Ntt for Polynomial {
70    type Output = NttPolynomial;
71
72    // Algorithm 41 NTT
73    //
74    // This implementation uses const-generic helper functions to ensure all loop
75    // bounds are compile-time constants, avoiding potential UDIV instructions.
76    fn ntt(&self) -> Self::Output {
77        let mut w: [Elem; 256] = self.0.clone().into();
78        let mut m = 0;
79
80        ntt_layer::<128, 1>(&mut w, &mut m);
81        ntt_layer::<64, 2>(&mut w, &mut m);
82        ntt_layer::<32, 4>(&mut w, &mut m);
83        ntt_layer::<16, 8>(&mut w, &mut m);
84        ntt_layer::<8, 16>(&mut w, &mut m);
85        ntt_layer::<4, 32>(&mut w, &mut m);
86        ntt_layer::<2, 64>(&mut w, &mut m);
87        ntt_layer::<1, 128>(&mut w, &mut m);
88
89        NttPolynomial::new(w.into())
90    }
91}
92
93impl<K: ArraySize> Ntt for Vector<K> {
94    type Output = NttVector<K>;
95
96    fn ntt(&self) -> Self::Output {
97        NttVector::new(self.0.iter().map(Polynomial::ntt).collect())
98    }
99}
100
/// Inverse number-theoretic transform.
///
/// Implemented in this file for `NttPolynomial` and for `NttVector<K>`.
#[allow(clippy::module_name_repetitions)]
pub(crate) trait NttInverse {
    /// Type produced by the inverse transform.
    type Output;

    /// Returns the inverse-transformed counterpart of `self`, leaving `self` untouched.
    fn ntt_inverse(&self) -> Self::Output;
}
106
107/// Constant-time inverse NTT butterfly layer.
108///
109/// Uses const generics to ensure loop bounds are compile-time constants,
110/// avoiding UDIV instructions from runtime `step_by` calculations.
111#[allow(clippy::inline_always)] // Required for constant-time guarantees in crypto code
112#[inline(always)]
113fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
114    w: &mut [Elem; 256],
115    m: &mut usize,
116) {
117    for i in 0..ITERATIONS {
118        let start = i * 2 * LEN;
119        *m -= 1;
120        let z = -ZETA_POW_BITREV[*m];
121        for j in start..(start + LEN) {
122            let t = w[j];
123            w[j] = t + w[j + LEN];
124            w[j + LEN] = z * (t - w[j + LEN]);
125        }
126    }
127}
128
129impl NttInverse for NttPolynomial {
130    type Output = Polynomial;
131
132    // Algorithm 42 NTT^{−1}
133    //
134    // This implementation uses const-generic helper functions to ensure all loop
135    // bounds are compile-time constants, avoiding potential UDIV instructions.
136    fn ntt_inverse(&self) -> Self::Output {
137        const INVERSE_256: Elem = Elem::new(8_347_681);
138
139        let mut w: [Elem; 256] = self.0.clone().into();
140        let mut m = 256;
141
142        ntt_inverse_layer::<1, 128>(&mut w, &mut m);
143        ntt_inverse_layer::<2, 64>(&mut w, &mut m);
144        ntt_inverse_layer::<4, 32>(&mut w, &mut m);
145        ntt_inverse_layer::<8, 16>(&mut w, &mut m);
146        ntt_inverse_layer::<16, 8>(&mut w, &mut m);
147        ntt_inverse_layer::<32, 4>(&mut w, &mut m);
148        ntt_inverse_layer::<64, 2>(&mut w, &mut m);
149        ntt_inverse_layer::<128, 1>(&mut w, &mut m);
150
151        INVERSE_256 * &Polynomial::new(w.into())
152    }
153}
154
155impl<K: ArraySize> NttInverse for NttVector<K> {
156    type Output = Vector<K>;
157
158    fn ntt_inverse(&self) -> Self::Output {
159        Vector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
160    }
161}
162
163impl MultiplyNtt for BaseField {
164    // Algorithm 45 MultiplyNTT
165    fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
166        NttPolynomial::new(
167            lhs.0
168                .iter()
169                .zip(rhs.0.iter())
170                .map(|(&x, &y)| x * y)
171                .collect(),
172        )
173    }
174}
175
#[cfg(test)]
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::{
        Array,
        typenum::{U2, U3},
    };

    use crate::algebra::*;

    // Reference (schoolbook) multiplication in R_q, modulo X^256 + 1
    fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
        let one = Elem::new(1);
        let minus_one = Elem::new(BaseField::Q - 1);

        let mut product = Polynomial::default();
        for (i, a) in lhs.0.iter().enumerate() {
            for (j, b) in rhs.0.iter().enumerate() {
                // A term of degree 256 or more wraps around with a sign flip (X^256 = -1).
                let (sign, slot) = match (i + j).checked_sub(256) {
                    Some(wrapped) => (minus_one, wrapped),
                    None => (one, i + j),
                };

                product.0[slot] = product.0[slot] + (sign * *a * *b);
            }
        }
        product
    }

    // NTT of the polynomial whose only non-zero coefficient is the constant term `x`; handy
    // for building simple test cases
    fn const_ntt(x: Int) -> NttPolynomial {
        let mut constant = Polynomial::default();
        constant.0[0] = Elem::new(x);
        constant.ntt()
    }

    #[test]
    fn ntt() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new((2 * i) as Int)));
        let (f_hat, g_hat) = (f.ntt(), g.ntt());

        // NTT followed by NTT^-1 must give back the original polynomial
        assert_eq!(f, f_hat.ntt_inverse());

        // Addition commutes with the transform
        let sum = &f + &g;
        let sum_hat = &f_hat + &g_hat;
        assert_eq!(sum, sum_hat.ntt_inverse());

        // Multiplication commutes with the transform
        let product = poly_mul(&f, &g);
        let product_hat = &f_hat * &g_hat;
        assert_eq!(product, product_hat.ntt_inverse());
    }

    #[test]
    fn ntt_vector() {
        // Length-3 vector with every entry equal to `const_ntt(x)`
        fn splat(x: Int) -> NttVector<U3> {
            NttVector::new(Array::from_fn(|_| const_ntt(x)))
        }

        let ones = splat(1);
        let twos = splat(2);
        let threes = splat(3);

        // Verify vector addition
        assert_eq!((&ones + &twos), threes);

        // Verify dot product
        assert_eq!((&ones * &twos), const_ntt(6));
        assert_eq!((&ones * &threes), const_ntt(9));
        assert_eq!((&twos * &threes), const_ntt(18));
    }

    #[test]
    fn ntt_matrix() {
        // Length-2 vector of constant polynomials
        fn pair(a: Int, b: Int) -> NttVector<U2> {
            NttVector::new(Array([const_ntt(a), const_ntt(b)]))
        }

        // Verify matrix multiplication by a vector
        let matrix: NttMatrix<U3, U2> = NttMatrix::new(Array([pair(1, 2), pair(3, 4), pair(5, 6)]));
        let input = pair(1, 2);
        let expected: NttVector<U3> =
            NttVector::new(Array([const_ntt(5), const_ntt(11), const_ntt(17)]));
        assert_eq!(&matrix * &input, expected);
    }
}