rav1e/util/kmeans.rs

1// Copyright (c) 2022-2023, The rav1e contributors. All rights reserved
2//
3// This source code is subject to the terms of the BSD 2 Clause License and
4// the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
5// was not distributed with this source code in the LICENSE file, you can
6// obtain it at www.aomedia.org/license/software. If the Alliance for Open
7// Media Patent License 1.0 was not distributed with this source code in the
8// PATENTS file, you can obtain it at www.aomedia.org/license/patent.
9
/// Find k-means for a sorted slice of integers that can be summed in `i64`.
///
/// `data` must be sorted in non-decreasing order. The slice is split into `K`
/// contiguous clusters whose boundaries sit at the rounded-up midpoints of
/// neighbouring centroids; an element equal to such a midpoint is counted in
/// both neighbouring clusters. Each centroid is the mean of its cluster,
/// rounded to the nearest integer with ties rounded up, and a cluster that
/// becomes empty keeps its previous centroid. Refinement stops once no
/// centroid changes, or after `2 * bit_length(data.len())` passes.
///
/// `K == 0` returns an empty array and `K == 1` returns the rounded mean of
/// the whole slice.
///
/// # Panics
///
/// Panics if `data` is empty and `K > 0`. Unsorted input yields meaningless
/// centroids and may also panic.
pub fn kmeans<T, const K: usize>(data: &[T]) -> [T; K]
where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
  i64: TryInto<T>,
  <i64 as std::convert::TryInto<T>>::Error: std::fmt::Debug,
{
  // Without this check the seeding below would read `data[0]` of an empty
  // slice (previously an unchecked, undefined-behaviour read in release).
  assert!(K == 0 || !data.is_empty(), "kmeans: `data` must not be empty");

  // Seed the centroids at evenly spaced indices, the last one at the final
  // element. Clamping the divisor to 1 avoids dividing by zero when K == 1.
  let divisor = if K > 1 { K - 1 } else { 1 };
  let mut low = [0; K];
  for (i, val) in low.iter_mut().enumerate() {
    *val = (i * (data.len() - 1)) / divisor;
  }
  let mut means = low.map(|i| data[i]);
  if K == 0 {
    return means;
  }

  // Cluster `j` covers `data[low[j]..high[j]]` and `sum[j]` is the sum of
  // that range. All clusters but the last start out empty; the last covers
  // `data[low[K - 1]..]` (one element for K >= 2, everything for K == 1).
  let mut high = low;
  let mut sum = [0i64; K];
  high[K - 1] = data.len();
  sum[K - 1] = data[low[K - 1]..].iter().map(|&d| -> i64 { d.into() }).sum();

  // Constrain complexity to O(n log n)
  let limit = 2 * (usize::BITS - data.len().leading_zeros());
  for _ in 0..limit {
    // Move the boundary between each pair of neighbouring clusters `i` and
    // `i + 1` to the rounded-up midpoint of their centroids.
    for (i, (threshold, (low, high))) in (means.iter().skip(1).zip(&means))
      .map(|(&c1, &c2)| {
        // SAFETY: `ceil((c1 + c2) / 2)` lies between `c2` and `c1`, so it is
        // representable in `T` (values are required to be summable in i64).
        unsafe {
          ((c1.into() + c2.into() + 1) >> 1).try_into().unwrap_unchecked()
        }
      })
      .zip(low.iter_mut().skip(1).zip(&mut high))
      .enumerate()
    {
      // SAFETY: the slice `sum[i..=i + 1]` has two elements, and every
      // boundary starts within `0..=data.len()`; `scan` never moves a
      // boundary outside that range.
      unsafe {
        scan(high, low, &mut sum[i..=i + 1], data, threshold);
      }
    }

    let mut changed = false;
    for (((m, &s), &hi), &lo) in
      means.iter_mut().zip(&sum).zip(&high).zip(&low)
    {
      let count = (hi - lo) as i64;
      if count == 0 {
        // Empty cluster: keep its previous centroid.
        continue;
      }
      // Round half up. `div_euclid` floors, so negative sums are rounded
      // correctly too; `/` truncates toward zero and would bias negative
      // means upward (e.g. a cluster of -3s would yield -2).
      let new_mean: T = (s + (count >> 1))
        .div_euclid(count)
        .try_into()
        .expect("the mean of a cluster lies within the range of `T`");
      changed |= *m != new_mean;
      *m = new_mean;
    }
    if !changed {
      break;
    }
  }

  means
}

/// Moves the boundary between two neighbouring clusters to threshold `t`.
///
/// `high` is the exclusive end of the lower cluster and `sum[0]` its running
/// sum; `low` is the start of the upper cluster and `sum[1]` its running sum.
/// For sorted `data`, afterwards `*high` is the number of elements `<= t` and
/// `*low` the number of elements `< t`. Each sum is adjusted by exactly the
/// elements that crossed its boundary.
///
/// # Safety
///
/// The caller must guarantee `*high <= data.len()`, `*low <= data.len()` and
/// `sum.len() >= 2`.
#[inline(never)]
unsafe fn scan<T>(
  high: &mut usize, low: &mut usize, sum: &mut [i64], data: &[T], t: T,
) where
  T: Copy,
  T: Into<i64>,
  T: PartialEq,
  T: PartialOrd,
{
  // Lower cluster: shrink while the last element is above `t`, then grow
  // while the next element is at most `t`.
  // SAFETY: `sum.len() >= 2` and `*high <= data.len()` per the contract.
  let mut n = *high;
  let mut s = *sum.get_unchecked(0);
  for &d in data.get_unchecked(..n).iter().rev().take_while(|&d| *d > t) {
    s -= d.into();
    n -= 1;
  }
  for &d in data.get_unchecked(n..).iter().take_while(|&d| *d <= t) {
    s += d.into();
    n += 1;
  }
  *high = n;
  *sum.get_unchecked_mut(0) = s;

  // Upper cluster: drop leading elements below `t`, then take back
  // preceding elements that are at least `t`.
  // SAFETY: `*low <= data.len()` per the contract.
  let mut n = *low;
  let mut s = *sum.get_unchecked(1);
  for &d in data.get_unchecked(n..).iter().take_while(|&d| *d < t) {
    s -= d.into();
    n += 1;
  }
  for &d in data.get_unchecked(..n).iter().rev().take_while(|&d| *d >= t) {
    s += d.into();
    n -= 1;
  }
  *low = n;
  *sum.get_unchecked_mut(1) = s;
}
103
#[cfg(test)]
mod test {
  use super::*;

  /// Sorts `values` in place, then returns the `K` centroids `kmeans` finds.
  fn centroids_of<const K: usize>(values: &mut [i32]) -> [i32; K] {
    values.sort_unstable();
    kmeans(values)
  }

  /// Three well-separated groups collapse onto their middle values.
  #[test]
  fn three_means() {
    let mut samples = [1, 2, 3, 10, 11, 12, 20, 21, 22];
    let found = centroids_of::<3>(&mut samples);
    assert_eq!(found, [2, 11, 21]);
  }

  /// Unsorted input with four groups is sorted before clustering.
  #[test]
  fn four_means() {
    let mut samples = [30, 31, 32, 1, 2, 3, 10, 11, 12, 20, 21, 22];
    let found = centroids_of::<4>(&mut samples);
    assert_eq!(found, [2, 11, 21, 31]);
  }
}