1use core::{fmt, iter, ops::Range};
35
36use smallvec::SmallVec;
37
38mod buffer;
39mod texture;
40
41pub(crate) use buffer::{BufferInitTracker, BufferInitTrackerAction};
42pub(crate) use texture::{
43    has_copy_partial_init_tracker_coverage, TextureInitRange, TextureInitTracker,
44    TextureInitTrackerAction,
45};
46
/// How an upcoming access relates to the initialization state of a memory range.
///
/// NOTE(review): the variant meanings below are inferred from their names; the
/// code consuming this enum lives outside this file — confirm against callers.
#[derive(Clone, Copy, Debug)]
pub(crate) enum MemoryInitKind {
    /// The access itself leaves the range initialized (presumably a full
    /// write), so no prior clear should be required.
    ImplicitlyInitialized,
    /// The access requires the range to be initialized beforehand (presumably
    /// a read), so any uninitialized part has to be cleared first.
    NeedsInitializedMemory,
}
56
57type UninitializedRangeVec<Idx> = SmallVec<[Range<Idx>; 1]>;
60
61#[derive(Debug, Clone)]
63pub(crate) struct InitTracker<Idx: Ord + Copy + Default> {
64    uninitialized_ranges: UninitializedRangeVec<Idx>,
67}
68
69pub(crate) struct UninitializedIter<'a, Idx: fmt::Debug + Ord + Copy> {
70    uninitialized_ranges: &'a UninitializedRangeVec<Idx>,
71    drain_range: Range<Idx>,
72    next_index: usize,
73}
74
75impl<'a, Idx> Iterator for UninitializedIter<'a, Idx>
76where
77    Idx: fmt::Debug + Ord + Copy,
78{
79    type Item = Range<Idx>;
80
81    fn next(&mut self) -> Option<Self::Item> {
82        self.uninitialized_ranges
83            .get(self.next_index)
84            .and_then(|range| {
85                if range.start < self.drain_range.end {
86                    self.next_index += 1;
87                    Some(
88                        range.start.max(self.drain_range.start)
89                            ..range.end.min(self.drain_range.end),
90                    )
91                } else {
92                    None
93                }
94            })
95    }
96}
97
98pub(crate) struct InitTrackerDrain<'a, Idx: fmt::Debug + Ord + Copy> {
99    uninitialized_ranges: &'a mut UninitializedRangeVec<Idx>,
100    drain_range: Range<Idx>,
101    first_index: usize,
102    next_index: usize,
103}
104
105impl<'a, Idx> Iterator for InitTrackerDrain<'a, Idx>
106where
107    Idx: fmt::Debug + Ord + Copy,
108{
109    type Item = Range<Idx>;
110
111    fn next(&mut self) -> Option<Self::Item> {
112        if let Some(r) = self
113            .uninitialized_ranges
114            .get(self.next_index)
115            .and_then(|range| {
116                if range.start < self.drain_range.end {
117                    Some(range.clone())
118                } else {
119                    None
120                }
121            })
122        {
123            self.next_index += 1;
124            Some(r.start.max(self.drain_range.start)..r.end.min(self.drain_range.end))
125        } else {
126            let num_affected = self.next_index - self.first_index;
127            if num_affected == 0 {
128                return None;
129            }
130            let first_range = &mut self.uninitialized_ranges[self.first_index];
131
132            if num_affected == 1
134                && first_range.start < self.drain_range.start
135                && first_range.end > self.drain_range.end
136            {
137                let old_start = first_range.start;
138                first_range.start = self.drain_range.end;
139                self.uninitialized_ranges
140                    .insert(self.first_index, old_start..self.drain_range.start);
141            }
142            else {
144                let remove_start = if first_range.start >= self.drain_range.start {
145                    self.first_index
146                } else {
147                    first_range.end = self.drain_range.start;
148                    self.first_index + 1
149                };
150
151                let last_range = &mut self.uninitialized_ranges[self.next_index - 1];
152                let remove_end = if last_range.end <= self.drain_range.end {
153                    self.next_index
154                } else {
155                    last_range.start = self.drain_range.end;
156                    self.next_index - 1
157                };
158
159                self.uninitialized_ranges.drain(remove_start..remove_end);
160            }
161
162            None
163        }
164    }
165}
166
167impl<'a, Idx> Drop for InitTrackerDrain<'a, Idx>
168where
169    Idx: fmt::Debug + Ord + Copy,
170{
171    fn drop(&mut self) {
172        if self.next_index <= self.first_index {
173            for _ in self {}
174        }
175    }
176}
177
178impl<Idx> InitTracker<Idx>
179where
180    Idx: fmt::Debug + Ord + Copy + Default,
181{
182    pub(crate) fn new(size: Idx) -> Self {
183        Self {
184            uninitialized_ranges: iter::once(Idx::default()..size).collect(),
185        }
186    }
187
188    pub(crate) fn check(&self, query_range: Range<Idx>) -> Option<Range<Idx>> {
197        let index = self
198            .uninitialized_ranges
199            .partition_point(|r| r.end <= query_range.start);
200        self.uninitialized_ranges
201            .get(index)
202            .and_then(|start_range| {
203                if start_range.start < query_range.end {
204                    let start = start_range.start.max(query_range.start);
205                    match self.uninitialized_ranges.get(index + 1) {
206                        Some(next_range) => {
207                            if next_range.start < query_range.end {
208                                Some(start..query_range.end)
211                            } else {
212                                Some(start..start_range.end.min(query_range.end))
213                            }
214                        }
215                        None => Some(start..start_range.end.min(query_range.end)),
216                    }
217                } else {
218                    None
219                }
220            })
221    }
222
223    pub(crate) fn uninitialized(&mut self, drain_range: Range<Idx>) -> UninitializedIter<Idx> {
225        let index = self
226            .uninitialized_ranges
227            .partition_point(|r| r.end <= drain_range.start);
228        UninitializedIter {
229            drain_range,
230            uninitialized_ranges: &self.uninitialized_ranges,
231            next_index: index,
232        }
233    }
234
235    pub(crate) fn drain(&mut self, drain_range: Range<Idx>) -> InitTrackerDrain<Idx> {
237        let index = self
238            .uninitialized_ranges
239            .partition_point(|r| r.end <= drain_range.start);
240        InitTrackerDrain {
241            drain_range,
242            uninitialized_ranges: &mut self.uninitialized_ranges,
243            first_index: index,
244            next_index: index,
245        }
246    }
247}
248
249impl InitTracker<u32> {
250    #[allow(dead_code)]
252    pub(crate) fn discard(&mut self, pos: u32) {
253        let r_idx = self.uninitialized_ranges.partition_point(|r| r.end < pos);
255        if let Some(r) = self.uninitialized_ranges.get(r_idx) {
256            if r.end == pos {
258                if let Some(right) = self.uninitialized_ranges.get(r_idx + 1) {
260                    if right.start == pos + 1 {
261                        self.uninitialized_ranges[r_idx] = r.start..right.end;
262                        self.uninitialized_ranges.remove(r_idx + 1);
263                        return;
264                    }
265                }
266                self.uninitialized_ranges[r_idx] = r.start..(pos + 1);
267            } else if r.start > pos {
268                if r.start == pos + 1 {
270                    self.uninitialized_ranges[r_idx] = pos..r.end;
271                } else {
272                    self.uninitialized_ranges.push(pos..(pos + 1));
274                }
275            }
276        } else {
277            self.uninitialized_ranges.push(pos..(pos + 1));
278        }
279    }
280}
281
#[cfg(test)]
mod test {
    use alloc::{vec, vec::Vec};
    use core::ops::Range;

