Skip to main content

ml_kem/
compress.rs

1use crate::algebra::{BaseField, Elem, Int, Polynomial, Vector};
2use array::ArraySize;
3use module_lattice::EncodingSize;
4use module_lattice::{Field, Truncate};
5
6// A convenience trait to allow us to associate some constants with a typenum
7pub(crate) trait CompressionFactor: EncodingSize {
8    const POW2_HALF: u32;
9    const MASK: Int;
10    const DIV_SHIFT: usize;
11    const DIV_MUL: u64;
12}
13
14impl<T> CompressionFactor for T
15where
16    T: EncodingSize,
17{
18    const POW2_HALF: u32 = 1 << (T::USIZE - 1);
19    const MASK: Int = (1 << T::USIZE) - 1;
20    const DIV_SHIFT: usize = 34;
21    #[allow(clippy::integer_division_remainder_used, reason = "constant")]
22    const DIV_MUL: u64 = (1 << T::DIV_SHIFT) / BaseField::QLL;
23}
24
25// Traits for objects that allow compression / decompression
26pub(crate) trait Compress {
27    fn compress<D: CompressionFactor>(&mut self) -> &Self;
28    fn decompress<D: CompressionFactor>(&mut self) -> &Self;
29}
30
31impl Compress for Elem {
32    // Equation 4.5: Compress_d(x) = round((2^d / q) x)
33    //
34    // Here and in decompression, we leverage the following facts:
35    //
36    //   round(a / b) = floor((a + b/2) / b)
37    //   a / q ~= (a * x) >> s where x >> s ~= 1/q
38    fn compress<D: CompressionFactor>(&mut self) -> &Self {
39        const Q_HALF: u64 = (BaseField::QLL + 1) >> 1;
40        let x = u64::from(self.0);
41        let y = (((x << D::USIZE) + Q_HALF) * D::DIV_MUL) >> D::DIV_SHIFT;
42        self.0 = u16::truncate(y) & D::MASK;
43        self
44    }
45
46    // Equation 4.6: Decompress_d(x) = round((q / 2^d) x)
47    fn decompress<D: CompressionFactor>(&mut self) -> &Self {
48        let x = u32::from(self.0);
49        let y = ((x * BaseField::QL) + D::POW2_HALF) >> D::USIZE;
50        self.0 = Truncate::truncate(y);
51        self
52    }
53}
54impl Compress for Polynomial {
55    fn compress<D: CompressionFactor>(&mut self) -> &Self {
56        for x in &mut self.0 {
57            x.compress::<D>();
58        }
59
60        self
61    }
62
63    fn decompress<D: CompressionFactor>(&mut self) -> &Self {
64        for x in &mut self.0 {
65            x.decompress::<D>();
66        }
67
68        self
69    }
70}
71
72impl<K: ArraySize> Compress for Vector<K> {
73    fn compress<D: CompressionFactor>(&mut self) -> &Self {
74        for x in &mut self.0 {
75            x.compress::<D>();
76        }
77
78        self
79    }
80
81    fn decompress<D: CompressionFactor>(&mut self) -> &Self {
82        for x in &mut self.0 {
83            x.decompress::<D>();
84        }
85
86        self
87    }
88}
89
#[cfg(test)]
#[allow(clippy::cast_possible_truncation, reason = "tests")]
#[allow(clippy::integer_division_remainder_used, reason = "tests")]
pub(crate) mod tests {
    use super::*;
    use array::typenum::{U1, U4, U5, U6, U10, U11, U12};
    use num_rational::Ratio;

    /// Reference `Compress_d`: exact rational `round((2^d / q) * x) mod 2^d`.
    fn rational_compress<D: CompressionFactor>(input: u16) -> u16 {
        let fraction = Ratio::new(u32::from(input) * (1 << D::USIZE), BaseField::QL);
        (fraction.round().to_integer() as u16) & D::MASK
    }

