1use crate::module_lattice::util::Truncate;
2use hybrid_array::{
3 Array,
4 typenum::{U256, Unsigned},
5};
6
7use crate::algebra::{AlgebraExt, BaseField, Decompose, Elem, Field, Polynomial, Vector};
8use crate::param::{EncodedHint, SignatureParams};
9
10fn make_hint<TwoGamma2: Unsigned>(z: Elem, r: Elem) -> bool {
11 let r1 = r.high_bits::<TwoGamma2>();
12 let v1 = (r + z).high_bits::<TwoGamma2>();
13 r1 != v1
14}
15
16#[allow(clippy::integer_division_remainder_used)]
19fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
20 let m: u32 = (BaseField::Q - 1) / TwoGamma2::U32;
21 let (r1, r0) = r.decompose::<TwoGamma2>();
22 let gamma2 = TwoGamma2::U32 / 2;
23 if h && r0.0 <= gamma2 {
24 Elem::new((r1.0 + 1) % m)
25 } else if h && r0.0 >= BaseField::Q - gamma2 {
26 Elem::new((r1.0 + m - 1) % m)
27 } else if h {
28 unreachable!();
31 } else {
32 r1
33 }
34}
35
36#[derive(Clone, PartialEq, Debug)]
37pub struct Hint<P>(pub Array<Array<bool, U256>, P::K>)
38where
39 P: SignatureParams;
40
41impl<P> Default for Hint<P>
42where
43 P: SignatureParams,
44{
45 fn default() -> Self {
46 Self(Array::default())
47 }
48}
49
50impl<P> Hint<P>
51where
52 P: SignatureParams,
53{
54 pub fn new(z: &Vector<P::K>, r: &Vector<P::K>) -> Self {
55 let zi = z.0.iter();
56 let ri = r.0.iter();
57
58 Self(
59 zi.zip(ri)
60 .map(|(zv, rv)| {
61 let zvi = zv.0.iter();
62 let rvi = rv.0.iter();
63
64 zvi.zip(rvi)
65 .map(|(&z, &r)| make_hint::<P::TwoGamma2>(z, r))
66 .collect()
67 })
68 .collect(),
69 )
70 }
71
72 pub fn hamming_weight(&self) -> usize {
73 self.0
74 .iter()
75 .map(|x| x.iter().filter(|x| **x).count())
76 .sum()
77 }
78
79 pub fn use_hint(&self, r: &Vector<P::K>) -> Vector<P::K> {
80 let hi = self.0.iter();
81 let ri = r.0.iter();
82
83 Vector::new(
84 hi.zip(ri)
85 .map(|(hv, rv)| {
86 let hvi = hv.iter();
87 let rvi = rv.0.iter();
88
89 Polynomial::new(
90 hvi.zip(rvi)
91 .map(|(&h, &r)| use_hint::<P::TwoGamma2>(h, r))
92 .collect(),
93 )
94 })
95 .collect(),
96 )
97 }
98
99 pub fn bit_pack(&self) -> EncodedHint<P> {
100 let mut y: EncodedHint<P> = Array::default();
101 let mut index = 0;
102 let omega = P::Omega::USIZE;
103 for i in 0..P::K::U8 {
104 let i_usize: usize = i.into();
105 for j in 0..256 {
106 if self.0[i_usize][j] {
107 y[index] = Truncate::truncate(j);
108 index += 1;
109 }
110 }
111
112 y[omega + i_usize] = Truncate::truncate(index);
113 }
114
115 y
116 }
117
118 fn monotonic(a: &[usize]) -> bool {
119 a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] <= *x)
120 }
121
122 pub fn bit_unpack(y: &EncodedHint<P>) -> Option<Self> {
123 let (indices, cuts) = P::split_hint(y);
124 let cuts: Array<usize, P::K> = cuts.iter().map(|x| usize::from(*x)).collect();
125
126 let indices: Array<usize, P::Omega> = indices.iter().map(|x| usize::from(*x)).collect();
127 let max_cut: usize = cuts.iter().copied().max().unwrap();
128 if !Self::monotonic(&cuts)
129 || max_cut > indices.len()
130 || indices[max_cut..].iter().copied().max().unwrap_or(0) > 0
131 {
132 return None;
133 }
134
135 let mut h = Self::default();
136 let mut start = 0;
137 for (i, &end) in cuts.iter().enumerate() {
138 let indices = &indices[start..end];
139
140 if !Self::monotonic(indices) {
141 return None;
142 }
143
144 for &j in indices {
145 h.0[i][j] = true;
146 }
147
148 start = end;
149 }
150
151 Some(h)
152 }
153}