ml_kem/
algebra.rs

1use core::ops::{Add, Mul, Sub};
2use hybrid_array::{typenum::U256, Array};
3use sha3::digest::XofReader;
4
5use crate::crypto::{PrfOutput, PRF, XOF};
6use crate::encode::Encode;
7use crate::param::{ArraySize, CbdSamplingSize};
8use crate::util::{Truncate, B32};
9
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
pub type Integer = u16;

/// An element of GF(q), with q = 3329.
///
/// `q` needs only 12 bits, so a `u16` holds a reduced element with room to spare: the sum of two
/// reduced elements (below `2q`) still fits, which lets addition and subtraction finish with a
/// single conditional subtraction.  Multiplication widens to `u32` before reducing.  The inner
/// value is public, so nothing here enforces that it is reduced; the arithmetic assumes it is.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct FieldElement(pub Integer);

#[cfg(feature = "zeroize")]
impl Zeroize for FieldElement {
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
26
27impl FieldElement {
28    pub const Q: Integer = 3329;
29    pub const Q32: u32 = Self::Q as u32;
30    pub const Q64: u64 = Self::Q as u64;
31    const BARRETT_SHIFT: usize = 24;
32    #[allow(clippy::integer_division_remainder_used)]
33    const BARRETT_MULTIPLIER: u64 = (1 << Self::BARRETT_SHIFT) / Self::Q64;
34
35    // A fast modular reduction for small numbers `x < 2*q`
36    fn small_reduce(x: u16) -> u16 {
37        if x < Self::Q {
38            x
39        } else {
40            x - Self::Q
41        }
42    }
43
44    fn barrett_reduce(x: u32) -> u16 {
45        let product = u64::from(x) * Self::BARRETT_MULTIPLIER;
46        let quotient = (product >> Self::BARRETT_SHIFT).truncate();
47        let remainder = x - quotient * Self::Q32;
48        Self::small_reduce(remainder.truncate())
49    }
50
51    // Algorithm 11. BaseCaseMultiply
52    //
53    // This is a hot loop.  We promote to u64 so that we can do the absolute minimum number of
54    // modular reductions, since these are the expensive operation.
55    fn base_case_multiply(a0: Self, a1: Self, b0: Self, b1: Self, i: usize) -> (Self, Self) {
56        let a0 = u32::from(a0.0);
57        let a1 = u32::from(a1.0);
58        let b0 = u32::from(b0.0);
59        let b1 = u32::from(b1.0);
60        let g = u32::from(GAMMA[i].0);
61
62        let b1g = u32::from(Self::barrett_reduce(b1 * g));
63
64        let c0 = Self::barrett_reduce(a0 * b0 + a1 * b1g);
65        let c1 = Self::barrett_reduce(a0 * b1 + a1 * b0);
66        (Self(c0), Self(c1))
67    }
68}
69
70impl Add<FieldElement> for FieldElement {
71    type Output = Self;
72
73    fn add(self, rhs: Self) -> Self {
74        Self(Self::small_reduce(self.0 + rhs.0))
75    }
76}
77
78impl Sub<FieldElement> for FieldElement {
79    type Output = Self;
80
81    fn sub(self, rhs: Self) -> Self {
82        // Guard against underflow if `rhs` is too large
83        Self(Self::small_reduce(self.0 + Self::Q - rhs.0))
84    }
85}
86
87impl Mul<FieldElement> for FieldElement {
88    type Output = FieldElement;
89
90    fn mul(self, rhs: FieldElement) -> FieldElement {
91        let x = u32::from(self.0);
92        let y = u32::from(rhs.0);
93        Self(Self::barrett_reduce(x * y))
94    }
95}
96
97/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255
98#[derive(Clone, Copy, Default, Debug, PartialEq)]
99pub struct Polynomial(pub Array<FieldElement, U256>);
100
101impl Add<&Polynomial> for &Polynomial {
102    type Output = Polynomial;
103
104    fn add(self, rhs: &Polynomial) -> Polynomial {
105        Polynomial(
106            self.0
107                .iter()
108                .zip(rhs.0.iter())
109                .map(|(&x, &y)| x + y)
110                .collect(),
111        )
112    }
113}
114
115impl Sub<&Polynomial> for &Polynomial {
116    type Output = Polynomial;
117
118    fn sub(self, rhs: &Polynomial) -> Polynomial {
119        Polynomial(
120            self.0
121                .iter()
122                .zip(rhs.0.iter())
123                .map(|(&x, &y)| x - y)
124                .collect(),
125        )
126    }
127}
128
129impl Mul<&Polynomial> for FieldElement {
130    type Output = Polynomial;
131
132    fn mul(self, rhs: &Polynomial) -> Polynomial {
133        Polynomial(rhs.0.iter().map(|&x| self * x).collect())
134    }
135}
136
137impl Polynomial {
138    // Algorithm 7. SamplePolyCBD_eta(B)
139    //
140    // To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
141    // ByteDecode.  We decode the PRF output into integers with eta bits, then use
142    // `count_ones` to perform the summation described in the algorithm.
143    pub fn sample_cbd<Eta>(B: &PrfOutput<Eta>) -> Self
144    where
145        Eta: CbdSamplingSize,
146    {
147        let vals: Polynomial = Encode::<Eta::SampleSize>::decode(B);
148        Self(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
149    }
150}
151
152/// A vector of polynomials of length `k`
153#[derive(Clone, Default, Debug, PartialEq)]
154pub struct PolynomialVector<K: ArraySize>(pub Array<Polynomial, K>);
155
156impl<K: ArraySize> Add<PolynomialVector<K>> for PolynomialVector<K> {
157    type Output = PolynomialVector<K>;
158
159    fn add(self, rhs: PolynomialVector<K>) -> PolynomialVector<K> {
160        PolynomialVector(
161            self.0
162                .iter()
163                .zip(rhs.0.iter())
164                .map(|(x, y)| x + y)
165                .collect(),
166        )
167    }
168}
169
170impl<K: ArraySize> PolynomialVector<K> {
171    pub fn sample_cbd<Eta>(sigma: &B32, start_n: u8) -> Self
172    where
173        Eta: CbdSamplingSize,
174    {
175        Self(Array::from_fn(|i| {
176            let N = start_n + i.truncate();
177            let prf_output = PRF::<Eta>(sigma, N);
178            Polynomial::sample_cbd::<Eta>(&prf_output)
179        }))
180    }
181}
182
183/// An element of the ring `T_q`, i.e., a tuple of 128 elements of the direct sum components of `T_q`.
184#[derive(Clone, Default, Debug, PartialEq)]
185pub struct NttPolynomial(pub Array<FieldElement, U256>);
186
187#[cfg(feature = "zeroize")]
188impl Zeroize for NttPolynomial {
189    fn zeroize(&mut self) {
190        for fe in self.0.iter_mut() {
191            fe.zeroize()
192        }
193    }
194}
195
196impl Add<&NttPolynomial> for &NttPolynomial {
197    type Output = NttPolynomial;
198
199    fn add(self, rhs: &NttPolynomial) -> NttPolynomial {
200        NttPolynomial(
201            self.0
202                .iter()
203                .zip(rhs.0.iter())
204                .map(|(&x, &y)| x + y)
205                .collect(),
206        )
207    }
208}
209
210// Algorithm 6. SampleNTT (lines 4-13)
211struct FieldElementReader<'a> {
212    xof: &'a mut dyn XofReader,
213    data: [u8; 96],
214    start: usize,
215    next: Option<Integer>,
216}
217
218impl<'a> FieldElementReader<'a> {
219    fn new(xof: &'a mut impl XofReader) -> Self {
220        let mut out = Self {
221            xof,
222            data: [0u8; 96],
223            start: 0,
224            next: None,
225        };
226
227        // Fill the buffer
228        out.xof.read(&mut out.data);
229
230        out
231    }
232
233    fn next(&mut self) -> FieldElement {
234        if let Some(val) = self.next {
235            self.next = None;
236            return FieldElement(val);
237        }
238
239        loop {
240            if self.start == self.data.len() {
241                self.xof.read(&mut self.data);
242                self.start = 0;
243            }
244
245            let end = self.start + 3;
246            let b = &self.data[self.start..end];
247            self.start = end;
248
249            let d1 = Integer::from(b[0]) + ((Integer::from(b[1]) & 0xf) << 8);
250            let d2 = (Integer::from(b[1]) >> 4) + ((Integer::from(b[2]) as Integer) << 4);
251
252            if d1 < FieldElement::Q {
253                if d2 < FieldElement::Q {
254                    self.next = Some(d2);
255                }
256                return FieldElement(d1);
257            }
258
259            if d2 < FieldElement::Q {
260                return FieldElement(d2);
261            }
262        }
263    }
264}
265
266impl NttPolynomial {
267    // Algorithm 6 SampleNTT(B)
268    pub fn sample_uniform(B: &mut impl XofReader) -> Self {
269        let mut reader = FieldElementReader::new(B);
270        Self(Array::from_fn(|_| reader.next()))
271    }
272}
273
274// Since the powers of zeta used in the NTT and MultiplyNTTs are fixed, we use pre-computed tables
275// to avoid the need to compute the exponetiations at runtime.
276//
277// * ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)}
278// * GAMMA[i] = zeta^{2 BitRev_7(i) + 1}
279//
280// Note that the const environment here imposes some annoying conditions.  Because operator
281// overloading can't be const, we have to do all the reductions here manually.  Because `for` loops
282// are forbidden in `const` functions, we do them manually with `while` loops.
283//
284// The values computed here match those provided in Appendix A of FIPS 203.  ZETA_POW_BITREV
285// corresponds to the first table, and GAMMA to the second table.
286#[allow(clippy::cast_possible_truncation)]
287const ZETA_POW_BITREV: [FieldElement; 128] = {
288    const ZETA: u64 = 17;
289    #[allow(clippy::integer_division_remainder_used)]
290    const fn bitrev7(x: usize) -> usize {
291        ((x >> 6) % 2)
292            | (((x >> 5) % 2) << 1)
293            | (((x >> 4) % 2) << 2)
294            | (((x >> 3) % 2) << 3)
295            | (((x >> 2) % 2) << 4)
296            | (((x >> 1) % 2) << 5)
297            | ((x % 2) << 6)
298    }
299
300    // Compute the powers of zeta
301    let mut pow = [FieldElement(0); 128];
302    let mut i = 0;
303    let mut curr = 1u64;
304    #[allow(clippy::integer_division_remainder_used)]
305    while i < 128 {
306        pow[i] = FieldElement(curr as u16);
307        i += 1;
308        curr = (curr * ZETA) % FieldElement::Q64;
309    }
310
311    // Reorder the powers according to bitrev7
312    let mut pow_bitrev = [FieldElement(0); 128];
313    let mut i = 0;
314    while i < 128 {
315        pow_bitrev[i] = pow[bitrev7(i)];
316        i += 1;
317    }
318    pow_bitrev
319};
320
321#[allow(clippy::cast_possible_truncation)]
322const GAMMA: [FieldElement; 128] = {
323    const ZETA: u64 = 17;
324    let mut gamma = [FieldElement(0); 128];
325    let mut i = 0;
326    while i < 128 {
327        let zpr = ZETA_POW_BITREV[i].0 as u64;
328        #[allow(clippy::integer_division_remainder_used)]
329        let g = (zpr * zpr * ZETA) % FieldElement::Q64;
330        gamma[i] = FieldElement(g as u16);
331        i += 1;
332    }
333    gamma
334};
335
336// Algorithm 10. MuliplyNTTs
337impl Mul<&NttPolynomial> for &NttPolynomial {
338    type Output = NttPolynomial;
339
340    fn mul(self, rhs: &NttPolynomial) -> NttPolynomial {
341        let mut out = NttPolynomial(Array::default());
342
343        for i in 0..128 {
344            let (c0, c1) = FieldElement::base_case_multiply(
345                self.0[2 * i],
346                self.0[2 * i + 1],
347                rhs.0[2 * i],
348                rhs.0[2 * i + 1],
349                i,
350            );
351
352            out.0[2 * i] = c0;
353            out.0[2 * i + 1] = c1;
354        }
355
356        out
357    }
358}
359
360impl From<Array<FieldElement, U256>> for NttPolynomial {
361    fn from(f: Array<FieldElement, U256>) -> NttPolynomial {
362        NttPolynomial(f)
363    }
364}
365
366impl From<NttPolynomial> for Array<FieldElement, U256> {
367    fn from(f_hat: NttPolynomial) -> Array<FieldElement, U256> {
368        f_hat.0
369    }
370}
371
372// Algorithm 8. NTT
373impl Polynomial {
374    pub fn ntt(&self) -> NttPolynomial {
375        let mut k = 1;
376
377        let mut f = self.0;
378        for len in [128, 64, 32, 16, 8, 4, 2] {
379            for start in (0..256).step_by(2 * len) {
380                let zeta = ZETA_POW_BITREV[k];
381                k += 1;
382
383                for j in start..(start + len) {
384                    let t = zeta * f[j + len];
385                    f[j + len] = f[j] - t;
386                    f[j] = f[j] + t;
387                }
388            }
389        }
390
391        f.into()
392    }
393}
394
395// Algorithm 9. NTT^{-1}
396impl NttPolynomial {
397    pub fn ntt_inverse(&self) -> Polynomial {
398        let mut f: Array<FieldElement, U256> = self.0.clone();
399
400        let mut k = 127;
401        for len in [2, 4, 8, 16, 32, 64, 128] {
402            for start in (0..256).step_by(2 * len) {
403                let zeta = ZETA_POW_BITREV[k];
404                k -= 1;
405
406                for j in start..(start + len) {
407                    let t = f[j];
408                    f[j] = t + f[j + len];
409                    f[j + len] = zeta * (f[j + len] - t);
410                }
411            }
412        }
413
414        FieldElement(3303) * &Polynomial(f)
415    }
416}
417
418/// A vector of K NTT-domain polynomials
419#[derive(Clone, Default, Debug, PartialEq)]
420pub struct NttVector<K: ArraySize>(pub Array<NttPolynomial, K>);
421
422impl<K: ArraySize> NttVector<K> {
423    pub fn sample_uniform(rho: &B32, i: usize, transpose: bool) -> Self {
424        Self(Array::from_fn(|j| {
425            let (i, j) = if transpose { (j, i) } else { (i, j) };
426            let mut xof = XOF(rho, j.truncate(), i.truncate());
427            NttPolynomial::sample_uniform(&mut xof)
428        }))
429    }
430}
431
432#[cfg(feature = "zeroize")]
433impl<K> Zeroize for NttVector<K>
434where
435    K: ArraySize,
436{
437    fn zeroize(&mut self) {
438        for poly in self.0.iter_mut() {
439            poly.zeroize();
440        }
441    }
442}
443
444impl<K: ArraySize> Add<&NttVector<K>> for &NttVector<K> {
445    type Output = NttVector<K>;
446
447    fn add(self, rhs: &NttVector<K>) -> NttVector<K> {
448        NttVector(
449            self.0
450                .iter()
451                .zip(rhs.0.iter())
452                .map(|(x, y)| x + y)
453                .collect(),
454        )
455    }
456}
457
458impl<K: ArraySize> Mul<&NttVector<K>> for &NttVector<K> {
459    type Output = NttPolynomial;
460
461    fn mul(self, rhs: &NttVector<K>) -> NttPolynomial {
462        self.0
463            .iter()
464            .zip(rhs.0.iter())
465            .map(|(x, y)| x * y)
466            .fold(NttPolynomial::default(), |x, y| &x + &y)
467    }
468}
469
470impl<K: ArraySize> PolynomialVector<K> {
471    pub fn ntt(&self) -> NttVector<K> {
472        NttVector(self.0.iter().map(Polynomial::ntt).collect())
473    }
474}
475
476impl<K: ArraySize> NttVector<K> {
477    pub fn ntt_inverse(&self) -> PolynomialVector<K> {
478        PolynomialVector(self.0.iter().map(NttPolynomial::ntt_inverse).collect())
479    }
480}
481
482/// A K x K matrix of NTT-domain polynomials.  Each vector represents a row of the matrix, so that
483/// multiplying on the right just requires iteration.
484#[derive(Clone, Default, Debug, PartialEq)]
485pub struct NttMatrix<K: ArraySize>(Array<NttVector<K>, K>);
486
487impl<K: ArraySize> Mul<&NttVector<K>> for &NttMatrix<K> {
488    type Output = NttVector<K>;
489
490    fn mul(self, rhs: &NttVector<K>) -> NttVector<K> {
491        NttVector(self.0.iter().map(|x| x * rhs).collect())
492    }
493}
494
495impl<K: ArraySize> NttMatrix<K> {
496    pub fn sample_uniform(rho: &B32, transpose: bool) -> Self {
497        Self(Array::from_fn(|i| {
498            NttVector::sample_uniform(rho, i, transpose)
499        }))
500    }
501
502    pub fn transpose(&self) -> Self {
503        Self(Array::from_fn(|i| {
504            NttVector(Array::from_fn(|j| self.0[j].0[i].clone()))
505        }))
506    }
507}
508
#[cfg(test)]
mod test {
    use super::*;
    use crate::util::Flatten;
    use hybrid_array::typenum::{U2, U3, U8};

