1use crate::module_lattice::encode::ArraySize;
2use crate::module_lattice::util::Truncate;
3use hybrid_array::Array;
4
5use crate::algebra::{
6 BaseField, Elem, Field, Int, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector,
7};
8use crate::crypto::{G, H};
9use crate::param::{Eta, MaskSamplingSize};
10
/// Reports whether bit `i` of the bit string `z` is set.
///
/// Bits are numbered least-significant-first within each byte: bit `i` is bit `i % 8`
/// of byte `z[i / 8]`.
///
/// # Panics
///
/// Panics if `i / 8` is not a valid index into `z`.
fn bit_set(z: &[u8], i: usize) -> bool {
    let byte = z[i / 8];
    (byte >> (i % 8)) & 1 == 1
}
17
18fn coeff_from_three_bytes(b: [u8; 3]) -> Option<Elem> {
20 let b0: Int = b[0].into();
21 let b1: Int = b[1].into();
22 let b2: Int = b[2].into();
23
24 let b2p = if b2 > 127 { b2 - 128 } else { b2 };
25
26 let z = (b2p << 16) + (b1 << 8) + b0;
27 (z < BaseField::Q).then_some(Elem::new(z))
28}
29
30fn coeff_from_half_byte(b: u8, eta: Eta) -> Option<Elem> {
32 match eta {
33 Eta::Two if b < 15 => {
34 let b = Int::from(match b {
35 b if b < 5 => b,
36 b if b < 10 => b - 5,
37 _ => b - 10,
38 });
39
40 if b <= 2 {
41 Some(Elem::new(2 - b))
42 } else {
43 Some(-Elem::new(b - 2))
44 }
45 }
46 Eta::Four if b < 9 => {
47 let b = Int::from(b);
48 if b <= 4 {
49 Some(Elem::new(4 - b))
50 } else {
51 Some(-Elem::new(b - 4))
52 }
53 }
54 _ => None,
55 }
56}
57
58fn coeffs_from_byte(z: u8, eta: Eta) -> (Option<Elem>, Option<Elem>) {
59 (
60 coeff_from_half_byte(z & 0x0F, eta),
61 coeff_from_half_byte(z >> 4, eta),
62 )
63}
64
65pub fn sample_in_ball(rho: &[u8], tau: usize) -> Polynomial {
67 const ONE: Elem = Elem::new(1);
68 const MINUS_ONE: Elem = Elem::new(BaseField::Q - 1);
69
70 let mut c = Polynomial::default();
71 let mut ctx = H::default().absorb(rho);
72
73 let mut s = [0u8; 8];
74 ctx.squeeze(&mut s);
75
76 let mut j = [0u8];
78 for i in (256 - tau)..256 {
79 ctx.squeeze(&mut j);
80 while usize::from(j[0]) > i {
81 ctx.squeeze(&mut j);
82 }
83
84 let j = usize::from(j[0]);
85 c.0[i] = c.0[j];
86 c.0[j] = if bit_set(&s, i + tau - 256) {
87 MINUS_ONE
88 } else {
89 ONE
90 };
91 }
92
93 c
94}
95
96fn rej_ntt_poly(rho: &[u8], r: u8, s: u8) -> NttPolynomial {
98 let mut j = 0;
99 let mut ctx = G::default().absorb(rho).absorb(&[s]).absorb(&[r]);
100
101 let mut a = NttPolynomial::default();
102 let mut s = [0u8; 3];
103 while j < 256 {
104 ctx.squeeze(&mut s);
105 if let Some(x) = coeff_from_three_bytes(s) {
106 a.0[j] = x;
107 j += 1;
108 }
109 }
110
111 a
112}
113
114fn rej_bounded_poly(rho: &[u8], eta: Eta, r: u16) -> Polynomial {
116 let mut j = 0;
117 let mut ctx = H::default().absorb(rho).absorb(&r.to_le_bytes());
118
119 let mut a = Polynomial::default();
120 let mut z = [0u8];
121 while j < 256 {
122 ctx.squeeze(&mut z);
123 let (z0, z1) = coeffs_from_byte(z[0], eta);
124
125 if let Some(z) = z0 {
126 a.0[j] = z;
127 j += 1;
128 }
129
130 if j == 256 {
131 break;
132 }
133
134 if let Some(z) = z1 {
135 a.0[j] = z;
136 j += 1;
137 }
138 }
139
140 a
141}
142
143pub fn expand_a<K: ArraySize, L: ArraySize>(rho: &[u8]) -> NttMatrix<K, L> {
145 NttMatrix::new(Array::from_fn(|r| {
146 NttVector::new(Array::from_fn(|s| {
147 rej_ntt_poly(rho, Truncate::truncate(r), Truncate::truncate(s))
148 }))
149 }))
150}
151
152pub fn expand_s<K: ArraySize>(rho: &[u8], eta: Eta, base: usize) -> Vector<K> {
160 Vector::new(Array::from_fn(|r| {
161 let r = Truncate::truncate(r + base);
162 rej_bounded_poly(rho, eta, r)
163 }))
164}
165
166pub fn expand_mask<K, Gamma1>(rho: &[u8], mu: u16) -> Vector<K>
168where
169 K: ArraySize,
170 Gamma1: MaskSamplingSize,
171{
172 Vector::new(Array::from_fn(|r| {
173 let r: u16 = Truncate::truncate(r);
174 let v = H::default()
175 .absorb(rho)
176 .absorb(&(mu + r).to_le_bytes())
177 .squeeze_new::<Gamma1::SampleSize>();
178
179 Gamma1::unpack(&v)
180 }))
181}
182
#[cfg(test)]
// The tests narrow small loop counters (all < 256, or tau <= 64) to byte-sized integers
// with `as`, so truncation cannot occur.
#[allow(clippy::as_conversions)]
#[allow(clippy::cast_possible_truncation)]
mod test {
    use super::*;
    use hybrid_array::typenum::{U16, U256};

