1use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18pub trait IteratorRandom: Iterator + Sized {
35    fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67    where
68        R: Rng + ?Sized,
69    {
70        let (mut lower, mut upper) = self.size_hint();
71        let mut result = None;
72
73        if upper == Some(lower) {
77            return match lower {
78                0 => None,
79                1 => self.next(),
80                _ => self.nth(rng.random_range(..lower)),
81            };
82        }
83
84        let mut coin_flipper = CoinFlipper::new(rng);
85        let mut consumed = 0;
86
87        loop {
89            if lower > 1 {
90                let ix = coin_flipper.rng.random_range(..lower + consumed);
91                let skip = if ix < lower {
92                    result = self.nth(ix);
93                    lower - (ix + 1)
94                } else {
95                    lower
96                };
97                if upper == Some(lower) {
98                    return result;
99                }
100                consumed += lower;
101                if skip > 0 {
102                    self.nth(skip - 1);
103                }
104            } else {
105                let elem = self.next();
106                if elem.is_none() {
107                    return result;
108                }
109                consumed += 1;
110                if coin_flipper.random_ratio_one_over(consumed) {
111                    result = elem;
112                }
113            }
114
115            let hint = self.size_hint();
116            lower = hint.0;
117            upper = hint.1;
118        }
119    }
120
121    #[allow(unknown_lints)]
144    #[allow(clippy::double_ended_iterator_last)]
145    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146    where
147        R: Rng + ?Sized,
148    {
149        let mut consumed = 0;
150        let mut result = None;
151        let mut coin_flipper = CoinFlipper::new(rng);
152
153        loop {
154            let mut next = 0;
159
160            let (lower, _) = self.size_hint();
161            if lower >= 2 {
162                let highest_selected = (0..lower)
163                    .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164                    .last();
165
166                consumed += lower;
167                next = lower;
168
169                if let Some(ix) = highest_selected {
170                    result = self.nth(ix);
171                    next -= ix + 1;
172                    debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173                }
174            }
175
176            let elem = self.nth(next);
177            if elem.is_none() {
178                return result;
179            }
180
181            if coin_flipper.random_ratio_one_over(consumed + 1) {
182                result = elem;
183            }
184            consumed += 1;
185        }
186    }
187
188    fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204    where
205        R: Rng + ?Sized,
206    {
207        let amount = buf.len();
208        let mut len = 0;
209        while len < amount {
210            if let Some(elem) = self.next() {
211                buf[len] = elem;
212                len += 1;
213            } else {
214                return len;
216            }
217        }
218
219        for (i, elem) in self.enumerate() {
221            let k = rng.random_range(..i + 1 + amount);
222            if let Some(slot) = buf.get_mut(k) {
223                *slot = elem;
224            }
225        }
226        len
227    }
228
229    #[cfg(feature = "alloc")]
244    fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
245    where
246        R: Rng + ?Sized,
247    {
248        let mut reservoir = Vec::with_capacity(amount);
249        reservoir.extend(self.by_ref().take(amount));
250
251        if reservoir.len() == amount {
256            for (i, elem) in self.enumerate() {
257                let k = rng.random_range(..i + 1 + amount);
258                if let Some(slot) = reservoir.get_mut(k) {
259                    *slot = elem;
260                }
261            }
262        } else {
263            reservoir.shrink_to_fit();
266        }
267        reservoir
268    }
269}
270
271impl<I> IteratorRandom for I where I: Iterator + Sized {}
272
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Wrapper that forwards `next` but not `size_hint`. It therefore reports
    /// the default hint `(0, None)`, exercising the unhinted code paths.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Wrapper whose lower size bound counts down within chunks of
    /// `chunk_size` elements (starting from `chunk_remaining`), instead of
    /// reporting the true remaining length.
    ///
    /// The upper bound is the exact remaining length when `hint_total_size`
    /// is set, and `None` otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Start a new chunk once the current one has been used up.
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Wrapper whose lower size bound is the remaining length capped at
    /// `window_size`.
    ///
    /// The upper bound is the exact remaining length when `hint_total_size`
    /// is set, and `None` otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // many RNG draws; skipped under Miri
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Each of the 9 values is expected about 1000 / 9 ≈ 111 times;
                // the bounds are a loose acceptance window around that.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // many RNG draws; skipped under Miri
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Each of the 9 values is expected about 1000 / 9 ≈ 111 times;
                // the bounds are a loose acceptance window around that.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty inputs must yield `None` from `choose_stable` itself.
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // many RNG draws; skipped under Miri
    fn test_iterator_choose_stable_stability() {
        // Every run starts from the same seed, so all iterators over the same
        // items must produce identical histograms.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        // Half-open range: values are in `min_val..max_val`.
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // Asking for more elements than exist returns all of them in order.
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        assert!(small_sample
            .iter()
            .all(|e| { **e >= min_val && **e < max_val }));
    }

    #[test]
    fn test_choose_multiple_zero_amount() {
        let mut r = crate::test::rng(413);
        let mut buf: [u32; 0] = [];
        assert_eq!((0..10u32).choose_multiple_fill(&mut r, &mut buf), 0);

        #[cfg(feature = "alloc")]
        assert!((0..10u32).choose_multiple(&mut r, 0).is_empty());
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_choose_multiple() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}