1use crate::{
2 B32,
3 crypto::{PRF, PrfOutput, XOF},
4 param::CbdSamplingSize,
5};
6use array::{Array, ArraySize, typenum::U256};
7use module_lattice::{Encode, Field, MultiplyNtt, Truncate};
8use sha3::digest::XofReader;
9
// The base field Z_q with q = 3329. The three type arguments are the integer
// representations used by `module_lattice` (u16 is the element storage type, see `Int`;
// u32/u64 are presumably the wider types used for products and reductions — the
// constants below use `BaseField::QLL` as a u64).
module_lattice::define_field!(BaseField, u16, u32, u64, 3329);

/// Integer storage type of a field element (`u16` per the field definition above).
pub(crate) type Int = <BaseField as Field>::Int;

/// A single element of the base field.
pub(crate) type Elem = module_lattice::Elem<BaseField>;

/// A polynomial over the base field in the coefficient domain (256 coefficients).
pub(crate) type Polynomial = module_lattice::Polynomial<BaseField>;

/// A length-`K` vector of coefficient-domain polynomials.
pub(crate) type Vector<K> = module_lattice::Vector<BaseField, K>;

/// A polynomial in the NTT domain.
pub(crate) type NttPolynomial = module_lattice::NttPolynomial<BaseField>;

/// A length-`K` vector of NTT-domain polynomials.
pub(crate) type NttVector<K> = module_lattice::NttVector<BaseField, K>;

