libm/math/support/int_traits/narrowing_div.rs
1/* SPDX-License-Identifier: MIT OR Apache-2.0 */
2use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};
3
4/// Trait for unsigned division of a double-wide integer
5/// when the quotient doesn't overflow.
6///
7/// This is the inverse of widening multiplication:
8/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
9/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
10pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
11 /// Computes `(self / n, self % n))`
12 ///
13 /// # Safety
14 /// The caller must ensure that `self.hi() < n`, or equivalently,
15 /// that the quotient does not overflow.
16 unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);
17
18 /// Returns `Some((self / n, self % n))` when `self.hi() < n`.
19 fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
20 if self.hi() < n {
21 Some(unsafe { self.unchecked_narrowing_div_rem(n) })
22 } else {
23 None
24 }
25 }
26}
27
// Primitive double-wide types already provide `/` and `%`, so narrowing
// division is an ordinary division in the wide type followed by a cast
// of both results down to the half-width type.
macro_rules! impl_narrowing_div_primitive {
    ($Wide:ident) => {
        impl NarrowingDiv for $Wide {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`.
                    unsafe { core::hint::unreachable_unchecked() }
                }
                // Bring the divisor up to the dividend's width once and reuse it.
                let divisor = n.widen();
                let quotient = self / divisor;
                let remainder = self % divisor;
                (quotient.cast(), remainder.cast())
            }
        }
    };
}
42
// Builds `u4N / u2N` narrowing division out of `u2N / uN` narrowing division.
// The approach favours simplicity over raw speed.
macro_rules! impl_narrowing_div_recurse {
    ($Wide:ident) => {
        impl NarrowingDiv for $Wide {
            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
                if self.hi() >= n {
                    // SAFETY: the caller guarantees `self.hi() < n`.
                    unsafe { core::hint::unreachable_unchecked() }
                }

                // Normalize: shift both operands so that the divisor's most
                // significant bit is set. `self.hi() < n` implies `n != 0`.
                let shift = n.leading_zeros();
                let divisor = n << shift;
                let dividend = self << shift;

                // Split the shifted dividend into its upper half and the two
                // digits of its lower half.
                let upper = dividend.hi();
                let (digit0, digit1) = dividend.lo().lo_hi();

                // SAFETY: `divisor.leading_zeros() == 0` holds for both calls because
                // of the normalization above, and `upper < divisor` follows from
                // `self.hi() < n`.
                let (q_hi, rem) = unsafe { div_three_digits_by_two(digit1, upper, divisor) };
                // SAFETY: `rem < divisor` is the postcondition of the previous call.
                let (q_lo, rem) = unsafe { div_three_digits_by_two(digit0, rem, divisor) };

                // The quotient is unaffected by normalization; the remainder
                // has to be shifted back.
                (Self::H::from_lo_hi(q_lo, q_hi), rem >> shift)
            }
        }
    };
}
74
75impl_narrowing_div_primitive!(u16);
76impl_narrowing_div_primitive!(u32);
77impl_narrowing_div_primitive!(u64);
78impl_narrowing_div_primitive!(u128);
79impl_narrowing_div_recurse!(u256);
80
81/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
82///
83/// Returns the quotient and remainder of `(a * R + a0) / n`,
84/// where `R = (1 << U::BITS)` is the digit size.
85///
86/// # Safety
87/// Requires that `n.leading_zeros() == 0` and `a < n`.
88unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
89where
90 U: HInt,
91 U::D: Int + NarrowingDiv,
92{
93 if n.leading_zeros() > 0 || a >= n {
94 unsafe { core::hint::unreachable_unchecked() }
95 }
96
97 // n = n1R + n0
98 let (n0, n1) = n.lo_hi();
99 // a = a2R + a1
100 let (a1, a2) = a.lo_hi();
101
102 let mut q;
103 let mut r;
104 let mut wrap;
105 // `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
106 if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
107 q = q0;
108 // a = qn1 + r1, where 0 <= r1 < n1
109
110 // Include the remainder with the low bits:
111 // r = a0 + r1R
112 r = U::D::from_lo_hi(a0, r1);
113
114 // Subtract the contribution of the divisor low bits with the estimated quotient
115 let d = q.widen_mul(n0);
116 (r, wrap) = r.overflowing_sub(d);
117
118 // Since `q` is the quotient of dividing with a slightly smaller divisor,
119 // it may be an overapproximation, but is never too small, and similarly,
120 // `r` is now either the correct remainder ...
121 if !wrap {
122 return (q, r);
123 }
124 // ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
125 // and we have to adjust.
126 q -= U::ONE;
127 } else {
128 debug_assert!(a2 == n1 && a1 < n0);
129 // Otherwise, `a2 == n1`, and the estimated quotient would be
130 // `R + (a1 % n1)`, but the correct quotient can't overflow.
131 // We'll start from `q = R = (1 << U::BITS)`,
132 // so `r = aR + a0 - qn = (a - n)R + a0`
133 r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
134 // Since `a < n`, the first decrement is always needed:
135 q = U::MAX; /* R - 1 */
136 }
137
138 (r, wrap) = r.overflowing_add(n);
139 if wrap {
140 return (q, r);
141 }
142
143 // If the remainder still didn't wrap, we need another step.
144 q -= U::ONE;
145 (r, wrap) = r.overflowing_add(n);
146 // Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
147 debug_assert!(wrap, "estimated quotient should be off by at most two");
148 (q, r)
149}
150
#[cfg(test)]
mod test {
    use super::{HInt, NarrowingDiv};

    /// Exhaustively checks, for every `u8` pair with a nonzero divisor, that
    /// narrowing division recovers `x` and the added remainder from `x * y + r`.
    #[test]
    fn inverse_mul() {
        for x in 0..=u8::MAX {
            for y in 1..=u8::MAX {
                let product = x.widen_mul(y);
                // Dividing `x * y + rem` by `y` must give back `(x, rem)`.
                let expect_roundtrip = |rem: u8| {
                    assert_eq!(
                        (product + u16::from(rem)).checked_narrowing_div_rem(y),
                        Some((x, rem))
                    );
                };

                // Smallest and largest possible remainders.
                expect_roundtrip(0);
                expect_roundtrip(y - 1);
                // Their immediate neighbours, which only exist for `y > 1`.
                if y > 1 {
                    expect_roundtrip(1);
                    expect_roundtrip(y - 2);
                }
            }
        }
    }
}