tungstenite/
utf8.rs

1use std::{cmp, error::Error, fmt, str};
2
3#[derive(Debug, Copy, Clone)]
4pub(crate) enum DecodeError<'a> {
5    /// In lossy decoding insert `valid_prefix`, then `"\u{FFFD}"`,
6    /// then call `decode()` again with `remaining_input`.
7    Invalid { valid_prefix: &'a str, invalid_sequence: &'a [u8], remaining_input: &'a [u8] },
8
9    /// Call the `incomplete_suffix.try_complete` method with more input when available.
10    /// If no more input is available, this is an invalid byte sequence.
11    Incomplete { valid_prefix: &'a str, incomplete_suffix: Incomplete },
12}
13
14impl<'a> fmt::Display for DecodeError<'a> {
15    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
16        match *self {
17            DecodeError::Invalid { valid_prefix, invalid_sequence, remaining_input } => write!(
18                f,
19                "found invalid byte sequence {invalid_sequence:02x?} after \
20                 {valid_byte_count} valid bytes, followed by {unprocessed_byte_count} more \
21                 unprocessed bytes",
22                invalid_sequence = invalid_sequence,
23                valid_byte_count = valid_prefix.len(),
24                unprocessed_byte_count = remaining_input.len()
25            ),
26            DecodeError::Incomplete { valid_prefix, incomplete_suffix } => write!(
27                f,
28                "found incomplete byte sequence {incomplete_suffix:02x?} after \
29                 {valid_byte_count} bytes",
30                incomplete_suffix = incomplete_suffix,
31                valid_byte_count = valid_prefix.len()
32            ),
33        }
34    }
35}
36
37impl<'a> Error for DecodeError<'a> {}
38
/// Owned copy of the bytes of a multi-byte sequence whose end has not been seen yet.
///
/// Only the first `buffer_len` bytes of `buffer` are meaningful.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Incomplete {
    /// Fixed-size storage for the bytes collected so far.
    pub(crate) buffer: [u8; 4],
    /// Number of leading bytes of `buffer` that are currently in use.
    pub(crate) buffer_len: u8,
}
44
/// Outcome of `Incomplete::try_complete` once the buffered sequence has been resolved.
///
/// `'buffer` is the lifetime of the bytes held by the `Incomplete` value, while
/// `'input` is the lifetime of the slice that was passed to `try_complete`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Completed<'buffer, 'input> {
    /// `Ok` with the decoded text, or `Err` with the bytes of an invalid sequence.
    pub(crate) result: Result<&'buffer str, &'buffer [u8]>,
    /// The part of the input that was not consumed.
    pub(crate) remaining_input: &'input [u8],
}
50
51pub(crate) fn decode(input: &'_ [u8]) -> Result<&'_ str, DecodeError<'_>> {
52    let error = match str::from_utf8(input) {
53        Ok(valid) => return Ok(valid),
54        Err(error) => error,
55    };
56
57    // FIXME: separate function from here to guide inlining?
58    let (valid, after_valid) = input.split_at(error.valid_up_to());
59    let valid = unsafe { str::from_utf8_unchecked(valid) };
60
61    match error.error_len() {
62        Some(invalid_sequence_length) => {
63            let (invalid, rest) = after_valid.split_at(invalid_sequence_length);
64            Err(DecodeError::Invalid {
65                valid_prefix: valid,
66                invalid_sequence: invalid,
67                remaining_input: rest,
68            })
69        }
70        None => Err(DecodeError::Incomplete {
71            valid_prefix: valid,
72            incomplete_suffix: Incomplete::new(after_valid),
73        }),
74    }
75}
76
77impl Incomplete {
78    pub(crate) fn new(bytes: &[u8]) -> Self {
79        let mut buffer = [0, 0, 0, 0];
80        let len = bytes.len();
81        buffer[..len].copy_from_slice(bytes);
82        Incomplete { buffer, buffer_len: len as u8 }
83    }
84
85    /// * `None`: still incomplete, call `try_complete` again with more input.
86    ///   If no more input is available, this is invalid byte sequence.
87    /// * `Some(completed)`: We’re done with this `Incomplete`,
88    ///   with either a valid chunk on invalid byte sequence in `completed.result`.
89    ///   To keep decoding, pass `completed.remaining_input` to `decode()`.
90    pub(crate) fn try_complete<'input>(
91        &mut self,
92        input: &'input [u8],
93    ) -> Option<Completed<'_, 'input>> {
94        let (consumed, opt_result) = self.try_complete_offsets(input);
95        let result = opt_result?;
96        let remaining_input = &input[consumed..];
97        let result_bytes = self.take_buffer();
98        let result = match result {
99            Ok(()) => Ok(unsafe { str::from_utf8_unchecked(result_bytes) }),
100            Err(()) => Err(result_bytes),
101        };
102        Some(Completed { result, remaining_input })
103    }
104
105    fn take_buffer(&mut self) -> &[u8] {
106        let len = self.buffer_len as usize;
107        self.buffer_len = 0;
108        &self.buffer[..len]
109    }
110
111    /// (consumed_from_input, None): not enough input
112    /// (consumed_from_input, Some(Err(()))): error bytes in buffer
113    /// (consumed_from_input, Some(Ok(()))): UTF-8 string in buffer
114    fn try_complete_offsets(&mut self, input: &[u8]) -> (usize, Option<Result<(), ()>>) {
115        let initial_buffer_len = self.buffer_len as usize;
116        let copied_from_input;
117        {
118            let unwritten = &mut self.buffer[initial_buffer_len..];
119            copied_from_input = cmp::min(unwritten.len(), input.len());
120            unwritten[..copied_from_input].copy_from_slice(&input[..copied_from_input]);
121        }
122        let spliced = &self.buffer[..initial_buffer_len + copied_from_input];
123        match str::from_utf8(spliced) {
124            Ok(_) => {
125                self.buffer_len = spliced.len() as u8;
126                (copied_from_input, Some(Ok(())))
127            }
128            Err(error) => {
129                let valid_up_to = error.valid_up_to();
130                if valid_up_to > 0 {
131                    let consumed = valid_up_to.checked_sub(initial_buffer_len).unwrap();
132                    self.buffer_len = valid_up_to as u8;
133                    (consumed, Some(Ok(())))
134                } else {
135                    match error.error_len() {
136                        Some(invalid_sequence_length) => {
137                            let consumed =
138                                invalid_sequence_length.checked_sub(initial_buffer_len).unwrap();
139                            self.buffer_len = invalid_sequence_length as u8;
140                            (consumed, Some(Err(())))
141                        }
142                        None => {
143                            self.buffer_len = spliced.len() as u8;
144                            (copied_from_input, None)
145                        }
146                    }
147                }
148            }
149        }
150    }
151}