Skip to main content

crypto_bigint/
primitives.rs

1use crate::{Choice, WideWord, Word};
2
3/// Computes `lhs + rhs + carry`, returning the result along with the new carry (0, 1, or 2).
4#[inline(always)]
5#[allow(clippy::cast_possible_truncation)]
6pub(crate) const fn carrying_add(lhs: Word, rhs: Word, carry: Word) -> (Word, Word) {
7    // We could use `Word::overflowing_add()` here analogous to `overflowing_add()`,
8    // but this version seems to produce a slightly better assembly.
9    let a = lhs as WideWord;
10    let b = rhs as WideWord;
11    let carry = carry as WideWord;
12    let ret = a + b + carry;
13    (ret as Word, (ret >> Word::BITS) as Word)
14}
15
16/// Computes `lhs + rhs`, returning the result along with the carry (0 or 1).
17#[inline(always)]
18pub(crate) const fn overflowing_add(lhs: Word, rhs: Word) -> (Word, Word) {
19    let (res, carry) = lhs.overflowing_add(rhs);
20    (res, carry as Word)
21}
22
23/// Computes `lhs - (rhs + borrow)`, returning the result along with the new borrow.
24#[inline(always)]
25pub(crate) const fn borrowing_sub(lhs: Word, rhs: Word, borrow: Word) -> (Word, Word) {
26    // XXX we cannot use WideWord casts here: https://github.com/rust-lang/rust/issues/149522
27    // rustc 1.87 through 1.91 incorrectly optimize some WideWord bit arithmetic.
28    let (ret, b2) = lhs.overflowing_sub(borrow >> (Word::BITS - 1));
29    let (ret, b1) = ret.overflowing_sub(rhs);
30    (ret, Word::MIN.wrapping_sub((b1 | b2) as Word))
31}
32
33/// Computes `lhs * rhs`, returning the low and the high words of the result.
34#[inline(always)]
35#[allow(clippy::cast_possible_truncation)]
36pub(crate) const fn widening_mul(lhs: Word, rhs: Word) -> (Word, Word) {
37    let a = lhs as WideWord;
38    let b = rhs as WideWord;
39    let ret = a * b;
40    (ret as Word, (ret >> Word::BITS) as Word)
41}
42
43/// Computes `(lhs * rhs) + addend + carry`, returning the result along with the new carry.
44#[inline(always)]
45#[allow(clippy::cast_possible_truncation)]
46pub(crate) const fn carrying_mul_add(
47    lhs: Word,
48    rhs: Word,
49    addend: Word,
50    carry: Word,
51) -> (Word, Word) {
52    let lhs = lhs as WideWord;
53    let rhs = rhs as WideWord;
54    let addend = addend as WideWord;
55    let carry = carry as WideWord;
56
57    // Cannot overflow:
58    // lhs      * rhs      + addend   + carry
59    // (2^64-1) * (2^64-1) + (2^64-1) + (2^64-1) =
60    // 2^128 - 2^65 + 1 + 2^64 - 1 + 2^64 - 1 =
61    // 2^128 - 2^65 + 2*2^64 - 1 =
62    // 2^128 - 1 = u128::MAX
63    let ret = ((lhs * rhs) + addend) + carry;
64    (ret as Word, (ret >> Word::BITS) as Word)
65}
66
67/// `const fn` equivalent of `u32::max(a, b)`.
68#[inline]
69pub(crate) const fn u32_max(a: u32, b: u32) -> u32 {
70    Choice::from_u32_lt(a, b).select_u32(a, b)
71}
72
73/// `const` equivalent of `u32::min(a, b)`.
74#[inline]
75pub(crate) const fn u32_min(a: u32, b: u32) -> u32 {
76    Choice::from_u32_lt(a, b).select_u32(b, a)
77}
78
/// Remainder calculation, constant time for a given divisor `d`.
///
/// Computes `n % d` with the direct-remainder method from "Faster Remainder by Direct
/// Computation: Applications to Compilers and Software Libraries" by Daniel Lemire,
/// Owen Kaser, and Nathan Kurz, Fig. 1. A 64-bit reciprocal `c = ceil(2^64 / d)` is
/// formed, multiplied (mod 2^64) by `n`, and the resulting fraction is scaled back up
/// by `d`; the integer part of that product is the remainder.
///
/// # Panics
///
/// Panics if `d` is zero.
#[inline(never)]
#[allow(clippy::cast_possible_truncation, reason = "needs triage")]
#[allow(clippy::integer_division_remainder_used, reason = "needs triage")]
pub(crate) const fn u32_rem(n: u32, d: u32) -> u32 {
    assert!(d > 0, "divisor must be nonzero");
    // `c = ceil(2^64 / d)`, computed as `floor((2^64 - 1) / d) + 1`. For `d == 1` the true
    // value is 2^64, which must wrap to 0 exactly as in the paper's unsigned C code; a plain
    // `+ 1` would instead panic when overflow checks are enabled. With `c == 0` the result
    // below is 0, which is the correct value of `n % 1`.
    let c = (u64::MAX / (d as u64)).wrapping_add(1);
    // The low 64 bits of `c * n` hold the fractional part of `n / d`, scaled by 2^64.
    let lowbits = c.wrapping_mul(n as u64);
    // Scaling the fraction by `d` moves the remainder into the high 64 bits. The product is
    // at most (2^64 - 1) * (2^32 - 1), so it always fits in a `u128`.
    (((lowbits as u128) * (d as u128)) >> 64) as u32
}
90
/// Returns how many bits are needed to write `n` in binary: the index of its highest
/// set bit plus one, or `0` when `n == 0`.
#[inline(always)]
pub(crate) const fn u32_bits(n: u32) -> u32 {
    let unused_high_bits = n.leading_zeros();
    u32::BITS - unused_high_bits
}
96
/// Returns a [`Choice`] that is set exactly when `a < b` (32-bit pointer targets).
///
/// On these targets `usize` is 32 bits wide, so narrowing to `u32` loses no information.
#[allow(clippy::cast_possible_truncation)]
#[cfg(target_pointer_width = "32")]
#[inline]
pub(crate) const fn usize_lt(a: usize, b: usize) -> Choice {
    let (lhs, rhs) = (a as u32, b as u32);
    Choice::from_u32_lt(lhs, rhs)
}
104
105/// Return a `Choice` representing whether `a < b`.
106#[allow(clippy::cast_possible_truncation)]
107#[cfg(target_pointer_width = "64")]
108#[inline]
109pub(crate) const fn usize_lt(a: usize, b: usize) -> Choice {
110    Choice::from_u64_lt(a as u64, b as u64)
111}
112
113cpubits::cpubits! {
114    32 => {
115        /// Returns the multiplicative inverse of the argument modulo 2^32.
116        ///
117        /// For correct results, the input `value` must be odd.
118        #[must_use]
119        pub(crate) const fn u32_invert_odd(value: u32) -> u32 {
120            debug_assert!(value & 1 == 1, "value must be odd");
121            let x = value.wrapping_mul(3) ^ 2;
122            let y = 1u32.wrapping_sub(x.wrapping_mul(value));
123            let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
124            let (x, y) = (x.wrapping_mul(y.wrapping_add(1)), y.wrapping_mul(y));
125            x.wrapping_mul(y.wrapping_add(1))
126        }
127    }
128}
129
/// Returns the multiplicative inverse of `value` modulo 2^64, using Hurchalla's
/// optimized form of Newton iteration for inverses modulo a power of two.
///
/// For correct results, the input `value` must be odd.
///
/// The starting guess `(3 * value) ^ 2` is correct to 5 bits and every refinement doubles
/// the number of correct bits, so four refinements (5 -> 10 -> 20 -> 40 -> 80 bits) are
/// enough for 64 bits.
///
/// For better understanding the implementation, the following paper is recommended:
/// J. Hurchalla, "An Improved Integer Multiplicative Inverse (modulo 2^w)",
/// <https://arxiv.org/abs/2204.04342>
#[must_use]
pub(crate) const fn u64_invert_odd(value: u64) -> u64 {
    debug_assert!(value & 1 == 1, "value must be odd");
    let mut inverse = value.wrapping_mul(3) ^ 2;
    // `err` tracks `1 - value * inverse`; it is squared after every refinement.
    let mut err = 1u64.wrapping_sub(inverse.wrapping_mul(value));
    let mut round = 0;
    while round < 3 {
        inverse = inverse.wrapping_mul(err.wrapping_add(1));
        err = err.wrapping_mul(err);
        round += 1;
    }
    // Fourth and final refinement; the updated `err` is not needed afterwards.
    inverse.wrapping_mul(err.wrapping_add(1))
}
149
#[cfg(test)]
mod tests {
    use super::{
        borrowing_sub, carrying_add, carrying_mul_add, overflowing_add, u32_bits, u32_max,
        u32_min, u32_rem, u64_invert_odd, usize_lt, widening_mul,
    };
    use crate::Word;

