emath/
ordered_float.rs

1//! Total order on floating point types.
2//! Can be used for sorting, min/max computation, and other collection algorithms.
3
4use std::cmp::Ordering;
5use std::hash::{Hash, Hasher};
6
/// Wraps a floating-point value to add a total order and a hash.
/// Possible types for `T` are `f32` and `f64`.
///
/// All NaNs are considered equal to each other, and greater than every non-NaN value.
/// The sign of zero is ignored: `-0.0` and `+0.0` compare equal and hash identically.
///
/// See also [`Float`].
#[derive(Clone, Copy)]
pub struct OrderedFloat<T>(pub T);
16
17impl<T: Float + Copy> OrderedFloat<T> {
18    #[inline]
19    pub fn into_inner(self) -> T {
20        self.0
21    }
22}
23
24impl<T: std::fmt::Debug> std::fmt::Debug for OrderedFloat<T> {
25    #[inline]
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        self.0.fmt(f)
28    }
29}
30
31impl<T: Float> Eq for OrderedFloat<T> {}
32
33impl<T: Float> PartialEq<Self> for OrderedFloat<T> {
34    #[inline]
35    fn eq(&self, other: &Self) -> bool {
36        // NaNs are considered equal (equivalent) when it comes to ordering
37        if self.0.is_nan() {
38            other.0.is_nan()
39        } else {
40            self.0 == other.0
41        }
42    }
43}
44
45impl<T: Float> PartialOrd<Self> for OrderedFloat<T> {
46    #[inline]
47    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48        Some(self.cmp(other))
49    }
50}
51
52impl<T: Float> Ord for OrderedFloat<T> {
53    #[inline]
54    fn cmp(&self, other: &Self) -> Ordering {
55        match self.0.partial_cmp(&other.0) {
56            Some(ord) => ord,
57            None => self.0.is_nan().cmp(&other.0.is_nan()),
58        }
59    }
60}
61
62impl<T: Float> Hash for OrderedFloat<T> {
63    fn hash<H: Hasher>(&self, state: &mut H) {
64        self.0.hash(state);
65    }
66}
67
68impl<T> From<T> for OrderedFloat<T> {
69    #[inline]
70    fn from(val: T) -> Self {
71        Self(val)
72    }
73}
74
75// ----------------------------------------------------------------------------
76
77/// Extension trait to provide `ord()` method.
78///
79/// Example with `f64`:
80/// ```
81/// use emath::Float as _;
82///
83/// let array = [1.0, 2.5, 2.0];
84/// let max = array.iter().max_by_key(|val| val.ord());
85///
86/// assert_eq!(max, Some(&2.5));
87/// ```
88pub trait Float: PartialOrd + PartialEq + private::FloatImpl {
89    /// Type to provide total order, useful as key in sorted contexts.
90    fn ord(self) -> OrderedFloat<Self>
91    where
92        Self: Sized;
93}
94
95impl Float for f32 {
96    #[inline]
97    fn ord(self) -> OrderedFloat<Self> {
98        OrderedFloat(self)
99    }
100}
101
102impl Float for f64 {
103    #[inline]
104    fn ord(self) -> OrderedFloat<Self> {
105        OrderedFloat(self)
106    }
107}
108
// The helper trait lives in a private module so that its methods (`is_nan`, `hash`)
// do not show up as extension methods in user code.
mod private {
    use super::{Hash as _, Hasher};

    /// Per-type primitives: NaN detection, plus a hash that agrees with the
    /// "all NaNs are equal, and `-0.0 == +0.0`" equality.
    pub trait FloatImpl {
        fn is_nan(&self) -> bool;

        fn hash<H: Hasher>(&self, state: &mut H);
    }

    impl FloatImpl for f32 {
        #[inline]
        fn is_nan(&self) -> bool {
            f32::is_nan(*self)
        }

        #[inline]
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Single "canonical" quiet-NaN bit pattern used for every NaN.
            const CANONICAL_NAN: u32 = 0x7fc0_0000;

            if f32::is_nan(*self) {
                CANONICAL_NAN.hash(state);
            } else {
                // -0.0 + 0.0 == +0.0, so both zeros end up with the same bits.
                // Trick borrowed from the `ordered-float` crate:
                // https://github.com/reem/rust-ordered-float/blob/1841f0541ea0e56779cbac03de2705149e020675/src/lib.rs#L2178-L2181
                (*self + 0.0).to_bits().hash(state);
            }
        }
    }

    impl FloatImpl for f64 {
        #[inline]
        fn is_nan(&self) -> bool {
            f64::is_nan(*self)
        }

        #[inline]
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Single "canonical" quiet-NaN bit pattern used for every NaN.
            const CANONICAL_NAN: u64 = 0x7ff8_0000_0000_0000;

            if f64::is_nan(*self) {
                CANONICAL_NAN.hash(state);
            } else {
                // Same zero-sign normalisation as for `f32`.
                (*self + 0.0).to_bits().hash(state);
            }
        }
    }
}