Skip to main content

ml_dsa/
hint.rs

1use crate::{
2    algebra::{AlgebraExt, BaseField, Decompose, Elem, Polynomial, Vector},
3    param::{EncodedHint, SignatureParams},
4};
5use ctutils::{Choice, CtEq, CtGt, CtSelect};
6use hybrid_array::{
7    Array,
8    typenum::{U256, Unsigned},
9};
10use module_lattice::{Field, Truncate};
11
12/// Algorithm 39 `MakeHint`: computes hint bit indicating whether adding `z` to `r` alters the high
13/// bits of `r`.
14fn make_hint<TwoGamma2: Unsigned>(z: Elem, r: Elem) -> bool {
15    let r1 = r.high_bits::<TwoGamma2>();
16    let v1 = (r + z).high_bits::<TwoGamma2>();
17    r1 != v1
18}
19
20/// Algorithm 40 `UseHint`: returns the high bits of `r` adjusted according to hint `h`.
21///
22/// All branches are replaced with constant-time selection to avoid
23/// leaking information about `r0` through branch timing.
24#[allow(clippy::integer_division_remainder_used, reason = "params are public")]
25fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
26    let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32;
27    let (r1, r0) = r.decompose::<TwoGamma2>();
28    let gamma2 = TwoGamma2::U32 / 2;
29
30    // Compute both possible hint-adjusted results unconditionally
31    let r1_inc = Elem::new((r1.0 + 1) % m);
32    let r1_dec = Elem::new((r1.0 + m - 1) % m);
33
34    // r0 is "positive" when r0 > 0 AND r0 <= gamma2
35    let r0_positive = !r0.0.ct_eq(&0) & !r0.0.ct_gt(&gamma2);
36    let hinted = Elem::new(u32::ct_select(&r1_dec.0, &r1_inc.0, r0_positive));
37
38    // Apply hint only when h is set
39    Elem::new(u32::ct_select(
40        &r1.0,
41        &hinted.0,
42        Choice::from_u8_lsb(u8::from(h)),
43    ))
44}
45
46#[derive(Clone, PartialEq, Debug)]
47pub(crate) struct Hint<P>(pub Array<Array<bool, U256>, P::K>)
48where
49    P: SignatureParams;
50
51impl<P> Default for Hint<P>
52where
53    P: SignatureParams,
54{
55    fn default() -> Self {
56        Self(Array::default())
57    }
58}
59
60impl<P> Hint<P>
61where
62    P: SignatureParams,
63{
64    pub(crate) fn new(z: &Vector<P::K>, r: &Vector<P::K>) -> Self {
65        let zi = z.0.iter();
66        let ri = r.0.iter();
67
68        Self(
69            zi.zip(ri)
70                .map(|(zv, rv)| {
71                    let zvi = zv.0.iter();
72                    let rvi = rv.0.iter();
73
74                    zvi.zip(rvi)
75                        .map(|(&z, &r)| make_hint::<P::TwoGamma2>(z, r))
76                        .collect()
77                })
78                .collect(),
79        )
80    }
81
82    pub(crate) fn hamming_weight(&self) -> usize {
83        self.0
84            .iter()
85            .map(|x| x.iter().filter(|x| **x).count())
86            .sum()
87    }
88
89    pub(crate) fn use_hint(&self, r: &Vector<P::K>) -> Vector<P::K> {
90        let hi = self.0.iter();
91        let ri = r.0.iter();
92
93        Vector::new(
94            hi.zip(ri)
95                .map(|(hv, rv)| {
96                    let hvi = hv.iter();
97                    let rvi = rv.0.iter();
98
99                    Polynomial::new(
100                        hvi.zip(rvi)
101                            .map(|(&h, &r)| use_hint::<P::TwoGamma2>(h, r))
102                            .collect(),
103                    )
104                })
105                .collect(),
106        )
107    }
108
109    pub(crate) fn bit_pack(&self) -> EncodedHint<P> {
110        let mut y: EncodedHint<P> = Array::default();
111        let mut index = 0;
112        let omega = P::Omega::USIZE;
113        for i in 0..P::K::U8 {
114            let i_usize: usize = i.into();
115            for j in 0..256 {
116                if self.0[i_usize][j] {
117                    y[index] = Truncate::truncate(j);
118                    index += 1;
119                }
120            }
121
122            y[omega + i_usize] = Truncate::truncate(index);
123        }
124
125        y
126    }
127
128    pub(crate) fn bit_unpack(y: &EncodedHint<P>) -> Option<Self> {
129        let (indices, cuts) = P::split_hint(y);
130        let cuts: Array<usize, P::K> = cuts.iter().map(|x| usize::from(*x)).collect();
131
132        let indices: Array<usize, P::Omega> = indices.iter().map(|x| usize::from(*x)).collect();
133        let max_cut: usize = cuts.iter().copied().max().expect("should have a maximum");
134
135        // cuts must be monotonic but can repeat
136        if !cuts.windows(2).all(|w| w[0] <= w[1])
137            || max_cut > indices.len()
138            || indices[max_cut..].iter().copied().max().unwrap_or(0) > 0
139        {
140            return None;
141        }
142
143        let mut h = Self::default();
144        let mut start = 0;
145        for (i, &end) in cuts.iter().enumerate() {
146            let indices = &indices[start..end];
147
148            // indices must be strictly increasing
149            if !indices.windows(2).all(|w| w[0] < w[1]) {
150                return None;
151            }
152
153            for &j in indices {
154                h.0[i][j] = true;
155            }
156
157            start = end;
158        }
159
160        Some(h)
161    }
162}
163
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "tests")]
mod test {
    //! Unit tests for `MakeHint`, `UseHint`, and the hint bit-packing routines.
    //!
    //! Preconditions on chosen inputs are asserted rather than guarded by `if`, so no test can
    //! pass by silently skipping its check.

