Skip to main content

ml_kem/
algebra.rs

1use crate::{
2    B32,
3    crypto::{PRF, PrfOutput, XOF},
4    param::CbdSamplingSize,
5};
6use array::{Array, ArraySize, typenum::U256};
7use module_lattice::{Encode, Field, MultiplyNtt, Truncate};
8use sha3::digest::XofReader;
9
10module_lattice::define_field!(BaseField, u16, u32, u64, 3329);
11
12pub(crate) type Int = <BaseField as Field>::Int;
13
14/// An element of GF(q).
15pub(crate) type Elem = module_lattice::Elem<BaseField>;
16
17/// An element of the ring `R_q`, i.e., a polynomial over `Z_q` of degree 255
18pub(crate) type Polynomial = module_lattice::Polynomial<BaseField>;
19
20/// A vector of polynomials of length `K`.
21pub(crate) type Vector<K> = module_lattice::Vector<BaseField, K>;
22
23/// An element of the ring `T_q` i.e. a tuple of 128 elements of the direct sum components of `T_q`.
24pub(crate) type NttPolynomial = module_lattice::NttPolynomial<BaseField>;
25
26/// A vector of K NTT-domain polynomials.
27pub(crate) type NttVector<K> = module_lattice::NttVector<BaseField, K>;
28
29/// A K x K matrix of NTT-domain polynomials.  Each vector represents a row of the matrix, so that
30/// multiplying on the right just requires iteration.
31pub(crate) type NttMatrix<K> = module_lattice::NttMatrix<BaseField, K, K>;
32
33/// Algorithm 7: `SampleNTT(B)`
34pub(crate) fn sample_ntt(B: &mut impl XofReader) -> NttPolynomial {
35    struct FieldElementReader<'a> {
36        xof: &'a mut dyn XofReader,
37        data: [u8; 96],
38        start: usize,
39        next: Option<Int>,
40    }
41
42    impl<'a> FieldElementReader<'a> {
43        fn new(xof: &'a mut impl XofReader) -> Self {
44            let mut out = Self {
45                xof,
46                data: [0u8; 96],
47                start: 0,
48                next: None,
49            };
50
51            // Fill the buffer
52            out.xof.read(&mut out.data);
53
54            out
55        }
56
57        fn next(&mut self) -> Elem {
58            if let Some(val) = self.next {
59                self.next = None;
60                return Elem::new(val);
61            }
62
63            loop {
64                if self.start == self.data.len() {
65                    self.xof.read(&mut self.data);
66                    self.start = 0;
67                }
68
69                let end = self.start + 3;
70                let b = &self.data[self.start..end];
71                self.start = end;
72
73                let d1 = Int::from(b[0]) + ((Int::from(b[1]) & 0xf) << 8);
74                let d2 = (Int::from(b[1]) >> 4) + (Int::from(b[2]) << 4);
75
76                if d1 < BaseField::Q {
77                    if d2 < BaseField::Q {
78                        self.next = Some(d2);
79                    }
80                    return Elem::new(d1);
81                }
82
83                if d2 < BaseField::Q {
84                    return Elem::new(d2);
85                }
86            }
87        }
88    }
89
90    let mut reader = FieldElementReader::new(B);
91    NttPolynomial::new(Array::from_fn(|_| reader.next()))
92}
93
94pub(crate) fn matrix_sample_ntt<K: ArraySize>(rho: &B32, transpose: bool) -> NttMatrix<K> {
95    NttMatrix::new(Array::from_fn(|i| {
96        NttVector::new(Array::from_fn(|j| {
97            let (i, j) = if transpose { (j, i) } else { (i, j) };
98            let mut xof = XOF(rho, Truncate::truncate(j), Truncate::truncate(i));
99            sample_ntt(&mut xof)
100        }))
101    }))
102}
103
104/// Algorithm 8: `SamplePolyCBD_eta(B)`
105///
106/// To avoid all the bitwise manipulation in the algorithm as written, we reuse the logic in
107/// `ByteDecode`.  We decode the PRF output into integers with eta bits, then use
108/// `count_ones` to perform the summation described in the algorithm.
109pub(crate) fn sample_poly_cbd<Eta>(B: &PrfOutput<Eta>) -> Polynomial
110where
111    Eta: CbdSamplingSize,
112{
113    let vals: Polynomial = Encode::<Eta::SampleSize>::decode(B);
114    Polynomial::new(vals.0.iter().map(|val| Eta::ONES[val.0 as usize]).collect())
115}
116
117pub(crate) fn sample_poly_vec_cbd<Eta, K>(sigma: &B32, start_n: u8) -> Vector<K>
118where
119    Eta: CbdSamplingSize,
120    K: ArraySize,
121{
122    Vector::new(Array::from_fn(|i| {
123        let N = start_n + u8::truncate(i);
124        let prf_output = PRF::<Eta>(sigma, N);
125        sample_poly_cbd::<Eta>(&prf_output)
126    }))
127}
128
/// The forward Number Theoretic Transform (NTT).
///
/// The NTT is a variant of the Discrete Fourier Transform defined over the finite field `Z_q`. It
/// maps the polynomial ring `R_q` to `T_q`, where multiplication is much cheaper. For ML-KEM the
/// transform is incomplete: it stops at 128 degree-one residues instead of 256 scalars. Products
/// in the NTT domain are therefore pairwise degree-one multiplications (`MultiplyNTTs`), not
/// plain coefficient-wise products.
pub(crate) trait Ntt {
    /// The NTT-domain counterpart of `Self`.
    type Output;

