1use crate::util::int::Usize;
2
3#[derive(Clone, Copy)]
10pub(crate) struct ByteClasses([u8; 256]);
11
12impl ByteClasses {
13    pub(crate) fn empty() -> ByteClasses {
16        ByteClasses([0; 256])
17    }
18
19    pub(crate) fn singletons() -> ByteClasses {
22        let mut classes = ByteClasses::empty();
23        for b in 0..=255 {
24            classes.set(b, b);
25        }
26        classes
27    }
28
29    #[inline]
31    pub(crate) fn set(&mut self, byte: u8, class: u8) {
32        self.0[usize::from(byte)] = class;
33    }
34
35    #[inline]
37    pub(crate) fn get(&self, byte: u8) -> u8 {
38        self.0[usize::from(byte)]
39    }
40
41    #[inline]
45    pub(crate) fn alphabet_len(&self) -> usize {
46        usize::from(self.0[255]) + 1
49    }
50
51    pub(crate) fn stride2(&self) -> usize {
60        let zeros = self.alphabet_len().next_power_of_two().trailing_zeros();
61        usize::try_from(zeros).unwrap()
62    }
63
64    pub(crate) fn stride(&self) -> usize {
68        1 << self.stride2()
69    }
70
71    #[inline]
75    pub(crate) fn is_singleton(&self) -> bool {
76        self.alphabet_len() == 256
77    }
78
79    pub(crate) fn iter(&self) -> ByteClassIter {
81        ByteClassIter { it: 0..self.alphabet_len() }
82    }
83
84    pub(crate) fn elements(&self, class: u8) -> ByteClassElements {
86        ByteClassElements { classes: self, class, bytes: 0..=255 }
87    }
88
89    fn element_ranges(&self, class: u8) -> ByteClassElementRanges {
94        ByteClassElementRanges { elements: self.elements(class), range: None }
95    }
96}
97
98impl core::fmt::Debug for ByteClasses {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        if self.is_singleton() {
101            write!(f, "ByteClasses(<one-class-per-byte>)")
102        } else {
103            write!(f, "ByteClasses(")?;
104            for (i, class) in self.iter().enumerate() {
105                if i > 0 {
106                    write!(f, ", ")?;
107                }
108                write!(f, "{:?} => [", class)?;
109                for (start, end) in self.element_ranges(class) {
110                    if start == end {
111                        write!(f, "{:?}", start)?;
112                    } else {
113                        write!(f, "{:?}-{:?}", start, end)?;
114                    }
115                }
116                write!(f, "]")?;
117            }
118            write!(f, ")")
119        }
120    }
121}
122
123#[derive(Debug)]
125pub(crate) struct ByteClassIter {
126    it: core::ops::Range<usize>,
127}
128
129impl Iterator for ByteClassIter {
130    type Item = u8;
131
132    fn next(&mut self) -> Option<u8> {
133        self.it.next().map(|class| class.as_u8())
134    }
135}
136
137#[derive(Debug)]
139pub(crate) struct ByteClassElements<'a> {
140    classes: &'a ByteClasses,
141    class: u8,
142    bytes: core::ops::RangeInclusive<u8>,
143}
144
145impl<'a> Iterator for ByteClassElements<'a> {
146    type Item = u8;
147
148    fn next(&mut self) -> Option<u8> {
149        while let Some(byte) = self.bytes.next() {
150            if self.class == self.classes.get(byte) {
151                return Some(byte);
152            }
153        }
154        None
155    }
156}
157
158#[derive(Debug)]
161pub(crate) struct ByteClassElementRanges<'a> {
162    elements: ByteClassElements<'a>,
163    range: Option<(u8, u8)>,
164}
165
166impl<'a> Iterator for ByteClassElementRanges<'a> {
167    type Item = (u8, u8);
168
169    fn next(&mut self) -> Option<(u8, u8)> {
170        loop {
171            let element = match self.elements.next() {
172                None => return self.range.take(),
173                Some(element) => element,
174            };
175            match self.range.take() {
176                None => {
177                    self.range = Some((element, element));
178                }
179                Some((start, end)) => {
180                    if usize::from(end) + 1 != usize::from(element) {
181                        self.range = Some((element, element));
182                        return Some((start, end));
183                    }
184                    self.range = Some((start, element));
185                }
186            }
187        }
188    }
189}
190
191#[derive(Clone, Debug)]
207pub(crate) struct ByteClassSet(ByteSet);
208
209impl Default for ByteClassSet {
210    fn default() -> ByteClassSet {
211        ByteClassSet::empty()
212    }
213}
214
215impl ByteClassSet {
216    pub(crate) fn empty() -> Self {
219        ByteClassSet(ByteSet::empty())
220    }
221
222    pub(crate) fn set_range(&mut self, start: u8, end: u8) {
225        debug_assert!(start <= end);
226        if start > 0 {
227            self.0.add(start - 1);
228        }
229        self.0.add(end);
230    }
231
232    pub(crate) fn byte_classes(&self) -> ByteClasses {
236        let mut classes = ByteClasses::empty();
237        let mut class = 0u8;
238        let mut b = 0u8;
239        loop {
240            classes.set(b, class);
241            if b == 255 {
242                break;
243            }
244            if self.0.contains(b) {
245                class = class.checked_add(1).unwrap();
246            }
247            b = b.checked_add(1).unwrap();
248        }
249        classes
250    }
251}
252
/// A set of byte values backed by a 256-bit bitmap.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct ByteSet {
    bits: BitSet,
}

