muldiv/
lib.rs

1// Copyright (C) 2016,2017 Sebastian Dröge <[email protected]>
2//
3// Licensed under the MIT license, see the LICENSE file or <http://opensource.org/licenses/MIT>
4
5#![no_std]
6
7//! Provides a trait for numeric types to perform combined multiplication and division with
8//! overflow protection.
9//!
10//! The [`MulDiv`] trait provides functions for performing combined multiplication and division for
11//! numeric types and comes with implementations for all the primitive integer types. Three
12//! variants with different rounding characteristics are provided: [`mul_div_floor()`],
13//! [`mul_div_round()`] and [`mul_div_ceil()`].
14//!
15//! ## Example
16//!
17//! ```rust
18//! extern crate muldiv;
19//! use muldiv::MulDiv;
20//! # fn main() {
21//! // Calculates 127 * 23 / 42 rounded down
22//! let x = 127u8.mul_div_floor(23, 42);
23//! assert_eq!(x, Some(69));
24//! # }
25//! ```
26//! [`MulDiv`]: trait.MulDiv.html
27//! [`mul_div_floor()`]: trait.MulDiv.html#tymethod.mul_div_floor
28//! [`mul_div_round()`]: trait.MulDiv.html#tymethod.mul_div_round
29//! [`mul_div_ceil()`]: trait.MulDiv.html#tymethod.mul_div_ceil
30
31use core::u16;
32use core::u32;
33use core::u64;
34use core::u8;
35
36use core::i16;
37use core::i32;
38use core::i64;
39use core::i8;
40
/// Trait for calculating `val * num / denom` with different rounding modes and overflow
/// protection.
///
/// Implementations of this trait have to ensure that even if the result of the multiplication does
/// not fit into the type, as long as it would fit after the division the correct result has to be
/// returned instead of `None`. `None` should only be returned if the overall result does not fit
/// into the type.
///
/// This specifically means that e.g. the `u64` implementation must, depending on the arguments, be
/// able to do 128 bit integer multiplication.
///
/// # Panics
///
/// The implementations for the primitive integer types provided by this crate panic if `denom` is
/// zero.
pub trait MulDiv<RHS = Self> {
    /// Output type for the methods of this trait.
    type Output;

    /// Calculates `floor(val * num / denom)`, i.e. the largest integer less than or equal to the
    /// result of the division.
    ///
    /// ## Example
    ///
    /// ```rust
    /// extern crate muldiv;
    /// use muldiv::MulDiv;
    ///
    /// # fn main() {
    /// let x = 3i8.mul_div_floor(4, 2);
    /// assert_eq!(x, Some(6));
    ///
    /// let x = 5i8.mul_div_floor(2, 3);
    /// assert_eq!(x, Some(3));
    ///
    /// let x = (-5i8).mul_div_floor(2, 3);
    /// assert_eq!(x, Some(-4));
    ///
    /// let x = 3i8.mul_div_floor(3, 2);
    /// assert_eq!(x, Some(4));
    ///
    /// let x = (-3i8).mul_div_floor(3, 2);
    /// assert_eq!(x, Some(-5));
    ///
    /// let x = 127i8.mul_div_floor(4, 3);
    /// assert_eq!(x, None);
    /// # }
    /// ```
    fn mul_div_floor(self, num: RHS, denom: RHS) -> Option<Self::Output>;

    /// Calculates `round(val * num / denom)`, i.e. the closest integer to the result of the
    /// division. If both surrounding integers are the same distance (`x.5`), the one with the bigger
    /// absolute value is returned (round half away from zero).
    ///
    /// ## Example
    ///
    /// ```rust
    /// extern crate muldiv;
    /// use muldiv::MulDiv;
    ///
    /// # fn main() {
    /// let x = 3i8.mul_div_round(4, 2);
    /// assert_eq!(x, Some(6));
    ///
    /// let x = 5i8.mul_div_round(2, 3);
    /// assert_eq!(x, Some(3));
    ///
    /// let x = (-5i8).mul_div_round(2, 3);
    /// assert_eq!(x, Some(-3));
    ///
    /// let x = 3i8.mul_div_round(3, 2);
    /// assert_eq!(x, Some(5));
    ///
    /// let x = (-3i8).mul_div_round(3, 2);
    /// assert_eq!(x, Some(-5));
    ///
    /// let x = 127i8.mul_div_round(4, 3);
    /// assert_eq!(x, None);
    /// # }
    /// ```
    fn mul_div_round(self, num: RHS, denom: RHS) -> Option<Self::Output>;

