ml_dsa/
hint.rs

1use crate::module_lattice::util::Truncate;
2use hybrid_array::{
3    Array,
4    typenum::{U256, Unsigned},
5};
6
7use crate::algebra::{AlgebraExt, BaseField, Decompose, Elem, Field, Polynomial, Vector};
8use crate::param::{EncodedHint, SignatureParams};
9
10fn make_hint<TwoGamma2: Unsigned>(z: Elem, r: Elem) -> bool {
11    let r1 = r.high_bits::<TwoGamma2>();
12    let v1 = (r + z).high_bits::<TwoGamma2>();
13    r1 != v1
14}
15
16// The method only deals with public data, so we don't need to worry that / and % are not
17// constant-time.
18#[allow(clippy::integer_division_remainder_used)]
19fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
20    let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32;
21    let (r1, r0) = r.decompose::<TwoGamma2>();
22    let gamma2 = TwoGamma2::U32 / 2;
23    if h && r0.0 <= gamma2 {
24        Elem::new((r1.0 + 1) % m)
25    } else if h && r0.0 >= BaseField::Q - gamma2 {
26        Elem::new((r1.0 + m - 1) % m)
27    } else if h {
28        // We use the Elem encoding even for signed integers.  Since r0 is computed
29        // mod+- 2*gamma2 (possibly minus 1), it is guaranteed to be in [-gamma2, gamma2].
30        unreachable!();
31    } else {
32        r1
33    }
34}
35
36#[derive(Clone, PartialEq, Debug)]
37pub struct Hint<P>(pub Array<Array<bool, U256>, P::K>)
38where
39    P: SignatureParams;
40
41impl<P> Default for Hint<P>
42where
43    P: SignatureParams,
44{
45    fn default() -> Self {
46        Self(Array::default())
47    }
48}
49
50impl<P> Hint<P>
51where
52    P: SignatureParams,
53{
54    pub fn new(z: &Vector<P::K>, r: &Vector<P::K>) -> Self {
55        let zi = z.0.iter();
56        let ri = r.0.iter();
57
58        Self(
59            zi.zip(ri)
60                .map(|(zv, rv)| {
61                    let zvi = zv.0.iter();
62                    let rvi = rv.0.iter();
63
64                    zvi.zip(rvi)
65                        .map(|(&z, &r)| make_hint::<P::TwoGamma2>(z, r))
66                        .collect()
67                })
68                .collect(),
69        )
70    }
71
72    pub fn hamming_weight(&self) -> usize {
73        self.0
74            .iter()
75            .map(|x| x.iter().filter(|x| **x).count())
76            .sum()
77    }
78
79    pub fn use_hint(&self, r: &Vector<P::K>) -> Vector<P::K> {
80        let hi = self.0.iter();
81        let ri = r.0.iter();
82
83        Vector::new(
84            hi.zip(ri)
85                .map(|(hv, rv)| {
86                    let hvi = hv.iter();
87                    let rvi = rv.0.iter();
88
89                    Polynomial::new(
90                        hvi.zip(rvi)
91                            .map(|(&h, &r)| use_hint::<P::TwoGamma2>(h, r))
92                            .collect(),
93                    )
94                })
95                .collect(),
96        )
97    }
98
99    pub fn bit_pack(&self) -> EncodedHint<P> {
100        let mut y: EncodedHint<P> = Array::default();
101        let mut index = 0;
102        let omega = P::Omega::USIZE;
103        for i in 0..P::K::U8 {
104            let i_usize: usize = i.into();
105            for j in 0..256 {
106                if self.0[i_usize][j] {
107                    y[index] = Truncate::truncate(j);
108                    index += 1;
109                }
110            }
111
112            y[omega + i_usize] = Truncate::truncate(index);
113        }
114
115        y
116    }
117
118    fn monotonic(a: &[usize]) -> bool {
119        a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] <= *x)
120    }
121
122    pub fn bit_unpack(y: &EncodedHint<P>) -> Option<Self> {
123        let (indices, cuts) = P::split_hint(y);
124        let cuts: Array<usize, P::K> = cuts.iter().map(|x| usize::from(*x)).collect();
125
126        let indices: Array<usize, P::Omega> = indices.iter().map(|x| usize::from(*x)).collect();
127        let max_cut: usize = cuts.iter().copied().max().unwrap();
128        if !Self::monotonic(&cuts)
129            || max_cut > indices.len()
130            || indices[max_cut..].iter().copied().max().unwrap_or(0) > 0
131        {
132            return None;
133        }
134
135        let mut h = Self::default();
136        let mut start = 0;
137        for (i, &end) in cuts.iter().enumerate() {
138            let indices = &indices[start..end];
139
140            if !Self::monotonic(indices) {
141                return None;
142            }
143
144            for &j in indices {
145                h.0[i][j] = true;
146            }
147
148            start = end;
149        }
150
151        Some(h)
152    }
153}