Skip to main content

ml_dsa/
sampling.rs

1use crate::{
2    algebra::{BaseField, Elem, Int, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector},
3    crypto::{G, H},
4    param::{Eta, MaskSamplingSize},
5};
6use hybrid_array::Array;
7use module_lattice::{ArraySize, Field, Truncate};
8#[cfg(feature = "zeroize")]
9use zeroize::Zeroize;
10
// Algorithm 13 BytesToBits
//
// Rather than expanding `z` into a bit array, this returns a single bit of BytesToBits(z):
// bit `i` is bit `i mod 8` (least significant first) of byte `i / 8`.
// Panics if `i / 8` is out of bounds for `z`.
fn bit_set(z: &[u8], i: usize) -> bool {
    let byte = z[i / 8];
    (byte >> (i % 8)) & 1 == 1
}
17
18// Algorithm 14 CoeffFromThreeBytes
19fn coeff_from_three_bytes(b: [u8; 3]) -> Option<Elem> {
20    let b0: Int = b[0].into();
21    let b1: Int = b[1].into();
22    let b2: Int = b[2].into();
23
24    let b2p = if b2 > 127 { b2 - 128 } else { b2 };
25
26    let z = (b2p << 16) + (b1 << 8) + b0;
27    (z < BaseField::Q).then_some(Elem::new(z))
28}
29
30// Algorithm 15 CoeffFromHalfByte
31fn coeff_from_half_byte(b: u8, eta: Eta) -> Option<Elem> {
32    match eta {
33        Eta::Two if b < 15 => {
34            let b = Int::from(match b {
35                b if b < 5 => b,
36                b if b < 10 => b - 5,
37                _ => b - 10,
38            });
39
40            if b <= 2 {
41                Some(Elem::new(2 - b))
42            } else {
43                Some(-Elem::new(b - 2))
44            }
45        }
46        Eta::Four if b < 9 => {
47            let b = Int::from(b);
48            if b <= 4 {
49                Some(Elem::new(4 - b))
50            } else {
51                Some(-Elem::new(b - 4))
52            }
53        }
54        _ => None,
55    }
56}
57
58fn coeffs_from_byte(z: u8, eta: Eta) -> (Option<Elem>, Option<Elem>) {
59    (
60        coeff_from_half_byte(z & 0x0F, eta),
61        coeff_from_half_byte(z >> 4, eta),
62    )
63}
64
65// Algorithm 29 SampleInBall
66pub(crate) fn sample_in_ball(rho: &[u8], tau: usize) -> Polynomial {
67    const ONE: Elem = Elem::new(1);
68    const MINUS_ONE: Elem = Elem::new(BaseField::Q - 1);
69
70    let mut c = Polynomial::default();
71    let mut ctx = H::default().absorb(rho);
72
73    let mut s = [0u8; 8];
74    ctx.squeeze(&mut s);
75
76    // h = bytes_to_bits(s)
77    let mut j = [0u8];
78    for i in (256 - tau)..256 {
79        ctx.squeeze(&mut j);
80        while usize::from(j[0]) > i {
81            ctx.squeeze(&mut j);
82        }
83
84        let j = usize::from(j[0]);
85        c.0[i] = c.0[j];
86        c.0[j] = if bit_set(&s, i + tau - 256) {
87            MINUS_ONE
88        } else {
89            ONE
90        };
91    }
92
93    c
94}
95
96// Algorithm 30 RejNTTPoly
97fn rej_ntt_poly(rho: &[u8], r: u8, s: u8) -> NttPolynomial {
98    let mut j = 0;
99    let mut ctx = G::default().absorb(rho).absorb(&[s]).absorb(&[r]);
100    let mut a = NttPolynomial::default();
101
102    // Squeeze 840 bytes (5 SHAKE128 blocks) in a single call rather than 3 bytes
103    // at a time.  The rejection probability is ~0.098%, so 280 candidates are
104    // almost always sufficient while requiring the same 5 Keccak-f permutations.
105    let mut buf = [0u8; 840];
106    ctx.squeeze(&mut buf);
107
108    for chunk in buf.chunks_exact(3) {
109        if let Some(x) = coeff_from_three_bytes([chunk[0], chunk[1], chunk[2]]) {
110            a.0[j] = x;
111            j += 1;
112            if j == 256 {
113                break;
114            }
115        }
116    }
117
118    // Fallback: astronomically unlikely (~10^-44), but required for correctness.
119    let mut tmp = [0u8; 3];
120    while j < 256 {
121        ctx.squeeze(&mut tmp);
122        if let Some(x) = coeff_from_three_bytes(tmp) {
123            a.0[j] = x;
124            j += 1;
125        }
126    }
127    #[cfg(feature = "zeroize")]
128    {
129        buf.zeroize();
130        tmp.zeroize();
131    }
132    a
133}
134
135// Algorithm 31 RejBoundedPoly
136fn rej_bounded_poly(rho: &[u8], eta: Eta, r: u16) -> Polynomial {
137    let mut j = 0;
138    let mut ctx = H::default().absorb(rho).absorb(&r.to_le_bytes());
139    let mut a = Polynomial::default();
140
141    // The reference implementation uses 136 bytes (1 SHAKE256 block) for eta=2 and 272 bytes (2 blocks) for eta=4.
142    let mut buf = [0u8; 272];
143    ctx.squeeze(&mut buf);
144
145    for &byte in &buf {
146        let (z0, z1) = coeffs_from_byte(byte, eta);
147        if let Some(x) = z0 {
148            a.0[j] = x;
149            j += 1;
150            if j == 256 {
151                break;
152            }
153        }
154        if let Some(x) = z1 {
155            a.0[j] = x;
156            j += 1;
157            if j == 256 {
158                break;
159            }
160        }
161    }
162
163    // Fallback: astronomically unlikely, but required for correctness.
164    let mut tmp = [0u8; 1];
165    while j < 256 {
166        ctx.squeeze(&mut tmp);
167        let (z0, z1) = coeffs_from_byte(tmp[0], eta);
168        if let Some(x) = z0 {
169            a.0[j] = x;
170            j += 1;
171        }
172        if j < 256 {
173            if let Some(x) = z1 {
174                a.0[j] = x;
175                j += 1;
176            }
177        }
178    }
179    #[cfg(feature = "zeroize")]
180    {
181        buf.zeroize();
182        tmp.zeroize();
183    }
184    a
185}
186
187// Algorithm 32 ExpandA
188pub(crate) fn expand_a<K: ArraySize, L: ArraySize>(rho: &[u8]) -> NttMatrix<K, L> {
189    NttMatrix::new(Array::from_fn(|r| {
190        NttVector::new(Array::from_fn(|s| {
191            rej_ntt_poly(rho, Truncate::truncate(r), Truncate::truncate(s))
192        }))
193    }))
194}
195
196// Algorithm 33 ExpandS
197//
198// We only do half of the algorithm here, because it's inconvenient to return two vectors of
199// different sizes.  So the caller has to call twice:
200//
201//    let s1 = Vector::<K>::expand_s(rho, 0);
202//    let s2 = Vector::<L>::expand_s(rho, L::USIZE);
203pub(crate) fn expand_s<K: ArraySize>(rho: &[u8], eta: Eta, base: usize) -> Vector<K> {
204    Vector::new(Array::from_fn(|r| {
205        let r = Truncate::truncate(r + base);
206        rej_bounded_poly(rho, eta, r)
207    }))
208}
209
210// Algorithm 34 ExpandMask
211pub(crate) fn expand_mask<K, Gamma1>(rho: &[u8], mu: u16) -> Vector<K>
212where
213    K: ArraySize,
214    Gamma1: MaskSamplingSize,
215{
216    Vector::new(Array::from_fn(|r| {
217        let r: u16 = Truncate::truncate(r);
218        let v = H::default()
219            .absorb(rho)
220            .absorb(&(mu + r).to_le_bytes())
221            .squeeze_new::<Gamma1::SampleSize>();
222
223        Gamma1::unpack(&v)
224    }))
225}
226
#[cfg(test)]
// Test inputs are built with small `as` casts whose values are known to fit.
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::typenum::{U16, U256};

    fn max_abs_1(p: &Polynomial) -> bool {
        p.0.iter()
            .all(|x| x.0 == 0 || x.0 == 1 || x.0 == BaseField::Q - 1)
    }

    fn hamming_weight(p: &Polynomial) -> usize {
        p.0.iter().filter(|x| x.0 != 0).count()
    }

    // BytesToBits numbering: bit i is bit (i mod 8) of byte i / 8, least significant first.
    #[test]
    fn test_bit_set() {
        let z = [0b0000_0101u8, 0x80];
        assert!(bit_set(&z, 0));
        assert!(!bit_set(&z, 1));
        assert!(bit_set(&z, 2));
        assert!(!bit_set(&z, 14));
        assert!(bit_set(&z, 15));
    }

    // CoeffFromThreeBytes: little-endian value with the top bit of the last byte ignored,
    // accepted only when it is below q = 0x7F_E001.
    #[test]
    fn test_coeff_from_three_bytes() {
        assert_eq!(
            coeff_from_three_bytes([0x01, 0x02, 0x03]).map(|x| x.0),
            Some(0x03_0201)
        );
        assert_eq!(
            coeff_from_three_bytes([0x01, 0x02, 0x83]).map(|x| x.0),
            Some(0x03_0201)
        );
        assert_eq!(
            coeff_from_three_bytes([0x00, 0xE0, 0x7F]).map(|x| x.0),
            Some(BaseField::Q - 1)
        );
        assert!(coeff_from_three_bytes([0x01, 0xE0, 0x7F]).is_none());
        assert!(coeff_from_three_bytes([0xFF, 0xFF, 0xFF]).is_none());
    }

    // CoeffFromHalfByte: 2 - (b mod 5) for eta = 2 (b < 15), 4 - b for eta = 4 (b < 9).
    // Results are checked after adding eta back, so negative values need not be spelled out
    // in their mod-q representation.
    #[test]
    fn test_coeff_from_half_byte() {
        let two = Elem::new(2);
        let four = Elem::new(4);

        // eta = 2: b = 0 -> 2, b = 7 -> 0, b = 14 -> -2, b = 15 rejected.
        assert_eq!(coeff_from_half_byte(0, Eta::Two).map(|x| (x + two).0), Some(4));
        assert_eq!(coeff_from_half_byte(7, Eta::Two).map(|x| (x + two).0), Some(2));
        assert_eq!(coeff_from_half_byte(14, Eta::Two).map(|x| (x + two).0), Some(0));
        assert!(coeff_from_half_byte(15, Eta::Two).is_none());

        // eta = 4: b = 0 -> 4, b = 8 -> -4, b = 9 rejected.
        assert_eq!(coeff_from_half_byte(0, Eta::Four).map(|x| (x + four).0), Some(8));
        assert_eq!(coeff_from_half_byte(8, Eta::Four).map(|x| (x + four).0), Some(0));
        assert!(coeff_from_half_byte(9, Eta::Four).is_none());
    }

    // Verify that SampleInBall returns a polynomial with the following properties:
    //   a. All coefficients are from {-1, 0, 1}
    //   b. Hamming weight is exactly tau
    //
    // We test 256 seeds for each tau in 1..=64 (64 is the most the 8 bytes of sign bits allow).
    #[test]
    fn test_sample_in_ball() {
        for tau in 1..=64 {
            for seed in 0_usize..256 {
                let rho = ((tau as u16) << 8) + (seed as u16);
                let p = sample_in_ball(&rho.to_be_bytes(), tau);
                assert_eq!(hamming_weight(&p), tau);
                assert!(max_abs_1(&p));
            }
        }
    }

    // Verify that RejNTTPoly produces samples in the proper range [0, q) across 16 distinct
    // seeds. Uniformity of the distribution is not checked yet.
    #[test]
    fn test_rej_ntt_poly() {
        let sample: Array<Array<Elem, U256>, U16> = Array::from_fn(|i| {
            let i = i as u8;
            let rho = [i; 32];
            rej_ntt_poly(&rho, i, i + 1).0
        });

        let sample = sample.as_flattened();

        let all_in_range = sample.iter().all(|x| x.0 < BaseField::Q);
        assert!(all_in_range);

        // TODO measure uniformity
    }

    // Exercises RejBoundedPoly: every coefficient must lie in [-eta, eta].
    #[test]
    fn test_sample_cbd() {
        let rho = [0; 32];

        // Eta = 2
        let sample = rej_bounded_poly(&rho, Eta::Two, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(2)).all(|x| x.0 < 5);
        assert!(all_in_range);
        // TODO measure uniformity

        // Eta = 4
        let sample = rej_bounded_poly(&rho, Eta::Four, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(4)).all(|x| x.0 < 9);
        assert!(all_in_range);
        // TODO measure uniformity
    }
}