    /// Reference `Decompress_d`: exact rational `round((q / 2^d) * y)`.
    fn rational_decompress<D: CompressionFactor>(input: u16) -> u16 {
        let fraction = Ratio::new(u32::from(input) * BaseField::QL, 1 << D::USIZE);
        fraction.round().to_integer() as u16
    }

    /// Inequality 4.7: for every x in Z_q,
    /// `|Decompress_d(Compress_d(x)) - x mod± q| <= round(q / 2^(d+1))`.
    fn compression_decompression_inequality<D: CompressionFactor>() {
        const QI32: i32 = BaseField::Q as i32;
        // The spec's bound B_q = round(q / 2^(d+1)); e.g. 832 for d = 1 and 0 for d = 12.
        let error_threshold = Ratio::new(QI32, 1 << (D::USIZE + 1))
            .round()
            .to_integer();

        for x in 0..BaseField::Q {
            let mut y = Elem::new(x);
            y.compress::<D>();
            y.decompress::<D>();

            // Centered reduction `mod± q`, mapping the difference into
            // [-(q-1)/2, (q-1)/2] regardless of its sign.
            let mut error = (i32::from(y.0) - i32::from(x)).rem_euclid(QI32);
            if error > (QI32 - 1) / 2 {
                error -= QI32;
            }

            assert!(
                error.abs() <= error_threshold,
                "Inequality failed for x = {x}: error = {}, error_threshold = {error_threshold}, D = {:?}",
                error.abs(),
                D::USIZE
            );
        }
    }

    /// `Compress_d(Decompress_d(y)) == y` for every d-bit y.
    fn decompression_compression_equality<D: CompressionFactor>() {
        for x in 0..(1 << D::USIZE) {
            let mut y = Elem::new(x);
            y.decompress::<D>();
            y.compress::<D>();

            assert_eq!(y.0, x, "failed for x: {}, D: {}", x, D::USIZE);
        }
    }

    /// Exhaustively compares `decompress` with the exact rational reference.
    fn decompress_kat<D: CompressionFactor>() {
        for y in 0..(1 << D::USIZE) {
            let x_expected = rational_decompress::<D>(y);
            let mut x_actual = Elem::new(y);
            x_actual.decompress::<D>();

            assert_eq!(x_expected, x_actual.0, "for y: {}, D: {}", y, D::USIZE);
        }
    }

    /// Exhaustively compares `compress` with the exact rational reference over 0..q.
    fn compress_kat<D: CompressionFactor>() {
        for x in 0..BaseField::Q {
            let y_expected = rational_compress::<D>(x);
            let mut y_actual = Elem::new(x);
            y_actual.compress::<D>();

            assert_eq!(y_expected, y_actual.0, "for x: {}, D: {}", x, D::USIZE);
        }
    }

    fn compress_decompress_properties<D: CompressionFactor>() {
        compression_decompression_inequality::<D>();
        decompression_compression_equality::<D>();
    }

    fn compress_decompress_kats<D: CompressionFactor>() {
        decompress_kat::<D>();
        compress_kat::<D>();
    }

    #[test]
    fn decompress_compress() {
        compress_decompress_properties::<U1>();
        compress_decompress_properties::<U4>();
        compress_decompress_properties::<U5>();
        compress_decompress_properties::<U6>();
        compress_decompress_properties::<U10>();
        compress_decompress_properties::<U11>();
        // Decompress-then-compress is only the identity for d < 12 (2^12 > q makes
        // decompression non-injective), so d = 12 checks the inequality alone.
        compression_decompression_inequality::<U12>();

        compress_decompress_kats::<U1>();
        compress_decompress_kats::<U4>();
        compress_decompress_kats::<U5>();
        compress_decompress_kats::<U6>();
        compress_decompress_kats::<U10>();
        compress_decompress_kats::<U11>();
        compress_decompress_kats::<U12>();
    }
}