1use hybrid_array::{
2 typenum::{Unsigned, U256},
3 Array,
4};
5
6use crate::algebra::{
7 FieldElement, Integer, NttPolynomial, NttVector, Polynomial, PolynomialVector,
8};
9use crate::param::{ArraySize, EncodedPolynomial, EncodingSize, VectorEncodingSize};
10use crate::util::Truncate;
11
/// The 256 coefficients of a single polynomial: the output of `byte_decode`
/// and the input of `byte_encode`.
type DecodedValue = Array<FieldElement, U256>;
13
14fn byte_encode<D: EncodingSize>(vals: &DecodedValue) -> EncodedPolynomial<D> {
18 let val_step = D::ValueStep::USIZE;
19 let byte_step = D::ByteStep::USIZE;
20
21 let mut bytes = EncodedPolynomial::<D>::default();
22
23 let vc = vals.chunks(val_step);
24 let bc = bytes.chunks_mut(byte_step);
25 for (v, b) in vc.zip(bc) {
26 let mut x = 0u128;
27 for (j, vj) in v.iter().enumerate() {
28 x |= u128::from(vj.0) << (D::USIZE * j);
29 }
30
31 let xb = x.to_le_bytes();
32 b.copy_from_slice(&xb[..byte_step]);
33 }
34
35 bytes
36}
37
38fn byte_decode<D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue {
42 let val_step = D::ValueStep::USIZE;
43 let byte_step = D::ByteStep::USIZE;
44 let mask = (1 << D::USIZE) - 1;
45
46 let mut vals = DecodedValue::default();
47
48 let vc = vals.chunks_mut(val_step);
49 let bc = bytes.chunks(byte_step);
50 for (v, b) in vc.zip(bc) {
51 let mut xb = [0u8; 16];
52 xb[..byte_step].copy_from_slice(b);
53
54 let x = u128::from_le_bytes(xb);
55 for (j, vj) in v.iter_mut().enumerate() {
56 let val: Integer = (x >> (D::USIZE * j)).truncate();
57 vj.0 = val & mask;
58
59 if D::USIZE == 12 {
60 vj.0 %= FieldElement::Q;
61 }
62 }
63 }
64
65 vals
66}
67
/// Conversion between a value and its fixed-size byte encoding.
///
/// `D` selects the encoding size; in this file's implementations it is the
/// number of bits used per polynomial coefficient.
pub trait Encode<D: EncodingSize> {
    /// Length, in bytes, of the encoded form.
    type EncodedSize: ArraySize;
    /// Serializes `self` into exactly `EncodedSize` bytes.
    fn encode(&self) -> Array<u8, Self::EncodedSize>;
    /// Reconstructs a value from its encoding.
    ///
    /// This is infallible: it returns no error type, so implementations must
    /// map every possible byte string to some value.
    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
}
73
74impl<D: EncodingSize> Encode<D> for Polynomial {
75 type EncodedSize = D::EncodedPolynomialSize;
76
77 fn encode(&self) -> Array<u8, Self::EncodedSize> {
78 byte_encode::<D>(&self.0)
79 }
80
81 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
82 Self(byte_decode::<D>(enc))
83 }
84}
85
86impl<D, K> Encode<D> for PolynomialVector<K>
87where
88 K: ArraySize,
89 D: VectorEncodingSize<K>,
90{
91 type EncodedSize = D::EncodedPolynomialVectorSize;
92
93 fn encode(&self) -> Array<u8, Self::EncodedSize> {
94 let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
95 <D as VectorEncodingSize<K>>::flatten(polys)
96 }
97
98 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
99 let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
100 Self(
101 unfold
102 .iter()
103 .map(|&x| <Polynomial as Encode<D>>::decode(x))
104 .collect(),
105 )
106 }
107}
108
109impl<D: EncodingSize> Encode<D> for NttPolynomial {
110 type EncodedSize = D::EncodedPolynomialSize;
111
112 fn encode(&self) -> Array<u8, Self::EncodedSize> {
113 byte_encode::<D>(&self.0)
114 }
115
116 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
117 Self(byte_decode::<D>(enc))
118 }
119}
120
121impl<D, K> Encode<D> for NttVector<K>
122where
123 D: VectorEncodingSize<K>,
124 K: ArraySize,
125{
126 type EncodedSize = D::EncodedPolynomialVectorSize;
127
128 fn encode(&self) -> Array<u8, Self::EncodedSize> {
129 let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
130 <D as VectorEncodingSize<K>>::flatten(polys)
131 }
132
133 fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
134 let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
135 Self(
136 unfold
137 .iter()
138 .map(|&x| <NttPolynomial as Encode<D>>::decode(x))
139 .collect(),
140 )
141 }
142}
143
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use core::fmt::Debug;
    use core::ops::Rem;
    use hybrid_array::typenum::{
        marker_traits::Zero, operator_aliases::Mod, U1, U10, U11, U12, U2, U3, U4, U5, U6, U8,
    };
    use rand::Rng;

    use crate::param::EncodedPolynomialVector;

    /// Tiles a short array out to length `D`, which the bounds require to be a
    /// multiple of the source length. This lets the known-answer vectors below
    /// be written as one short repeating pattern.
    trait Repeat<T: Clone, D: ArraySize> {
        fn repeat(&self) -> Array<T, D>;
    }

    impl<T, N, D> Repeat<T, D> for Array<T, N>
    where
        N: ArraySize,
        T: Clone,
        D: ArraySize + Rem<N>,
        Mod<D, N>: Zero,
    {
        // `%` here acts on test-only indices, not on secret data.
        #[allow(clippy::integer_division_remainder_used)]
        fn repeat(&self) -> Array<T, D> {
            Array::from_fn(|i| self[i % N::USIZE].clone())
        }
    }

    /// Checks `byte_encode`/`byte_decode` against one known-answer pair, then
    /// round-trips a random polynomial whose coefficients are in range for `D`.
    #[allow(clippy::integer_division_remainder_used)]
    fn byte_codec_test<D>(decoded: DecodedValue, encoded: EncodedPolynomial<D>)
    where
        D: EncodingSize,
    {
        // Known-answer test, in both directions.
        let actual_encoded = byte_encode::<D>(&decoded);
        assert_eq!(actual_encoded, encoded);

        let actual_decoded = byte_decode::<D>(&encoded);
        assert_eq!(actual_decoded, decoded);

        // Random round trip. Coefficients are reduced into the valid range:
        // 0..Q for D = 12 (the decoder reduces mod Q there), otherwise 0..2^D.
        let mut rng = rand::thread_rng();
        let mut decoded: Array<Integer, U256> = Default::default();
        rng.fill(decoded.as_mut_slice());
        let m = match D::USIZE {
            12 => FieldElement::Q,
            d => (1 as Integer) << d,
        };
        let decoded = decoded.iter().map(|x| FieldElement(x % m)).collect();

        let actual_encoded = byte_encode::<D>(&decoded);
        let actual_decoded = byte_decode::<D>(&actual_encoded);
        assert_eq!(actual_decoded, decoded);

        // Re-encode the *decoder's output*, not the original input. Encoding the
        // original input again would only test that `byte_encode` is
        // deterministic; this checks that encode(decode(bytes)) == bytes.
        let actual_reencoded = byte_encode::<D>(&actual_decoded);
        assert_eq!(actual_reencoded, actual_encoded);
    }

    #[test]
    fn byte_codec() {
        // D = 1: alternating 0/1 coefficients, packed LSB-first, give 0b1010_1010.
        let decoded: DecodedValue = Array::<_, U2>([FieldElement(0), FieldElement(1)]).repeat();
        let encoded: EncodedPolynomial<U1> = Array([0xaa; 32]);
        byte_codec_test::<U1>(decoded, encoded);

        // Coefficients 0..=7, repeated, encoded at several bit widths.
        let decoded: DecodedValue = Array::<_, U8>([
            FieldElement(0),
            FieldElement(1),
            FieldElement(2),
            FieldElement(3),
            FieldElement(4),
            FieldElement(5),
            FieldElement(6),
            FieldElement(7),
        ])
        .repeat();

        // D = 4: one coefficient per nibble, low nibble first.
        let encoded: EncodedPolynomial<U4> = Array::<_, U4>([0x10, 0x32, 0x54, 0x76]).repeat();
        byte_codec_test::<U4>(decoded, encoded);

        let encoded: EncodedPolynomial<U5> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        byte_codec_test::<U5>(decoded, encoded);

        let encoded: EncodedPolynomial<U6> =
            Array::<_, U6>([0x40, 0x20, 0x0c, 0x44, 0x61, 0x1c]).repeat();
        byte_codec_test::<U6>(decoded, encoded);

        let encoded: EncodedPolynomial<U10> =
            Array::<_, U10>([0x00, 0x04, 0x20, 0xc0, 0x00, 0x04, 0x14, 0x60, 0xc0, 0x01]).repeat();
        byte_codec_test::<U10>(decoded, encoded);

        let encoded: EncodedPolynomial<U11> = Array::<_, U11>([
            0x00, 0x08, 0x80, 0x00, 0x06, 0x40, 0x80, 0x02, 0x18, 0xe0, 0x00,
        ])
        .repeat();
        byte_codec_test::<U11>(decoded, encoded);

        let encoded: EncodedPolynomial<U12> = Array::<_, U12>([
            0x00, 0x10, 0x00, 0x02, 0x30, 0x00, 0x04, 0x50, 0x00, 0x06, 0x70, 0x00,
        ])
        .repeat();
        byte_codec_test::<U12>(decoded, encoded);
    }

    /// All-ones input decodes to 0xfff per coefficient at D = 12, which is
    /// larger than Q and must therefore come back reduced mod Q.
    #[allow(clippy::integer_division_remainder_used)]
    #[test]
    fn byte_codec_12_mod() {
        let encoded: EncodedPolynomial<U12> = Array([0xff; 384]);
        let decoded: DecodedValue = Array([FieldElement(0xfff % FieldElement::Q); 256]);

        let actual_decoded = byte_decode::<U12>(&encoded);
        assert_eq!(actual_decoded, decoded);
    }

    /// Checks `Encode::encode`/`Encode::decode` for `T` against a known answer.
    fn vector_codec_known_answer_test<D, T>(decoded: T, encoded: Array<u8, T::EncodedSize>)
    where
        D: EncodingSize,
        T: Encode<D> + PartialEq + Debug,
    {
        let actual_encoded = decoded.encode();
        assert_eq!(actual_encoded, encoded);

        let actual_decoded: T = Encode::decode(&encoded);
        assert_eq!(actual_decoded, decoded);
    }

    #[test]
    fn vector_codec() {
        let poly = Polynomial(
            Array::<_, U8>([
                FieldElement(0),
                FieldElement(1),
                FieldElement(2),
                FieldElement(3),
                FieldElement(4),
                FieldElement(5),
                FieldElement(6),
                FieldElement(7),
            ])
            .repeat(),
        );

        // Every polynomial in the vector is identical, so the expected vector
        // encoding is the same 5-byte D = 5 pattern used in `byte_codec`,
        // tiled across the whole output.
        let decoded: PolynomialVector<U2> = PolynomialVector(Array([poly, poly]));
        let encoded: EncodedPolynomialVector<U5, U2> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        vector_codec_known_answer_test::<U5, PolynomialVector<U2>>(decoded, encoded);

        let decoded: PolynomialVector<U3> = PolynomialVector(Array([poly, poly, poly]));
        let encoded: EncodedPolynomialVector<U5, U3> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        vector_codec_known_answer_test::<U5, PolynomialVector<U3>>(decoded, encoded);

        let decoded: PolynomialVector<U4> = PolynomialVector(Array([poly, poly, poly, poly]));
        let encoded: EncodedPolynomialVector<U5, U4> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        vector_codec_known_answer_test::<U5, PolynomialVector<U4>>(decoded, encoded);
    }
}