1pub use crate::module_lattice::algebra::Field;
2pub use crate::module_lattice::util::Truncate;
3use hybrid_array::{
4 ArraySize,
5 typenum::{Shleft, U1, U13, Unsigned},
6};
7
8use crate::define_field;
9use crate::module_lattice::algebra;
10
// ML-DSA base field Z_q with q = 8_380_417 = 2^23 - 2^13 + 1. The three integer types are
// presumably the element type followed by wider types for intermediates — confirm against
// `define_field!`.
define_field!(BaseField, u32, u64, u128, 8_380_417);

/// Integer type holding one [`BaseField`] value (`u32`).
pub type Int = <BaseField as Field>::Int;

/// Element of [`BaseField`].
pub type Elem = algebra::Elem<BaseField>;
/// [`algebra::Polynomial`] with coefficients in [`BaseField`].
pub type Polynomial = algebra::Polynomial<BaseField>;
/// [`algebra::Vector`] of `K` base-field polynomials.
pub type Vector<K> = algebra::Vector<BaseField, K>;
/// [`algebra::NttPolynomial`] over [`BaseField`].
pub type NttPolynomial = algebra::NttPolynomial<BaseField>;
/// [`algebra::NttVector`] over [`BaseField`] with `K` entries.
pub type NttVector<K> = algebra::NttVector<BaseField, K>;
/// [`algebra::NttMatrix`] over [`BaseField`], parameterized by the dimensions `K` and `L`.
pub type NttMatrix<K, L> = algebra::NttMatrix<BaseField, K, L>;
21
22pub trait BarrettReduce: Unsigned {
26 const SHIFT: usize;
27 const MULTIPLIER: u64;
28
29 fn reduce(x: u32) -> u32 {
30 let m = Self::U64;
31 let x: u64 = x.into();
32 let quotient = (x * Self::MULTIPLIER) >> Self::SHIFT;
33 let remainder = x - quotient * m;
34
35 if remainder < m {
36 Truncate::truncate(remainder)
37 } else {
38 Truncate::truncate(remainder - m)
39 }
40 }
41}
42
43impl<M> BarrettReduce for M
44where
45 M: Unsigned,
46{
47 #[allow(clippy::as_conversions)]
48 const SHIFT: usize = 2 * (M::U64.ilog2() + 1) as usize;
49 #[allow(clippy::integer_division_remainder_used)]
50 const MULTIPLIER: u64 = (1 << Self::SHIFT) / M::U64;
51}
52
53pub trait Decompose {
54 fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem);
55}
56
57impl Decompose for Elem {
58 fn decompose<TwoGamma2: Unsigned>(self) -> (Elem, Elem) {
60 let r_plus = self.clone();
61 let r0 = r_plus.mod_plus_minus::<TwoGamma2>();
62
63 if r_plus - r0 == Elem::new(BaseField::Q - 1) {
64 (Elem::new(0), r0 - Elem::new(1))
65 } else {
66 let mut r1 = r_plus - r0;
67 r1.0 /= TwoGamma2::U32;
68 (r1, r0)
69 }
70 }
71}
72
73#[allow(clippy::module_name_repetitions)] pub trait AlgebraExt: Sized {
75 fn mod_plus_minus<M: Unsigned>(&self) -> Self;
76 fn infinity_norm(&self) -> Int;
77 fn power2round(&self) -> (Self, Self);
78 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self;
79 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self;
80}
81
82impl AlgebraExt for Elem {
83 fn mod_plus_minus<M: Unsigned>(&self) -> Self {
84 let raw_mod = Elem::new(M::reduce(self.0));
85 if raw_mod.0 <= M::U32 >> 1 {
86 raw_mod
87 } else {
88 raw_mod - Elem::new(M::U32)
89 }
90 }
91
92 fn infinity_norm(&self) -> u32 {
101 if self.0 <= BaseField::Q >> 1 {
102 self.0
103 } else {
104 BaseField::Q - self.0
105 }
106 }
107
108 fn power2round(&self) -> (Self, Self) {
115 type D = U13;
116 type Pow2D = Shleft<U1, D>;
117
118 let r_plus = self.clone();
119 let r0 = r_plus.mod_plus_minus::<Pow2D>();
120 let r1 = Elem::new((r_plus - r0).0 >> D::USIZE);
121
122 (r1, r0)
123 }
124
125 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
127 self.decompose::<TwoGamma2>().0
128 }
129
130 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
132 self.decompose::<TwoGamma2>().1
133 }
134}
135
136impl AlgebraExt for Polynomial {
137 fn mod_plus_minus<M: Unsigned>(&self) -> Self {
138 Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
139 }
140
141 fn infinity_norm(&self) -> u32 {
142 self.0.iter().map(AlgebraExt::infinity_norm).max().unwrap()
143 }
144
145 fn power2round(&self) -> (Self, Self) {
146 let mut r1 = Self::default();
147 let mut r0 = Self::default();
148
149 for (i, x) in self.0.iter().enumerate() {
150 (r1.0[i], r0.0[i]) = x.power2round();
151 }
152
153 (r1, r0)
154 }
155
156 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
157 Self(
158 self.0
159 .iter()
160 .map(AlgebraExt::high_bits::<TwoGamma2>)
161 .collect(),
162 )
163 }
164
165 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
166 Self(
167 self.0
168 .iter()
169 .map(AlgebraExt::low_bits::<TwoGamma2>)
170 .collect(),
171 )
172 }
173}
174
175impl<K: ArraySize> AlgebraExt for Vector<K> {
176 fn mod_plus_minus<M: Unsigned>(&self) -> Self {
177 Self(self.0.iter().map(AlgebraExt::mod_plus_minus::<M>).collect())
178 }
179
180 fn infinity_norm(&self) -> u32 {
181 self.0.iter().map(AlgebraExt::infinity_norm).max().unwrap()
182 }
183
184 fn power2round(&self) -> (Self, Self) {
185 let mut r1 = Self::default();
186 let mut r0 = Self::default();
187
188 for (i, x) in self.0.iter().enumerate() {
189 (r1.0[i], r0.0[i]) = x.power2round();
190 }
191
192 (r1, r0)
193 }
194
195 fn high_bits<TwoGamma2: Unsigned>(&self) -> Self {
196 Self(
197 self.0
198 .iter()
199 .map(AlgebraExt::high_bits::<TwoGamma2>)
200 .collect(),
201 )
202 }
203
204 fn low_bits<TwoGamma2: Unsigned>(&self) -> Self {
205 Self(
206 self.0
207 .iter()
208 .map(AlgebraExt::low_bits::<TwoGamma2>)
209 .collect(),
210 )
211 }
212}
213
#[cfg(test)]
mod test {
    use super::*;