    /// True when every coefficient is 0, 1 or -1 (the latter stored as `Q - 1`).
    fn max_abs_1(p: &Polynomial) -> bool {
        p.0.iter()
            .all(|x| x.0 == 0 || x.0 == 1 || x.0 == BaseField::Q - 1)
    }

    /// Number of nonzero coefficients.
    fn hamming_weight(p: &Polynomial) -> usize {
        p.0.iter().filter(|x| x.0 != 0).count()
    }

    #[test]
    fn test_sample_in_ball() {
        for tau in 1..65 {
            // Cover every low-byte seed value, including 255.
            for seed in 0_usize..=255 {
                let rho = ((tau as u16) << 8) + (seed as u16);
                let p = sample_in_ball(&rho.to_be_bytes(), tau);
                assert_eq!(hamming_weight(&p), tau);
                assert!(max_abs_1(&p));
            }
        }
    }

    #[test]
    fn test_rej_ntt_poly() {
        let sample: Array<Array<Elem, U256>, U16> = Array::from_fn(|i| {
            let i = i as u8;
            let rho = [i; 32];
            rej_ntt_poly(&rho, i, i + 1).0
        });

        let sample = sample.as_flattened();

        let all_in_range = sample.iter().all(|x| x.0 < BaseField::Q);
        assert!(all_in_range);
    }

    #[test]
    fn test_coeff_from_half_byte() {
        // eta = 2: nibbles 0..15 map to 2 - (b mod 5); shifted by +2 that is 4 - (b mod 5).
        for b in 0u8..15 {
            let shifted = coeff_from_half_byte(b, Eta::Two).map(|x| (x + Elem::new(2)).0);
            assert_eq!(shifted, Some(4 - Int::from(b % 5)));
        }
        assert!(coeff_from_half_byte(15, Eta::Two).is_none());

        // eta = 4: nibbles 0..9 map to 4 - b; shifted by +4 that is 8 - b.
        for b in 0u8..9 {
            let shifted = coeff_from_half_byte(b, Eta::Four).map(|x| (x + Elem::new(4)).0);
            assert_eq!(shifted, Some(8 - Int::from(b)));
        }
        for b in 9u8..16 {
            assert!(coeff_from_half_byte(b, Eta::Four).is_none());
        }
    }

    #[test]
    fn test_sample_cbd() {
        let rho = [0; 32];

        // Shifting by +eta must land every coefficient in 0..=2*eta.
        let sample = rej_bounded_poly(&rho, Eta::Two, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(2)).all(|x| x.0 < 5);
        assert!(all_in_range);

        let sample = rej_bounded_poly(&rho, Eta::Four, 0).0;
        let all_in_range = sample.iter().map(|x| *x + Elem::new(4)).all(|x| x.0 < 9);
        assert!(all_in_range);
    }
}