1use ctutils::{CtEq, CtGt, CtLt, CtSelect};
2use hybrid_array::{
3 ArraySize,
4 typenum::{Shleft, U1, U13, Unsigned},
5};
6use module_lattice::{Field, Truncate};
7
// The ML-DSA base field `Z_q` with `q = 8_380_417`. Elements are stored as `u32` (see `Int`
// below); `u64` and `u128` are presumably the widened types the macro uses for intermediate
// products — confirm against `module_lattice::define_field!`.
module_lattice::define_field!(BaseField, u32, u64, u128, 8_380_417);

/// Integer type holding a `BaseField` element (`u32` for this field).
pub(crate) type Int = <BaseField as Field>::Int;

// Crate-local shorthands for the `module_lattice` algebra types specialised to `BaseField`.

/// A single field element.
pub(crate) type Elem = module_lattice::Elem<BaseField>;
/// A polynomial whose coefficients are `Elem`s.
pub(crate) type Polynomial = module_lattice::Polynomial<BaseField>;
/// A length-`K` vector of polynomials.
pub(crate) type Vector<K> = module_lattice::Vector<BaseField, K>;
/// A polynomial in the number-theoretic-transform (NTT) domain, per the `module_lattice` naming.
pub(crate) type NttPolynomial = module_lattice::NttPolynomial<BaseField>;
/// A length-`K` vector of NTT-domain polynomials.
pub(crate) type NttVector<K> = module_lattice::NttVector<BaseField, K>;
/// A matrix of NTT-domain polynomials (presumably `K` rows by `L` columns — confirm in
/// `module_lattice`).
pub(crate) type NttMatrix<K, L> = module_lattice::NttMatrix<BaseField, K, L>;
18
19pub(crate) trait BarrettReduce: Unsigned {
23 const SHIFT: usize;
24 const MULTIPLIER: u64;
25
26 fn reduce(x: u32) -> u32 {
27 let m = Self::U64;
28 let x: u64 = x.into();
29 let quotient = (x * Self::MULTIPLIER) >> Self::SHIFT;
30 let remainder = x - quotient * m;
31
32 let r_small: u32 = Truncate::truncate(remainder);
33 let r_large: u32 = Truncate::truncate(remainder.wrapping_sub(m));
34 u32::ct_select(&r_large, &r_small, remainder.ct_lt(&m))
35 }
36}
37
38impl<M> BarrettReduce for M
39where
40 M: Unsigned,
41{
42 #[allow(clippy::as_conversions)]
43 const SHIFT: usize = 2 * (M::U64.ilog2() + 1) as usize;
44 #[allow(clippy::integer_division_remainder_used, reason = "constant")]
45 const MULTIPLIER: u64 = (1 << Self::SHIFT) / M::U64;
46}
47
/// Splitting of a field element into a high part and a centered low part (FIPS 204
/// `Decompose`).
pub(crate) trait Decompose {
    /// Returns `(r1, r0)` such that `r1 * TwoGamma2 + r0 ≡ self (mod q)`. `r0` is a centered
    /// remainder stored as a field element, so negative values appear as `q - |r0|`.
    fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
}
51
52pub(crate) trait ConstantTimeDiv: Unsigned {
58 const CT_DIV_SHIFT: usize;
60 const CT_DIV_MULTIPLIER: u64;
62
63 #[allow(clippy::inline_always)] #[inline(always)]
67 fn ct_div(x: u32) -> u32 {
68 let x64 = u64::from(x);
71 let quotient = (x64 * Self::CT_DIV_MULTIPLIER) >> Self::CT_DIV_SHIFT;
72 #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
75 let result = quotient as u32;
76 result
77 }
78}
79
80impl<M> ConstantTimeDiv for M
81where
82 M: Unsigned,
83{
84 const CT_DIV_SHIFT: usize = 48;
88
89 #[allow(clippy::integer_division_remainder_used, reason = "constant")]
92 const CT_DIV_MULTIPLIER: u64 = (1u64 << Self::CT_DIV_SHIFT).div_ceil(M::U64);
93}
94
95impl Decompose for Elem {
96 fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
102 let r_plus = self.clone();
103 let r0 = r_plus.mod_plus_minus::<TwoGamma2>();
104
105 let diff = r_plus - r0;
106 let is_edge = diff.0.ct_eq(&(BaseField::Q - 1));
107
108 let edge = (Elem::new(0), r0 - Elem::new(1));
110 let r1 = Elem::new(TwoGamma2::ct_div(diff.0));
111 let normal = (r1, r0);
112
113 let r1_out = Elem::new(u32::ct_select(&normal.0.0, &edge.0.0, is_edge));
114 let r0_out = Elem::new(u32::ct_select(&normal.1.0, &edge.1.0, is_edge));
115 (r1_out, r0_out)
116 }
117}
118
/// Rounding and norm helpers from FIPS 204, implemented for field elements and lifted
/// elementwise to polynomials and vectors.
#[allow(clippy::module_name_repetitions)]
pub(crate) trait AlgebraExt: Sized {
    /// Centered reduction modulo `M` (`mod±`). Negative results are stored as their field
    /// representatives `q - |r|`.
    fn mod_plus_minus<M: Unsigned>(&self) -> Self;
    /// Largest absolute value among the centered representatives. For a single element,
    /// this is the element's own absolute value.
    fn infinity_norm(&self) -> Int;
    /// Splits into `(r1, r0)` with `r1 * 2^d + r0 ≡ self (mod q)`. `r0` is the centered
    /// remainder modulo `2^d`; the `Elem` implementation uses `d = 13`.
    fn power2round(&self) -> (Self, Self);
    /// High part `r1` of `Decompose::decompose` with respect to `TwoGamma2`.
    fn high_bits<TwoGamma2: Unsigned>(&self) -> Self;
    /// Low part `r0` of `Decompose::decompose` with respect to `TwoGamma2`.
    fn low_bits<TwoGamma2: Unsigned>(&self) -> Self;
}
127
128impl AlgebraExt for Elem {
129 fn mod_plus_minus<M: Unsigned>(&self) -> Self {
130 let raw_mod = Elem::new(M::reduce(self.0));
131 let in_lower_half = !raw_mod.0.ct_gt(&(M::U32 >> 1));
132 Elem::new(u32::ct_select(
133 &(raw_mod - Elem::new(M::U32)).0,
134 &raw_mod.0,
135 in_lower_half,
136 ))
137 }
138
139 fn infinity_norm(&self) -> u32 {
148 let in_lower_half = !self.0.ct_gt(&(BaseField::Q >> 1));
149 u32::ct_select(&(BaseField::Q - self.0), &self.0, in_lower_half)
150 }
151
152 fn power2round(&self) -> (Self, Self) {
159 type D = U13;
160 type Pow2D = Shleft<U1, D>;
161
162 let r_plus = self.clone();
163 let r0 = r_plus.mod_plus_minus::<Pow2D>();
164 let r1 = Elem::new((r_plus - r0).0 >> D::USIZE);
165
166 (r1, r0)
167 }
168
169 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
171 self.decompose::<TwoGamma2>().0
172 }
173
174 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
176 self.decompose::<TwoGamma2>().1
177 }
178}
179
180impl AlgebraExt for Polynomial {
181 fn mod_plus_minus<M: Unsigned>(&self) -> Self {
182 Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
183 }
184
185 fn infinity_norm(&self) -> u32 {
186 self.0
187 .iter()
188 .map(AlgebraExt::infinity_norm)
189 .max()
190 .expect("should have a maximum")
191 }
192
193 fn power2round(&self) -> (Self, Self) {
194 let mut r1 = Self::default();
195 let mut r0 = Self::default();
196
197 for (i, x) in self.0.iter().enumerate() {
198 (r1.0[i], r0.0[i]) = x.power2round();
199 }
200
201 (r1, r0)
202 }
203
204 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
205 Self(
206 self.0
207 .iter()
208 .map(AlgebraExt::high_bits::<TwoGamma2>)
209 .collect(),
210 )
211 }
212
213 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
214 Self(
215 self.0
216 .iter()
217 .map(AlgebraExt::low_bits::<TwoGamma2>)
218 .collect(),
219 )
220 }
221}
222
223impl<K: ArraySize> AlgebraExt for Vector<K> {
224 fn mod_plus_minus<M: Unsigned>(&self) -> Self {
225 Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
226 }
227
228 fn infinity_norm(&self) -> u32 {
229 self.0
230 .iter()
231 .map(AlgebraExt::infinity_norm)
232 .max()
233 .expect("should have a maximum")
234 }
235
236 fn power2round(&self) -> (Self, Self) {
237 let mut r1 = Self::default();
238 let mut r0 = Self::default();
239
240 for (i, x) in self.0.iter().enumerate() {
241 (r1.0[i], r0.0[i]) = x.power2round();
242 }
243
244 (r1, r0)
245 }
246
247 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
248 Self(
249 self.0
250 .iter()
251 .map(AlgebraExt::high_bits::<TwoGamma2>)
252 .collect(),
253 )
254 }
255
256 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
257 Self(
258 self.0
259 .iter()
260 .map(AlgebraExt::low_bits::<TwoGamma2>)
261 .collect(),
262 )
263 }
264}
265
#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "tests")]
mod test {
    use super::*;