    /// Transforms `self` into the NTT domain.
    fn ntt(&self) -> Self::Output;
}
136
137/// One layer of the forward NTT butterfly.
138///
139/// `LEN` is the butterfly half-length and `ITERATIONS = 128 / LEN` is the number of
140/// butterfly groups in the layer. Making both compile-time constants lets the compiler
141/// eliminate the iterator length calculation (`256 / (2 * LEN)`) that `step_by` would
142/// otherwise compute with a `UDIV` instruction.
143#[inline(always)]
144fn ntt_layer<const LEN: usize, const ITERATIONS: usize>(f: &mut Array<Elem, U256>, k: &mut usize) {
145    for i in 0..ITERATIONS {
146        let start = i * 2 * LEN;
147        let zeta = ZETA_POW_BITREV[*k];
148        *k += 1;
149
150        for j in start..(start + LEN) {
151            let t = zeta * f[j + LEN];
152            f[j + LEN] = f[j] - t;
153            f[j] = f[j] + t;
154        }
155    }
156}
157
158/// Algorithm 9: `NTT`
159impl Ntt for Polynomial {
160    type Output = NttPolynomial;
161
162    fn ntt(&self) -> NttPolynomial {
163        let mut k = 1;
164        let mut f = self.0;
165
166        ntt_layer::<128, 1>(&mut f, &mut k);
167        ntt_layer::<64, 2>(&mut f, &mut k);
168        ntt_layer::<32, 4>(&mut f, &mut k);
169        ntt_layer::<16, 8>(&mut f, &mut k);
170        ntt_layer::<8, 16>(&mut f, &mut k);
171        ntt_layer::<4, 32>(&mut f, &mut k);
172        ntt_layer::<2, 64>(&mut f, &mut k);
173
174        f.into()
175    }
176}
177
178impl<K: ArraySize> Ntt for Vector<K> {
179    type Output = NttVector<K>;
180
181    fn ntt(&self) -> NttVector<K> {
182        NttVector::new(self.0.iter().map(Ntt::ntt).collect())
183    }
184}
185
/// The inverse Number Theoretic Transform.
///
/// This maps an element of `T_q` back to the polynomial ring `R_q`, undoing the forward NTT. It
/// lets results computed in the NTT domain, including products of pairwise degree-one residues,
/// be returned to ordinary coefficient form modulo the same prime.
#[allow(clippy::module_name_repetitions)]
pub(crate) trait NttInverse {
    /// The coefficient-domain counterpart of `Self`.
    type Output;

