ml_kem/
compress.rs

1use crate::algebra::{FieldElement, Integer, Polynomial, PolynomialVector};
2use crate::param::{ArraySize, EncodingSize};
3use crate::util::Truncate;
4
5// A convenience trait to allow us to associate some constants with a typenum
6pub trait CompressionFactor: EncodingSize {
7    const POW2_HALF: u32;
8    const MASK: Integer;
9    const DIV_SHIFT: usize;
10    const DIV_MUL: u64;
11}
12
13impl<T> CompressionFactor for T
14where
15    T: EncodingSize,
16{
17    const POW2_HALF: u32 = 1 << (T::USIZE - 1);
18    const MASK: Integer = ((1 as Integer) << T::USIZE) - 1;
19    const DIV_SHIFT: usize = 34;
20    #[allow(clippy::integer_division_remainder_used)]
21    const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / FieldElement::Q64;
22}
23
24// Traits for objects that allow compression / decompression
25pub trait Compress {
26    fn compress<D: CompressionFactor>(&mut self) -> &Self;
27    fn decompress<D: CompressionFactor>(&mut self) -> &Self;
28}
29
30impl Compress for FieldElement {
31    // Equation 4.5: Compress_d(x) = round((2^d / q) x)
32    //
33    // Here and in decompression, we leverage the following facts:
34    //
35    //   round(a / b) = floor((a + b/2) / b)
36    //   a / q ~= (a * x) >> s where x >> s ~= 1/q
37    fn compress<D: CompressionFactor>(&mut self) -> &Self {
38        const Q_HALF: u64 = (FieldElement::Q64 + 1) >> 1;
39        let x = u64::from(self.0);
40        let y = ((((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT).truncate();
41        self.0 = y.truncate() & D::MASK;
42        self
43    }
44
45    // Equation 4.6: Decompress_d(x) = round((q / 2^d) x)
46    fn decompress<D: CompressionFactor>(&mut self) -> &Self {
47        let x = u32::from(self.0);
48        let y = ((x * FieldElement::Q32) + D::POW2_HALF) >> D::USIZE;
49        self.0 = y.truncate();
50        self
51    }
52}
53impl Compress for Polynomial {
54    fn compress<D: CompressionFactor>(&mut self) -> &Self {
55        for x in &mut self.0 {
56            x.compress::<D>();
57        }
58
59        self
60    }
61
62    fn decompress<D: CompressionFactor>(&mut self) -> &Self {
63        for x in &mut self.0 {
64            x.decompress::<D>();
65        }
66
67        self
68    }
69}
70
71impl<K: ArraySize> Compress for PolynomialVector<K> {
72    fn compress<D: CompressionFactor>(&mut self) -> &Self {
73        for x in &mut self.0 {
74            x.compress::<D>();
75        }
76
77        self
78    }
79
80    fn decompress<D: CompressionFactor>(&mut self) -> &Self {
81        for x in &mut self.0 {
82            x.decompress::<D>();
83        }
84
85        self
86    }
87}
88
#[cfg(test)]
pub(crate) mod test {
    use super::*;
    use hybrid_array::typenum::{U1, U10, U11, U12, U4, U5, U6};
    use num_rational::Ratio;

    /// Reference `Compress_d` using exact rational arithmetic:
    /// `round((2^d / q) * x) mod 2^d`.
    fn rational_compress<D: CompressionFactor>(input: u16) -> u16 {
        let fraction = Ratio::new(u32::from(input) * (1 << D::USIZE), FieldElement::Q32);
        (fraction.round().to_integer() as u16) & D::MASK
    }

    /// Reference `Decompress_d` using exact rational arithmetic:
    /// `round((q / 2^d) * y)`.
    fn rational_decompress<D: CompressionFactor>(input: u16) -> u16 {
        let fraction = Ratio::new(u32::from(input) * FieldElement::Q32, 1 << D::USIZE);
        fraction.round().to_integer() as u16
    }

    /// Verify inequality 4.7 for every `x` in `[0, q)`:
    ///
    /// `|Decompress_d(Compress_d(x)) - x mod± q| <= round(q / 2^(d+1))`
    ///
    /// Here `mod± q` is the centered remainder in `[-(q-1)/2, (q-1)/2]`.
    // Test-only arithmetic on public values; the crate-wide ban on integer
    // division does not apply here.
    #[allow(clippy::integer_division_remainder_used)]
    fn compression_decompression_inequality<D: CompressionFactor>() {
        const QI32: i32 = FieldElement::Q as i32;
        // The bound from inequality 4.7 is round(q / 2^(d+1)). The previous
        // floor(q / 2^d) was roughly twice as loose as the spec.
        let error_threshold = Ratio::new(QI32, 1 << (D::USIZE + 1))
            .round()
            .to_integer();

        for x in 0..FieldElement::Q {
            let mut y = FieldElement(x);
            y.compress::<D>();
            y.decompress::<D>();

            // Centered reduction: map the difference into [0, q), then move the
            // upper half down so the result lies in [-(q-1)/2, (q-1)/2].
            let diff = (i32::from(y.0) - i32::from(x)).rem_euclid(QI32);
            let error = if diff > (QI32 - 1) / 2 {
                diff - QI32
            } else {
                diff
            };

            assert!(
                error.abs() <= error_threshold,
                "Inequality failed for x = {x}: error = {}, error_threshold = {error_threshold}, D = {:?}",
                error.abs(),
                D::USIZE
            );
        }
    }

    /// Check that `Compress_d(Decompress_d(y)) == y` for every `d`-bit `y`.
    fn decompression_compression_equality<D: CompressionFactor>() {
        for x in 0..(1 << D::USIZE) {
            let mut y = FieldElement(x);
            y.decompress::<D>();
            y.compress::<D>();

            assert_eq!(y.0, x, "failed for x: {}, D: {}", x, D::USIZE);
        }
    }

    /// Compare `decompress` with the exact rational reference for every `d`-bit input.
    fn decompress_KAT<D: CompressionFactor>() {
        for y in 0..(1 << D::USIZE) {
            let x_expected = rational_decompress::<D>(y);
            let mut x_actual = FieldElement(y);
            x_actual.decompress::<D>();

            assert_eq!(x_expected, x_actual.0, "for y: {}, D: {}", y, D::USIZE);
        }
    }

    /// Compare `compress` with the exact rational reference for every `x` in `[0, q)`.
    fn compress_KAT<D: CompressionFactor>() {
        for x in 0..FieldElement::Q {
            let y_expected = rational_compress::<D>(x);
            let mut y_actual = FieldElement(x);
            y_actual.compress::<D>();

            assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE);
        }
    }

    fn compress_decompress_properties<D: CompressionFactor>() {
        compression_decompression_inequality::<D>();
        decompression_compression_equality::<D>();
    }

    fn compress_decompress_KATs<D: CompressionFactor>() {
        decompress_KAT::<D>();
        compress_KAT::<D>();
    }

    #[test]
    fn decompress_compress() {
        compress_decompress_properties::<U1>();
        compress_decompress_properties::<U4>();
        compress_decompress_properties::<U5>();
        compress_decompress_properties::<U6>();
        compress_decompress_properties::<U10>();
        compress_decompress_properties::<U11>();
        // Decompress-then-compress is the identity only for d < 12, so d = 12
        // checks the inequality alone.
        compression_decompression_inequality::<U12>();

        compress_decompress_KATs::<U1>();
        compress_decompress_KATs::<U4>();
        compress_decompress_KATs::<U5>();
        compress_decompress_KATs::<U6>();
        compress_decompress_KATs::<U10>();
        compress_decompress_KATs::<U11>();
        compress_decompress_KATs::<U12>();
    }
}