/// A square `K × K` matrix of NTT-domain polynomials.
pub(crate) type NttMatrix<K> = module_lattice::NttMatrix<BaseField, K, K>;
32
33pub(crate) fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
35 struct FieldElementReader<'a> {
36 xof: &'a mut dyn XofReader,
37 data: [u8; 96],
38 start: usize,
39 next: Option<Int>,
40 }
41
42 impl<'a> FieldElementReader<'a> {
43 fn new(xof: &'a mut impl XofReader) -> Self {
44 let mut out = Self {
45 xof,
46 data: [0u8; 96],
47 start: 0,
48 next: None,
49 };
50
51 out.xof.read(&mut out.data);
53
54 out
55 }
56
57 fn next(&mut self) -> Elem {
58 if let Some(val) = self.next {
59 self.next = None;
60 return Elem::new(val);
61 }
62
63 loop {
64 if self.start == self.data.len() {
65 self.xof.read(&mut self.data);
66 self.start = 0;
67 }
68
69 let end = self.start + 3;
70 let b = &self.data[self.start..end];
71 self.start = end;
72
73 let d1 = Int::from(b[0]) + ((Int::from(b[1]) & 0xf) << 8);
74 let d2 = (Int::from(b[1]) >> 4) + (Int::from(b[2]) << 4);
75
76 if d1 < BaseField::Q {
77 if d2 < BaseField::Q {
78 self.next = Some(d2);
79 }
80 return Elem::new(d1);
81 }
82
83 if d2 < BaseField::Q {
84 return Elem::new(d2);
85 }
86 }
87 }
88 }
89
90 let mut reader = FieldElementReader::new(B);
91 NttPolynomial::new(Array::from_fn(|_| reader.next()))
92}
93
94pub(crate) fn matrix_sample_ntt<K: ArraySize>(rho: &B32, transpose: bool) -> NttMatrix<K> {
95 NttMatrix::new(Array::from_fn(|i| {
96 NttVector::new(Array::from_fn(|j| {
97 let (i, j) = if transpose { (j, i) } else { (i, j) };
98 let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i));
99 sample_ntt(&mut xof)
100 }))
101 }))
102}
103
104pub(crate) fn sample_poly_cbd<Eta>(B: &PrfOutput<Eta>) -> Polynomial
110where
111 Eta: CbdSamplingSize,
112{
113 let vals: Polynomial = Encode::<Eta::SampleSize>::decode(B);
114 Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
115}
116
117pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> Vector<K>
118where
119 Eta: CbdSamplingSize,
120 K: ArraySize,
121{
122 Vector::new(Array::from_fn(|i| {
123 let N = start_n + u8::truncate(i);
124 let prf_output = PRF::<Eta>(sigma, N);
125 sample_poly_cbd::<Eta>(&prf_output)
126 }))
127}
128
/// Forward number-theoretic transform from the coefficient domain into the NTT domain.
pub(crate) trait Ntt {
    /// The NTT-domain counterpart of `Self`.
    type Output;
    /// Returns the forward NTT of `self`; `self` is only borrowed and is left unchanged.
    fn ntt(&self) -> Self::Output;
}
136
137#[inline(always)]
144fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(f: &mut Array<Elem, U256>, k: &mut usize) {
145 for i in 0..ITERATIONS {
146 let start = i * 2 * LEN;
147 let zeta = ZETA_POW_BITREV[*k];
148 *k += 1;
149
150 for j in start..(start + LEN) {
151 let t = zeta * f[j + LEN];
152 f[j + LEN] = f[j] - t;
153 f[j] = f[j] + t;
154 }
155 }
156}
157
158impl Ntt for Polynomial {
160 type Output = NttPolynomial;
161
162 fn ntt(&self) -> NttPolynomial {
163 let mut k = 1;
164 let mut f = self.0;
165
166 ntt_layer::<128, 1>(&mut f, &mut k);
167 ntt_layer::<64, 2>(&mut f, &mut k);
168 ntt_layer::<32, 4>(&mut f, &mut k);
169 ntt_layer::<16, 8>(&mut f, &mut k);
170 ntt_layer::<8, 16>(&mut f, &mut k);
171 ntt_layer::<4, 32>(&mut f, &mut k);
172 ntt_layer::<2, 64>(&mut f, &mut k);
173
174 f.into()
175 }
176}
177
178impl<K: ArraySize> Ntt for Vector<K> {
179 type Output = NttVector<K>;
180
181 fn ntt(&self) -> NttVector<K> {
182 NttVector::new(self.0.iter().map(Ntt::ntt).collect())
183 }
184}
185
/// Inverse number-theoretic transform, mapping NTT-domain values back to the
/// coefficient domain.
#[allow(clippy::module_name_repetitions)]
pub(crate) trait NttInverse {
    /// The coefficient-domain counterpart of `Self`.
    type Output;
    /// Returns the inverse NTT of `self`; `self` is only borrowed and is left unchanged.
    fn ntt_inverse(&self) -> Self::Output;
}
193
194#[inline(always)]
198fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
199 f: &mut Array<Elem, U256>,
200 k: &mut usize,
201) {
202 for i in 0..ITERATIONS {
203 let start = i * 2 * LEN;
204 let zeta = ZETA_POW_BITREV[*k];
205 *k -= 1;
206
207 for j in start..(start + LEN) {
208 let t = f[j];
209 f[j] = t + f[j + LEN];
210 f[j + LEN] = zeta * (f[j + LEN] - t);
211 }
212 }
213}
214
215impl NttInverse for NttPolynomial {
217 type Output = Polynomial;
218
219 fn ntt_inverse(&self) -> Polynomial {
220 let mut f: Array<Elem, U256> = self.0.clone();
221 let mut k = 127;
222
223 ntt_inverse_layer::<2, 64>(&mut f, &mut k);
224 ntt_inverse_layer::<4, 32>(&mut f, &mut k);
225 ntt_inverse_layer::<8, 16>(&mut f, &mut k);
226 ntt_inverse_layer::<16, 8>(&mut f, &mut k);
227 ntt_inverse_layer::<32, 4>(&mut f, &mut k);
228 ntt_inverse_layer::<64, 2>(&mut f, &mut k);
229 ntt_inverse_layer::<128, 1>(&mut f, &mut k);
230
231 Elem::new(3303) * &Polynomial::new(f)
232 }
233}
234
235impl<K: ArraySize> NttInverse for NttVector<K> {
236 type Output = Vector<K>;
237
238 fn ntt_inverse(&self) -> Vector<K> {
239 Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
240 }
241}
242
243impl MultiplyNtt for BaseField {
245 fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
246 let mut out = NttPolynomial::new(Array::default());
247
248 for i in 0..128 {
249 let (c0, c1) = base_case_multiply(
250 lhs.0[2 * i],
251 lhs.0[2 * i + 1],
252 rhs.0[2 * i],
253 rhs.0[2 * i + 1],
254 i,
255 );
256
257 out.0[2 * i] = c0;
258 out.0[2 * i + 1] = c1;
259 }
260
261 out
262 }
263}
264
265#[inline]
270fn base_case_multiply(a0: Elem, a1: Elem, b0: Elem, b1: Elem, i: usize) -> (Elem, Elem) {
271 let a0 = u32::from(a0.0);
272 let a1 = u32::from(a1.0);
273 let b0 = u32::from(b0.0);
274 let b1 = u32::from(b1.0);
275 let g = u32::from(GAMMA[i].0);
276
277 let b1g = u32::from(BaseField::barrett_reduce(b1 * g));
278
279 let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g);
280 let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0);
281 (Elem::new(c0), Elem::new(c1))
282}
283
284#[allow(clippy::integer_division_remainder_used, reason = "constant")]
297const ZETA_POW_BITREV: [Elem; 128] = {
298 const ZETA: u64 = 17;
299
300 const fn bitrev7(x: usize) -> usize {
301 ((x >> 6) % 2)
302 | (((x >> 5) % 2) << 1)
303 | (((x >> 4) % 2) << 2)
304 | (((x >> 3) % 2) << 3)
305 | (((x >> 2) % 2) << 4)
306 | (((x >> 1) % 2) << 5)
307 | ((x % 2) << 6)
308 }
309
310 let mut pow = [Elem::new(0); 128];
312 let mut i = 0;
313 let mut curr = 1u64;
314
315 while i < 128 {
316 pow[i] = Elem::new((curr & 0xFFFF) as u16);
317 i += 1;
318 curr = (curr * ZETA) % BaseField::QLL;
319 }
320
321 let mut pow_bitrev = [Elem::new(0); 128];
323 let mut i = 0;
324 while i < 128 {
325 pow_bitrev[i] = pow[bitrev7(i)];
326 i += 1;
327 }
328 pow_bitrev
329};
330
331#[allow(clippy::integer_division_remainder_used, reason = "constant")]
332const GAMMA: [Elem; 128] = {
333 const ZETA: u64 = 17;
334 let mut gamma = [Elem::new(0); 128];
335 let mut i = 0;
336 while i < 128 {
337 let zpr = ZETA_POW_BITREV[i].0 as u64;
338 let g = (zpr * zpr * ZETA) % BaseField::QLL;
339 gamma[i] = Elem::new((g & 0xFFFF) as u16);
340 i += 1;
341 }
342 gamma
343};
344
#[cfg(test)]
mod test {
    use super::{
        Array, B32, BaseField, Elem, Field, Int, Ntt, NttInverse, NttMatrix, NttPolynomial,
        NttVector, PRF, Polynomial, U256, XOF,
    };
    use array::{
        ArraySize, Flatten,
        typenum::{U2, U3, U8},
    };

