1use crate::{
2 algebra::{AlgebraExt, BaseField, Decompose, Elem, Polynomial, Vector},
3 param::{EncodedHint, SignatureParams},
4};
5use ctutils::{Choice, CtEq, CtGt, CtSelect};
6use hybrid_array::{
7 Array,
8 typenum::{U256, Unsigned},
9};
10use module_lattice::{Field, Truncate};
11
12fn make_hint<TwoGamma2: Unsigned>(z: Elem, r: Elem) -> bool {
15 let r1 = r.high_bits::<TwoGamma2>();
16 let v1 = (r + z).high_bits::<TwoGamma2>();
17 r1 != v1
18}
19
20#[allow(clippy::integer_division_remainder_used, reason = "params are public")]
25fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
26 let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32;
27 let (r1, r0) = r.decompose::<TwoGamma2>();
28 let gamma2 = TwoGamma2::U32 / 2;
29
30 let r1_inc = Elem::new((r1.0 + 1) % m);
32 let r1_dec = Elem::new((r1.0 + m - 1) % m);
33
34 let r0_positive = !r0.0.ct_eq(&0) & !r0.0.ct_gt(&gamma2);
36 let hinted = Elem::new(u32::ct_select(&r1_dec.0, &r1_inc.0, r0_positive));
37
38 Elem::new(u32::ct_select(
40 &r1.0,
41 &hinted.0,
42 Choice::from_u8_lsb(u8::from(h)),
43 ))
44}
45
46#[derive(Clone, PartialEq, Debug)]
47pub(crate) struct Hint<P>(pub Array<Array<bool, U256>, P::K>)
48where
49 P: SignatureParams;
50
51impl<P> Default for Hint<P>
52where
53 P: SignatureParams,
54{
55 fn default() -> Self {
56 Self(Array::default())
57 }
58}
59
60impl<P> Hint<P>
61where
62 P: SignatureParams,
63{
64 pub(crate) fn new(z: &Vector<P::K>, r: &Vector<P::K>) -> Self {
65 let zi = z.0.iter();
66 let ri = r.0.iter();
67
68 Self(
69 zi.zip(ri)
70 .map(|(zv, rv)| {
71 let zvi = zv.0.iter();
72 let rvi = rv.0.iter();
73
74 zvi.zip(rvi)
75 .map(|(&z, &r)| make_hint::<P::TwoGamma2>(z, r))
76 .collect()
77 })
78 .collect(),
79 )
80 }
81
82 pub(crate) fn hamming_weight(&self) -> usize {
83 self.0
84 .iter()
85 .map(|x| x.iter().filter(|x| **x).count())
86 .sum()
87 }
88
89 pub(crate) fn use_hint(&self, r: &Vector<P::K>) -> Vector<P::K> {
90 let hi = self.0.iter();
91 let ri = r.0.iter();
92
93 Vector::new(
94 hi.zip(ri)
95 .map(|(hv, rv)| {
96 let hvi = hv.iter();
97 let rvi = rv.0.iter();
98
99 Polynomial::new(
100 hvi.zip(rvi)
101 .map(|(&h, &r)| use_hint::<P::TwoGamma2>(h, r))
102 .collect(),
103 )
104 })
105 .collect(),
106 )
107 }
108
109 pub(crate) fn bit_pack(&self) -> EncodedHint<P> {
110 let mut y: EncodedHint<P> = Array::default();
111 let mut index = 0;
112 let omega = P::Omega::USIZE;
113 for i in 0..P::K::U8 {
114 let i_usize: usize = i.into();
115 for j in 0..256 {
116 if self.0[i_usize][j] {
117 y[index] = Truncate::truncate(j);
118 index += 1;
119 }
120 }
121
122 y[omega + i_usize] = Truncate::truncate(index);
123 }
124
125 y
126 }
127
128 pub(crate) fn bit_unpack(y: &EncodedHint<P>) -> Option<Self> {
129 let (indices, cuts) = P::split_hint(y);
130 let cuts: Array<usize, P::K> = cuts.iter().map(|x| usize::from(*x)).collect();
131
132 let indices: Array<usize, P::Omega> = indices.iter().map(|x| usize::from(*x)).collect();
133 let max_cut: usize = cuts.iter().copied().max().expect("should have a maximum");
134
135 if !cuts.windows(2).all(|w| w[0] <= w[1])
137 || max_cut > indices.len()
138 || indices[max_cut..].iter().copied().max().unwrap_or(0) > 0
139 {
140 return None;
141 }
142
143 let mut h = Self::default();
144 let mut start = 0;
145 for (i, &end) in cuts.iter().enumerate() {
146 let indices = &indices[start..end];
147
148 if !indices.windows(2).all(|w| w[0] < w[1]) {
150 return None;
151 }
152
153 for &j in indices {
154 h.0[i][j] = true;
155 }
156
157 start = end;
158 }
159
160 Some(h)
161 }
162}
163
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "tests")]
mod test {
    use super::*;
    use crate::{MlDsa44, MlDsa65, ParameterSet};