    /// Transforms `self` out of the NTT domain.
    fn ntt_inverse(&self) -> Self::Output;
}
193
194/// One layer of the inverse NTT butterfly.
195///
196/// See [`ntt_layer`] for the rationale behind the const generics.
197#[inline(always)]
198fn ntt_inverse_layer<const LEN: usize, const ITERATIONS: usize>(
199    f: &mut Array<Elem, U256>,
200    k: &mut usize,
201) {
202    for i in 0..ITERATIONS {
203        let start = i * 2 * LEN;
204        let zeta = ZETA_POW_BITREV[*k];
205        *k -= 1;
206
207        for j in start..(start + LEN) {
208            let t = f[j];
209            f[j] = t + f[j + LEN];
210            f[j + LEN] = zeta * (f[j + LEN] - t);
211        }
212    }
213}
214
215/// Algorithm 10: `NTT^{-1}`
216impl NttInverse for NttPolynomial {
217    type Output = Polynomial;
218
219    fn ntt_inverse(&self) -> Polynomial {
220        let mut f: Array<Elem, U256> = self.0.clone();
221        let mut k = 127;
222
223        ntt_inverse_layer::<2, 64>(&mut f, &mut k);
224        ntt_inverse_layer::<4, 32>(&mut f, &mut k);
225        ntt_inverse_layer::<8, 16>(&mut f, &mut k);
226        ntt_inverse_layer::<16, 8>(&mut f, &mut k);
227        ntt_inverse_layer::<32, 4>(&mut f, &mut k);
228        ntt_inverse_layer::<64, 2>(&mut f, &mut k);
229        ntt_inverse_layer::<128, 1>(&mut f, &mut k);
230
231        Elem::new(3303) * &Polynomial::new(f)
232    }
233}
234
235impl<K: ArraySize> NttInverse for NttVector<K> {
236    type Output = Vector<K>;
237
238    fn ntt_inverse(&self) -> Vector<K> {
239        Vector::new(self.0.iter().map(NttInverse::ntt_inverse).collect())
240    }
241}
242
243/// Algorithm 11: `MultiplyNTTs`
244impl MultiplyNtt for BaseField {
245    fn multiply_ntt(lhs: &NttPolynomial, rhs: &NttPolynomial) -> NttPolynomial {
246        let mut out = NttPolynomial::new(Array::default());
247
248        for i in 0..128 {
249            let (c0, c1) = base_case_multiply(
250                lhs.0[2 * i],
251                lhs.0[2 * i + 1],
252                rhs.0[2 * i],
253                rhs.0[2 * i + 1],
254                i,
255            );
256
257            out.0[2 * i] = c0;
258            out.0[2 * i + 1] = c1;
259        }
260
261        out
262    }
263}
264
265/// Algorithm 12: `BaseCaseMultiply`
266///
267/// This is a hot loop.  We promote to u64 so that we can do the absolute minimum number of
268/// modular reductions, since these are the expensive operation.
269#[inline]
270fn base_case_multiply(a0: Elem, a1: Elem, b0: Elem, b1: Elem, i: usize) -> (Elem, Elem) {
271    let a0 = u32::from(a0.0);
272    let a1 = u32::from(a1.0);
273    let b0 = u32::from(b0.0);
274    let b1 = u32::from(b1.0);
275    let g = u32::from(GAMMA[i].0);
276
277    let b1g = u32::from(BaseField::barrett_reduce(b1 * g));
278
279    let c0 = BaseField::barrett_reduce(a0 * b0 + a1 * b1g);
280    let c1 = BaseField::barrett_reduce(a0 * b1 + a1 * b0);
281    (Elem::new(c0), Elem::new(c1))
282}
283
284/// Since the powers of zeta used in the `NTT` and `MultiplyNTTs` are fixed, we use pre-computed
285/// tables to avoid the need to compute the exponentiations at runtime.
286///
287/// * `ZETA_POW_BITREV[i] = zeta^{BitRev_7(i)}`
288/// * `GAMMA[i] = zeta^{2 BitRev_7(i) + 1}`
289///
290/// Note that the const environment here imposes some annoying conditions.  Because operator
291/// overloading can't be const, we have to do all the reductions here manually.  Because `for` loops
292/// are forbidden in `const` functions, we do them manually with `while` loops.
293///
294/// The values computed here match those provided in Appendix A of FIPS 203.
295/// `ZETA_POW_BITREV` corresponds to the first table, and `GAMMA` to the second table.
296#[allow(clippy::integer_division_remainder_used, reason = "constant")]
297const ZETA_POW_BITREV: [Elem; 128] = {
298    const ZETA: u64 = 17;
299
300    const fn bitrev7(x: usize) -> usize {
301        ((x >> 6) % 2)
302            | (((x >> 5) % 2) << 1)
303            | (((x >> 4) % 2) << 2)
304            | (((x >> 3) % 2) << 3)
305            | (((x >> 2) % 2) << 4)
306            | (((x >> 1) % 2) << 5)
307            | ((x % 2) << 6)
308    }
309
310    // Compute the powers of zeta
311    let mut pow = [Elem::new(0); 128];
312    let mut i = 0;
313    let mut curr = 1u64;
314
315    while i < 128 {
316        pow[i] = Elem::new((curr & 0xFFFF) as u16);
317        i += 1;
318        curr = (curr * ZETA) % BaseField::QLL;
319    }
320
321    // Reorder the powers according to bitrev7
322    let mut pow_bitrev = [Elem::new(0); 128];
323    let mut i = 0;
324    while i < 128 {
325        pow_bitrev[i] = pow[bitrev7(i)];
326        i += 1;
327    }
328    pow_bitrev
329};
330
331#[allow(clippy::integer_division_remainder_used, reason = "constant")]
332const GAMMA: [Elem; 128] = {
333    const ZETA: u64 = 17;
334    let mut gamma = [Elem::new(0); 128];
335    let mut i = 0;
336    while i < 128 {
337        let zpr = ZETA_POW_BITREV[i].0 as u64;
338        let g = (zpr * zpr * ZETA) % BaseField::QLL;
339        gamma[i] = Elem::new((g & 0xFFFF) as u16);
340        i += 1;
341    }
342    gamma
343};
344
#[cfg(test)]
mod test {
    use super::{
        Array, B32, BaseField, Elem, Field, Int, Ntt, NttInverse, NttMatrix, NttPolynomial,
        NttVector, PRF, Polynomial, U256, XOF,
    };
    use array::{
        ArraySize, Flatten,
        typenum::{U2, U3, U8},
    };