    /// NTT of the constant polynomial `x` (every other coefficient is zero).
    fn const_ntt(x: Int) -> NttPolynomial {
        let mut poly = Polynomial::default();
        poly.0[0] = Elem::new(x);
        poly.ntt()
    }

    /// Schoolbook product in Z_q[X]/(X^256 + 1), used as a reference for NTT multiplication.
    fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
        let minus_one = Elem::new(BaseField::Q - 1);
        let mut acc = Polynomial::default();
        for (i, x) in lhs.0.iter().enumerate() {
            for (j, y) in rhs.0.iter().enumerate() {
                let degree = i + j;
                // X^256 = -1, so terms that wrap around pick up a minus sign.
                let (sign, slot) = if degree < 256 {
                    (Elem::new(1), degree)
                } else {
                    (minus_one, degree - 256)
                };
                acc.0[slot] = acc.0[slot] + (sign * *x * *y);
            }
        }
        acc
    }

    /// Returns the transpose of a square NTT-domain matrix.
    fn matrix_transpose<K: ArraySize>(matrix: &NttMatrix<K>) -> NttMatrix<K> {
        NttMatrix::new(Array::from_fn(|row| {
            NttVector::new(Array::from_fn(|col| matrix.0[col].0[row].clone()))
        }))
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)]
    fn polynomial_ops() {
        let ramp = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let double_ramp = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int)));
        let triple_ramp = Polynomial::new(Array::from_fn(|i| Elem::new(3 * i as Int)));

        assert_eq!((&ramp + &double_ramp), triple_ramp);
        assert_eq!((&triple_ramp - &double_ramp), ramp);
        assert_eq!(Elem::new(3) * &ramp, triple_ramp);
    }

    #[test]
    #[allow(clippy::cast_possible_truncation, clippy::similar_names)]
    fn ntt() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int)));
        let f_hat = f.ntt();
        let g_hat = g.ntt();

        // Round trip.
        assert_eq!(f, f_hat.ntt_inverse());

        // Addition commutes with the transform.
        let sum = &f + &g;
        let sum_hat = &f_hat + &g_hat;
        assert_eq!(sum, sum_hat.ntt_inverse());

        // NTT-domain multiplication matches schoolbook multiplication.
        let product = poly_mul(&f, &g);
        let product_hat = &f_hat * &g_hat;
        assert_eq!(product, product_hat.ntt_inverse());
    }

    #[test]
    fn ntt_vector() {
        let splat = |x: Int| -> NttVector<U3> { NttVector::new(Array::from_fn(|_| const_ntt(x))) };
        let (v1, v2, v3) = (splat(1), splat(2), splat(3));

        assert_eq!((&v1 + &v2), v3);

        assert_eq!((&v1 * &v2), const_ntt(6));
        assert_eq!((&v1 * &v3), const_ntt(9));
        assert_eq!((&v2 * &v3), const_ntt(18));

        assert_ne!(v1, v2);
        assert_ne!(v1, v3);
        assert_ne!(v2, v3);
    }

    #[test]
    fn ntt_matrix() {
        let m: NttMatrix<U3> = NttMatrix::new(Array([
            NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
            NttVector::new(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
            NttVector::new(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
        ]));
        let input: NttVector<U3> = NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)]));
        let expected: NttVector<U3> =
            NttVector::new(Array([const_ntt(14), const_ntt(32), const_ntt(50)]));
        assert_eq!(&m * &input, expected);

        let m_t = NttMatrix::new(Array([
            NttVector::new(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
            NttVector::new(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
            NttVector::new(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
        ]));
        assert_eq!(matrix_transpose(&m), m_t);
    }

    /// Upper bound on the KL divergence (in bits) between an empirical sample
    /// and its reference distribution.
    const KL_THRESHOLD: f64 = 2.05;

    const Q_SIZE: usize = BaseField::Q as usize;
    /// Probability mass over the residues `0..q`.
    type Distribution = [f64; Q_SIZE];

    /// Centered binomial distribution with eta = 2 (negative values wrap to q - k).
    static CBD2: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 2] = 1.0 / 16.0;
        dist[Q_SIZE - 1] = 4.0 / 16.0;
        dist[0] = 6.0 / 16.0;
        dist[1] = 4.0 / 16.0;
        dist[2] = 1.0 / 16.0;
        dist
    };
    /// Centered binomial distribution with eta = 3 (negative values wrap to q - k).
    static CBD3: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 3] = 1.0 / 64.0;
        dist[Q_SIZE - 2] = 6.0 / 64.0;
        dist[Q_SIZE - 1] = 15.0 / 64.0;
        dist[0] = 20.0 / 64.0;
        dist[1] = 15.0 / 64.0;
        dist[2] = 6.0 / 64.0;
        dist[3] = 1.0 / 64.0;
        dist
    };
    /// Uniform distribution over Z_q.
    static UNIFORM: Distribution = [1.0 / (BaseField::Q as f64); Q_SIZE];

    /// Kullback–Leibler divergence D(p || q) in bits; terms with p = 0 contribute nothing.
    fn kl_divergence(p: &Distribution, q: &Distribution) -> f64 {
        p.iter()
            .zip(q.iter())
            .filter(|(pi, _)| **pi != 0.0)
            .map(|(pi, qi)| pi * (pi / qi).log2())
            .sum()
    }

    /// Asserts that `sample` stays within the support of `ref_dist`, and that its
    /// empirical distribution is within `KL_THRESHOLD` of `ref_dist`.
    #[allow(clippy::cast_precision_loss, clippy::large_stack_arrays)]
    fn test_sample(sample: &[Elem], ref_dist: &Distribution) {
        let weight: f64 = 1.0 / (sample.len() as f64);
        let mut observed: Distribution = [0.0; Q_SIZE];

        for elem in sample {
            let slot = elem.0 as usize;
            assert!(elem.0 < BaseField::Q);
            assert!(ref_dist[slot] > 0.0);
            observed[slot] += weight;
        }

        let divergence = kl_divergence(&observed, ref_dist);
        assert!(divergence < KL_THRESHOLD);
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)]
    fn sample_uniform() {
        let rho = B32::default();
        let polys: Array<Array<Elem, U256>, U8> = Array::from_fn(|i| {
            let mut xof = XOF(&rho, 0, i as u8);
            super::sample_ntt(&mut xof).into()
        });

        test_sample(&polys.flatten(), &UNIFORM);
    }

    #[test]
    fn sample_poly_cbd() {
        let sigma = B32::default();

        let eta2 = super::sample_poly_cbd::<U2>(&PRF::<U2>(&sigma, 0)).0;
        test_sample(&eta2, &CBD2);

        let eta3 = super::sample_poly_cbd::<U3>(&PRF::<U3>(&sigma, 0)).0;
        test_sample(&eta3, &CBD3);
    }
}