ml_dsa/
ntt.rs

1use crate::module_lattice::algebra::Field;
2use crate::module_lattice::encode::ArraySize;
3use core::ops::Mul;
4
5use crate::algebra::{BaseField, Elem, NttPolynomial, NttVector, Polynomial, Vector};
6
7// Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables
8// to avoid the need to compute the exponetiations at runtime.
9//
10//   ZETA_POW_BITREV[i] = zeta^{BitRev_8(i)}
11//
12// Note that the const environment here imposes some annoying conditions.  Because operator
13// overloading can't be const, we have to do all the reductions here manually.  Because `for` loops
14// are forbidden in `const` functions, we do them manually with `while` loops.
15//
16// The values computed here match those provided in Appendix B of FIPS 204.
17#[allow(clippy::cast_possible_truncation)]
18#[allow(clippy::as_conversions)]
19#[allow(clippy::integer_division_remainder_used)]
20const ZETA_POW_BITREV: [Elem; 256] = {
21    const ZETA: u64 = 1753;
22    const fn bitrev8(x: usize) -> usize {
23        (x as u8).reverse_bits() as usize
24    }
25
26    // Compute the powers of zeta
27    let mut pow = [Elem::new(0); 256];
28    let mut i = 0;
29    let mut curr = 1u64;
30    while i < 256 {
31        pow[i] = Elem::new(curr as u32);
32        i += 1;
33        curr = (curr * ZETA) % BaseField::QL;
34    }
35
36    // Reorder the powers according to bitrev8
37    // Note that entry 0 is left as zero, in order to match the `zetas` array in the
38    // specification.
39    let mut pow_bitrev = [Elem::new(0); 256];
40    let mut i = 1;
41    while i < 256 {
42        pow_bitrev[i] = pow[bitrev8(i)];
43        i += 1;
44    }
45    pow_bitrev
46};
47
/// Forward number-theoretic transform (NTT).
///
/// The transform takes `&self`, so the input is left untouched and a new NTT-domain value is
/// returned.
pub trait Ntt {
    /// The NTT-domain representation produced by [`Ntt::ntt`].
    type Output;

    /// Returns the NTT of `self`.
    fn ntt(&self) -> Self::Output;
}
52
53impl Ntt for Polynomial {
54    type Output = NttPolynomial;
55
56    // Algorithm 41 NTT
57    fn ntt(&self) -> Self::Output {
58        let mut w = self.0.clone();
59
60        let mut m = 0;
61        for len in [128, 64, 32, 16, 8, 4, 2, 1] {
62            for start in (0..256).step_by(2 * len) {
63                m += 1;
64                let z = ZETA_POW_BITREV[m];
65
66                for j in start..(start + len) {
67                    let t = z * w[j + len];
68                    w[j + len] = w[j] - t;
69                    w[j] = w[j] + t;
70                }
71            }
72        }
73
74        NttPolynomial::new(w)
75    }
76}
77
78impl<K: ArraySize> Ntt for Vector<K> {
79    type Output = NttVector<K>;
80
81    fn ntt(&self) -> Self::Output {
82        NttVector::new(self.0.iter().map(Polynomial::ntt).collect())
83    }
84}
85
/// Inverse number-theoretic transform, mapping NTT-domain values back to coefficient form.
// `NttInverse` repeats the module name `ntt`, which this pedantic lint would otherwise flag.
#[allow(clippy::module_name_repetitions)]
pub trait NttInverse {
    /// The coefficient-domain representation produced by [`NttInverse::ntt_inverse`].
    type Output;