    use crate::{MlDsa65, ParameterSet};

    type Mod = <MlDsa65 as ParameterSet>::TwoGamma2;
    type Pow2D = Shleft<U1, U13>;
    const MOD: u32 = Mod::U32;

    /// Inputs shared by the tests: every value below `MOD` (the original coverage), a strided
    /// sweep of the whole field, and the top `MOD` values, which include Decompose's
    /// `r+ - r0 = q - 1` corner case.
    fn samples() -> impl Iterator<Item = u32> {
        (0..MOD)
            .chain((0..BaseField::Q).step_by(97))
            .chain(BaseField::Q - MOD..BaseField::Q)
    }

    /// Maps a field element to its centered integer representative in `[-(q-1)/2, (q-1)/2]`.
    fn centered(e: Elem) -> i64 {
        let v = i64::from(e.0);
        if e.0 > BaseField::Q / 2 {
            v - i64::from(BaseField::Q)
        } else {
            v
        }
    }

    #[test]
    fn barrett_reduce() {
        for x in samples() {
            assert_eq!(<Mod as BarrettReduce>::reduce(x), x % MOD);
            assert_eq!(<Pow2D as BarrettReduce>::reduce(x), x % (1 << 13));
        }
    }

    #[test]
    fn mod_plus_minus() {
        for x in samples() {
            let x = Elem::new(x);
            let x0 = x.mod_plus_minus::<Mod>();

            let positive_bound = x0.0 <= MOD / 2;
            // The negative half is strict: -MOD/2 itself is never produced.
            let negative_bound = x0.0 > BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            // x0 must be congruent to x modulo MOD as integers; this stays valid across the
            // whole field, unlike adding MOD inside the field, which wraps near q.
            assert_eq!(
                (i64::from(x.0) - centered(x0)).rem_euclid(i64::from(MOD)),
                0
            );
        }
    }

    #[test]
    fn decompose() {
        let r1_bound = (BaseField::Q - 1) / MOD;
        for x in samples() {
            let x = Elem::new(x);
            let (x1, x0) = x.decompose::<Mod>();

            let positive_bound = x0.0 <= MOD / 2;
            // Inclusive: the corner case can yield r0 - 1 = -MOD/2.
            let negative_bound = x0.0 >= BaseField::Q - MOD / 2;
            assert!(positive_bound || negative_bound);

            assert!(x1.0 < r1_bound);

            let xx = (MOD * x1.0 + x0.0) % BaseField::Q;
            assert_eq!(xx, x.0);
        }
    }

    #[test]
    fn power2round() {
        for x in samples() {
            let x = Elem::new(x);
            let (x1, x0) = x.power2round();
            let x0 = centered(x0);

            assert!(-(1 << 12) < x0 && x0 <= (1 << 12));
            assert_eq!(i64::from(x1.0) * (1 << 13) + x0, i64::from(x.0));
        }
    }
}
262}