Skip to main content

crypto_bigint/uint/rand.rs

1//! Random number generator support
2
3use super::Uint;
4use crate::{
5    CtLt, Encoding, Limb, NonZero, Random, RandomBits, RandomBitsError, RandomMod, bitlen,
6};
7use rand_core::{Rng, TryRng};
8
9impl<const LIMBS: usize> Random for Uint<LIMBS> {
10    fn try_random_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
11        let mut limbs = [Limb::ZERO; LIMBS];
12
13        for limb in &mut limbs {
14            *limb = Limb::try_random_from_rng(rng)?;
15        }
16
17        Ok(limbs.into())
18    }
19}
20
21/// Fill the given limbs slice with random bits.
22///
23/// NOTE: Assumes that the limbs in the given slice are zeroed!
24///
25/// When combined with a platform-independent "4-byte sequential" `rng`, this function is
26/// platform-independent. We consider an RNG "`X`-byte sequential" whenever
27/// `rng.fill_bytes(&mut bytes[..i]); rng.fill_bytes(&mut bytes[i..])` constructs the same `bytes`,
28/// as long as `i` is a multiple of `X`.
29/// Note that the `TryRng` trait does _not_ require this behaviour from `rng`.
30#[allow(clippy::integer_division_remainder_used, reason = "public parameter")]
31pub(crate) fn random_bits_core<T, R: TryRng + ?Sized>(
32    rng: &mut R,
33    x: &mut T,
34    n_bits: u32,
35) -> Result<(), R::Error>
36where
37    T: Encoding,
38{
39    if n_bits == 0 {
40        return Ok(());
41    }
42
43    let n_bytes = bitlen::to_bytes(n_bits);
44    let hi_mask = u8::MAX >> ((u8::BITS - (n_bits % u8::BITS)) % u8::BITS);
45
46    let mut buffer = x.to_le_bytes();
47    let slice = buffer.as_mut();
48    rng.try_fill_bytes(&mut slice[..n_bytes])?;
49    slice[n_bytes - 1] &= hi_mask;
50    *x = T::from_le_bytes(buffer);
51
52    Ok(())
53}
54
55impl<const LIMBS: usize> RandomBits for Uint<LIMBS> {
56    fn try_random_bits<R: TryRng + ?Sized>(
57        rng: &mut R,
58        bit_length: u32,
59    ) -> Result<Self, RandomBitsError<R::Error>> {
60        Self::try_random_bits_with_precision(rng, bit_length, Self::BITS)
61    }
62
63    fn try_random_bits_with_precision<R: TryRng + ?Sized>(
64        rng: &mut R,
65        bit_length: u32,
66        bits_precision: u32,
67    ) -> Result<Self, RandomBitsError<R::Error>> {
68        if bits_precision != Self::BITS {
69            return Err(RandomBitsError::BitsPrecisionMismatch {
70                bits_precision,
71                integer_bits: Self::BITS,
72            });
73        }
74        if bit_length > Self::BITS {
75            return Err(RandomBitsError::BitLengthTooLarge {
76                bit_length,
77                bits_precision,
78            });
79        }
80        let mut x = Self::ZERO;
81        random_bits_core(rng, &mut x, bit_length).map_err(RandomBitsError::RandCore)?;
82        Ok(x)
83    }
84}
85
86impl<const LIMBS: usize> RandomMod for Uint<LIMBS> {
87    fn random_mod_vartime<R: Rng + ?Sized>(rng: &mut R, modulus: &NonZero<Self>) -> Self {
88        let mut x = Self::ZERO;
89        let Ok(()) = random_mod_vartime_core(rng, &mut x, modulus, modulus.bits_vartime());
90        x
91    }
92
93    fn try_random_mod_vartime<R: TryRng + ?Sized>(
94        rng: &mut R,
95        modulus: &NonZero<Self>,
96    ) -> Result<Self, R::Error> {
97        let mut x = Self::ZERO;
98        random_mod_vartime_core(rng, &mut x, modulus, modulus.bits_vartime())?;
99        Ok(x)
100    }
101}
102
103/// Generic implementation of `random_mod_vartime` which can be shared with `BoxedUint`.
104// TODO(tarcieri): obtain `n_bits` via a trait like `Integer`
105pub(super) fn random_mod_vartime_core<T, R: TryRng + ?Sized>(
106    rng: &mut R,
107    x: &mut T,
108    modulus: &NonZero<T>,
109    n_bits: u32,
110) -> Result<(), R::Error>
111where
112    T: Encoding + CtLt,
113{
114    loop {
115        random_bits_core(rng, x, n_bits)?;
116        if x.ct_lt(modulus).into() {
117            return Ok(());
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use crate::uint::rand::random_bits_core;
    use crate::{Limb, NonZero, Random, RandomBits, RandomMod, U256, U1024, Uint};
    use chacha20::ChaCha8Rng;
    use rand_core::{Rng, SeedableRng};

    /// Expected output of `U1024::random_from_rng` for the RNG returned by
    /// [`get_four_sequential_rng`].
    ///
    /// Several tests below derive their expectations from this value.
    const RANDOM_OUTPUT: U1024 = Uint::from_be_hex(concat![
        "A484C4C693EECC47C3B919AE0D16DF2259CD1A8A9B8EA8E0862878227D4B40A3",
        "C54302F2EB1E2F69E17653A37F1BCC44277FA208E6B31E08CDC4A23A7E88E660",
        "EF781C7DD2D368BAD438539D6A2E923C8CAE14CB947EB0BDE10D666732024679",
        "0F6760A48F9B887CB2FB0D3281E2A6E67746A55FBAD8C037B585F767A79A3B6C"
    ]);

    /// Construct a "4-byte sequential" `rng`, i.e. one for which
    /// `rng.fill_bytes(&mut buffer[..x]); rng.fill_bytes(&mut buffer[x..])` produces the same
    /// `buffer` for every `x` in `0..buffer.len()` with `x % 4 == 0`.
    fn get_four_sequential_rng() -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(0)
    }

    /// Make sure the random value constructed is consistent across platforms
    #[test]
    fn random_platform_independence() {
        let mut rng = get_four_sequential_rng();
        assert_eq!(U1024::random_from_rng(&mut rng), RANDOM_OUTPUT);
    }

    #[test]
    fn random_mod_vartime() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        // Ensure `random_mod_vartime` runs in a reasonable amount of time
        let modulus = NonZero::new(U256::from(42u8)).unwrap();
        let res = U256::random_mod_vartime(&mut rng, &modulus);

        // Check that the value is in range
        assert!(res < U256::from(42u8));

        // Ensure `random_mod_vartime` runs in a reasonable amount of time
        // when the modulus is larger than 1 limb
        let modulus = NonZero::new(U256::from(0x10000000000000001u128)).unwrap();
        let res = U256::random_mod_vartime(&mut rng, &modulus);

        // Check that the value is in range
        assert!(res < U256::from(0x10000000000000001u128));
    }

    #[test]
    fn random_bits() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);

        // With a fixed seed, each sample is expected to use (nearly) its full bit length;
        // `lower_bound` is the slack allowed below the top bit.
        let lower_bound = 16;

        // Full length of the integer
        let bit_length = U256::BITS;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
        }

        // A multiple of limb size
        let bit_length = U256::BITS - Limb::BITS;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
            assert!(res < (U256::ONE << bit_length));
        }

        // A multiple of 8
        let bit_length = U256::BITS - Limb::BITS - 8;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
            assert!(res < (U256::ONE << bit_length));
        }

        // Not a multiple of 8
        let bit_length = U256::BITS - Limb::BITS - 8 - 3;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res > (U256::ONE << (bit_length - lower_bound)));
            assert!(res < (U256::ONE << bit_length));
        }

        // One incomplete limb
        let bit_length = 7;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert!(res < (U256::ONE << bit_length));
        }

        // Zero bits
        let bit_length = 0;
        for _ in 0..10 {
            let res = U256::random_bits(&mut rng, bit_length);
            assert_eq!(res, U256::ZERO);
        }
    }

    /// Make sure the `random_bits` output is consistent across platforms
    #[test]
    fn random_bits_platform_independence() {
        let mut rng = get_four_sequential_rng();

        let bit_length = 989;
        let mut val = U1024::ZERO;
        random_bits_core(&mut rng, &mut val, bit_length).expect("safe");

        // The result must equal the low `bit_length` bits of the full-width sample.
        assert_eq!(
            val,
            RANDOM_OUTPUT.bitand(&U1024::ONE.shl(bit_length).wrapping_sub(&Uint::ONE))
        );

        // Test that the RNG is in the same state
        let mut state = [0u8; 16];
        rng.fill_bytes(&mut state);

        assert_eq!(
            state,
            [
                198, 196, 132, 164, 240, 211, 223, 12, 36, 189, 139, 48, 94, 1, 123, 253
            ]
        );
    }

    /// Make sure `random_mod_vartime` output is consistent across platforms
    #[test]
    fn random_mod_vartime_platform_independence() {
        let mut rng = get_four_sequential_rng();

        let modulus = NonZero::new(U256::from_u32(8192)).unwrap();
        let mut vals = [U256::ZERO; 5];
        for val in &mut vals {
            *val = U256::random_mod_vartime(&mut rng, &modulus);
        }
        let expected = [55, 3378, 2172, 1657, 5323];
        for (want, got) in expected.into_iter().zip(vals.into_iter()) {
            assert_eq!(got, U256::from_u32(want));
        }

        let modulus =
            NonZero::new(U256::ZERO.wrapping_sub(&U256::from_u64(rng.next_u64()))).unwrap();
        let val = U256::random_mod_vartime(&mut rng, &modulus);
        assert_eq!(
            val,
            U256::from_be_hex("E17653A37F1BCC44277FA208E6B31E08CDC4A23A7E88E660EF781C7DD2D368BA")
        );

        let mut state = [0u8; 16];
        rng.fill_bytes(&mut state);

        assert_eq!(
            state,
            [
                105, 47, 30, 235, 242, 2, 67, 197, 163, 64, 75, 125, 34, 120, 40, 134,
            ],
        );
    }

    /// Test that random bytes are sampled consecutively.
    #[test]
    fn random_bits_4_bytes_sequential() {
        // Test for multiples of 4 bytes, i.e., multiples of 32 bits.
        let bit_lengths = [0, 32, 64, 128, 192, 992];

        for bit_length in bit_lengths {
            let mut rng = get_four_sequential_rng();
            let mut first = U1024::ZERO;
            let mut second = U1024::ZERO;
            random_bits_core(&mut rng, &mut first, bit_length).expect("safe");
            random_bits_core(&mut rng, &mut second, U1024::BITS - bit_length).expect("safe");
            // Splitting one full-width draw into two consecutive draws must reassemble it.
            assert_eq!(second.shl(bit_length).bitor(&first), RANDOM_OUTPUT);
        }
    }
}