    type Tracker = super::InitTracker<u32>;

    /// Runs `tracker.drain(range)` to completion and returns every yielded range.
    fn drained(tracker: &mut Tracker, range: Range<u32>) -> Vec<Range<u32>> {
        tracker.drain(range).collect()
    }

    #[test]
    fn check_for_newly_created_tracker() {
        let tracker = Tracker::new(10);
        // Nothing is initialized yet, so every query comes back unchanged.
        for query in [0..10, 0..3, 3..4, 4..10].iter() {
            assert_eq!(tracker.check(query.clone()), Some(query.clone()));
        }
    }

    #[test]
    fn check_for_drained_tracker() {
        let mut tracker = Tracker::new(10);
        // Dropping the drain without iterating still marks 0..10 initialized.
        tracker.drain(0..10);
        for query in [0..10, 0..3, 3..4, 4..10].iter() {
            assert_eq!(tracker.check(query.clone()), None);
        }
    }

    #[test]
    fn check_for_partially_filled_tracker() {
        let mut tracker = Tracker::new(25);
        // Leaves 5..10 and 15..20 uninitialized.
        tracker.drain(0..5);
        tracker.drain(10..15);
        tracker.drain(20..25);

        // When two uninitialized ranges overlap the query, the result runs from
        // the first uninitialized index to the end of the query.
        assert_eq!(tracker.check(0..25), Some(5..25));
        assert_eq!(tracker.check(0..5), None);
        assert_eq!(tracker.check(3..8), Some(5..8));
        assert_eq!(tracker.check(3..17), Some(5..17));
        assert_eq!(tracker.check(8..22), Some(8..22));
        // Only 15..20 overlaps here, so the end is exact.
        assert_eq!(tracker.check(17..22), Some(17..20));
        assert_eq!(tracker.check(20..25), None);
    }

