1use core::ops::{Add, Mul, Sub};
2use hybrid_array::{typenum::U256, Array};
3use sha3::digest::XofReader;
4
5use crate::crypto::{PrfOutput, PRF, XOF};
6use crate::encode::Encode;
7use crate::param::{ArraySize, CbdSamplingSize};
8use crate::util::{Truncate, B32};
9
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
/// Integer type used to store the value of a [`FieldElement`].
pub type Integer = u16;

/// An element of the field of integers modulo [`FieldElement::Q`].
///
/// The wrapped value is public; the arithmetic operators reduce their results into `[0, Q)`,
/// but nothing in the type itself prevents constructing an unreduced value.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct FieldElement(pub Integer);
19
#[cfg(feature = "zeroize")]
impl Zeroize for FieldElement {
    /// Overwrites the stored value with zero using the `u16` zeroize implementation.
    fn zeroize(&mut self) {
        let FieldElement(inner) = self;
        inner.zeroize();
    }
}
26
27impl FieldElement {
28 pub const Q: Integer = 3329;
29 pub const Q32: u32 = Self::Q as u32;
30 pub const Q64: u64 = Self::Q as u64;
31 const BARRETT_SHIFT: usize = 24;
32 #[allow(clippy::integer_division_remainder_used)]
33 const BARRETT_MULTIPLIER: u64 = (1 << Self::BARRETT_SHIFT) / Self::Q64;
34
35 fn small_reduce(x: u16) -> u16 {
37 if x < Self::Q {
38 x
39 } else {
40 x - Self::Q
41 }
42 }
43
44 fn barrett_reduce(x: u32) -> u16 {
45 let product = u64::from(x) * Self::BARRETT_MULTIPLIER;
46 let quotient = (product >> Self::BARRETT_SHIFT).truncate();
47 let remainder = x - quotient * Self::Q32;
48 Self::small_reduce(remainder.truncate())
49 }
50
51 fn base_case_multiply(a0: Self, a1: Self, b0: Self, b1: Self, i: usize) -> (Self, Self) {
56 let a0 = u32::from(a0.0);
57 let a1 = u32::from(a1.0);
58 let b0 = u32::from(b0.0);
59 let b1 = u32::from(b1.0);
60 let g = u32::from(GAMMA[i].0);
61
62 let b1g = u32::from(Self::barrett_reduce(b1 * g));
63
64 let c0 = Self::barrett_reduce(a0 * b0 + a1 * b1g);
65 let c1 = Self::barrett_reduce(a0 * b1 + a1 * b0);
66 (Self(c0), Self(c1))
67 }
68}
69
70impl Add<FieldElement> for FieldElement {
71 type Output = Self;
72
73 fn add(self, rhs: Self) -> Self {
74 Self(Self::small_reduce(self.0 + rhs.0))
75 }
76}
77
78impl Sub<FieldElement> for FieldElement {
79 type Output = Self;
80
81 fn sub(self, rhs: Self) -> Self {
82 Self(Self::small_reduce(self.0 + Self::Q - rhs.0))
84 }
85}
86
87impl Mul<FieldElement> for FieldElement {
88 type Output = FieldElement;
89
90 fn mul(self, rhs: FieldElement) -> FieldElement {
91 let x = u32::from(self.0);
92 let y = u32::from(rhs.0);
93 Self(Self::barrett_reduce(x * y))
94 }
95}
96
97#[derive(Clone, Copy, Default, Debug, PartialEq)]
99pub struct Polynomial(pub Array<FieldElement, U256>);
100
101impl Add<&Polynomial> for &Polynomial {
102 type Output = Polynomial;
103
104 fn add(self, rhs: &Polynomial) -> Polynomial {
105 Polynomial(
106 self.0
107 .iter()
108 .zip(rhs.0.iter())
109 .map(|(&x, &y)| x + y)
110 .collect(),
111 )
112 }
113}
114
115impl Sub<&Polynomial> for &Polynomial {
116 type Output = Polynomial;
117
118 fn sub(self, rhs: &Polynomial) -> Polynomial {
119 Polynomial(
120 self.0
121 .iter()
122 .zip(rhs.0.iter())
123 .map(|(&x, &y)| x - y)
124 .collect(),
125 )
126 }
127}
128
129impl Mul<&Polynomial> for FieldElement {
130 type Output = Polynomial;
131
132 fn mul(self, rhs: &Polynomial) -> Polynomial {
133 Polynomial(rhs.0.iter().map(|&x| self * x).collect())
134 }
135}
136
137impl Polynomial {
138 pub fn sample_cbd<Eta>(B: &PrfOutput<Eta>) -> Self
144 where
145 Eta: CbdSamplingSize,
146 {
147 let vals: Polynomial = Encode::<Eta::SampleSize>::decode(B);
148 Self(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
149 }
150}
151
152#[derive(Clone, Default, Debug, PartialEq)]
154pub struct PolynomialVector<K: ArraySize>(pub Array<Polynomial, K>);
155
156impl<K: ArraySize> Add<PolynomialVector<K>> for PolynomialVector<K> {
157 type Output = PolynomialVector<K>;
158
159 fn add(self, rhs: PolynomialVector<K>) -> PolynomialVector<K> {
160 PolynomialVector(
161 self.0
162 .iter()
163 .zip(rhs.0.iter())
164 .map(|(x, y)| x + y)
165 .collect(),
166 )
167 }
168}
169
170impl<K: ArraySize> PolynomialVector<K> {
171 pub fn sample_cbd<Eta>(sigma: &B32, start_n: u8) -> Self
172 where
173 Eta: CbdSamplingSize,
174 {
175 Self(Array::from_fn(|i| {
176 let N = start_n + i.truncate();
177 let prf_output = PRF::<Eta>(sigma, N);
178 Polynomial::sample_cbd::<Eta>(&prf_output)
179 }))
180 }
181}
182
183#[derive(Clone, Default, Debug, PartialEq)]
185pub struct NttPolynomial(pub Array<FieldElement, U256>);
186
187#[cfg(feature = "zeroize")]
188impl Zeroize for NttPolynomial {
189 fn zeroize(&mut self) {
190 for fe in self.0.iter_mut() {
191 fe.zeroize()
192 }
193 }
194}
195
196impl Add<&NttPolynomial> for &NttPolynomial {
197 type Output = NttPolynomial;
198
199 fn add(self, rhs: &NttPolynomial) -> NttPolynomial {
200 NttPolynomial(
201 self.0
202 .iter()
203 .zip(rhs.0.iter())
204 .map(|(&x, &y)| x + y)
205 .collect(),
206 )
207 }
208}
209
210struct FieldElementReader<'a> {
212 xof: &'a mut dyn XofReader,
213 data: [u8; 96],
214 start: usize,
215 next: Option<Integer>,
216}
217
218impl<'a> FieldElementReader<'a> {
219 fn new(xof: &'a mut impl XofReader) -> Self {
220 let mut out = Self {
221 xof,
222 data: [0u8; 96],
223 start: 0,
224 next: None,
225 };
226
227 out.xof.read(&mut out.data);
229
230 out
231 }
232
233 fn next(&mut self) -> FieldElement {
234 if let Some(val) = self.next {
235 self.next = None;
236 return FieldElement(val);
237 }
238
239 loop {
240 if self.start == self.data.len() {
241 self.xof.read(&mut self.data);
242 self.start = 0;
243 }
244
245 let end = self.start + 3;
246 let b = &self.data[self.start..end];
247 self.start = end;
248
249 let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8);
250 let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4);
251
252 if d1 < FieldElement::Q {
253 if d2 < FieldElement::Q {
254 self.next = Some(d2);
255 }
256 return FieldElement(d1);
257 }
258
259 if d2 < FieldElement::Q {
260 return FieldElement(d2);
261 }
262 }
263 }
264}
265
266impl NttPolynomial {
267 pub fn sample_uniform(B: &mut impl XofReader) -> Self {
269 let mut reader = FieldElementReader::new(B);
270 Self(Array::from_fn(|_| reader.next()))
271 }
272}
273
274#[allow(clippy::cast_possible_truncation)]
287const ZETA_POW_BITREV: [FieldElement; 128] = {
288 const ZETA: u64 = 17;
289 #[allow(clippy::integer_division_remainder_used)]
290 const fn bitrev7(x: usize) -> usize {
291 ((x >> 6) % 2)
292 | (((x >> 5) % 2) << 1)
293 | (((x >> 4) % 2) << 2)
294 | (((x >> 3) % 2) << 3)
295 | (((x >> 2) % 2) << 4)
296 | (((x >> 1) % 2) << 5)
297 | ((x % 2) << 6)
298 }
299
300 let mut pow = [FieldElement(0); 128];
302 let mut i = 0;
303 let mut curr = 1u64;
304 #[allow(clippy::integer_division_remainder_used)]
305 while i < 128 {
306 pow[i] = FieldElement(curr as u16);
307 i += 1;
308 curr = (curr * ZETA) % FieldElement::Q64;
309 }
310
311 let mut pow_bitrev = [FieldElement(0); 128];
313 let mut i = 0;
314 while i < 128 {
315 pow_bitrev[i] = pow[bitrev7(i)];
316 i += 1;
317 }
318 pow_bitrev
319};
320
321#[allow(clippy::cast_possible_truncation)]
322const GAMMA: [FieldElement; 128] = {
323 const ZETA: u64 = 17;
324 let mut gamma = [FieldElement(0); 128];
325 let mut i = 0;
326 while i < 128 {
327 let zpr = ZETA_POW_BITREV[i].0 as u64;
328 #[allow(clippy::integer_division_remainder_used)]
329 let g = (zpr * zpr * ZETA) % FieldElement::Q64;
330 gamma[i] = FieldElement(g as u16);
331 i += 1;
332 }
333 gamma
334};
335
336impl Mul<&NttPolynomial> for &NttPolynomial {
338 type Output = NttPolynomial;
339
340 fn mul(self, rhs: &NttPolynomial) -> NttPolynomial {
341 let mut out = NttPolynomial(Array::default());
342
343 for i in 0..128 {
344 let (c0, c1) = FieldElement::base_case_multiply(
345 self.0[2 * i],
346 self.0[2 * i + 1],
347 rhs.0[2 * i],
348 rhs.0[2 * i + 1],
349 i,
350 );
351
352 out.0[2 * i] = c0;
353 out.0[2 * i + 1] = c1;
354 }
355
356 out
357 }
358}
359
360impl From<Array<FieldElement, U256>> for NttPolynomial {
361 fn from(f: Array<FieldElement, U256>) -> NttPolynomial {
362 NttPolynomial(f)
363 }
364}
365
366impl From<NttPolynomial> for Array<FieldElement, U256> {
367 fn from(f_hat: NttPolynomial) -> Array<FieldElement, U256> {
368 f_hat.0
369 }
370}
371
372impl Polynomial {
374 pub fn ntt(&self) -> NttPolynomial {
375 let mut k = 1;
376
377 let mut f = self.0;
378 for len in [128, 64, 32, 16, 8, 4, 2] {
379 for start in (0..256).step_by(2 * len) {
380 let zeta = ZETA_POW_BITREV[k];
381 k += 1;
382
383 for j in start..(start + len) {
384 let t = zeta * f[j + len];
385 f[j + len] = f[j] - t;
386 f[j] = f[j] + t;
387 }
388 }
389 }
390
391 f.into()
392 }
393}
394
395impl NttPolynomial {
397 pub fn ntt_inverse(&self) -> Polynomial {
398 let mut f: Array<FieldElement, U256> = self.0.clone();
399
400 let mut k = 127;
401 for len in [2, 4, 8, 16, 32, 64, 128] {
402 for start in (0..256).step_by(2 * len) {
403 let zeta = ZETA_POW_BITREV[k];
404 k -= 1;
405
406 for j in start..(start + len) {
407 let t = f[j];
408 f[j] = t + f[j + len];
409 f[j + len] = zeta * (f[j + len] - t);
410 }
411 }
412 }
413
414 FieldElement(3303) * &Polynomial(f)
415 }
416}
417
418#[derive(Clone, Default, Debug, PartialEq)]
420pub struct NttVector<K: ArraySize>(pub Array<NttPolynomial, K>);
421
422impl<K: ArraySize> NttVector<K> {
423 pub fn sample_uniform(rho: &B32, i: usize, transpose: bool) -> Self {
424 Self(Array::from_fn(|j| {
425 let (i, j) = if transpose { (j, i) } else { (i, j) };
426 let mut xof = XOF(rho, j.truncate(), i.truncate());
427 NttPolynomial::sample_uniform(&mut xof)
428 }))
429 }
430}
431
#[cfg(feature = "zeroize")]
impl<K> Zeroize for NttVector<K>
where
    K: ArraySize,
{
    /// Zeroizes every polynomial in the vector, in order.
    fn zeroize(&mut self) {
        self.0.iter_mut().for_each(Zeroize::zeroize);
    }
}
443
444impl<K: ArraySize> Add<&NttVector<K>> for &NttVector<K> {
445 type Output = NttVector<K>;
446
447 fn add(self, rhs: &NttVector<K>) -> NttVector<K> {
448 NttVector(
449 self.0
450 .iter()
451 .zip(rhs.0.iter())
452 .map(|(x, y)| x + y)
453 .collect(),
454 )
455 }
456}
457
458impl<K: ArraySize> Mul<&NttVector<K>> for &NttVector<K> {
459 type Output = NttPolynomial;
460
461 fn mul(self, rhs: &NttVector<K>) -> NttPolynomial {
462 self.0
463 .iter()
464 .zip(rhs.0.iter())
465 .map(|(x, y)| x * y)
466 .fold(NttPolynomial::default(), |x, y| &x + &y)
467 }
468}
469
470impl<K: ArraySize> PolynomialVector<K> {
471 pub fn ntt(&self) -> NttVector<K> {
472 NttVector(self.0.iter().map(Polynomial::ntt).collect())
473 }
474}
475
476impl<K: ArraySize> NttVector<K> {
477 pub fn ntt_inverse(&self) -> PolynomialVector<K> {
478 PolynomialVector(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
479 }
480}
481
482#[derive(Clone, Default, Debug, PartialEq)]
485pub struct NttMatrix<K: ArraySize>(Array<NttVector<K>, K>);
486
487impl<K: ArraySize> Mul<&NttVector<K>> for &NttMatrix<K> {
488 type Output = NttVector<K>;
489
490 fn mul(self, rhs: &NttVector<K>) -> NttVector<K> {
491 NttVector(self.0.iter().map(|x| x * rhs).collect())
492 }
493}
494
495impl<K: ArraySize> NttMatrix<K> {
496 pub fn sample_uniform(rho: &B32, transpose: bool) -> Self {
497 Self(Array::from_fn(|i| {
498 NttVector::sample_uniform(rho, i, transpose)
499 }))
500 }
501
502 pub fn transpose(&self) -> Self {
503 Self(Array::from_fn(|i| {
504 NttVector(Array::from_fn(|j| self.0[j].0[i].clone()))
505 }))
506 }
507}
508
#[cfg(test)]
mod test {
    use super::*;
    use crate::util::Flatten;
    use hybrid_array::typenum::{U2, U3, U8};

    // Reference (schoolbook) polynomial multiplication modulo X^256 + 1, used only to check
    // the NTT-domain product. Terms with degree >= 256 wrap around with a sign flip
    // (X^256 = -1), represented as multiplication by q - 1.
    impl Mul<&Polynomial> for &Polynomial {
        type Output = Polynomial;

        fn mul(self, rhs: &Polynomial) -> Self::Output {
            let mut out = Self::Output::default();
            for (i, x) in self.0.iter().enumerate() {
                for (j, y) in rhs.0.iter().enumerate() {
                    let (sign, index) = if i + j < 256 {
                        (FieldElement(1), i + j)
                    } else {
                        (FieldElement(FieldElement::Q - 1), i + j - 256)
                    };

                    out.0[index] = out.0[index] + (sign * *x * *y);
                }
            }
            out
        }
    }

    // NTT of the constant polynomial `x`. Used to build NTT-domain vectors and matrices whose
    // products are easy to predict by hand.
    fn const_ntt(x: Integer) -> NttPolynomial {
        let mut p = Polynomial::default();
        p.0[0] = FieldElement(x);
        p.ntt()
    }

    // Checks coefficient-wise add, sub and scalar multiplication on small values
    // (coefficients stay below q, so no reduction is involved).
    #[test]
    fn polynomial_ops() {
        let f = Polynomial(Array::from_fn(|i| FieldElement(i as Integer)));
        let g = Polynomial(Array::from_fn(|i| FieldElement(2 * i as Integer)));
        let sum = Polynomial(Array::from_fn(|i| FieldElement(3 * i as Integer)));
        assert_eq!((&f + &g), sum);
        assert_eq!((&sum - &g), f);
        assert_eq!(FieldElement(3) * &f, sum);
    }

    // Checks that the NTT round-trips, commutes with addition, and that the NTT-domain
    // product matches the schoolbook product defined above.
    #[test]
    fn ntt() {
        let f = Polynomial(Array::from_fn(|i| FieldElement(i as Integer)));
        let g = Polynomial(Array::from_fn(|i| FieldElement(2 * i as Integer)));
        let f_hat = f.ntt();
        let g_hat = g.ntt();

        // Round trip: NTT followed by inverse NTT is the identity.
        let f_unhat = f_hat.ntt_inverse();
        assert_eq!(f, f_unhat);

        // Addition commutes with the transform.
        let fg = &f + &g;
        let f_hat_g_hat = &f_hat + &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);

        // Multiplication in the NTT domain equals multiplication modulo X^256 + 1.
        let fg = &f * &g;
        let f_hat_g_hat = &f_hat * &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);
    }

    // Vector addition and inner product on constant polynomials: e.g. (1,1,1)·(2,2,2) = 6.
    #[test]
    fn ntt_vector() {
        let v1: NttVector<U3> = NttVector(Array([const_ntt(1), const_ntt(1), const_ntt(1)]));
        let v2: NttVector<U3> = NttVector(Array([const_ntt(2), const_ntt(2), const_ntt(2)]));
        let v3: NttVector<U3> = NttVector(Array([const_ntt(3), const_ntt(3), const_ntt(3)]));
        assert_eq!((&v1 + &v2), v3);

        assert_eq!((&v1 * &v2), const_ntt(6));
        assert_eq!((&v1 * &v3), const_ntt(9));
        assert_eq!((&v2 * &v3), const_ntt(18));
    }

    // Matrix-vector product and transpose on constant polynomials:
    // [[1,2,3],[4,5,6],[7,8,9]] * (1,2,3) = (14,32,50).
    #[test]
    fn ntt_matrix() {
        let a: NttMatrix<U3> = NttMatrix(Array([
            NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
            NttVector(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
            NttVector(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
        ]));
        let v_in: NttVector<U3> = NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)]));
        let v_out: NttVector<U3> = NttVector(Array([const_ntt(14), const_ntt(32), const_ntt(50)]));
        assert_eq!(&a * &v_in, v_out);

        let aT = NttMatrix(Array([
            NttVector(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
            NttVector(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
            NttVector(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
        ]));
        assert_eq!(a.transpose(), aT);
    }

    // Upper bound on the KL divergence (in bits, since `kl_divergence` uses `log2`) between
    // a sample's empirical distribution and its reference distribution.
    // NOTE(review): how 2.05 was chosen is not recorded here. With at most a few thousand
    // samples spread over q = 3329 buckets the empirical distribution is coarse, so this acts
    // as a loose sanity bound rather than a tight statistical test.
    const KL_THRESHOLD: f64 = 2.05;

    // A probability distribution over the q field values, indexed by the element's value.
    type Distribution = [f64; Q_SIZE];
    const Q_SIZE: usize = FieldElement::Q as usize;
    // Binomial weights 1,4,6,4,1 over the values -2..=2; negative values are stored at
    // their representatives q-2 and q-1.
    const CBD2: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 2] = 1.0 / 16.0;
        dist[Q_SIZE - 1] = 4.0 / 16.0;
        dist[0] = 6.0 / 16.0;
        dist[1] = 4.0 / 16.0;
        dist[2] = 1.0 / 16.0;
        dist
    };
    // Binomial weights 1,6,15,20,15,6,1 over the values -3..=3, negatives stored mod q.
    const CBD3: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 3] = 1.0 / 64.0;
        dist[Q_SIZE - 2] = 6.0 / 64.0;
        dist[Q_SIZE - 1] = 15.0 / 64.0;
        dist[0] = 20.0 / 64.0;
        dist[1] = 15.0 / 64.0;
        dist[2] = 6.0 / 64.0;
        dist[3] = 1.0 / 64.0;
        dist
    };
    // Every field value equally likely.
    const UNIFORM: Distribution = [1.0 / (FieldElement::Q as f64); Q_SIZE];

    // KL divergence D(p || q) in bits. Terms with p = 0 contribute 0 by convention.
    // `test_sample` only calls this after checking that q > 0 wherever p > 0, so no term
    // divides by zero.
    fn kl_divergence(p: &Distribution, q: &Distribution) -> f64 {
        p.iter()
            .zip(q.iter())
            .map(|(p, q)| if *p == 0.0 { 0.0 } else { p * (p / q).log2() })
            .sum()
    }

    // Checks that every sampled value is a reduced field element in the support of
    // `ref_dist`, then checks that the empirical distribution is within `KL_THRESHOLD` of it.
    fn test_sample(sample: &[FieldElement], ref_dist: &Distribution) {
        let mut sample_dist: Distribution = [0.0; Q_SIZE];
        let bump: f64 = 1.0 / (sample.len() as f64);
        for x in sample {
            assert!(x.0 < FieldElement::Q);
            assert!(ref_dist[x.0 as usize] > 0.0);

            sample_dist[x.0 as usize] += bump;
        }

        let d = kl_divergence(&sample_dist, ref_dist);
        assert!(d < KL_THRESHOLD);
    }

    // Draws 8 * 256 coefficients from `sample_uniform` (XOF inputs (0, i) for i in 0..8)
    // and compares them against the uniform distribution.
    #[test]
    fn sample_uniform() {
        let rho = B32::default();
        let sample: Array<Array<FieldElement, U256>, U8> = Array::from_fn(|i| {
            let mut xof = XOF(&rho, 0, i as u8);
            NttPolynomial::sample_uniform(&mut xof).into()
        });

        test_sample(&sample.flatten(), &UNIFORM);
    }

    // Draws one polynomial each with Eta = U2 and Eta = U3 and compares the 256
    // coefficients against the matching binomial reference distribution.
    #[test]
    fn sample_cbd() {
        let sigma = B32::default();
        let prf_output = PRF::<U2>(&sigma, 0);
        let sample = Polynomial::sample_cbd::<U2>(&prf_output).0;
        test_sample(&sample, &CBD2);

        let sigma = B32::default();
        let prf_output = PRF::<U3>(&sigma, 0);
        let sample = Polynomial::sample_cbd::<U3>(&prf_output).0;
        test_sample(&sample, &CBD3);
    }
}