Skip to main content

tendril/
futf.rs

1// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
4// option. This file may not be copied, modified, or distributed
5// except according to those terms.
6
7use debug_unreachable::debug_unreachable;
8use std::{char, slice};
9
/// What a complete or partial UTF-8 codepoint means.
///
/// Validation is partly lazy: a `Prefix` or `Suffix` result does not prove
/// that the sequence can actually be completed into a valid codepoint.
///
/// Variant order is significant, because `PartialOrd`/`Ord` are derived.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Meaning {
    /// A complete, valid Unicode scalar value.
    Whole(char),

    /// An encoded UTF-16 leading (high) surrogate, i.e. a value in the
    /// range `U+D800` - `U+DBFF`, which is not a valid Unicode codepoint.
    ///
    /// The payload is the offset from `U+D800` (a 10-bit value).
    ///
    /// Such sequences occur in UTF-8 variants such as CESU-8 and WTF-8.
    LeadSurrogate(u16),

    /// An encoded UTF-16 trailing (low) surrogate, i.e. a value in the
    /// range `U+DC00` - `U+DFFF`, which is not a valid Unicode codepoint.
    ///
    /// The payload is the offset from `U+DC00` (a 10-bit value).
    ///
    /// Such sequences occur in UTF-8 variants such as CESU-8 and WTF-8.
    TrailSurrogate(u16),

    /// Only the beginning of a codepoint was found before the buffer ended.
    ///
    /// The payload is the number of additional bytes the sequence needs.
    Prefix(usize),

    /// Only the tail of a codepoint was found: scanning backwards ran off
    /// the start of the buffer before a lead byte was seen.
    ///
    /// Up to 3 more bytes may be needed.
    Suffix,
}
48
49/// Represents a complete or partial UTF-8 codepoint.
50#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
51pub struct Codepoint<'a> {
52    /// The bytes that make up the partial or full codepoint.
53    ///
54    /// For a `Suffix` this depends on `idx`. We don't scan forward
55    /// for additional continuation bytes after the reverse scan
56    /// failed to locate a multibyte sequence start.
57    pub bytes: &'a [u8],
58
59    /// Start of the codepoint in the buffer, expressed as an offset
60    /// back from `idx`.
61    pub rewind: usize,
62
63    /// Meaning of the partial or full codepoint.
64    pub meaning: Meaning,
65}
66
/// Coarse classification of a single byte of a UTF-8 stream.
#[derive(Debug, PartialEq, Eq)]
enum Byte {
    /// A single-byte codepoint, `0x00..=0x7F`.
    Ascii,
    /// A lead byte announcing a sequence of the given total length (2, 3 or 4).
    Start(usize),
    /// A continuation byte, `0x80..=0xBF`.
    Cont,
}