    // Multiplication in R_q, modulo X^256 + 1
    impl Mul<&Polynomial> for &Polynomial {
        type Output = Polynomial;

        fn mul(self, rhs: &Polynomial) -> Self::Output {
            let mut out = Self::Output::default();
            for (i, x) in self.0.iter().enumerate() {
                for (j, y) in rhs.0.iter().enumerate() {
                    let (sign, index) = if i + j < 256 {
                        (FieldElement(1), i + j)
                    } else {
                        (FieldElement(FieldElement::Q - 1), i + j - 256)
                    };

                    out.0[index] = out.0[index] + (sign * *x * *y);
                }
            }
            out
        }
    }

    // A polynomial with only a scalar component, to make simple test cases
    fn const_ntt(x: Integer) -> NttPolynomial {
        let mut p = Polynomial::default();
        p.0[0] = FieldElement(x);
        p.ntt()
    }

    #[test]
    fn polynomial_ops() {
        let f = Polynomial(Array::from_fn(|i| FieldElement(i as Integer)));
        let g = Polynomial(Array::from_fn(|i| FieldElement(2 * i as Integer)));
        let sum = Polynomial(Array::from_fn(|i| FieldElement(3 * i as Integer)));
        assert_eq!((&f + &g), sum);
        assert_eq!((&sum - &g), f);
        assert_eq!(FieldElement(3) * &f, sum);
    }