    /// Checks the increment / decrement arithmetic of `use_hint` on a handful
    /// of representative inputs.
    #[test]
    fn use_hint_arithmetic() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;

        // With the hint bit clear, the high part is returned unchanged.
        let r = Elem::new(1000);
        let (expected_r1, _) = r.decompose::<TwoGamma2>();
        assert_eq!(use_hint::<TwoGamma2>(false, r), expected_r1);

        // First input whose low part lies in (0, gamma2]: expect an increment mod m.
        for test_r in 1..TwoGamma2::U32 {
            let r = Elem::new(test_r);
            let (r1, r0) = r.decompose::<TwoGamma2>();
            if r0.0 > 0 && r0.0 <= gamma2 {
                let result = use_hint::<TwoGamma2>(true, r);
                assert_eq!(result, Elem::new((r1.0 + 1) % m));
                break;
            }
        }

        // First input near Q whose low part is stored as a value >= Q - gamma2:
        // expect a decrement mod m.
        for test_r in (BaseField::Q - TwoGamma2::U32)..BaseField::Q {
            let r = Elem::new(test_r);
            let (r1, r0) = r.decompose::<TwoGamma2>();
            if r0.0 >= BaseField::Q - gamma2 {
                let result = use_hint::<TwoGamma2>(true, r);
                assert_eq!(result, Elem::new((r1.0 + m - 1) % m));
                break;
            }
        }

        // Incrementing from r1 == m - 1 must wrap around to 0.
        let r_at_max = Elem::new(TwoGamma2::U32 * (m - 1) + 1);
        let (r1_max, r0_max) = r_at_max.decompose::<TwoGamma2>();
        if r1_max.0 == m - 1 && r0_max.0 > 0 && r0_max.0 <= gamma2 {
            assert_eq!(use_hint::<TwoGamma2>(true, r_at_max).0, 0);
        }

        // r = 1 is expected to take the increment path.
        let r_one = Elem::new(1);
        let (r1_one, _) = r_one.decompose::<TwoGamma2>();
        assert_eq!(use_hint::<TwoGamma2>(true, r_one).0, (r1_one.0 + 1) % m);