    use crate::{MlDsa65, ParameterSet};

    type Mod = <MlDsa65 as ParameterSet>::TwoGamma2;
    const MOD: u32 = Mod::U32;
    const MOD_ELEM: Elem = Elem::new(MOD);

    /// Stride for sweeps over the whole field: keeps them fast while still hitting every
    /// high-bits bucket many times.
    const FIELD_STRIDE: usize = 97;

    #[test]
    fn mod_plus_minus() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let x0 = x.mod_plus_minus::<Mod>();

            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 > BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            let xn = x + MOD_ELEM;
            let x0n = x0 + MOD_ELEM;
            assert_eq!(xn.0 % MOD, x0n.0 % MOD);
        }
    }

    #[test]
    fn decompose() {
        for x in 0..MOD {
            let x = Elem::new(x);
            let (x1, x0) = x.decompose::<Mod>();

            let positive_bound = x0.0 <= MOD / 2;
            let negative_bound = x0.0 >= BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            let xx = (MOD * x1.0 + x0.0) % BaseField::Q;
            assert_eq!(xx, x.0);
        }
    }

    #[test]
    fn decompose_full_field() {
        // `decompose` above only covers `[0, MOD)`, where r1 is 0 or 1. Sweep the whole field
        // so that every quotient produced by `ct_div` is exercised, plus the wrap-around case.
        let r1_limit = (BaseField::Q - 1) / MOD;
        for x in (0..BaseField::Q)
            .step_by(FIELD_STRIDE)
            .chain([BaseField::Q - 1])
        {
            let (r1, r0) = Elem::new(x).decompose::<Mod>();
            assert!(r1.0 < r1_limit);
            assert!(r0.0 <= MOD / 2 || r0.0 >= BaseField::Q - MOD / 2);
            assert_eq!((MOD * r1.0 + r0.0) % BaseField::Q, x);
        }
    }

    #[test]
    fn barrett_reduce_boundary() {
        let m_minus_1 = Mod::U32 - 1;
        assert_eq!(Mod::reduce(m_minus_1), m_minus_1);
        assert_eq!(Mod::reduce(Mod::U32), 0);
        assert_eq!(Mod::reduce(Mod::U32 + 1), 1);
        assert_eq!(Mod::reduce(2 * Mod::U32 - 1), m_minus_1);
        assert_eq!(Mod::reduce(2 * Mod::U32), 0);
    }

    #[test]
    fn barrett_reduce_power_of_two() {
        // The modulus `power2round` relies on.
        type Pow2D = Shleft<U1, U13>;
        for x in [0, 1, 8191, 8192, 8193, BaseField::Q - 1, u32::MAX] {
            assert_eq!(Pow2D::reduce(x), x % 8192);
        }
    }

    #[test]
    fn constant_time_div_accuracy() {
        for x in 0..1000 {
            assert_eq!(Mod::ct_div(x), x / Mod::U32);
        }
        for x in (BaseField::Q - 1000)..BaseField::Q {
            assert_eq!(Mod::ct_div(x), x / Mod::U32);
        }
    }

    #[test]
    fn constant_time_div_exact_multiples() {
        // `decompose` divides exact multiples of MOD; check each one below q and its
        // predecessor.
        for k in 0..=((BaseField::Q - 1) / MOD) {
            assert_eq!(Mod::ct_div(k * MOD), k);
            if k > 0 {
                assert_eq!(Mod::ct_div(k * MOD - 1), k - 1);
            }
        }
    }

    #[test]
    fn decompose_edge_case() {
        let q_minus_1 = Elem::new(BaseField::Q - 1);
        let (r1, r0) = q_minus_1.decompose::<Mod>();
        let reconstructed = (MOD * r1.0 + r0.0) % BaseField::Q;
        assert_eq!(reconstructed, q_minus_1.0);
    }

    #[test]
    fn high_low_bits_consistency() {
        for x in [0, 1, MOD / 2, MOD - 1, MOD, MOD + 1, BaseField::Q - 1] {
            let elem = Elem::new(x);
            let (decomp_high, decomp_low) = elem.decompose::<Mod>();
            assert_eq!(elem.high_bits::<Mod>(), decomp_high);
            assert_eq!(elem.low_bits::<Mod>(), decomp_low);
        }
    }

    #[test]
    fn power2round_reconstructs() {
        const POW2D: u32 = 1 << 13;
        for x in (0..BaseField::Q)
            .step_by(FIELD_STRIDE)
            .chain([BaseField::Q - 1])
        {
            let (r1, r0) = Elem::new(x).power2round();
            assert!(r1.0 < 1 << 10);
            assert!(r0.0 <= POW2D / 2 || r0.0 > BaseField::Q - POW2D / 2);
            assert_eq!(((r1.0 << 13) + r0.0) % BaseField::Q, x);
        }
    }

    #[test]
    fn infinity_norm_is_centered_absolute_value() {
        let half = BaseField::Q >> 1;
        assert_eq!(Elem::new(0).infinity_norm(), 0);
        assert_eq!(Elem::new(1).infinity_norm(), 1);
        assert_eq!(Elem::new(half).infinity_norm(), half);
        assert_eq!(Elem::new(half + 1).infinity_norm(), half);
        assert_eq!(Elem::new(BaseField::Q - 1).infinity_norm(), 1);
    }
}