    #[test]
    fn test_carrying_add() {
        assert_eq!(carrying_add(1, 2, 0), (3, 0));
        assert_eq!(carrying_add(Word::MAX, 1, 0), (0, 1));
        assert_eq!(carrying_add(Word::MAX, 0, 1), (0, 1));
        // Largest possible inputs produce the documented maximum carry of 2.
        assert_eq!(carrying_add(Word::MAX, Word::MAX, Word::MAX), (Word::MAX - 2, 2));
    }

    #[test]
    fn test_overflowing_add() {
        assert_eq!(overflowing_add(2, 3), (5, 0));
        assert_eq!(overflowing_add(Word::MAX, 1), (0, 1));
        assert_eq!(overflowing_add(Word::MAX, Word::MAX), (Word::MAX - 1, 1));
    }

    /// Regression coverage for `borrowing_sub`, whose implementation is shaped around a
    /// rustc miscompilation (rust-lang/rust#149522). Borrows are 0 or `Word::MAX`.
    #[test]
    fn test_borrowing_sub() {
        assert_eq!(borrowing_sub(5, 3, 0), (2, 0));
        assert_eq!(borrowing_sub(5, 3, Word::MAX), (1, 0));
        assert_eq!(borrowing_sub(0, 1, 0), (Word::MAX, Word::MAX));
        assert_eq!(borrowing_sub(0, 0, Word::MAX), (Word::MAX, Word::MAX));
        assert_eq!(borrowing_sub(0, Word::MAX, Word::MAX), (0, Word::MAX));
        assert_eq!(
            borrowing_sub(Word::MAX, Word::MAX, Word::MAX),
            (Word::MAX, Word::MAX)
        );
    }