        // r = Q - 1 is checked against the decrement path when its low part is
        // in the upper (>= Q - gamma2) range.
        let r_qm1 = Elem::new(BaseField::Q - 1);
        let (r1_qm1, r0_qm1) = r_qm1.decompose::<TwoGamma2>();
        if r0_qm1.0 >= BaseField::Q - gamma2 {
            assert_eq!(use_hint::<TwoGamma2>(true, r_qm1).0, (r1_qm1.0 + m - 1) % m);
        }
    }

    /// Searches just above `TwoGamma2 * (m - 1)` for an input with
    /// `r1 == m - 1` and a positive low part, and checks that the increment
    /// wraps to zero. Fails if no such input is found.
    #[test]
    fn use_hint_m_wraparound() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;

        let r_base = TwoGamma2::U32 * (m - 1);
        for offset in 1..100 {
            let r = Elem::new(r_base + offset);
            let (r1, r0) = r.decompose::<TwoGamma2>();
            if r1.0 == m - 1 && r0.0 > 0 && r0.0 <= TwoGamma2::U32 / 2 {
                assert_eq!(use_hint::<TwoGamma2>(true, r).0, 0);
                return;
            }
        }
        panic!("Could not find suitable test value");
    }

    /// A low part of exactly zero is not positive, so the hint must decrement.
    #[test]
    fn use_hint_r0_is_zero() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;
        let r = Elem::new(0);
        let (r1, r0) = r.decompose::<TwoGamma2>();
        assert_eq!(r0.0, 0);

        let result = use_hint::<TwoGamma2>(true, r);
        assert_eq!(result, Elem::new((r1.0 + m - 1) % m));
    }

    /// Looks for an input whose low part equals `Q - gamma2` and checks that
    /// the hint decrements there.
    ///
    /// The test fails if the search finds nothing, so it cannot pass without
    /// having exercised the boundary.
    #[test]
    fn use_hint_threshold() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;

        let threshold = BaseField::Q - gamma2;
        for test_r in (threshold - 100)..(threshold + 100) {
            // Only values below Q are used as inputs.
            if test_r >= BaseField::Q {
                continue;
            }
            let r = Elem::new(test_r);
            let (r1, r0) = r.decompose::<TwoGamma2>();
            if r0.0 == threshold {
                let expected = (r1.0 + m - 1) % m;
                assert_eq!(use_hint::<TwoGamma2>(true, r).0, expected);
                return;
            }
        }
        panic!("Could not find suitable test value");
    }

    /// `decompose` must return a low part in one of the two expected ranges,
    /// and `TwoGamma2 * r1 + r0` must be congruent to the input modulo Q.
    #[test]
    fn decompose_produces_valid_r0() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;

        for test_r in [
            0,
            1000,
            BaseField::Q / 2,
            BaseField::Q - 1000,
            BaseField::Q - 1,
        ] {
            let r = Elem::new(test_r);
            let (r1, r0) = r.decompose::<TwoGamma2>();

            let in_positive_range = r0.0 <= gamma2;
            let in_negative_range = r0.0 >= BaseField::Q - gamma2;
            assert!(in_positive_range || in_negative_range);

            let reconstructed = TwoGamma2::U32 * r1.0 + r0.0;
            assert_eq!(reconstructed % BaseField::Q, test_r % BaseField::Q);
        }
    }

    /// `make_hint` must report exactly whether adding `z` changes the high
    /// bits of `r`; in particular `z == 0` never produces a hint.
    #[test]
    fn make_hint_correctness() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;

        for test_r in [0, 1000, BaseField::Q / 2, BaseField::Q - 1] {
            let r = Elem::new(test_r);
            let r1 = r.high_bits::<TwoGamma2>();

            assert!(!make_hint::<TwoGamma2>(Elem::new(0), r));

            for test_z in [0, 1, TwoGamma2::U32 / 2, TwoGamma2::U32] {
                let z = Elem::new(test_z);
                let h = make_hint::<TwoGamma2>(z, r);
                let v1 = (r + z).high_bits::<TwoGamma2>();
                assert_eq!(h, r1 != v1);
            }
        }
    }

    /// Packing a hint and unpacking it again must give back the same hint.
    #[test]
    fn hint_round_trip() {
        fn test<P: SignatureParams + PartialEq + core::fmt::Debug>() {
            let mut h = Hint::<P>::default();
            // Set a few bits in every row; `h.0` holds exactly `P::K` rows.
            for i in 0..P::K::USIZE {
                h.0[i][0] = true;
                h.0[i][10] = true;
                if i > 0 {
                    h.0[i][i * 5] = true;
                }
            }
            let packed = h.bit_pack();
            let unpacked = Hint::<P>::bit_unpack(&packed).unwrap();
            assert_eq!(h, unpacked);
        }
        test::<MlDsa44>();
        test::<MlDsa65>();
    }
}