// ml_dsa/sampling.rs

1use crate::module_lattice::encode::ArraySize;
2use crate::module_lattice::util::Truncate;
3use hybrid_array::Array;
4
5use crate::algebra::{
6    BaseField, Elem, Field, Int, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector,
7};
8use crate::crypto::{G, H};
9use crate::param::{Eta, MaskSamplingSize};
10
// Algorithm 13 BytesToBits, evaluated on demand: instead of expanding `z` into a
// bit array, return bit `i` of the little-endian bit string encoded by `z`
// (bit 0 is the least significant bit of `z[0]`).  Panics if `i / 8` is out of
// bounds for `z`.
fn bit_set(z: &[u8], i: usize) -> bool {
    let (byte, shift) = (i / 8, i % 8);
    ((z[byte] >> shift) & 1) == 1
}
17
18// Algorithm 14 CoeffFromThreeBytes
19fn coeff_from_three_bytes(b: [u8; 3]) -> Option<Elem> {
20    let b0: Int = b[0].into();
21    let b1: Int = b[1].into();
22    let b2: Int = b[2].into();
23
24    let b2p = if b2 > 127 { b2 - 128 } else { b2 };
25
26    let z = (b2p << 16) + (b1 << 8) + b0;
27    (z < BaseField::Q).then_some(Elem::new(z))
28}
29
30// Algorithm 15 CoeffFromHalfByte
31fn coeff_from_half_byte(b: u8, eta: Eta) -> Option<Elem> {
32    match eta {
33        Eta::Two if b < 15 => {
34            let b = Int::from(match b {
35                b if b < 5 => b,
36                b if b < 10 => b - 5,
37                _ => b - 10,
38            });
39
40            if b <= 2 {
41                Some(Elem::new(2 - b))
42            } else {
43                Some(-Elem::new(b - 2))
44            }
45        }
46        Eta::Four if b < 9 => {
47            let b = Int::from(b);
48            if b <= 4 {
49                Some(Elem::new(4 - b))
50            } else {
51                Some(-Elem::new(b - 4))
52            }
53        }
54        _ => None,
55    }
56}
57
58fn coeffs_from_byte(z: u8, eta: Eta) -> (Option<Elem>, Option<Elem>) {
59    (
60        coeff_from_half_byte(z & 0x0F, eta),
61        coeff_from_half_byte(z >> 4, eta),
62    )
63}
64
65// Algorithm 29 SampleInBall
66pub fn sample_in_ball(rho: &[u8], tau: usize) -> Polynomial {
67    const ONE: Elem = Elem::new(1);
68    const MINUS_ONE: Elem = Elem::new(BaseField::Q - 1);
69
70    let mut c = Polynomial::default();
71    let mut ctx = H::default().absorb(rho);
72
73    let mut s = [0u8; 8];
74    ctx.squeeze(&mut s);
75
76    // h = bytes_to_bits(s)
77    let mut j = [0u8];
78    for i in (256 - tau)..256 {
79        ctx.squeeze(&mut j);
80        while usize::from(j[0]) > i {
81            ctx.squeeze(&mut j);
82        }
83
84        let j = usize::from(j[0]);
85        c.0[i] = c.0[j];
86        c.0[j] = if bit_set(&s, i + tau - 256) {
87            MINUS_ONE
88        } else {
89            ONE
90        };
91    }
92
93    c
94}
95
96// Algorithm 30 RejNTTPoly
97fn rej_ntt_poly(rho: &[u8], r: u8, s: u8) -> NttPolynomial {
98    let mut j = 0;
99    let mut ctx = G::default().absorb(rho).absorb(&[s]).absorb(&[r]);
100
101    let mut a = NttPolynomial::default();
102    let mut s = [0u8; 3];
103    while j < 256 {
104        ctx.squeeze(&mut s);
105        if let Some(x) = coeff_from_three_bytes(s) {
106            a.0[j] = x;
107            j += 1;
108        }
109    }
110
111    a
112}
113
114// Algorithm 31 RejBoundedPoly
115fn rej_bounded_poly(rho: &[u8], eta: Eta, r: u16) -> Polynomial {
116    let mut j = 0;
117    let mut ctx = H::default().absorb(rho).absorb(&r.to_le_bytes());
118
119    let mut a = Polynomial::default();
120    let mut z = [0u8];
121    while j < 256 {
122        ctx.squeeze(&mut z);
123        let (z0, z1) = coeffs_from_byte(z[0], eta);
124
125        if let Some(z) = z0 {
126            a.0[j] = z;
127            j += 1;
128        }
129
130        if j == 256 {
131            break;
132        }
133
134        if let Some(z) = z1 {
135            a.0[j] = z;
136            j += 1;
137        }
138    }
139
140    a
141}
142
143// Algorithm 32 ExpandA
144pub fn expand_a<K: ArraySize, L: ArraySize>(rho: &[u8]) -> NttMatrix<K, L> {
145    NttMatrix::new(Array::from_fn(|r| {
146        NttVector::new(Array::from_fn(|s| {
147            rej_ntt_poly(rho, Truncate::truncate(r), Truncate::truncate(s))
148        }))
149    }))
150}
151
152// Algorithm 33 ExpandS
153//
154// We only do half of the algorithm here, because it's inconvenient to return two vectors of
155// different sizes.  So the caller has to call twice:
156//
157//    let s1 = Vector::<K>::expand_s(rho, 0);
158//    let s2 = Vector::<L>::expand_s(rho, L::USIZE);
159pub fn expand_s<K: ArraySize>(rho: &[u8], eta: Eta, base: usize) -> Vector<K> {
160    Vector::new(Array::from_fn(|r| {
161        let r = Truncate::truncate(r + base);
162        rej_bounded_poly(rho, eta, r)
163    }))
164}
165
166// Algorithm 34 ExpandMask
167pub fn expand_mask<K, Gamma1>(rho: &[u8], mu: u16) -> Vector<K>
168where
169    K: ArraySize,
170    Gamma1: MaskSamplingSize,
171{
172    Vector::new(Array::from_fn(|r| {
173        let r: u16 = Truncate::truncate(r);
174        let v = H::default()
175            .absorb(rho)
176            .absorb(&(mu + r).to_le_bytes())
177            .squeeze_new::<Gamma1::SampleSize>();
178
179        Gamma1::unpack(&v)
180    }))
181}
182
#[cfg(test)]
// The tests build small fixed inputs with `as` casts on values that provably fit
// in the target type (tau < 65, seed < 256, i < 16), so these lints are silenced.
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::typenum::{U16, U256};

    // True if every coefficient of `p` is 0, 1, or -1 (represented as q - 1).
    fn max_abs_1(p: &Polynomial) -> bool {
        p.0.iter()
            .all(|x| x.0 == 0 || x.0 == 1 || x.0 == BaseField::Q - 1)
    }

    // Number of nonzero coefficients of `p`.
    fn hamming_weight(p: &Polynomial) -> usize {
        p.0.iter().filter(|x| x.0 != 0).count()
    }

    // Verify that SampleInBall returns a polynomial with the following properties:
    //   a. All coefficients are from {-1, 0, 1}
    //   b. Hamming weight is exactly tau
    //
    // We test 256 seeds for each value of tau from 1 to 64.
    #[test]
    fn test_sample_in_ball() {
        for tau in 1..65 {
            for seed in 0_usize..256 {
                let rho = ((tau as u16) << 8) + (seed as u16);
                let p = sample_in_ball(&rho.to_be_bytes(), tau);
                assert_eq!(hamming_weight(&p), tau);
                assert!(max_abs_1(&p));
            }
        }
    }

    // Verify that RejNTTPoly produces samples in the proper range [0, q).
    // The "roughly uniform" criterion is not checked yet (see TODO below).
    #[test]
    fn test_rej_ntt_poly() {
        let sample: Array<Array<Elem, U256>, U16> = Array::from_fn(|i| {
            let i = i as u8;
            let rho = [i; 32];
            rej_ntt_poly(&rho, i, i + 1).0
        });

        let sample = sample.as_flattened();

        let all_in_range = sample.iter().all(|x| x.0 < BaseField::Q);
        assert!(all_in_range);

        // TODO measure uniformity
    }

    // Verify that RejBoundedPoly produces coefficients in [-eta, eta] for both
    // supported eta values.  Adding eta to each coefficient maps that range onto
    // [0, 2 * eta], which can be checked with a single unsigned comparison.
    #[test]
    fn test_sample_cbd() {
        let rho = [0; 32];

        // Eta = 2
        let sample = rej_bounded_poly(&rho, Eta::Two, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(2)).all(|x| x.0 < 5);
        assert!(all_in_range);
        // TODO measure uniformity

        // Eta = 4
        let sample = rej_bounded_poly(&rho, Eta::Four, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(4)).all(|x| x.0 < 9);
        assert!(all_in_range);
        // TODO measure uniformity
    }
}