    #[test]
    fn drain_already_drained() {
        let mut tracker = Tracker::new(30);
        tracker.drain(10..20);

        // Overlapping and repeated drains must cope with already-initialized parts.
        tracker.drain(5..15);
        tracker.drain(15..25);
        tracker.drain(0..30);
        tracker.drain(0..30);

        assert_eq!(tracker.check(0..30), None);
    }

    #[test]
    fn drain_never_returns_ranges_twice_for_same_range() {
        let mut tracker = Tracker::new(19);
        assert_eq!(tracker.drain(0..19).count(), 1);
        assert_eq!(tracker.drain(0..19).count(), 0);

        let mut tracker = Tracker::new(17);
        for query in [5..8, 1..3, 7..13].iter() {
            assert_eq!(tracker.drain(query.clone()).count(), 1);
            assert_eq!(tracker.drain(query.clone()).count(), 0);
        }
    }

    #[test]
    fn drain_splits_ranges_correctly() {
        let mut tracker = Tracker::new(1337);
        assert_eq!(drained(&mut tracker, 21..42), vec![21..42]);
        assert_eq!(drained(&mut tracker, 900..1000), vec![900..1000]);

        // Straddles both holes punched above.
        assert_eq!(
            drained(&mut tracker, 5..1003),
            vec![5..21, 42..900, 1000..1003]
        );
        assert_eq!(drained(&mut tracker, 0..1337), vec![0..5, 1003..1337]);
    }

    #[test]
    fn discard_adds_range_on_cleared() {
        let mut tracker = Tracker::new(10);
        tracker.drain(0..10);
        for &pos in [0, 5, 9].iter() {
            tracker.discard(pos);
        }
        assert_eq!(tracker.check(0..1), Some(0..1));
        assert_eq!(tracker.check(1..5), None);
        assert_eq!(tracker.check(5..6), Some(5..6));
        assert_eq!(tracker.check(6..9), None);
        assert_eq!(tracker.check(9..10), Some(9..10));
    }

    #[test]
    fn discard_does_nothing_on_uncleared() {
        let mut tracker = Tracker::new(10);
        for &pos in [0, 5, 9].iter() {
            tracker.discard(pos);
        }
        assert_eq!(tracker.uninitialized_ranges.len(), 1);
        assert_eq!(tracker.uninitialized_ranges[0], 0..10);
    }

    #[test]
    fn discard_extends_ranges() {
        let mut tracker = Tracker::new(10);
        tracker.drain(3..7);
        // 2 and 7 are still uninitialized, so these are no-ops.
        tracker.discard(2);
        tracker.discard(7);
        assert_eq!(tracker.uninitialized_ranges.len(), 2);
        assert_eq!(tracker.uninitialized_ranges[0], 0..3);
        assert_eq!(tracker.uninitialized_ranges[1], 7..10);
    }

    #[test]
    fn discard_merges_ranges() {
        let mut tracker = Tracker::new(10);
        tracker.drain(3..4);
        // Re-discarding the only initialized index joins 0..3 and 4..10 again.
        tracker.discard(3);
        assert_eq!(tracker.uninitialized_ranges.len(), 1);
        assert_eq!(tracker.uninitialized_ranges[0], 0..10);
    }
}