1use crate::module_lattice::algebra::Field;
2use crate::module_lattice::encode::ArraySize;
3use core::ops::Mul;
4
5use crate::algebra::{BaseField, Elem, NttPolynomial, NttVector, Polynomial, Vector};
6
7#[allow(clippy::cast_possible_truncation)]
18#[allow(clippy::as_conversions)]
19#[allow(clippy::integer_division_remainder_used)]
20const ZETA_POW_BITREV: [Elem; 256] = {
21 const ZETA: u64 = 1753;
22 const fn bitrev8(x: usize) -> usize {
23 (x as u8).reverse_bits() as usize
24 }
25
26 let mut pow = [Elem::new(0); 256];
28 let mut i = 0;
29 let mut curr = 1u64;
30 while i < 256 {
31 pow[i] = Elem::new(curr as u32);
32 i += 1;
33 curr = (curr * ZETA) % BaseField::QL;
34 }
35
36 let mut pow_bitrev = [Elem::new(0); 256];
40 let mut i = 1;
41 while i < 256 {
42 pow_bitrev[i] = pow[bitrev8(i)];
43 i += 1;
44 }
45 pow_bitrev
46};
47
/// Forward number-theoretic transform.
///
/// In this file it is implemented for [`Polynomial`] (yielding an [`NttPolynomial`])
/// and for [`Vector`] (transforming every entry independently).
pub trait Ntt {
    /// The NTT-domain type produced by [`Ntt::ntt`].
    type Output;
    /// Returns the NTT-domain representation of `self`; `self` is borrowed, not modified.
    fn ntt(&self) -> Self::Output;
}
52
53impl Ntt for Polynomial {
54 type Output = NttPolynomial;
55
56 fn ntt(&self) -> Self::Output {
58 let mut w = self.0.clone();
59
60 let mut m = 0;
61 for len in [128, 64, 32, 16, 8, 4, 2, 1] {
62 for start in (0..256).step_by(2 * len) {
63 m += 1;
64 let z = ZETA_POW_BITREV[m];
65
66 for j in start..(start + len) {
67 let t = z * w[j + len];
68 w[j + len] = w[j] - t;
69 w[j] = w[j] + t;
70 }
71 }
72 }
73
74 NttPolynomial::new(w)
75 }
76}
77
78impl<K: ArraySize> Ntt for Vector<K> {
79 type Output = NttVector<K>;
80
81 fn ntt(&self) -> Self::Output {
82 NttVector::new(self.0.iter().map(Polynomial::ntt).collect())
83 }
84}
85
/// Inverse number-theoretic transform, mapping NTT-domain values back to coefficient form.
///
/// In this file it is implemented for [`NttPolynomial`] (yielding a [`Polynomial`])
/// and for [`NttVector`] (transforming every entry independently).
// NOTE(review): the lint allowance presumably exists because the enclosing module is named
// `ntt`, which the trait name repeats — confirm against the module declaration.
#[allow(clippy::module_name_repetitions)]
pub trait NttInverse {
    /// The coefficient-domain type produced by [`NttInverse::ntt_inverse`].
    type Output;
    /// Returns the coefficient-domain representation of `self`; `self` is borrowed.
    fn ntt_inverse(&self) -> Self::Output;
}
91
92impl NttInverse for NttPolynomial {
93 type Output = Polynomial;
94
95 fn ntt_inverse(&self) -> Self::Output {
97 const INVERSE_256: Elem = Elem::new(8_347_681);
98
99 let mut w = self.0.clone();
100
101 let mut m = 256;
102 for len in [1, 2, 4, 8, 16, 32, 64, 128] {
103 for start in (0..256).step_by(2 * len) {
104 m -= 1;
105 let z = -ZETA_POW_BITREV[m];
106
107 for j in start..(start + len) {
108 let t = w[j];
109 w[j] = t + w[j + len];
110 w[j + len] = z * (t - w[j + len]);
111 }
112 }
113 }
114
115 INVERSE_256 * &Polynomial::new(w)
116 }
117}
118
119impl<K: ArraySize> NttInverse for NttVector<K> {
120 type Output = Vector<K>;
121
122 fn ntt_inverse(&self) -> Self::Output {
123 Vector::new(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
124 }
125}
126
127impl Mul<&NttPolynomial> for &NttPolynomial {
128 type Output = NttPolynomial;
129
130 fn mul(self, rhs: &NttPolynomial) -> NttPolynomial {
132 NttPolynomial::new(
133 self.0
134 .iter()
135 .zip(rhs.0.iter())
136 .map(|(&x, &y)| x * y)
137 .collect(),
138 )
139 }
140}
141
#[cfg(test)]
// The `as` casts below convert small loop indices (< 512) into `Int`; they cannot truncate.
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::{
        Array,
        typenum::{U2, U3},
    };

    use crate::algebra::*;

    /// Reference schoolbook multiplication in the negacyclic ring (X^256 = -1).
    ///
    /// This is used only to check that pointwise multiplication in the NTT domain
    /// agrees with ring multiplication.
    impl Mul<&Polynomial> for &Polynomial {
        type Output = Polynomial;

        fn mul(self, rhs: &Polynomial) -> Self::Output {
            let mut out = Self::Output::default();
            for (i, x) in self.0.iter().enumerate() {
                for (j, y) in rhs.0.iter().enumerate() {
                    // X^(i+j) with i + j >= 256 wraps to -X^(i+j-256).
                    let (sign, index) = if i + j < 256 {
                        (Elem::new(1), i + j)
                    } else {
                        (Elem::new(BaseField::Q - 1), i + j - 256)
                    };

                    out.0[index] = out.0[index] + (sign * *x * *y);
                }
            }
            out
        }
    }

    /// Returns the NTT of the constant polynomial `x` (only coefficient 0 is non-zero).
    fn const_ntt(x: Int) -> NttPolynomial {
        let mut p = Polynomial::default();
        p.0[0] = Elem::new(x);
        p.ntt()
    }

    #[test]
    fn zeta_table() {
        // Entry 0 is never read by the transforms and is left as zero.
        assert_eq!(ZETA_POW_BITREV[0], Elem::new(0));
        // BitRev8(128) = 1, so entry 128 holds ZETA^1 = 1753.
        assert_eq!(ZETA_POW_BITREV[128], Elem::new(1753));
        // BitRev8(1) = 128, so entry 1 holds ZETA^128. ZETA is a primitive 512th root of
        // unity, so its square ZETA^256 must equal -1 mod q.
        assert_eq!(
            ZETA_POW_BITREV[1] * ZETA_POW_BITREV[1],
            Elem::new(BaseField::Q - 1)
        );
    }

    #[test]
    fn ntt() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new((2 * i) as Int)));
        let f_hat = f.ntt();
        let g_hat = g.ntt();

        // The inverse undoes the forward transform.
        let f_unhat = f_hat.ntt_inverse();
        assert_eq!(f, f_unhat);

        // The transform is additive.
        let f_plus_g = &f + &g;
        let f_hat_plus_g_hat = &f_hat + &g_hat;
        assert_eq!(f_plus_g, f_hat_plus_g_hat.ntt_inverse());

        // Pointwise NTT multiplication matches negacyclic ring multiplication.
        let fg = &f * &g;
        let f_hat_g_hat = &f_hat * &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);
    }

    #[test]
    fn ntt_inverse_then_ntt_roundtrip() {
        // The forward transform also undoes the inverse (the other direction of the bijection).
        let f_hat = NttPolynomial::new(Array::from_fn(|i| Elem::new((3 * i + 1) as Int)));
        assert_eq!(f_hat.ntt_inverse().ntt(), f_hat);
    }

    #[test]
    fn ntt_vector() {
        let v1: NttVector<U3> = NttVector::new(Array([const_ntt(1), const_ntt(1), const_ntt(1)]));
        let v2: NttVector<U3> = NttVector::new(Array([const_ntt(2), const_ntt(2), const_ntt(2)]));
        let v3: NttVector<U3> = NttVector::new(Array([const_ntt(3), const_ntt(3), const_ntt(3)]));
        assert_eq!((&v1 + &v2), v3);

        // Inner products of constant vectors: 3 * (a * b).
        assert_eq!((&v1 * &v2), const_ntt(6));
        assert_eq!((&v1 * &v3), const_ntt(9));
        assert_eq!((&v2 * &v3), const_ntt(18));
    }

    #[test]
    fn ntt_matrix() {
        // [[1, 2], [3, 4], [5, 6]] * [1, 2] = [5, 11, 17]
        let a: NttMatrix<U3, U2> = NttMatrix::new(Array([
            NttVector::new(Array([const_ntt(1), const_ntt(2)])),
            NttVector::new(Array([const_ntt(3), const_ntt(4)])),
            NttVector::new(Array([const_ntt(5), const_ntt(6)])),
        ]));
        let v_in: NttVector<U2> = NttVector::new(Array([const_ntt(1), const_ntt(2)]));
        let v_out: NttVector<U3> =
            NttVector::new(Array([const_ntt(5), const_ntt(11), const_ntt(17)]));
        assert_eq!(&a * &v_in, v_out);
    }
}