    /// A polynomial with only a scalar component, to make simple test cases
    fn const_ntt(x: Int) -> NttPolynomial {
        let mut p = Polynomial::default();
        p.0[0] = Elem::new(x);
        p.ntt()
    }

    /// Multiplication in `R_q`, modulo X^256 + 1
    fn poly_mul(lhs: &Polynomial, rhs: &Polynomial) -> Polynomial {
        let mut out = Polynomial::default();
        for (i, x) in lhs.0.iter().enumerate() {
            for (j, y) in rhs.0.iter().enumerate() {
                let (sign, index) = if i + j < 256 {
                    (Elem::new(1), i + j)
                } else {
                    (Elem::new(BaseField::Q - 1), i + j - 256)
                };

                out.0[index] = out.0[index] + (sign * *x * *y);
            }
        }
        out
    }

    /// Transpose `NttMatrix`
    fn matrix_transpose<K: ArraySize>(matrix: &NttMatrix<K>) -> NttMatrix<K> {
        NttMatrix::new(Array::from_fn(|i| {
            NttVector::new(Array::from_fn(|j| matrix.0[j].0[i].clone()))
        }))
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)]
    fn polynomial_ops() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int)));
        let sum = Polynomial::new(Array::from_fn(|i| Elem::new(3 * i as Int)));
        assert_eq!((&f + &g), sum);
        assert_eq!((&sum - &g), f);
        assert_eq!(Elem::new(3) * &f, sum);
    }

    #[test]
    #[allow(clippy::cast_possible_truncation, clippy::similar_names)]
    fn ntt() {
        let f = Polynomial::new(Array::from_fn(|i| Elem::new(i as Int)));
        let g = Polynomial::new(Array::from_fn(|i| Elem::new(2 * i as Int)));
        let f_hat = f.ntt();
        let g_hat = g.ntt();

        // Verify that NTT and NTT^-1 are actually inverses
        let f_unhat = f_hat.ntt_inverse();
        assert_eq!(f, f_unhat);

        // Verify that NTT is a homomorphism with regard to addition
        let fg = &f + &g;
        let f_hat_g_hat = &f_hat + &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);

        // Verify that NTT is a homomorphism with regard to multiplication
        let fg = poly_mul(&f, &g);
        let f_hat_g_hat = &f_hat * &g_hat;
        let fg_unhat = f_hat_g_hat.ntt_inverse();
        assert_eq!(fg, fg_unhat);
    }

    #[test]
    fn ntt_vector() {
        // Verify vector addition
        let v1: NttVector<U3> = NttVector::new(Array([const_ntt(1), const_ntt(1), const_ntt(1)]));
        let v2: NttVector<U3> = NttVector::new(Array([const_ntt(2), const_ntt(2), const_ntt(2)]));
        let v3: NttVector<U3> = NttVector::new(Array([const_ntt(3), const_ntt(3), const_ntt(3)]));
        assert_eq!((&v1 + &v2), v3);

        // Verify dot product
        assert_eq!((&v1 * &v2), const_ntt(6));
        assert_eq!((&v1 * &v3), const_ntt(9));
        assert_eq!((&v2 * &v3), const_ntt(18));

        // Verify inequality (catches PartialEq mutation that returns true unconditionally)
        assert_ne!(v1, v2);
        assert_ne!(v1, v3);
        assert_ne!(v2, v3);
    }

    #[test]
    fn ntt_matrix() {
        // Verify matrix multiplication by a vector
        let a: NttMatrix<U3> = NttMatrix::new(Array([
            NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)])),
            NttVector::new(Array([const_ntt(4), const_ntt(5), const_ntt(6)])),
            NttVector::new(Array([const_ntt(7), const_ntt(8), const_ntt(9)])),
        ]));
        let v_in: NttVector<U3> = NttVector::new(Array([const_ntt(1), const_ntt(2), const_ntt(3)]));
        let v_out: NttVector<U3> =
            NttVector::new(Array([const_ntt(14), const_ntt(32), const_ntt(50)]));
        assert_eq!(&a * &v_in, v_out);

        // Verify transpose
        let aT = NttMatrix::new(Array([
            NttVector::new(Array([const_ntt(1), const_ntt(4), const_ntt(7)])),
            NttVector::new(Array([const_ntt(2), const_ntt(5), const_ntt(8)])),
            NttVector::new(Array([const_ntt(3), const_ntt(6), const_ntt(9)])),
        ]));
        assert_eq!(matrix_transpose(&a), aT);
    }

    #[test]
    fn matrix_sample_ntt_transpose() {
        // Sampling with `transpose = true` must produce exactly the transpose of the matrix
        // sampled with `transpose = false` from the same seed.
        let rho = B32::default();
        let a: NttMatrix<U3> = super::matrix_sample_ntt(&rho, false);
        let a_t: NttMatrix<U3> = super::matrix_sample_ntt(&rho, true);
        assert_eq!(matrix_transpose(&a), a_t);
    }

    // To verify the accuracy of sampling, we use a theorem related to the law of large numbers,
    // which bounds the convergence of the Kullback-Leibler distance between the empirical
    // distribution and the hypothesized distribution.
    //
    // Theorem (Cover & Thomas, 1991, Theorem 12.2.1): Let $X_1, \ldots, X_n$ be i.i.d. $~P(x)$.
    // Then:
    //
    //   Pr{ D(P_{x^n} || P) > \epsilon } \leq 2^{ -n ( \epsilon - |X| log(n+1) / n ) }
    //
    // So if we test by computing D(P_{x^n} || P) and requiring the value to be below a threshold
    // \epsilon, an unbiased sampler fails with probability at most 2^{-k}, where (logs base 2):
    //
    //   \epsilon = ( k + |X| log(n+1) ) / n
    //
    // Taking k = n = 256 gives \epsilon = 1 + |X| log(257) / 256 ~= 1 + 0.0313 |X|:
    //
    //   CBD(eta = 2) => |X| = 5   => epsilon ~= 1.156
    //   CBD(eta = 3) => |X| = 7   => epsilon ~= 1.219
    //
    // A threshold of 2.05 is therefore comfortably above what the bound requires for the CBD
    // tests, so an unbiased sampler fails them only with negligible probability. For the uniform
    // test the alphabet is all of Z_q (|X| = q = 3329), and the bound above is vacuous. There the
    // threshold is a heuristic sanity check only (see `sample_uniform`).
    const KL_THRESHOLD: f64 = 2.05;

    // The centered binomial distributions are calculated as:
    //
    //   bin_\eta(k) = (2\eta \choose k + \eta) 2^{-2\eta}
    //
    // for k in $-\eta, \ldots, \eta$.  The cases of interest here are \eta = 2, 3.
    type Distribution = [f64; Q_SIZE];
    const Q_SIZE: usize = BaseField::Q as usize;
    static CBD2: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 2] = 1.0 / 16.0;
        dist[Q_SIZE - 1] = 4.0 / 16.0;
        dist[0] = 6.0 / 16.0;
        dist[1] = 4.0 / 16.0;
        dist[2] = 1.0 / 16.0;
        dist
    };
    static CBD3: Distribution = {
        let mut dist = [0.0; Q_SIZE];
        dist[Q_SIZE - 3] = 1.0 / 64.0;
        dist[Q_SIZE - 2] = 6.0 / 64.0;
        dist[Q_SIZE - 1] = 15.0 / 64.0;
        dist[0] = 20.0 / 64.0;
        dist[1] = 15.0 / 64.0;
        dist[2] = 6.0 / 64.0;
        dist[3] = 1.0 / 64.0;
        dist
    };
    static UNIFORM: Distribution = [1.0 / (BaseField::Q as f64); Q_SIZE];

    /// Kullback-Leibler divergence `D(p || q)` in bits.
    fn kl_divergence(p: &Distribution, q: &Distribution) -> f64 {
        p.iter()
            .zip(q.iter())
            .map(|(p, q)| if *p == 0.0 { 0.0 } else { p * (p / q).log2() })
            .sum()
    }

    #[allow(clippy::cast_precision_loss, clippy::large_stack_arrays)]
    fn test_sample(sample: &[Elem], ref_dist: &Distribution) {
        // Verify data and compute the empirical distribution
        let mut sample_dist: Distribution = [0.0; Q_SIZE];
        let bump: f64 = 1.0 / (sample.len() as f64);
        for x in sample {
            assert!(x.0 < BaseField::Q, "sample {} is not reduced modulo q", x.0);
            assert!(
                ref_dist[x.0 as usize] > 0.0,
                "sample {} is outside the support of the reference distribution",
                x.0
            );

            sample_dist[x.0 as usize] += bump;
        }

        let d = kl_divergence(&sample_dist, ref_dist);
        assert!(
            d < KL_THRESHOLD,
            "KL divergence {d} is not below the threshold {KL_THRESHOLD}"
        );
    }

    #[test]
    #[allow(clippy::cast_possible_truncation)]
    fn sample_uniform() {
        // We require roughly Q/2 samples to verify the uniform distribution. For M < N, the
        // uniform distribution over a subset of M elements has the following KL distance from
        // the uniform distribution over all N elements:
        //
        //   sum_{M terms} (1/M) log( (1/M) / (1/N) ) = log(N / M)
        //
        // Since Q ~= 2^11 and 256 == 2^8, a single run of 256 samples sits several bits away even
        // for a perfect sampler. We need 2^3 == 8 runs of 256 to get out of that bad regime and
        // obtain a meaningful measurement.
        let rho = B32::default();
        let sample: Array<Array<Elem, U256>, U8> = Array::from_fn(|i| {
            let mut xof = XOF(&rho, 0, i as u8);
            super::sample_ntt(&mut xof).into()
        });

        test_sample(&sample.flatten(), &UNIFORM);
    }

    #[test]
    fn sample_poly_cbd() {
        // Eta = 2
        let sigma = B32::default();
        let prf_output = PRF::<U2>(&sigma, 0);
        let sample = super::sample_poly_cbd::<U2>(&prf_output).0;
        test_sample(&sample, &CBD2);

        // Eta = 3
        let sigma = B32::default();
        let prf_output = PRF::<U3>(&sigma, 0);
        let sample = super::sample_poly_cbd::<U3>(&prf_output).0;
        test_sample(&sample, &CBD3);
    }
}