    /// Returns the inverse NTT of `self`.
    fn ntt_inverse(&self) -> Self::Output;
}
91
92impl NttInverse for NttPolynomial {
93    type Output = Polynomial;
94
95    // Algorithm 42 NTT^{−1}
96    fn ntt_inverse(&self) -> Self::Output {
97        const INVERSE_256: Elem = Elem::new(8_347_681);
98
99        let mut w = self.0.clone();
100
101        let mut m = 256;
102        for len in [1, 2, 4, 8, 16, 32, 64, 128] {
103            for start in (0..256).step_by(2 * len) {
104                m -= 1;
105                let z = -ZETA_POW_BITREV[m];
106
107                for j in start..(start + len) {
108                    let t = w[j];
109                    w[j] = t + w[j + len];
110                    w[j + len] = z * (t - w[j + len]);
111                }
112            }
113        }
114
115        INVERSE_256 * &Polynomial::new(w)
116    }
117}
118
119impl<K: ArraySize> NttInverse for NttVector<K> {
120    type Output = Vector<K>;
121
122    fn ntt_inverse(&self) -> Self::Output {
123        Vector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
124    }
125}
126
127impl Mul<&NttPolynomial> for &NttPolynomial {
128    type Output = NttPolynomial;
129
130    // Algorithm 45 MultiplyNTT
131    fn mul(self, rhs: &NttPolynomial) -> NttPolynomial {
132        NttPolynomial::new(
133            self.0
134                .iter()
135                .zip(rhs.0.iter())
136                .map(|(&x, &y)| x * y)
137                .collect(),
138        )
139    }
140}
141
#[cfg(test)]
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::{
        Array,
        typenum::{U2, U3},
    };

    use crate::algebra::*;

    /// Schoolbook multiplication in R_q, i.e. modulo X^256 + 1.
    ///
    /// This is the reference against which NTT-domain multiplication is checked.
    impl Mul<&Polynomial> for &Polynomial {
        type Output = Polynomial;

        fn mul(self, rhs: &Polynomial) -> Self::Output {
            // X^256 = -1 in this ring, so terms of degree >= 256 wrap around with a sign flip.
            let plus_one = Elem::new(1);
            let minus_one = Elem::new(BaseField::Q - 1);

            let mut acc = Self::Output::default();
            for (i, &a) in self.0.iter().enumerate() {
                for (j, &b) in rhs.0.iter().enumerate() {
                    let degree = i + j;
                    let (sign, slot) = if degree >= 256 {
                        (minus_one, degree - 256)
                    } else {
                        (plus_one, degree)
                    };
                    acc.0[slot] = acc.0[slot] + (sign * a * b);
                }
            }
            acc
        }
    }

    /// NTT of the constant polynomial `x` (all higher coefficients zero), for simple test cases.
    fn const_ntt(x: Int) -> NttPolynomial {
        let mut constant = Polynomial::default();
        constant.0[0] = Elem::new(x);
        constant.ntt()
    }

    #[test]
    fn ntt() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new((2 * i) as Int)));
        let (f_hat, g_hat) = (f.ntt(), g.ntt());

        // NTT^-1 undoes NTT.
        assert_eq!(f, f_hat.ntt_inverse());

        // The transform respects addition.
        assert_eq!(&f + &g, (&f_hat + &g_hat).ntt_inverse());

        // Ring multiplication corresponds to entry-wise multiplication in the NTT domain.
        assert_eq!(&f * &g, (&f_hat * &g_hat).ntt_inverse());
    }

    #[test]
    fn ntt_vector() {
        // A length-3 vector whose entries are all the constant `c`.
        let splat = |c: Int| -> NttVector<U3> {
            NttVector::new(Array([const_ntt(c), const_ntt(c), const_ntt(c)]))
        };
        let (v1, v2, v3) = (splat(1), splat(2), splat(3));

        // Vector addition
        assert_eq!(&v1 + &v2, v3);

        // Dot products of constant vectors: 3 * a * b
        assert_eq!(&v1 * &v2, const_ntt(6));
        assert_eq!(&v1 * &v3, const_ntt(9));
        assert_eq!(&v2 * &v3, const_ntt(18));
    }

    #[test]
    fn ntt_matrix() {
        // A length-2 vector with constant entries `x` and `y`.
        let pair = |x: Int, y: Int| -> NttVector<U2> {
            NttVector::new(Array([const_ntt(x), const_ntt(y)]))
        };

        // Rows [1 2], [3 4], [5 6] applied to [1 2] give [5 11 17].
        let a: NttMatrix<U3, U2> = NttMatrix::new(Array([pair(1, 2), pair(3, 4), pair(5, 6)]));
        let v_in = pair(1, 2);
        let v_out: NttVector<U3> =
            NttVector::new(Array([const_ntt(5), const_ntt(11), const_ntt(17)]));
        assert_eq!(&a * &v_in, v_out);
    }
}