1use crate::{
2 algebra::{BaseField, Elem, Int, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector},
3 crypto::{G, H},
4 param::{Eta, MaskSamplingSize},
5};
6use hybrid_array::Array;
7use module_lattice::{ArraySize, Field, Truncate};
8#[cfg(feature = "zeroize")]
9use zeroize::Zeroize;
10
/// Reports whether bit `i` of `z` is set, reading `z` as a bit string where bit 0 is the
/// least-significant bit of `z[0]`, bit 8 is the least-significant bit of `z[1]`, and so on.
///
/// Panics if `i >= 8 * z.len()`.
fn bit_set(z: &[u8], i: usize) -> bool {
    // Shift/mask rather than divide/modulo by 8.
    let byte = z[i >> 3];
    let shift = i & 0x07;
    (byte >> shift) & 1 == 1
}
17
18fn coeff_from_three_bytes(b: [u8; 3]) -> Option<Elem> {
20 let b0: Int = b[0].into();
21 let b1: Int = b[1].into();
22 let b2: Int = b[2].into();
23
24 let b2p = if b2 > 127 { b2 - 128 } else { b2 };
25
26 let z = (b2p << 16) + (b1 << 8) + b0;
27 (z < BaseField::Q).then_some(Elem::new(z))
28}
29
30fn coeff_from_half_byte(b: u8, eta: Eta) -> Option<Elem> {
32 match eta {
33 Eta::Two if b < 15 => {
34 let b = Int::from(match b {
35 b if b < 5 => b,
36 b if b < 10 => b - 5,
37 _ => b - 10,
38 });
39
40 if b <= 2 {
41 Some(Elem::new(2 - b))
42 } else {
43 Some(-Elem::new(b - 2))
44 }
45 }
46 Eta::Four if b < 9 => {
47 let b = Int::from(b);
48 if b <= 4 {
49 Some(Elem::new(4 - b))
50 } else {
51 Some(-Elem::new(b - 4))
52 }
53 }
54 _ => None,
55 }
56}
57
58fn coeffs_from_byte(z: u8, eta: Eta) -> (Option<Elem>, Option<Elem>) {
59 (
60 coeff_from_half_byte(z & 0x0F, eta),
61 coeff_from_half_byte(z >> 4, eta),
62 )
63}
64
65pub(crate) fn sample_in_ball(rho: &[u8], tau: usize) -> Polynomial {
67 const ONE: Elem = Elem::new(1);
68 const MINUS_ONE: Elem = Elem::new(BaseField::Q - 1);
69
70 let mut c = Polynomial::default();
71 let mut ctx = H::default().absorb(rho);
72
73 let mut s = [0u8; 8];
74 ctx.squeeze(&mut s);
75
76 let mut j = [0u8];
78 for i in (256 - tau)..256 {
79 ctx.squeeze(&mut j);
80 while usize::from(j[0]) > i {
81 ctx.squeeze(&mut j);
82 }
83
84 let j = usize::from(j[0]);
85 c.0[i] = c.0[j];
86 c.0[j] = if bit_set(&s, i + tau - 256) {
87 MINUS_ONE
88 } else {
89 ONE
90 };
91 }
92
93 c
94}
95
96fn rej_ntt_poly(rho: &[u8], r: u8, s: u8) -> NttPolynomial {
98 let mut j = 0;
99 let mut ctx = G::default().absorb(rho).absorb(&[s]).absorb(&[r]);
100 let mut a = NttPolynomial::default();
101
102 let mut buf = [0u8; 840];
106 ctx.squeeze(&mut buf);
107
108 for chunk in buf.chunks_exact(3) {
109 if let Some(x) = coeff_from_three_bytes([chunk[0], chunk[1], chunk[2]]) {
110 a.0[j] = x;
111 j += 1;
112 if j == 256 {
113 break;
114 }
115 }
116 }
117
118 let mut tmp = [0u8; 3];
120 while j < 256 {
121 ctx.squeeze(&mut tmp);
122 if let Some(x) = coeff_from_three_bytes(tmp) {
123 a.0[j] = x;
124 j += 1;
125 }
126 }
127 #[cfg(feature = "zeroize")]
128 {
129 buf.zeroize();
130 tmp.zeroize();
131 }
132 a
133}
134
135fn rej_bounded_poly(rho: &[u8], eta: Eta, r: u16) -> Polynomial {
137 let mut j = 0;
138 let mut ctx = H::default().absorb(rho).absorb(&r.to_le_bytes());
139 let mut a = Polynomial::default();
140
141 let mut buf = [0u8; 272];
143 ctx.squeeze(&mut buf);
144
145 for &byte in &buf {
146 let (z0, z1) = coeffs_from_byte(byte, eta);
147 if let Some(x) = z0 {
148 a.0[j] = x;
149 j += 1;
150 if j == 256 {
151 break;
152 }
153 }
154 if let Some(x) = z1 {
155 a.0[j] = x;
156 j += 1;
157 if j == 256 {
158 break;
159 }
160 }
161 }
162
163 let mut tmp = [0u8; 1];
165 while j < 256 {
166 ctx.squeeze(&mut tmp);
167 let (z0, z1) = coeffs_from_byte(tmp[0], eta);
168 if let Some(x) = z0 {
169 a.0[j] = x;
170 j += 1;
171 }
172 if j < 256 {
173 if let Some(x) = z1 {
174 a.0[j] = x;
175 j += 1;
176 }
177 }
178 }
179 #[cfg(feature = "zeroize")]
180 {
181 buf.zeroize();
182 tmp.zeroize();
183 }
184 a
185}
186
187pub(crate) fn expand_a<K: ArraySize, L: ArraySize>(rho: &[u8]) -> NttMatrix<K, L> {
189 NttMatrix::new(Array::from_fn(|r| {
190 NttVector::new(Array::from_fn(|s| {
191 rej_ntt_poly(rho, Truncate::truncate(r), Truncate::truncate(s))
192 }))
193 }))
194}
195
196pub(crate) fn expand_s<K: ArraySize>(rho: &[u8], eta: Eta, base: usize) -> Vector<K> {
204 Vector::new(Array::from_fn(|r| {
205 let r = Truncate::truncate(r + base);
206 rej_bounded_poly(rho, eta, r)
207 }))
208}
209
210pub(crate) fn expand_mask<K, Gamma1>(rho: &[u8], mu: u16) -> Vector<K>
212where
213 K: ArraySize,
214 Gamma1: MaskSamplingSize,
215{
216 Vector::new(Array::from_fn(|r| {
217 let r: u16 = Truncate::truncate(r);
218 let v = H::default()
219 .absorb(rho)
220 .absorb(&(mu + r).to_le_bytes())
221 .squeeze_new::<Gamma1::SampleSize>();
222
223 Gamma1::unpack(&v)
224 }))
225}
226
#[cfg(test)]
// The tests build seeds and indices from small loop counters (tau <= 64, seed <= 255,
// i < 16), so the `as` casts below cannot truncate.
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::typenum::{U16, U256};

    /// True if every coefficient of `p` is 0, 1, or `Q - 1` (i.e. -1).
    fn max_abs_1(p: &Polynomial) -> bool {
        p.0.iter()
            .all(|x| x.0 == 0 || x.0 == 1 || x.0 == BaseField::Q - 1)
    }

    /// Number of non-zero coefficients in `p`.
    fn hamming_weight(p: &Polynomial) -> usize {
        p.0.iter().filter(|x| x.0 != 0).count()
    }

    // `sample_in_ball` must place exactly `tau` non-zero coefficients, all +/-1.
    #[test]
    fn test_sample_in_ball() {
        for tau in 1..65 {
            // Cover every low-byte seed value, including 255.
            for seed in 0_usize..=255 {
                let rho = ((tau as u16) << 8) + (seed as u16);
                let p = sample_in_ball(&rho.to_be_bytes(), tau);
                assert_eq!(hamming_weight(&p), tau);
                assert!(max_abs_1(&p));
            }
        }
    }

    // Every coefficient produced by `rej_ntt_poly` must already be reduced below Q.
    #[test]
    fn test_rej_ntt_poly() {
        let sample: Array<Array<Elem, U256>, U16> = Array::from_fn(|i| {
            let i = i as u8;
            let rho = [i; 32];
            rej_ntt_poly(&rho, i, i + 1).0
        });

        let sample = sample.as_flattened();

        let all_in_range = sample.iter().all(|x| x.0 < BaseField::Q);
        assert!(all_in_range);
    }

    // `rej_bounded_poly` must produce coefficients in [-eta, eta]. Shifting by +eta maps
    // that range onto 0..=2*eta.
    #[test]
    fn test_rej_bounded_poly() {
        let rho = [0; 32];

        let sample = rej_bounded_poly(&rho, Eta::Two, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(2)).all(|x| x.0 < 5);
        assert!(all_in_range);

        let sample = rej_bounded_poly(&rho, Eta::Four, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(4)).all(|x| x.0 < 9);
        assert!(all_in_range);
    }
}