    #[test]
    fn ntt() {
        let f = Polynomial(Array::from_fn(|i| FieldElement(i as Integer)));
        let g = Polynomial(Array::from_fn(|i| FieldElement(2 * i as Integer)));
        let f_hat = f.ntt();
        let g_hat = g.ntt();

        // Verify that NTT and NTT^-1 are actually inverses
        let f_unhat = f_hat.ntt_inverse();
        assert_eq!(f, f_unhat);

        // Verify that NTT is a homomorphism with regard to addition
        let fg = &f + &g;
        let f_hat_g_hat = &f_hat + &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);

        // Verify that NTT is a homomorphism with regard to multiplication
        let fg = &f * &g;
        let f_hat_g_hat = &f_hat * &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);
    }

    #[test]
    fn ntt_vector() {
        // Verify vector addition
        let v1: NttVector<U3> = NttVector(Array([const_ntt(1), const_ntt(1), const_ntt(1)]));
        let v2: NttVector<U3> = NttVector(Array([const_ntt(2), const_ntt(2), const_ntt(2)]));
        let v3: NttVector<U3> = NttVector(Array([const_ntt(3), const_ntt(3), const_ntt(3)]));
        assert_eq!((&v1 + &v2), v3);

        // Verify dot product
        assert_eq!((&v1 * &v2), const_ntt(6));
        assert_eq!((&v1 * &v3), const_ntt(9));
        assert_eq!((&v2 * &v3), const_ntt(18));
    }

    #[test]
    fn ntt_matrix() {
        // Verify matrix multiplication by a vector
        let a: NttMatrix<U3> = NttMatrix(Array([
            NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
            NttVector(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
            NttVector(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
        ]));
        let v_in: NttVector<U3> = NttVector(Array([const_ntt(1), const_ntt(2), const_ntt(3)]));
        let v_out: NttVector<U3> = NttVector(Array([const_ntt(14), const_ntt(32), const_ntt(50)]));
        assert_eq!(&a * &v_in, v_out);

        // Verify transpose
        let aT = NttMatrix(Array([
            NttVector(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
            NttVector(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
            NttVector(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
        ]));
        assert_eq!(a.transpose(), aT);
    }

    // To verify the accuracy of sampling, we use a theorem related to the law of large numbers,
    // which bounds the convergence of the Kullback-Leibler distance between the empirical
    // distribution and the hypothesized distribution.
    //
    // Theorem (Cover & Thomas, 1991, Theorem 12.2.1): Let $X_1, \ldots, X_n$ be i.i.d. $\sim P(x)$.
    // Then:
    //
    //   Pr{ D(P_{x^n} || P) > \epsilon } \leq 2^{ -n ( \epsilon - |X| \log(n+1) / n ) }
    //
    // So if we test by computing D(P_{x^n} || P) and requiring the value to be below a threshold
    // \epsilon, then an unbiased sampler fails with probability at most 2^{-k}, where
    //
    //   k = n \epsilon - |X| \log(n+1)
    //
    // If we take k = 256 and n = 256, then we can solve for the required threshold \epsilon:
    //
    //   \epsilon = 1 + |X| \log(n+1) / n ~= 1 + 0.0313 |X|
    //
    // For the CBD cases we're interested in here:
    //
    //   CBD(eta = 2) => |X| = 5 => epsilon ~= 1.16
    //   CBD(eta = 3) => |X| = 7 => epsilon ~= 1.22
    //
    // so taking epsilon = 2.05 is comfortably conservative for them.  The uniform test below
    // uses |X| = q = 3329 and n = 2048.  There the bound only becomes non-trivial for epsilon
    // above roughly 18, so for that test the threshold is an empirical sanity check rather than
    // a value derived from the theorem.
    const KL_THRESHOLD: f64 = 2.05;

    // The centered binomial distributions are calculated as:
    //
    //   bin_\eta(k) = (2\eta \choose k + \eta) 2^{-2\eta}
    //
    // for k in $-\eta, \ldots, \eta$.  The cases of interest here are \eta = 2, 3.
    type Distribution = [f64; Q_SIZE];
    const Q_SIZE: usize = FieldElement::Q as usize;
    const CBD2: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 2] = 1.0 / 16.0;
        dist[Q_SIZE - 1] = 4.0 / 16.0;
        dist[0] = 6.0 / 16.0;
        dist[1] = 4.0 / 16.0;
        dist[2] = 1.0 / 16.0;
        dist
    };
    const CBD3: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 3] = 1.0 / 64.0;
        dist[Q_SIZE - 2] = 6.0 / 64.0;
        dist[Q_SIZE - 1] = 15.0 / 64.0;
        dist[0] = 20.0 / 64.0;
        dist[1] = 15.0 / 64.0;
        dist[2] = 6.0 / 64.0;
        dist[3] = 1.0 / 64.0;
        dist
    };
    const UNIFORM: Distribution = [1.0 / (FieldElement::Q as f64); Q_SIZE];

    fn kl_divergence(p: &Distribution, q: &Distribution) -> f64 {
        p.iter()
            .zip(q.iter())
            .map(|(p, q)| if *p == 0.0 { 0.0 } else { p * (p / q).log2() })
            .sum()
    }

    fn test_sample(sample: &[FieldElement], ref_dist: &Distribution) {
        // Verify data and compute the empirical distribution
        let mut sample_dist: Distribution = [0.0; Q_SIZE];
        let bump: f64 = 1.0 / (sample.len() as f64);
        for x in sample {
            assert!(x.0 < FieldElement::Q);
            assert!(ref_dist[x.0 as usize] > 0.0);

            sample_dist[x.0 as usize] += bump;
        }

        let d = kl_divergence(&sample_dist, ref_dist);
        assert!(d < KL_THRESHOLD);
    }

    #[test]
    fn sample_uniform() {
        // We require roughly Q/2 samples to verify the uniform distribution.  This is because for
        // M < N, the uniform distribution over a subset of M elements has KL distance from the
        // uniform distribution over all N elements of:
        //
        //   sum over the M support points of (1/M) log((1/M) / (1/N)) = log(N / M)
        //
        // Since Q ~= 2^11 and 256 == 2^8, we need 2^3 == 8 runs of 256 to get out of the bad
        // regime and get a meaningful measurement.
        let rho = B32::default();
        let sample: Array<Array<FieldElement, U256>, U8> = Array::from_fn(|i| {
            let mut xof = XOF(&rho, 0, i as u8);
            NttPolynomial::sample_uniform(&mut xof).into()
        });

        test_sample(&sample.flatten(), &UNIFORM);
    }

    #[test]
    fn sample_cbd() {
        // Eta = 2
        let sigma = B32::default();
        let prf_output = PRF::<U2>(&sigma, 0);
        let sample = Polynomial::sample_cbd::<U2>(&prf_output).0;
        test_sample(&sample, &CBD2);

        // Eta = 3
        let sigma = B32::default();
        let prf_output = PRF::<U3>(&sigma, 0);
        let sample = Polynomial::sample_cbd::<U3>(&prf_output).0;
        test_sample(&sample, &CBD3);
    }
}