    #[test]
    fn test_widening_mul() {
        assert_eq!(widening_mul(3, 4), (12, 0));
        assert_eq!(widening_mul(Word::MAX, 2), (Word::MAX - 1, 1));
        assert_eq!(widening_mul(Word::MAX, Word::MAX), (1, Word::MAX - 1));
    }

    #[test]
    fn carrying_mul_add_cannot_overflow() {
        let lhs = Word::MAX;
        let rhs = Word::MAX;
        let addend = Word::MAX;
        let carry_in = Word::MAX;
        let (result, carry_out) = carrying_mul_add(lhs, rhs, addend, carry_in);
        assert_eq!(result, Word::MAX);
        assert_eq!(carry_out, Word::MAX);
    }

    #[test]
    fn test_u32_const_min() {
        assert_eq!(u32_min(0, 5), 0);
        assert_eq!(u32_min(7, 0), 0);
        assert_eq!(u32_min(7, 5), 5);
        assert_eq!(u32_min(7, 7), 7);
    }

    #[test]
    fn test_u32_const_max() {
        assert_eq!(u32_max(0, 5), 5);
        assert_eq!(u32_max(7, 0), 7);
        assert_eq!(u32_max(7, 5), 7);
        assert_eq!(u32_max(7, 7), 7);
    }

    #[test]
    fn test_u32_const_rem() {
        assert_eq!(u32_rem(0, 5), 0);
        assert_eq!(u32_rem(4, 5), 4);
        assert_eq!(u32_rem(7, 5), 2);
        assert_eq!(u32_rem(101, 5), 1);
    }

    #[test]
    fn test_u32_bits() {
        assert_eq!(u32_bits(0), 0);
        assert_eq!(u32_bits(1), 1);
        assert_eq!(u32_bits(255), 8);
        assert_eq!(u32_bits(256), 9);
        assert_eq!(u32_bits(u32::MAX), 32);
    }

    #[test]
    fn test_usize_const_lt() {
        assert!(usize_lt(0, 5).to_bool_vartime());
        assert!(!usize_lt(7, 0).to_bool_vartime());
        assert!(!usize_lt(7, 5).to_bool_vartime());
        assert!(!usize_lt(7, 7).to_bool_vartime());
    }

    #[test]
    fn test_u64_invert_odd() {
        assert_eq!(u64_invert_odd(1), 1);
        for v in [3u64, 5, 0xdead_beef_0000_0001, u64::MAX] {
            assert_eq!(u64_invert_odd(v).wrapping_mul(v), 1);
        }
    }

    cpubits::cpubits! {
        32 => {
            #[test]
            fn test_u32_invert_odd() {
                use super::u32_invert_odd;

                assert_eq!(u32_invert_odd(1), 1);
                assert_eq!(u32_invert_odd(5).wrapping_mul(5), 1);
                assert_eq!(u32_invert_odd(u32::MAX).wrapping_mul(u32::MAX), 1);
            }
        }
    }
}