    use super::*;
    use crate::{MlDsa44, MlDsa65, ParameterSet};

    #[test]
    fn use_hint_arithmetic() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;
        let is_positive = |r0: Elem| r0.0 > 0 && r0.0 <= gamma2;
        let is_negative = |r0: Elem| r0.0 >= BaseField::Q - gamma2;

        // h=false returns r1 unchanged
        let r = Elem::new(1000);
        let (expected_r1, _) = r.decompose::<TwoGamma2>();
        assert_eq!(use_hint::<TwoGamma2>(false, r), expected_r1);

        // h=true with positive r0: increment r1 mod m
        let r = (1..TwoGamma2::U32)
            .map(Elem::new)
            .find(|r| is_positive(r.decompose::<TwoGamma2>().1))
            .expect("some r in (0, 2*gamma2) has a positive r0");
        let (r1, _) = r.decompose::<TwoGamma2>();
        assert_eq!(use_hint::<TwoGamma2>(true, r), Elem::new((r1.0 + 1) % m));

        // h=true with negative r0: decrement r1 mod m
        let r = ((BaseField::Q - TwoGamma2::U32)..BaseField::Q)
            .map(Elem::new)
            .find(|r| is_negative(r.decompose::<TwoGamma2>().1))
            .expect("some r in [q - 2*gamma2, q) has a negative r0");
        let (r1, _) = r.decompose::<TwoGamma2>();
        assert_eq!(use_hint::<TwoGamma2>(true, r), Elem::new((r1.0 + m - 1) % m));

        // r1 = m - 1 with a positive r0 wraps around to 0
        let r_at_max = Elem::new(TwoGamma2::U32 * (m - 1) + 1);
        let (r1_max, r0_max) = r_at_max.decompose::<TwoGamma2>();
        assert_eq!(r1_max.0, m - 1);
        assert!(is_positive(r0_max));
        assert_eq!(use_hint::<TwoGamma2>(true, r_at_max).0, 0);

        // Test with r=1
        let r_one = Elem::new(1);
        let (r1_one, _) = r_one.decompose::<TwoGamma2>();
        assert_eq!(use_hint::<TwoGamma2>(true, r_one).0, (r1_one.0 + 1) % m);