impl Byte {
    /// Classifies `x` by its high-order bits.
    ///
    /// Returns `None` for `0xF8..=0xFF`, which can never appear in UTF-8.
    /// `0xC0`/`0xC1` and `0xF5..=0xF7` are still reported as lead bytes;
    /// complete sequences starting with them are rejected by `decode`.
    #[inline(always)]
    fn classify(x: u8) -> Option<Byte> {
        match x {
            0x00..=0x7F => Some(Byte::Ascii),
            0x80..=0xBF => Some(Byte::Cont),
            0xC0..=0xDF => Some(Byte::Start(2)),
            0xE0..=0xEF => Some(Byte::Start(3)),
            0xF0..=0xF7 => Some(Byte::Start(4)),
            0xF8..=0xFF => None,
        }
    }
}
89
/// Returns `true` if every byte in `buf` is a UTF-8 continuation byte
/// (`0b10xx_xxxx`). An empty slice trivially qualifies.
#[inline(always)]
fn all_cont(buf: &[u8]) -> bool {
    // Continuation bytes are exactly those whose top two bits are `10`;
    // fail as soon as any byte has a different tag.
    !buf.iter().any(|&byte| (byte & 0b1100_0000) != 0b1000_0000)
}
95
96// NOTE: Assumes the buffer is a syntactically valid multi-byte UTF-8 sequence:
97// a starting byte followed by the correct number of continuation bytes.
98#[inline(always)]
99unsafe fn decode(buf: &[u8]) -> Option<Meaning> {
100    debug_assert!(buf.len() >= 2);
101    debug_assert!(buf.len() <= 4);
102    let n;
103    match buf.len() {
104        2 => {
105            n = ((*buf.get_unchecked(0) & 0b11111) as u32) << 6
106                | ((*buf.get_unchecked(1) & 0x3F) as u32);
107            if n < 0x80 {
108                return None;
109            } // Overlong
110        },
111        3 => {
112            n = ((*buf.get_unchecked(0) & 0b1111) as u32) << 12
113                | ((*buf.get_unchecked(1) & 0x3F) as u32) << 6
114                | ((*buf.get_unchecked(2) & 0x3F) as u32);
115            match n {
116                0x0000..=0x07FF => return None, // Overlong
117                0xD800..=0xDBFF => return Some(Meaning::LeadSurrogate(n as u16 - 0xD800)),
118                0xDC00..=0xDFFF => return Some(Meaning::TrailSurrogate(n as u16 - 0xDC00)),
119                _ => {},
120            }
121        },
122        4 => {
123            n = ((*buf.get_unchecked(0) & 0b111) as u32) << 18
124                | ((*buf.get_unchecked(1) & 0x3F) as u32) << 12
125                | ((*buf.get_unchecked(2) & 0x3F) as u32) << 6
126                | ((*buf.get_unchecked(3) & 0x3F) as u32);
127            if n < 0x1_0000 {
128                return None;
129            } // Overlong
130        },
131        _ => debug_unreachable!(),
132    }
133
134    char::from_u32(n).map(Meaning::Whole)
135}
136
/// Returns `buf[start..start + new_len]` without a bounds check.
///
/// # Safety
///
/// The caller must guarantee `start <= buf.len()` and
/// `new_len <= buf.len() - start`. Both are asserted only in debug builds.
#[inline(always)]
unsafe fn unsafe_slice(buf: &[u8], start: usize, new_len: usize) -> &[u8] {
    debug_assert!(start <= buf.len());
    debug_assert!(new_len <= (buf.len() - start));
    // SAFETY: per the contract above, `first..first + new_len` lies within
    // `buf`, so the pointer and length describe initialized, borrowed bytes.
    let first = buf.as_ptr().add(start);
    slice::from_raw_parts(first, new_len)
}
143
144/// Describes the UTF-8 codepoint containing the byte at index `idx` within
145/// `buf`.
146///
147/// Returns `None` if `idx` is out of range, or if `buf` contains invalid UTF-8
148/// in the vicinity of `idx`.
149#[inline]
150pub fn classify<'a>(buf: &'a [u8], idx: usize) -> Option<Codepoint<'a>> {
151    if idx >= buf.len() {
152        return None;
153    }
154
155    unsafe {
156        let x = *buf.get_unchecked(idx);
157        match Byte::classify(x)? {
158            Byte::Ascii => Some(Codepoint {
159                bytes: unsafe_slice(buf, idx, 1),
160                rewind: 0,
161                meaning: Meaning::Whole(x as char),
162            }),
163            Byte::Start(n) => {
164                let avail = buf.len() - idx;
165                if avail >= n {
166                    let bytes = unsafe_slice(buf, idx, n);
167                    if !all_cont(unsafe_slice(bytes, 1, n - 1)) {
168                        return None;
169                    }
170                    let meaning = decode(bytes)?;
171                    Some(Codepoint {
172                        bytes,
173                        rewind: 0,
174                        meaning,
175                    })
176                } else {
177                    Some(Codepoint {
178                        bytes: unsafe_slice(buf, idx, avail),
179                        rewind: 0,
180                        meaning: Meaning::Prefix(n - avail),
181                    })
182                }
183            },
184            Byte::Cont => {
185                let mut start = idx;
186                let mut checked = 0;
187                loop {
188                    if start == 0 {
189                        // Whoops, fell off the beginning.
190                        return Some(Codepoint {
191                            bytes: unsafe_slice(buf, 0, idx + 1),
192                            rewind: idx,
193                            meaning: Meaning::Suffix,
194                        });
195                    }
196
197                    start -= 1;
198                    checked += 1;
199                    match Byte::classify(*buf.get_unchecked(start))? {
200                        Byte::Cont => (),
201                        Byte::Start(n) => {
202                            let avail = buf.len() - start;
203                            if avail >= n {
204                                let bytes = unsafe_slice(buf, start, n);
205                                if checked < n {
206                                    if !all_cont(unsafe_slice(bytes, checked, n - checked)) {
207                                        return None;
208                                    }
209                                }
210                                let meaning = decode(bytes)?;
211                                return Some(Codepoint {
212                                    bytes,
213                                    rewind: idx - start,
214                                    meaning,
215                                });
216                            } else {
217                                return Some(Codepoint {
218                                    bytes: unsafe_slice(buf, start, avail),
219                                    rewind: idx - start,
220                                    meaning: Meaning::Prefix(n - avail),
221                                });
222                            }
223                        },
224                        _ => return None,
225                    }
226
227                    if idx - start >= 3 {
228                        // We looked at 3 bytes before a continuation byte
229                        // and didn't find a start byte.
230                        return None;
231                    }
232                }
233            },
234        }
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::{all_cont, classify, decode, Byte, Meaning};
    use std::borrow::ToOwned;
    use std::io::Write;

    #[test]
    fn classify_all_bytes() {
        for n in 0x00..0x80 {
            assert_eq!(Byte::classify(n), Some(Byte::Ascii));
        }
        for n in 0x80..0xC0 {
            assert_eq!(Byte::classify(n), Some(Byte::Cont));
        }
        for n in 0xC0..0xE0 {
            assert_eq!(Byte::classify(n), Some(Byte::Start(2)));
        }
        for n in 0xE0..0xF0 {
            assert_eq!(Byte::classify(n), Some(Byte::Start(3)));
        }
        for n in 0xF0..0xF8 {
            assert_eq!(Byte::classify(n), Some(Byte::Start(4)));
        }
        for n in 0xF8..0xFF {
            assert_eq!(Byte::classify(n), None);
        }
        assert_eq!(Byte::classify(0xFF), None);
    }

    #[test]
    fn test_all_cont() {
        assert!(all_cont(b""));
        assert!(all_cont(b"\x80"));
        assert!(all_cont(b"\xBF"));
        assert!(all_cont(b"\x80\xBF\x80\xBF"));

        assert!(!all_cont(b"z"));
        assert!(!all_cont(b"\xC0\xBF"));
        assert!(!all_cont(b"\xFF"));
        assert!(!all_cont(b"\x80\xBFz\x80\xBF"));
        assert!(!all_cont(b"\x80\xBF\xC0\x80\xBF"));
        assert!(!all_cont(b"\x80\xBF\xFF\x80\xBF"));
        assert!(!all_cont(b"\x80\xBF\x80\xBFz"));
        assert!(!all_cont(b"\x80\xBF\x80\xBF\xC0"));
        assert!(!all_cont(b"z\x80\xBF\x80\xBF"));
        assert!(!all_cont(b"\xC0\x80\xBF\x80\xBF"));
    }

    #[test]
    fn test_decode() {
        unsafe {
            assert_eq!(Some(Meaning::Whole('ő')), decode(b"\xC5\x91"));
            assert_eq!(Some(Meaning::Whole('\u{a66e}')), decode(b"\xEA\x99\xAE"));
            assert_eq!(
                Some(Meaning::Whole('\u{1f4a9}')),
                decode(b"\xF0\x9F\x92\xA9")
            );
            assert_eq!(
                Some(Meaning::Whole('\u{10ffff}')),
                decode(b"\xF4\x8F\xBF\xBF")
            );

            assert_eq!(
                Some(Meaning::LeadSurrogate(0x0000)),
                decode(b"\xED\xA0\x80")
            );
            assert_eq!(
                Some(Meaning::LeadSurrogate(0x0001)),
                decode(b"\xED\xA0\x81")
            );
            assert_eq!(
                Some(Meaning::LeadSurrogate(0x03FE)),
                decode(b"\xED\xAF\xBE")
            );
            assert_eq!(
                Some(Meaning::LeadSurrogate(0x03FF)),
                decode(b"\xED\xAF\xBF")
            );

            assert_eq!(
                Some(Meaning::TrailSurrogate(0x0000)),
                decode(b"\xED\xB0\x80")
            );
            assert_eq!(
                Some(Meaning::TrailSurrogate(0x0001)),
                decode(b"\xED\xB0\x81")
            );
            assert_eq!(
                Some(Meaning::TrailSurrogate(0x03FE)),
                decode(b"\xED\xBF\xBE")
            );
            assert_eq!(
                Some(Meaning::TrailSurrogate(0x03FF)),
                decode(b"\xED\xBF\xBF")
            );

            // The last 4-byte UTF-8 sequence. This would be U+1FFFFF, which is out of
            // range.
            assert_eq!(None, decode(b"\xF7\xBF\xBF\xBF"));

            // First otherwise-valid sequence (would be U+110000) that is out of range
            assert_eq!(None, decode(b"\xF4\x90\x80\x80"));

            // Overlong sequences
            assert_eq!(None, decode(b"\xC0\x80"));
            assert_eq!(None, decode(b"\xC1\xBF"));
            assert_eq!(None, decode(b"\xE0\x80\x80"));
            assert_eq!(None, decode(b"\xE0\x9F\xBF"));
            assert_eq!(None, decode(b"\xF0\x80\x80\x80"));
            assert_eq!(None, decode(b"\xF0\x8F\xBF\xBF"));

            // First non-overlong sequence for each sequence length
            assert_eq!(Some(Meaning::Whole('\u{80}')), decode(b"\xC2\x80"));
            assert_eq!(Some(Meaning::Whole('\u{800}')), decode(b"\xE0\xA0\x80"));
            assert_eq!(
                Some(Meaning::Whole('\u{10000}')),
                decode(b"\xF0\x90\x80\x80")
            );
        }
    }

    // 256 bytes of arbitrary data, used as the background into which test
    // sequences are written.
    static JUNK: &[u8] = b"\
        \xf8\x0d\x07\x25\xa6\x7b\x95\xeb\x47\x01\x7f\xee\
        \x3b\x00\x60\x57\x1d\x9e\x5d\x0a\x0b\x0a\x7c\x75\
        \x13\xa1\x82\x46\x27\x34\xe9\x52\x61\x0d\xec\x10\
        \x54\x49\x6e\x54\xdf\x7b\xe1\x31\x8c\x06\x21\x83\
        \x0f\xb5\x1f\x4c\x6a\x71\x52\x42\x74\xe7\x7b\x50\
        \x59\x1f\x6a\xd4\xff\x06\x92\x33\xc4\x34\x97\xff\
        \xcc\xb5\xc4\x00\x7b\xc3\x4a\x7f\x7e\x63\x96\x58\
        \x51\x63\x21\x54\x53\x2f\x03\x8a\x7d\x41\x79\x98\
        \x5b\xcb\xb8\x94\x6b\x73\xf3\x0c\x5a\xd7\xc4\x12\
        \x7a\x2b\x9a\x2e\x67\x62\x2a\x00\x45\x2c\xfe\x7d\
        \x8d\xd6\x51\x4e\x59\x36\x72\x1b\xae\xaa\x06\xe8\
        \x71\x1b\x85\xd3\x35\xb5\xbe\x9e\x16\x96\x72\xd8\
        \x1a\x48\xba\x4d\x55\x4f\x1b\xa2\x77\xfa\x8f\x71\
        \x58\x7d\x03\x93\xa2\x3a\x76\x51\xda\x48\xe2\x3f\
        \xeb\x8d\xda\x89\xae\xf7\xbd\x3d\xb6\x37\x97\xca\
        \x99\xcc\x4a\x8d\x62\x89\x97\xe3\xc0\xd1\x8d\xc1\
        \x26\x11\xbb\x8d\x53\x61\x4f\x76\x03\x00\x30\xd3\
        \x5f\x86\x19\x52\x9c\x3e\x99\x8c\xb7\x21\x48\x1c\
        \x85\xae\xad\xd5\x74\x00\x6c\x3e\xd0\x17\xff\x76\
        \x5c\x32\xc3\xfb\x24\x99\xd4\x4c\xa4\x1f\x66\x46\
        \xe7\x2d\x44\x56\x7d\x14\xd9\x76\x91\x37\x2f\xb7\
        \xcc\x1b\xd3\xc2";

    #[test]
    fn classify_whole() {
        assert_eq!(JUNK.len(), 256);

        for &c in &[
            '\0',
            '\x01',
            'o',
            'z',
            'ő',
            '\u{2764}',
            '\u{a66e}',
            '\u{1f4a9}',
            '\u{1f685}',
        ] {
            // Leave room for up to 4 encoded bytes at the end of the buffer.
            for idx in 0..JUNK.len() - 3 {
                let mut buf = JUNK.to_owned();
                let ch = format!("{}", c).into_bytes();
                (&mut buf[idx..]).write_all(&ch).unwrap();

                for j in 0..ch.len() {
                    let class = classify(&buf, idx + j).unwrap();
                    assert_eq!(class.bytes, &*ch);
                    assert_eq!(class.rewind, j);
                    assert_eq!(class.meaning, Meaning::Whole(c));
                }
            }
        }
    }

    #[test]
    fn classify_surrogates() {
        for &(s, b) in &[
            (Meaning::LeadSurrogate(0x0000), b"\xED\xA0\x80"),
            (Meaning::LeadSurrogate(0x0001), b"\xED\xA0\x81"),
            (Meaning::LeadSurrogate(0x03FE), b"\xED\xAF\xBE"),
            (Meaning::LeadSurrogate(0x03FF), b"\xED\xAF\xBF"),
            (Meaning::TrailSurrogate(0x0000), b"\xED\xB0\x80"),
            (Meaning::TrailSurrogate(0x0001), b"\xED\xB0\x81"),
            (Meaning::TrailSurrogate(0x03FE), b"\xED\xBF\xBE"),
            (Meaning::TrailSurrogate(0x03FF), b"\xED\xBF\xBF"),
        ] {
            for idx in 0..JUNK.len() - 2 {
                let mut buf = JUNK.to_owned();
                (&mut buf[idx..]).write_all(b).unwrap();

                let class = classify(&buf, idx).unwrap();
                assert_eq!(class.bytes, b);
                assert_eq!(class.rewind, 0);
                assert_eq!(class.meaning, s);
            }
        }
    }

    #[test]
    fn classify_prefix_suffix() {
        for &c in &['ő', '\u{a66e}', '\u{1f4a9}'] {
            let ch = format!("{}", c).into_bytes();
            // Every proper prefix length, 1..=len-1, written at the end of the buffer.
            for pfx in 1..ch.len() {
                let mut buf = JUNK.to_owned();
                let buflen = buf.len();
                (&mut buf[buflen - pfx..buflen])
                    .write_all(&ch[..pfx])
                    .unwrap();
                for j in 0..pfx {
                    let idx = buflen - 1 - j;
                    let class = classify(&buf, idx).unwrap();
                    assert_eq!(class.bytes, &ch[..pfx]);
                    assert_eq!(class.rewind, pfx - 1 - j);
                    assert_eq!(class.meaning, Meaning::Prefix(ch.len() - pfx));
                }
            }
            // Every proper suffix length, 1..=len-1, written at the start of the buffer.
            for sfx in 1..ch.len() {
                let ch_bytes = &ch[ch.len() - sfx..];
                let mut buf = JUNK.to_owned();
                (&mut *buf).write_all(ch_bytes).unwrap();
                for j in 0..sfx {
                    let class = classify(&buf, j).unwrap();
                    assert!(ch_bytes.starts_with(class.bytes));
                    assert_eq!(class.rewind, j);
                    assert_eq!(class.meaning, Meaning::Suffix);
                }
            }
        }
    }

    #[test]
    fn out_of_bounds() {
        assert!(classify(b"", 0).is_none());
        assert!(classify(b"", 7).is_none());
        assert!(classify(b"aaaaaaa", 7).is_none());
    }

    #[test]
    fn malformed() {
        assert_eq!(None, classify(b"\xFF", 0));
        assert_eq!(None, classify(b"\xC5\xC5", 0));
        assert_eq!(None, classify(b"x\x91", 1));
        assert_eq!(None, classify(b"\x91\x91\x91\x91", 3));
        assert_eq!(None, classify(b"\x91\x91\x91\x91\x91", 4));
        assert_eq!(None, classify(b"\xEA\x91\xFF", 1));
        assert_eq!(None, classify(b"\xF0\x90\x90\xF0", 0));
        assert_eq!(None, classify(b"\xF0\x90\x90\xF0", 1));
        assert_eq!(None, classify(b"\xF0\x90\x90\xF0", 2));

        for i in 0..4 {
            // out of range: U+110000
            assert_eq!(None, classify(b"\xF4\x90\x80\x80", i));

            // out of range: U+1FFFFF
            assert_eq!(None, classify(b"\xF7\xBF\xBF\xBF", i));

            // Overlong sequences
            assert_eq!(None, classify(b"\xC0\x80", i));
            assert_eq!(None, classify(b"\xC1\xBF", i));
            assert_eq!(None, classify(b"\xE0\x80\x80", i));
            assert_eq!(None, classify(b"\xE0\x9F\xBF", i));
            assert_eq!(None, classify(b"\xF0\x80\x80\x80", i));
            assert_eq!(None, classify(b"\xF0\x8F\xBF\xBF", i));
        }
    }
}