1use crate::algebra::{BaseField, Elem, NttPolynomial, NttVector, Polynomial, Vector};
2use module_lattice::{ArraySize, Field, MultiplyNtt};
3
4#[allow(clippy::cast_possible_truncation)]
15#[allow(clippy::as_conversions)]
16#[allow(clippy::integer_division_remainder_used, reason = "constant")]
17const ZETA_POW_BITREV: [Elem; 256] = {
18 const ZETA: u64 = 1753;
19 const fn bitrev8(x: usize) -> usize {
20 (x as u8).reverse_bits() as usize
21 }
22
23 let mut pow = [Elem::new(0); 256];
25 let mut i = 0;
26 let mut curr = 1u64;
27 while i < 256 {
28 pow[i] = Elem::new(curr as u32);
29 i += 1;
30 curr = (curr * ZETA) % BaseField::QL;
31 }
32
33 let mut pow_bitrev = [Elem::new(0); 256];
37 let mut i = 1;
38 while i < 256 {
39 pow_bitrev[i] = pow[bitrev8(i)];
40 i += 1;
41 }
42 pow_bitrev
43};
44
/// Forward number-theoretic transform.
///
/// Maps a coefficient-domain value to its NTT-domain representation ([`Self::Output`]).
pub(crate) trait Ntt {
    /// Type produced by the transform (e.g. `NttPolynomial` for `Polynomial`).
    type Output;
    /// Computes the forward NTT of `self`, leaving `self` untouched.
    fn ntt(&self) -> Self::Output;
}
49
50#[allow(clippy::inline_always)] #[inline(always)]
56fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(w: &mut [Elem; 256], m: &mut usize) {
57 for i in 0..ITERATIONS {
58 let start = i * 2 * LEN;
59 *m += 1;
60 let z = ZETA_POW_BITREV[*m];
61 for j in start..(start + LEN) {
62 let t = z * w[j + LEN];
63 w[j + LEN] = w[j] - t;
64 w[j] = w[j] + t;
65 }
66 }
67}
68
69impl Ntt for Polynomial {
70 type Output = NttPolynomial;
71
72 fn ntt(&self) -> Self::Output {
77 let mut w: [Elem; 256] = self.0.clone().into();
78 let mut m = 0;
79
80 ntt_layer::<128, 1>(&mut w, &mut m);
81 ntt_layer::<64, 2>(&mut w, &mut m);
82 ntt_layer::<32, 4>(&mut w, &mut m);
83 ntt_layer::<16, 8>(&mut w, &mut m);
84 ntt_layer::<8, 16>(&mut w, &mut m);
85 ntt_layer::<4, 32>(&mut w, &mut m);
86 ntt_layer::<2, 64>(&mut w, &mut m);
87 ntt_layer::<1, 128>(&mut w, &mut m);
88
89 NttPolynomial::new(w.into())
90 }
91}
92
93impl<K: ArraySize> Ntt for Vector<K> {
94 type Output = NttVector<K>;
95
96 fn ntt(&self) -> Self::Output {
97 NttVector::new(self.0.iter().map(Polynomial::ntt).collect())
98 }
99}
100
/// Inverse number-theoretic transform.
///
/// Maps an NTT-domain value back to the coefficient domain ([`Self::Output`]).
#[allow(clippy::module_name_repetitions)]
pub(crate) trait NttInverse {
    /// Type produced by the inverse transform (e.g. `Polynomial` for `NttPolynomial`).
    type Output;
    /// Computes the inverse NTT of `self`, leaving `self` untouched.
    fn ntt_inverse(&self) -> Self::Output;
}
106
107#[allow(clippy::inline_always)] #[inline(always)]
113fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
114 w: &mut [Elem; 256],
115 m: &mut usize,
116) {
117 for i in 0..ITERATIONS {
118 let start = i * 2 * LEN;
119 *m -= 1;
120 let z = -ZETA_POW_BITREV[*m];
121 for j in start..(start + LEN) {
122 let t = w[j];
123 w[j] = t + w[j + LEN];
124 w[j + LEN] = z * (t - w[j + LEN]);
125 }
126 }
127}
128
129impl NttInverse for NttPolynomial {
130 type Output = Polynomial;
131
132 fn ntt_inverse(&self) -> Self::Output {
137 const INVERSE_256: Elem = Elem::new(8_347_681);
138
139 let mut w: [Elem; 256] = self.0.clone().into();
140 let mut m = 256;
141
142 ntt_inverse_layer::<1, 128>(&mut w, &mut m);
143 ntt_inverse_layer::<2, 64>(&mut w, &mut m);
144 ntt_inverse_layer::<4, 32>(&mut w, &mut m);
145 ntt_inverse_layer::<8, 16>(&mut w, &mut m);
146 ntt_inverse_layer::<16, 8>(&mut w, &mut m);
147 ntt_inverse_layer::<32, 4>(&mut w, &mut m);
148 ntt_inverse_layer::<64, 2>(&mut w, &mut m);
149 ntt_inverse_layer::<128, 1>(&mut w, &mut m);
150
151 INVERSE_256 * &Polynomial::new(w.into())
152 }
153}
154
155impl<K: ArraySize> NttInverse for NttVector<K> {
156 type Output = Vector<K>;
157
158 fn ntt_inverse(&self) -> Self::Output {
159 Vector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
160 }
161}
162
163impl MultiplyNtt for BaseField {
164 fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
166 NttPolynomial::new(
167 lhs.0
168 .iter()
169 .zip(rhs.0.iter())
170 .map(|(&x, &y)| x * y)
171 .collect(),
172 )
173 }
174}
175
#[cfg(test)]
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::{
        Array,
        typenum::{U2, U3},
    };

    use crate::algebra::*;

    /// Reference schoolbook product in `Z_q[X] / (X^256 + 1)`.
    ///
    /// Because `X^256 = -1`, any term whose degree reaches 256 or more wraps around with its
    /// sign flipped.
    fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
        let minus_one = Elem::new(BaseField::Q - 1);
        let mut out = Polynomial::default();
        for (i, &a) in lhs.0.iter().enumerate() {
            for (j, &b) in rhs.0.iter().enumerate() {
                let degree = i + j;
                let (sign, index) = if degree >= 256 {
                    (minus_one, degree - 256)
                } else {
                    (Elem::new(1), degree)
                };
                out.0[index] = out.0[index] + (sign * a * b);
            }
        }
        out
    }

    /// NTT of the constant polynomial `x`.
    fn const_ntt(x: Int) -> NttPolynomial {
        let mut constant = Polynomial::default();
        constant.0[0] = Elem::new(x);
        constant.ntt()
    }

    #[test]
    fn ntt() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new((2 * i) as Int)));
        let f_hat = f.ntt();
        let g_hat = g.ntt();

        // Round trip: NTT^{-1}(NTT(f)) == f.
        assert_eq!(f, f_hat.ntt_inverse());

        // Linearity: addition commutes with the transform.
        let sum = &f + &g;
        assert_eq!(sum, (&f_hat + &g_hat).ntt_inverse());

        // Convolution: pointwise NTT product matches the negacyclic ring product.
        let product = poly_mul(&f, &g);
        assert_eq!(product, (&f_hat * &g_hat).ntt_inverse());
    }

    #[test]
    fn ntt_vector() {
        let ones: NttVector<U3> = NttVector::new(Array::from_fn(|_| const_ntt(1)));
        let twos: NttVector<U3> = NttVector::new(Array::from_fn(|_| const_ntt(2)));
        let threes: NttVector<U3> = NttVector::new(Array::from_fn(|_| const_ntt(3)));
        assert_eq!(&ones + &twos, threes);

        // Dot products of constant vectors of length 3.
        assert_eq!(&ones * &twos, const_ntt(6));
        assert_eq!(&ones * &threes, const_ntt(9));
        assert_eq!(&twos * &threes, const_ntt(18));
    }

    #[test]
    fn ntt_matrix() {
        let matrix: NttMatrix<U3, U2> = NttMatrix::new(Array([
            NttVector::new(Array([const_ntt(1), const_ntt(2)])),
            NttVector::new(Array([const_ntt(3), const_ntt(4)])),
            NttVector::new(Array([const_ntt(5), const_ntt(6)])),
        ]));
        let input: NttVector<U2> = NttVector::new(Array([const_ntt(1), const_ntt(2)]));
        let expected: NttVector<U3> =
            NttVector::new(Array([const_ntt(5), const_ntt(11), const_ntt(17)]));
        assert_eq!(&matrix * &input, expected);
    }
}