    /// Calculates `ceil(val * num / denom)`, i.e. the smallest integer greater than or equal to
    /// the result of the division.
    ///
    /// ## Example
    ///
    /// ```rust
    /// extern crate muldiv;
    /// use muldiv::MulDiv;
    ///
    /// # fn main() {
    /// let x = 3i8.mul_div_ceil(4, 2);
    /// assert_eq!(x, Some(6));
    ///
    /// let x = 5i8.mul_div_ceil(2, 3);
    /// assert_eq!(x, Some(4));
    ///
    /// let x = (-5i8).mul_div_ceil(2, 3);
    /// assert_eq!(x, Some(-3));
    ///
    /// let x = 3i8.mul_div_ceil(3, 2);
    /// assert_eq!(x, Some(5));
    ///
    /// let x = (-3i8).mul_div_ceil(3, 2);
    /// assert_eq!(x, Some(-4));
    ///
    /// let x = (127i8).mul_div_ceil(4, 3);
    /// assert_eq!(x, None);
    /// # }
    /// ```
    fn mul_div_ceil(self, num: RHS, denom: RHS) -> Option<Self::Output>;
}
149
// Implements `MulDiv` for the unsigned integer type `$t`.
//
// The product `self * num` and the rounding bias are computed in `$u`. Every invocation in this
// file pairs `$t` with an unsigned type twice its width, where this sum cannot overflow. Only
// the final quotient is narrowed back to `$t`, and `None` is returned when it exceeds
// `$t::MAX`.
//
// All three methods panic if `denom` is zero.
macro_rules! mul_div_impl_unsigned {
    ($t:ident, $u:ident) => {
        impl MulDiv for $t {
            type Output = $t;

            fn mul_div_floor(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);
                let product = (self as $u) * (num as $u);
                let quotient = product / (denom as $u);
                if quotient <= $t::MAX as $u {
                    Some(quotient as $t)
                } else {
                    None
                }
            }

            fn mul_div_round(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);
                // Adding half the divisor before truncating rounds exact halves up.
                let bias = (denom >> 1) as $u;
                let product = (self as $u) * (num as $u);
                let quotient = (product + bias) / (denom as $u);
                if quotient <= $t::MAX as $u {
                    Some(quotient as $t)
                } else {
                    None
                }
            }

            fn mul_div_ceil(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);
                // Adding `denom - 1` before truncating rounds any non-zero remainder up.
                let bias = (denom - 1) as $u;
                let product = (self as $u) * (num as $u);
                let quotient = (product + bias) / (denom as $u);
                if quotient <= $t::MAX as $u {
                    Some(quotient as $t)
                } else {
                    None
                }
            }
        }
    };
}
187
// Property tests for an unsigned `MulDiv` implementation. `$t` is the type under test and `$u`
// is a wider unsigned type in which the exact reference result is computed.
#[cfg(test)]
macro_rules! mul_div_impl_unsigned_tests {
    ($t:ident, $u:ident) => {
        use super::*;

        use quickcheck::{quickcheck, Arbitrary, Gen};

        // A `$t` that is never zero, used as the denominator.
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct NonZero($t);

        impl Arbitrary for NonZero {
            fn arbitrary(g: &mut Gen) -> Self {
                // Keep drawing until a non-zero value comes up.
                let mut candidate = $t::arbitrary(g);
                while candidate == 0 {
                    candidate = $t::arbitrary(g);
                }
                NonZero(candidate)
            }
        }

        // Checks `res` against `reference`, the exact result computed in `$u`: `res` must be
        // `None` exactly when `reference` does not fit into `$t`.
        fn agrees_with(res: Option<$t>, reference: $u) -> bool {
            if reference <= $t::MAX as $u {
                res == Some(reference as $t)
            } else {
                res.is_none()
            }
        }

        quickcheck! {
            fn scale_floor(val: $t, num: $t, den: NonZero) -> bool {
                let res = val.mul_div_floor(num, den.0);

                let product = (val as $u) * (num as $u);
                let reference = product / (den.0 as $u);

                agrees_with(res, reference)
            }
        }

        quickcheck! {
            fn scale_round(val: $t, num: $t, den: NonZero) -> bool {
                let res = val.mul_div_round(num, den.0);

                let product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                // Round up when the remainder is at least half of the divisor.
                let round_up = product % divisor >= (divisor + 1) >> 1;
                let reference = product / divisor + if round_up { 1 } else { 0 };

                agrees_with(res, reference)
            }
        }

        quickcheck! {
            fn scale_ceil(val: $t, num: $t, den: NonZero) -> bool {
                let res = val.mul_div_ceil(num, den.0);

                let product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let inexact = product % divisor != 0;
                let reference = product / divisor + if inexact { 1 } else { 0 };

                agrees_with(res, reference)
            }
        }
    };
}
262
263mul_div_impl_unsigned!(u64, u128);
264mul_div_impl_unsigned!(u32, u64);
265mul_div_impl_unsigned!(u16, u32);
266mul_div_impl_unsigned!(u8, u16);
267
// FIXME: https://github.com/rust-lang/rust/issues/12249
//
// One test module per unsigned type. Each instantiates the property tests with the type under
// test and the double-width type used to compute the reference result.
#[cfg(test)]
mod muldiv_u8_tests {
    mul_div_impl_unsigned_tests!(u8, u16);
}

