ml_kem/
encode.rs

1use hybrid_array::{
2    typenum::{Unsigned, U256},
3    Array,
4};
5
6use crate::algebra::{
7    FieldElement, Integer, NttPolynomial, NttVector, Polynomial, PolynomialVector,
8};
9use crate::param::{ArraySize, EncodedPolynomial, EncodingSize, VectorEncodingSize};
10use crate::util::Truncate;
11
/// The 256 field-element coefficients of one polynomial: the input of
/// `byte_encode` and the output of `byte_decode`.
type DecodedValue = Array<FieldElement, U256>;
13
14// Algorithm 4 ByteEncode_d(F)
15//
16// Note: This algorithm performs compression as well as encoding.
17fn byte_encode<D: EncodingSize>(vals: &DecodedValue) -> EncodedPolynomial<D> {
18    let val_step = D::ValueStep::USIZE;
19    let byte_step = D::ByteStep::USIZE;
20
21    let mut bytes = EncodedPolynomial::<D>::default();
22
23    let vc = vals.chunks(val_step);
24    let bc = bytes.chunks_mut(byte_step);
25    for (v, b) in vc.zip(bc) {
26        let mut x = 0u128;
27        for (j, vj) in v.iter().enumerate() {
28            x |= u128::from(vj.0) << (D::USIZE * j);
29        }
30
31        let xb = x.to_le_bytes();
32        b.copy_from_slice(&xb[..byte_step]);
33    }
34
35    bytes
36}
37
38// Algorithm 5 ByteDecode_d(F)
39//
40// Note: This function performs decompression as well as decoding.
41fn byte_decode<D: EncodingSize>(bytes: &EncodedPolynomial<D>) -> DecodedValue {
42    let val_step = D::ValueStep::USIZE;
43    let byte_step = D::ByteStep::USIZE;
44    let mask = (1 << D::USIZE) - 1;
45
46    let mut vals = DecodedValue::default();
47
48    let vc = vals.chunks_mut(val_step);
49    let bc = bytes.chunks(byte_step);
50    for (v, b) in vc.zip(bc) {
51        let mut xb = [0u8; 16];
52        xb[..byte_step].copy_from_slice(b);
53
54        let x = u128::from_le_bytes(xb);
55        for (j, vj) in v.iter_mut().enumerate() {
56            let val: Integer = (x >> (D::USIZE * j)).truncate();
57            vj.0 = val & mask;
58
59            if D::USIZE == 12 {
60                vj.0 %= FieldElement::Q;
61            }
62        }
63    }
64
65    vals
66}
67
/// Conversion between a value and its fixed-size byte encoding, parameterized
/// by the encoding width `D`.
pub trait Encode<D: EncodingSize> {
    /// Length, in bytes, of the encoded form.
    type EncodedSize: ArraySize;
    /// Serializes `self` into its width-`D` byte encoding.
    fn encode(&self) -> Array<u8, Self::EncodedSize>;
    /// Reconstructs a value from a byte encoding of the same width `D`.
    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self;
}
73
74impl<D: EncodingSize> Encode<D> for Polynomial {
75    type EncodedSize = D::EncodedPolynomialSize;
76
77    fn encode(&self) -> Array<u8, Self::EncodedSize> {
78        byte_encode::<D>(&self.0)
79    }
80
81    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
82        Self(byte_decode::<D>(enc))
83    }
84}
85
86impl<D, K> Encode<D> for PolynomialVector<K>
87where
88    K: ArraySize,
89    D: VectorEncodingSize<K>,
90{
91    type EncodedSize = D::EncodedPolynomialVectorSize;
92
93    fn encode(&self) -> Array<u8, Self::EncodedSize> {
94        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
95        <D as VectorEncodingSize<K>>::flatten(polys)
96    }
97
98    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
99        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
100        Self(
101            unfold
102                .iter()
103                .map(|&x| <Polynomial as Encode<D>>::decode(x))
104                .collect(),
105        )
106    }
107}
108
109impl<D: EncodingSize> Encode<D> for NttPolynomial {
110    type EncodedSize = D::EncodedPolynomialSize;
111
112    fn encode(&self) -> Array<u8, Self::EncodedSize> {
113        byte_encode::<D>(&self.0)
114    }
115
116    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
117        Self(byte_decode::<D>(enc))
118    }
119}
120
121impl<D, K> Encode<D> for NttVector<K>
122where
123    D: VectorEncodingSize<K>,
124    K: ArraySize,
125{
126    type EncodedSize = D::EncodedPolynomialVectorSize;
127
128    fn encode(&self) -> Array<u8, Self::EncodedSize> {
129        let polys = self.0.iter().map(|x| Encode::<D>::encode(x)).collect();
130        <D as VectorEncodingSize<K>>::flatten(polys)
131    }
132
133    fn decode(enc: &Array<u8, Self::EncodedSize>) -> Self {
134        let unfold = <D as VectorEncodingSize<K>>::unflatten(enc);
135        Self(
136            unfold
137                .iter()
138                .map(|&x| <NttPolynomial as Encode<D>>::decode(x))
139                .collect(),
140        )
141    }
142}
143
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use core::fmt::Debug;
    use core::ops::Rem;
    use hybrid_array::typenum::{
        marker_traits::Zero, operator_aliases::Mod, U1, U10, U11, U12, U2, U3, U4, U5, U6, U8,
    };
    use rand::Rng;

    use crate::param::EncodedPolynomialVector;

    /// Helper trait to build a larger array by cyclically repeating a smaller one.
    ///
    /// The `Mod<D, N>: Zero` bound requires the target length `D` to be an exact
    /// multiple of the source length `N`.
    trait Repeat<T: Clone, D: ArraySize> {
        fn repeat(&self) -> Array<T, D>;
    }

    impl<T, N, D> Repeat<T, D> for Array<T, N>
    where
        N: ArraySize,
        T: Clone,
        D: ArraySize + Rem<N>,
        Mod<D, N>: Zero,
    {
        // Test-only index arithmetic on public positions; no secret data involved.
        #[allow(clippy::integer_division_remainder_used)]
        fn repeat(&self) -> Array<T, D> {
            Array::from_fn(|i| self[i % N::USIZE].clone())
        }
    }

    /// Checks `byte_encode`/`byte_decode` at width `D` against a known answer, then
    /// checks both round-trip directions on random inputs.
    // Test-only reduction of random values into range; no secret data involved.
    #[allow(clippy::integer_division_remainder_used)]
    fn byte_codec_test<D>(decoded: DecodedValue, encoded: EncodedPolynomial<D>)
    where
        D: EncodingSize,
    {
        // Test known answer
        let actual_encoded = byte_encode::<D>(&decoded);
        assert_eq!(actual_encoded, encoded);

        let actual_decoded = byte_decode::<D>(&encoded);
        assert_eq!(actual_decoded, decoded);

        // Test random encode/decode round trip: coefficients drawn from the valid
        // range ([0, 2^d), or [0, q) for d = 12) must survive encode then decode.
        let mut rng = rand::thread_rng();
        let mut decoded: Array<Integer, U256> = Default::default();
        rng.fill(decoded.as_mut_slice());
        let m = match D::USIZE {
            12 => FieldElement::Q,
            d => (1 as Integer) << d,
        };
        let decoded = decoded.iter().map(|x| FieldElement(x % m)).collect();

        let actual_encoded = byte_encode::<D>(&decoded);
        let actual_decoded = byte_decode::<D>(&actual_encoded);
        assert_eq!(actual_decoded, decoded);

        // Test random decode/encode round trip. For d < 12 every bit pattern is a
        // valid encoding, so decoding arbitrary bytes and re-encoding them must
        // reproduce the input exactly. For d = 12, decoding reduces mod q, so
        // arbitrary bytes need not round-trip and the check is skipped.
        if D::USIZE != 12 {
            let mut random_encoded = EncodedPolynomial::<D>::default();
            rng.fill(random_encoded.as_mut_slice());
            let reencoded = byte_encode::<D>(&byte_decode::<D>(&random_encoded));
            assert_eq!(reencoded, random_encoded);
        }
    }

    #[test]
    fn byte_codec() {
        // The 1-bit can only represent decoded values equal to 0 or 1.
        let decoded: DecodedValue = Array::<_, U2>([FieldElement(0), FieldElement(1)]).repeat();
        let encoded: EncodedPolynomial<U1> = Array([0xaa; 32]);
        byte_codec_test::<U1>(decoded, encoded);

        // For other codec widths, we use a standard sequence
        let decoded: DecodedValue = Array::<_, U8>([
            FieldElement(0),
            FieldElement(1),
            FieldElement(2),
            FieldElement(3),
            FieldElement(4),
            FieldElement(5),
            FieldElement(6),
            FieldElement(7),
        ])
        .repeat();

        let encoded: EncodedPolynomial<U4> = Array::<_, U4>([0x10, 0x32, 0x54, 0x76]).repeat();
        byte_codec_test::<U4>(decoded, encoded);

        let encoded: EncodedPolynomial<U5> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        byte_codec_test::<U5>(decoded, encoded);

        let encoded: EncodedPolynomial<U6> =
            Array::<_, U6>([0x40, 0x20, 0x0c, 0x44, 0x61, 0x1c]).repeat();
        byte_codec_test::<U6>(decoded, encoded);

        let encoded: EncodedPolynomial<U10> =
            Array::<_, U10>([0x00, 0x04, 0x20, 0xc0, 0x00, 0x04, 0x14, 0x60, 0xc0, 0x01]).repeat();
        byte_codec_test::<U10>(decoded, encoded);

        let encoded: EncodedPolynomial<U11> = Array::<_, U11>([
            0x00, 0x08, 0x80, 0x00, 0x06, 0x40, 0x80, 0x02, 0x18, 0xe0, 0x00,
        ])
        .repeat();
        byte_codec_test::<U11>(decoded, encoded);

        let encoded: EncodedPolynomial<U12> = Array::<_, U12>([
            0x00, 0x10, 0x00, 0x02, 0x30, 0x00, 0x04, 0x50, 0x00, 0x06, 0x70, 0x00,
        ])
        .repeat();
        byte_codec_test::<U12>(decoded, encoded);
    }

    // Test-only computation of the expected value; no secret data involved.
    #[allow(clippy::integer_division_remainder_used)]
    #[test]
    fn byte_codec_12_mod() {
        // ByteDecode_12 is required to reduce mod q
        let encoded: EncodedPolynomial<U12> = Array([0xff; 384]);
        let decoded: DecodedValue = Array([FieldElement(0xfff % FieldElement::Q); 256]);

        let actual_decoded = byte_decode::<U12>(&encoded);
        assert_eq!(actual_decoded, decoded);
    }

    /// Checks that `decoded` encodes to exactly `encoded` at width `D`, and that
    /// `encoded` decodes back to `decoded`.
    fn vector_codec_known_answer_test<D, T>(decoded: T, encoded: Array<u8, T::EncodedSize>)
    where
        D: EncodingSize,
        T: Encode<D> + PartialEq + Debug,
    {
        let actual_encoded = decoded.encode();
        assert_eq!(actual_encoded, encoded);

        let actual_decoded: T = Encode::decode(&encoded);
        assert_eq!(actual_decoded, decoded);
    }

    #[test]
    fn vector_codec() {
        let poly = Polynomial(
            Array::<_, U8>([
                FieldElement(0),
                FieldElement(1),
                FieldElement(2),
                FieldElement(3),
                FieldElement(4),
                FieldElement(5),
                FieldElement(6),
                FieldElement(7),
            ])
            .repeat(),
        );

        // The required vector sizes are 2, 3, and 4.
        let decoded: PolynomialVector<U2> = PolynomialVector(Array([poly, poly]));
        let encoded: EncodedPolynomialVector<U5, U2> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        vector_codec_known_answer_test::<U5, PolynomialVector<U2>>(decoded, encoded);

        let decoded: PolynomialVector<U3> = PolynomialVector(Array([poly, poly, poly]));
        let encoded: EncodedPolynomialVector<U5, U3> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        vector_codec_known_answer_test::<U5, PolynomialVector<U3>>(decoded, encoded);

        let decoded: PolynomialVector<U4> = PolynomialVector(Array([poly, poly, poly, poly]));
        let encoded: EncodedPolynomialVector<U5, U4> =
            Array::<_, U5>([0x20, 0x88, 0x41, 0x8a, 0x39]).repeat();
        vector_codec_known_answer_test::<U5, PolynomialVector<U4>>(decoded, encoded);
    }
}