/// The 256 membership bits of a `ByteSet`.
///
/// Bytes `0..=127` live in the first `u128` and bytes `128..=255` in the
/// second; within a word, bit `byte % 128` stands for the byte.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
struct BitSet([u128; 2]);

impl ByteSet {
    /// Creates a set that contains no bytes.
    pub(crate) fn empty() -> ByteSet {
        ByteSet { bits: BitSet([0; 2]) }
    }

    /// Inserts `byte` into the set. Inserting a byte that is already
    /// present leaves the set unchanged.
    pub(crate) fn add(&mut self, byte: u8) {
        let bucket = usize::from(byte / 128);
        let bit = byte % 128;
        self.bits.0[bucket] |= 1 << bit;
    }

    /// Returns true if and only if `byte` is a member of the set.
    pub(crate) fn contains(&self, byte: u8) -> bool {
        let bucket = usize::from(byte / 128);
        let bit = byte % 128;
        self.bits.0[bucket] & (1 << bit) != 0
    }
}

/// Formats the bitmap as the set of byte values whose bits are set, in
/// ascending order.
impl core::fmt::Debug for BitSet {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Wrap the bits in a `ByteSet` to reuse its membership test.
        let set = ByteSet { bits: *self };
        f.debug_set()
            .entries((0u8..=255).filter(|&byte| set.contains(byte)))
            .finish()
    }
}
298
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::*;

    /// Collects, in ascending order, every byte that `classes` maps to
    /// `class`.
    fn members(classes: &ByteClasses, class: u8) -> Vec<u8> {
        classes.elements(class).collect()
    }

    /// Asserts that every `(byte, class)` pair in `expected` matches what
    /// `classes` reports for that byte.
    fn assert_classes(classes: &ByteClasses, expected: &[(u8, u8)]) {
        for &(byte, class) in expected {
            assert_eq!(classes.get(byte), class);
        }
    }

    /// Range boundaries split the byte space into consecutively numbered
    /// classes.
    #[test]
    fn byte_classes() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');
        assert_classes(
            &set.byte_classes(),
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (b'a' - 1, 0),
                (b'a', 1),
                (b'm', 1),
                (b'z', 1),
                (b'z' + 1, 2),
                (254, 2),
                (255, 2),
            ],
        );

        let mut set = ByteClassSet::empty();
        set.set_range(0, 2);
        set.set_range(4, 6);
        assert_classes(
            &set.byte_classes(),
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (3, 1),
                (4, 2),
                (5, 2),
                (6, 2),
                (7, 3),
                (255, 3),
            ],
        );
    }

    /// Marking every byte as its own range yields the full 256-class
    /// alphabet.
    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        (0u8..=255).for_each(|b| set.set_range(b, b));
        assert_eq!(set.byte_classes().alphabet_len(), 256);
    }

    /// `elements` enumerates exactly the bytes of each class for a typical
    /// multi-range set.
    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        assert_eq!(classes.alphabet_len(), 7);

        let class0 = members(&classes, 0);
        assert_eq!(class0.len(), 98);
        assert_eq!(class0.first(), Some(&b'\x00'));
        assert_eq!(class0.last(), Some(&b'a'));

        assert_eq!(members(&classes, 1), vec![b'b', b'c', b'd']);
        assert_eq!(members(&classes, 2), vec![b'e', b'f']);
        assert_eq!(
            members(&classes, 3),
            vec![b'g', b'h', b'i', b'j', b'k', b'l', b'm'],
        );

        let class4 = members(&classes, 4);
        assert_eq!(class4.len(), 12);
        assert_eq!(class4.first(), Some(&b'n'));
        assert_eq!(class4.last(), Some(&b'y'));

        assert_eq!(members(&classes, 5), vec![b'z']);

        let class6 = members(&classes, 6);
        assert_eq!(class6.len(), 133);
        assert_eq!(class6.first(), Some(&b'\x7B'));
        assert_eq!(class6.last(), Some(&b'\xFF'));
    }

    /// With one class per byte, each class contains exactly its own byte.
    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 256);
        assert_eq!(members(&classes, b'a'), vec![b'a']);
    }

    /// With a single class, that class contains every byte.
    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 1);

        let all = members(&classes, 0);
        assert_eq!(all.len(), 256);
        assert_eq!(all.first(), Some(&b'\x00'));
        assert_eq!(all.last(), Some(&b'\xFF'));
    }
}