#[cfg(test)]
mod muldiv_u16_tests {
    mul_div_impl_unsigned_tests!(u16, u32);
}

#[cfg(test)]
mod muldiv_u32_tests {
    mul_div_impl_unsigned_tests!(u32, u64);
}

#[cfg(test)]
mod muldiv_u64_tests {
    mul_div_impl_unsigned_tests!(u64, u128);
}
288
// Implements `MulDiv` for the signed integer type `$t`.
//
// Macro parameters:
// - `$t`: the signed type to implement the trait for.
// - `$u`: the unsigned type of the same width as `$t`. Its own `MulDiv` implementation performs
//   the widened multiplication/division on the operands' magnitudes, so it must exist too.
// - `$v`: the unsigned type twice as wide as `$u`. This expansion does not use it (the `$u`
//   implementation does its own widening), but it is still accepted so existing invocations
//   keep compiling.
// - `$b`: the bit width of `$t`, used to compute `|$t::MIN|` as a `$u`.
//
// The sign of the exact result is the product of the operands' signs, and its magnitude is
// `|self| * |num| / |denom|`. For a negative result, rounding the value down (floor) means
// rounding its magnitude up (ceil) and vice versa, so `floor` and `ceil` swap the unsigned
// operation they delegate to when the result is negative. Rounding half away from zero is
// symmetric, so `round` delegates to the unsigned `mul_div_round` for either sign.
//
// The magnitude is then converted back to `$t`. Anything up to `$t::MAX` fits for either sign.
// A negative result may additionally have magnitude `|$t::MIN|`; every other magnitude yields
// `None`.
//
// All three methods panic if `denom` is zero.
macro_rules! mul_div_impl_signed {
    ($t:ident, $u:ident, $v:ident, $b:expr) => {
        impl MulDiv for $t {
            type Output = $t;

            fn mul_div_floor(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);

                // Sign of the exact result: -1, 0 or 1.
                let sgn = self.signum() * num.signum() * denom.signum();

                // `|$t::MIN|`, which does not fit into `$t` but does fit into `$u`.
                let min_val: $u = 1 << ($b - 1);
                // `wrapping_abs` maps `$t::MIN` to itself, and that bit pattern read as `$u` is
                // exactly `min_val`, so this yields the true magnitude for every input.
                let abs = |x: $t| x.wrapping_abs() as $u;

                let self_u = abs(self);
                let num_u = abs(num);
                let denom_u = abs(denom);

                // Rounding a negative value down rounds its magnitude up.
                if sgn < 0 {
                    self_u.mul_div_ceil(num_u, denom_u)
                } else {
                    self_u.mul_div_floor(num_u, denom_u)
                }
                .and_then(|r| {
                    if r <= $t::MAX as $u {
                        Some(sgn * (r as $t))
                    } else if sgn < 0 && r == min_val {
                        Some($t::MIN)
                    } else {
                        None
                    }
                })
            }

            fn mul_div_round(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);

                // Sign of the exact result: -1, 0 or 1.
                let sgn = self.signum() * num.signum() * denom.signum();

                // `|$t::MIN|`, which does not fit into `$t` but does fit into `$u`.
                let min_val: $u = 1 << ($b - 1);
                // `wrapping_abs` maps `$t::MIN` to itself, and that bit pattern read as `$u` is
                // exactly `min_val`, so this yields the true magnitude for every input.
                let abs = |x: $t| x.wrapping_abs() as $u;

                let self_u = abs(self);
                let num_u = abs(num);
                let denom_u = abs(denom);

                // Rounding half away from zero treats both signs symmetrically, so the magnitude
                // is rounded the same way whatever the sign of the result.
                self_u.mul_div_round(num_u, denom_u).and_then(|r| {
                    if r <= $t::MAX as $u {
                        Some(sgn * (r as $t))
                    } else if sgn < 0 && r == min_val {
                        Some($t::MIN)
                    } else {
                        None
                    }
                })
            }

            fn mul_div_ceil(self, num: $t, denom: $t) -> Option<$t> {
                assert_ne!(denom, 0);

                // Sign of the exact result: -1, 0 or 1.
                let sgn = self.signum() * num.signum() * denom.signum();

                // `|$t::MIN|`, which does not fit into `$t` but does fit into `$u`.
                let min_val: $u = 1 << ($b - 1);
                // `wrapping_abs` maps `$t::MIN` to itself, and that bit pattern read as `$u` is
                // exactly `min_val`, so this yields the true magnitude for every input.
                let abs = |x: $t| x.wrapping_abs() as $u;

                let self_u = abs(self);
                let num_u = abs(num);
                let denom_u = abs(denom);

                // Rounding a negative value up rounds its magnitude down.
                if sgn < 0 {
                    self_u.mul_div_floor(num_u, denom_u)
                } else {
                    self_u.mul_div_ceil(num_u, denom_u)
                }
                .and_then(|r| {
                    if r <= $t::MAX as $u {
                        Some(sgn * (r as $t))
                    } else if sgn < 0 && r == min_val {
                        Some($t::MIN)
                    } else {
                        None
                    }
                })
            }
        }
    };
}
386
387mul_div_impl_signed!(i64, u64, u128, 64);
388mul_div_impl_signed!(i32, u32, u64, 32);
389mul_div_impl_signed!(i16, u16, u32, 16);
390mul_div_impl_signed!(i8, u8, u16, 8);
391
// Property tests for a signed `MulDiv` implementation. `$t` is the type under test and `$u` is a
// wider signed type in which the exact (truncating) reference result is computed and then
// adjusted to the requested rounding mode.
#[cfg(test)]
macro_rules! mul_div_impl_signed_tests {
    ($t:ident, $u:ident) => {
        use super::*;

        use quickcheck::{quickcheck, Arbitrary, Gen};

        // A `$t` that is never zero, used as the denominator.
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct NonZero($t);

        impl Arbitrary for NonZero {
            fn arbitrary(g: &mut Gen) -> Self {
                // Keep drawing until a non-zero value comes up.
                let mut candidate = $t::arbitrary(g);
                while candidate == 0 {
                    candidate = $t::arbitrary(g);
                }
                NonZero(candidate)
            }
        }

        // Checks `res` against `reference`, the exact result computed in `$u`: `res` must be
        // `None` exactly when `reference` lies outside the range of `$t`.
        fn agrees_with(res: Option<$t>, reference: $u) -> bool {
            let fits = reference >= $t::MIN as $u && reference <= $t::MAX as $u;
            if fits {
                res == Some(reference as $t)
            } else {
                res.is_none()
            }
        }

        quickcheck! {
            fn scale_floor(val: $t, num: $t, den: NonZero) -> bool {
                let res = val.mul_div_floor(num, den.0);

                let negative = val.signum() * num.signum() * den.0.signum() < 0;
                let product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = product / divisor;
                // Truncation rounds towards zero; a negative inexact result must go one lower.
                let inexact = product % divisor != 0;
                let reference = if negative && inexact { truncated - 1 } else { truncated };

                agrees_with(res, reference)
            }
        }

        quickcheck! {
            fn scale_round(val: $t, num: $t, den: NonZero) -> bool {
                let res = val.mul_div_round(num, den.0);

                let sgn = val.signum() * num.signum() * den.0.signum();
                let product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let at_least_half = (product % divisor).abs() >= (divisor.abs() + 1) >> 1;
                // Step one unit away from zero, in the direction of the result's sign.
                let step = if at_least_half { sgn as $u } else { 0 };
                let reference = product / divisor + step;

                agrees_with(res, reference)
            }
        }

        quickcheck! {
            fn scale_ceil(val: $t, num: $t, den: NonZero) -> bool {
                let res = val.mul_div_ceil(num, den.0);

                let positive = val.signum() * num.signum() * den.0.signum() > 0;
                let product = (val as $u) * (num as $u);
                let divisor = den.0 as $u;
                let truncated = product / divisor;
                // Truncation rounds towards zero; a positive inexact result must go one higher.
                let inexact = product % divisor != 0;
                let reference = if positive && inexact { truncated + 1 } else { truncated };

                agrees_with(res, reference)
            }
        }
    };
}
476
// FIXME: https://github.com/rust-lang/rust/issues/12249
//
// One test module per signed type. Each instantiates the property tests with the type under
// test and the double-width signed type used to compute the reference result.
#[cfg(test)]
mod muldiv_i8_tests {
    mul_div_impl_signed_tests!(i8, i16);
}

#[cfg(test)]
mod muldiv_i16_tests {
    mul_div_impl_signed_tests!(i16, i32);
}

#[cfg(test)]
mod muldiv_i32_tests {
    mul_div_impl_signed_tests!(i32, i64);
}

#[cfg(test)]
mod muldiv_i64_tests {
    mul_div_impl_signed_tests!(i64, i128);
}