        // r = q - 1 has a negative r0, so r1 is decremented
        let r_qm1 = Elem::new(BaseField::Q - 1);
        let (r1_qm1, r0_qm1) = r_qm1.decompose::<TwoGamma2>();
        assert!(is_negative(r0_qm1));
        assert_eq!(use_hint::<TwoGamma2>(true, r_qm1).0, (r1_qm1.0 + m - 1) % m);
    }

    #[test]
    fn use_hint_m_wraparound() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;

        let r_base = TwoGamma2::U32 * (m - 1);
        for offset in 1..100 {
            let r = Elem::new(r_base + offset);
            let (r1, r0) = r.decompose::<TwoGamma2>();
            if r1.0 == m - 1 && r0.0 > 0 && r0.0 <= TwoGamma2::U32 / 2 {
                assert_eq!(use_hint::<TwoGamma2>(true, r).0, 0);
                return;
            }
        }
        panic!("Could not find suitable test value");
    }

    #[test]
    fn use_hint_r0_is_zero() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;
        let r = Elem::new(0);
        let (r1, r0) = r.decompose::<TwoGamma2>();
        assert_eq!(r0.0, 0);

        let result = use_hint::<TwoGamma2>(true, r);
        assert_eq!(result, Elem::new((r1.0 + m - 1) % m));
    }

    #[test]
    fn use_hint_threshold() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;
        let m = (BaseField::Q - 1) / TwoGamma2::U32;

        // r0 = -gamma2 (stored as q - gamma2) counts as non-positive, so r1 is decremented.
        let threshold = BaseField::Q - gamma2;
        for test_r in (threshold - 100)..(threshold + 100) {
            if test_r >= BaseField::Q {
                continue;
            }
            let r = Elem::new(test_r);
            let (r1, r0) = r.decompose::<TwoGamma2>();
            if r0.0 == threshold {
                let expected = (r1.0 + m - 1) % m;
                assert_eq!(use_hint::<TwoGamma2>(true, r).0, expected);
                return;
            }
        }
        panic!("no r near q - gamma2 decomposes to r0 = -gamma2");
    }

    #[test]
    fn decompose_produces_valid_r0() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;

        for test_r in [
            0,
            1000,
            BaseField::Q / 2,
            BaseField::Q - 1000,
            BaseField::Q - 1,
        ] {
            let r = Elem::new(test_r);
            let (r1, r0) = r.decompose::<TwoGamma2>();

            let in_positive_range = r0.0 <= gamma2;
            let in_negative_range = r0.0 >= BaseField::Q - gamma2;
            assert!(in_positive_range || in_negative_range);

            let reconstructed = TwoGamma2::U32 * r1.0 + r0.0;
            assert_eq!(reconstructed % BaseField::Q, test_r % BaseField::Q);
        }
    }

    #[test]
    fn make_hint_correctness() {
        type TwoGamma2 = <MlDsa65 as ParameterSet>::TwoGamma2;
        let gamma2 = TwoGamma2::U32 / 2;

        for test_r in [0, 1000, BaseField::Q / 2, BaseField::Q - 1] {
            let r = Elem::new(test_r);
            let r1 = r.high_bits::<TwoGamma2>();

            assert!(!make_hint::<TwoGamma2>(Elem::new(0), r));

            for test_z in [0, 1, TwoGamma2::U32 / 2, TwoGamma2::U32] {
                let z = Elem::new(test_z);
                let h = make_hint::<TwoGamma2>(z, r);
                let v1 = (r + z).high_bits::<TwoGamma2>();
                assert_eq!(h, r1 != v1);
            }
        }

        // Staying inside one interval keeps the high bits: HighBits(0) = HighBits(1).
        assert!(!make_hint::<TwoGamma2>(Elem::new(1), Elem::new(0)));
        // Crossing the gamma2 boundary changes them: HighBits(gamma2) != HighBits(gamma2 + 1).
        assert!(make_hint::<TwoGamma2>(Elem::new(1), Elem::new(gamma2)));
        // Wrapping from q - 1 to 0 does not change them: both have high bits 0.
        assert!(!make_hint::<TwoGamma2>(
            Elem::new(1),
            Elem::new(BaseField::Q - 1)
        ));
    }

    #[test]
    fn hint_round_trip() {
        fn round_trip<P: SignatureParams + PartialEq + core::fmt::Debug>(h: &Hint<P>) {
            let packed = h.bit_pack();
            let unpacked =
                Hint::<P>::bit_unpack(&packed).expect("canonical encoding must decode");
            assert_eq!(*h, unpacked);
        }

        fn test<P: SignatureParams + PartialEq + core::fmt::Debug>() {
            // Empty hint: every end marker is zero.
            round_trip(&Hint::<P>::default());

            // A few flags per row, including positions repeated across rows.
            let mut h = Hint::<P>::default();
            for i in 0..P::K::USIZE {
                h.0[i][0] = true;
                h.0[i][10] = true;
                if i > 0 {
                    h.0[i][i * 5] = true;
                }
            }
            round_trip(&h);

            // Exactly omega flags spread over all rows fills every position slot.
            let omega = P::Omega::USIZE;
            let k = P::K::USIZE;
            let mut h = Hint::<P>::default();
            for n in 0..omega {
                h.0[n % k][n / k] = true;
            }
            assert_eq!(h.hamming_weight(), omega);
            round_trip(&h);
        }
        test::<MlDsa44>();
        test::<MlDsa65>();
    }

    #[test]
    fn bit_unpack_rejects_non_canonical() {
        fn test<P: SignatureParams + PartialEq + core::fmt::Debug>() {
            let omega = P::Omega::USIZE;
            let k = P::K::USIZE;

            // Row 0 holds positions 3 and 7, so every end marker is 2.
            let mut h = Hint::<P>::default();
            h.0[0][3] = true;
            h.0[0][7] = true;
            let packed = h.bit_pack();
            assert_eq!(Hint::<P>::bit_unpack(&packed), Some(h));

            // Positions within a row out of order.
            let mut bad = packed.clone();
            bad[0] = 7;
            bad[1] = 3;
            assert!(Hint::<P>::bit_unpack(&bad).is_none());

            // Repeated position within a row.
            let mut bad = packed.clone();
            bad[1] = 3;
            assert!(Hint::<P>::bit_unpack(&bad).is_none());

            // Non-zero padding after the last used slot.
            let mut bad = packed.clone();
            bad[2] = 1;
            assert!(Hint::<P>::bit_unpack(&bad).is_none());

            // Decreasing end markers (row 0 ends at 2, row 1 at 1).
            let mut bad = packed.clone();
            bad[omega + 1] = 1;
            assert!(Hint::<P>::bit_unpack(&bad).is_none());

            // Final end marker beyond omega.
            let mut bad = packed.clone();
            bad[omega + k - 1] = Truncate::truncate(omega + 1);
            assert!(Hint::<P>::bit_unpack(&bad).is_none());
        }
        test::<MlDsa44>();
        test::<MlDsa65>();
    }
}