Skip to main content

quick_xml/
utils.rs

1use std::borrow::{Borrow, Cow};
2use std::fmt::{self, Debug, Formatter};
3use std::io;
4use std::iter::FusedIterator;
5use std::ops::Deref;
6
7#[cfg(feature = "async-tokio")]
8use std::{
9    pin::Pin,
10    task::{Context, Poll},
11};
12
13#[cfg(feature = "serialize")]
14use serde::de::{Deserialize, Deserializer, Error, Visitor};
15#[cfg(feature = "serialize")]
16use serde::ser::{Serialize, Serializer};
17
/// Writes a `Cow<[u8]>` as `Owned("...")` or `Borrowed("...")`, where the
/// inner bytes are rendered by [`write_byte_string`].
// The parameter must stay `&Cow` (not `&[u8]`, as `clippy::ptr_arg` suggests):
// the variant itself is part of the produced text.
#[allow(clippy::ptr_arg)]
pub fn write_cow_string(f: &mut Formatter, cow_string: &Cow<[u8]>) -> fmt::Result {
    let (prefix, bytes): (&str, &[u8]) = match cow_string {
        Cow::Borrowed(slice) => ("Borrowed(", *slice),
        Cow::Owned(vec) => ("Owned(", vec.as_slice()),
    };
    f.write_str(prefix)?;
    write_byte_string(f, bytes)?;
    f.write_str(")")
}

/// Writes `byte_string` surrounded by double quotes in a human-readable form.
///
/// Printable ASCII (space through `~`) is written as-is, except `"` which is
/// escaped as `\"`. Every other byte is written in `0x`-prefixed upper-case
/// hex without zero padding (e.g. byte 10 becomes `0xA`).
pub fn write_byte_string(f: &mut Formatter, byte_string: &[u8]) -> fmt::Result {
    f.write_str("\"")?;
    for &byte in byte_string {
        if byte == b'"' {
            f.write_str("\\\"")?;
        } else if (b' '..=b'~').contains(&byte) {
            write!(f, "{}", byte as char)?;
        } else {
            write!(f, "{:#02X}", byte)?;
        }
    }
    f.write_str("\"")
}
45
46////////////////////////////////////////////////////////////////////////////////////////////////////
47
/// A version of [`Cow`] that can borrow from two different buffers, one of
/// which is the deserializer input.
///
/// # Lifetimes
///
/// - `'i`: lifetime of the data that the deserializer borrows from the parsed input
/// - `'s`: lifetime of the data owned by a deserializer
pub enum CowRef<'i, 's, B>
where
    B: ToOwned + ?Sized,
{
    /// An input borrowed from the parsed data
    Input(&'i B),
    /// An input borrowed from the buffer owned by another deserializer
    Slice(&'s B),
    /// An input taken from an external deserializer, owned by that deserializer
    Owned(<B as ToOwned>::Owned),
}

impl<'i, 's, B> Deref for CowRef<'i, 's, B>
where
    B: ToOwned + ?Sized,
    B::Owned: Borrow<B>,
{
    type Target = B;

    /// Returns a reference to the wrapped data, whichever variant holds it.
    fn deref(&self) -> &B {
        match self {
            Self::Input(input) => *input,
            Self::Slice(slice) => *slice,
            Self::Owned(owned) => owned.borrow(),
        }
    }
}

impl<'i, 's, B> Debug for CowRef<'i, 's, B>
where
    B: ToOwned + ?Sized + Debug,
    B::Owned: Debug,
{
    /// Formats only the wrapped data (no variant name), delegating to the
    /// `Debug` implementation of `B` or of `B::Owned`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            Self::Input(input) => Debug::fmt(*input, f),
            Self::Slice(slice) => Debug::fmt(*slice, f),
            Self::Owned(owned) => Debug::fmt(owned, f),
        }
    }
}
95
96impl<'i, 's> CowRef<'i, 's, str> {
97    /// Supply to the visitor a borrowed string, a string slice, or an owned
98    /// string depending on the kind of input. Unlike [`Self::deserialize_all`],
99    /// only part of [`Self::Owned`] string will be passed to the visitor.
100    ///
101    /// Calls
102    /// - `visitor.visit_borrowed_str` if data borrowed from the input
103    /// - `visitor.visit_str` if data borrowed from another source
104    /// - `visitor.visit_string` if data owned by this type
105    #[cfg(feature = "serialize")]
106    pub fn deserialize_str<V, E>(self, visitor: V) -> Result<V::Value, E>
107    where
108        V: Visitor<'i>,
109        E: Error,
110    {
111        match self {
112            Self::Input(s) => visitor.visit_borrowed_str(s),
113            Self::Slice(s) => visitor.visit_str(s),
114            Self::Owned(s) => visitor.visit_string(s),
115        }
116    }
117
118    /// Calls [`Visitor::visit_bool`] with `true` or `false` if text contains
119    /// [valid] boolean representation, otherwise calls [`Self::deserialize_str`].
120    ///
121    /// The valid boolean representations are only `"true"`, `"false"`, `"1"`, and `"0"`.
122    ///
123    /// [valid]: https://www.w3.org/TR/xmlschema11-2/#boolean
124    #[cfg(feature = "serialize")]
125    pub fn deserialize_bool<V, E>(self, visitor: V) -> Result<V::Value, E>
126    where
127        V: Visitor<'i>,
128        E: Error,
129    {
130        match self.as_ref() {
131            "1" | "true" => visitor.visit_bool(true),
132            "0" | "false" => visitor.visit_bool(false),
133            _ => self.deserialize_str(visitor),
134        }
135    }
136}
137
138////////////////////////////////////////////////////////////////////////////////////////////////////
139
140/// Wrapper around `Vec<u8>` that has a human-readable debug representation:
141/// printable ASCII symbols output as is, all other output in HEX notation.
142///
143/// Also, when [`serialize`] feature is on, this type deserialized using
144/// [`deserialize_byte_buf`](serde::Deserializer::deserialize_byte_buf) instead
145/// of vector's generic [`deserialize_seq`](serde::Deserializer::deserialize_seq)
146///
147/// [`serialize`]: ../index.html#serialize
148#[derive(PartialEq, Eq)]
149pub struct ByteBuf(pub Vec<u8>);
150
151impl Debug for ByteBuf {
152    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
153        write_byte_string(f, &self.0)
154    }
155}
156
157#[cfg(feature = "serialize")]
158impl<'de> Deserialize<'de> for ByteBuf {
159    fn deserialize<D>(d: D) -> Result<Self, D::Error>
160    where
161        D: Deserializer<'de>,
162    {
163        struct ValueVisitor;
164
165        impl<'de> Visitor<'de> for ValueVisitor {
166            type Value = ByteBuf;
167
168            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
169                f.write_str("byte data")
170            }
171
172            fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
173                Ok(ByteBuf(v.to_vec()))
174            }
175
176            fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
177                Ok(ByteBuf(v))
178            }
179        }
180
181        d.deserialize_byte_buf(ValueVisitor)
182    }
183}
184
185#[cfg(feature = "serialize")]
186impl Serialize for ByteBuf {
187    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
188    where
189        S: Serializer,
190    {
191        serializer.serialize_bytes(&self.0)
192    }
193}
194
195////////////////////////////////////////////////////////////////////////////////////////////////////
196
197/// Wrapper around `&[u8]` that has a human-readable debug representation:
198/// printable ASCII symbols output as is, all other output in HEX notation.
199///
200/// Also, when [`serialize`] feature is on, this type deserialized using
201/// [`deserialize_bytes`](serde::Deserializer::deserialize_bytes) instead
202/// of vector's generic [`deserialize_seq`](serde::Deserializer::deserialize_seq)
203///
204/// [`serialize`]: ../index.html#serialize
205#[derive(PartialEq, Eq)]
206pub struct Bytes<'de>(pub &'de [u8]);
207
208impl<'de> Debug for Bytes<'de> {
209    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
210        write_byte_string(f, self.0)
211    }
212}
213
214#[cfg(feature = "serialize")]
215impl<'de> Deserialize<'de> for Bytes<'de> {
216    fn deserialize<D>(d: D) -> Result<Self, D::Error>
217    where
218        D: Deserializer<'de>,
219    {
220        struct ValueVisitor;
221
222        impl<'de> Visitor<'de> for ValueVisitor {
223            type Value = Bytes<'de>;
224
225            fn expecting(&self, f: &mut Formatter) -> fmt::Result {
226                f.write_str("borrowed bytes")
227            }
228
229            fn visit_borrowed_bytes<E: Error>(self, v: &'de [u8]) -> Result<Self::Value, E> {
230                Ok(Bytes(v))
231            }
232        }
233
234        d.deserialize_bytes(ValueVisitor)
235    }
236}
237
238#[cfg(feature = "serialize")]
239impl<'de> Serialize for Bytes<'de> {
240    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
241    where
242        S: Serializer,
243    {
244        serializer.serialize_bytes(self.0)
245    }
246}
247
248////////////////////////////////////////////////////////////////////////////////////////////////////
249
/// A simple producer of an infinite stream of bytes, useful in tests.
///
/// The stream is the `chunk` field repeated indefinitely (an empty `chunk`
/// produces an empty stream). The [`io::Read`] and [`io::BufRead`]
/// implementations share the same position (`consumed`) and byte counter
/// (`overall_read`), so they can be mixed freely.
pub struct Fountain<'a> {
    /// That piece of data repeated infinitely...
    pub chunk: &'a [u8],
    /// Offset into `chunk` of the next byte to produce; wraps back to 0 once
    /// the whole `chunk` has been consumed
    pub consumed: usize,
    /// The overall count of read bytes
    pub overall_read: u64,
}

impl<'a> io::Read for Fountain<'a> {
    /// Copies as much of the rest of the current `chunk` repetition as fits
    /// into `buf`, then advances the stream past the copied bytes.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let chunk = self.chunk;
        let available = &chunk[self.consumed..];
        let len = buf.len().min(available.len());

        // `buf` may be longer than what is left of the chunk: copy only into
        // its prefix, since `copy_from_slice` panics on a length mismatch.
        buf[..len].copy_from_slice(&available[..len]);
        // Advance the shared position and counter exactly like `consume` does,
        // so the next read continues the stream instead of repeating it.
        io::BufRead::consume(self, len);
        Ok(len)
    }
}

impl<'a> io::BufRead for Fountain<'a> {
    /// Exposes the rest of the current `chunk` repetition without copying.
    #[inline]
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        Ok(&self.chunk[self.consumed..])
    }

    /// Marks `amt` bytes as read, wrapping to the start of `chunk` when its
    /// end is reached. As per the `BufRead` contract, `amt` must not exceed
    /// the length of the slice returned by `fill_buf`.
    fn consume(&mut self, amt: usize) {
        self.consumed += amt;
        if self.consumed == self.chunk.len() {
            self.consumed = 0;
        }
        self.overall_read += amt as u64;
    }
}
287
#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncRead for Fountain<'a> {
    /// Fills as much of `buf` as possible from the rest of the current `chunk`
    /// repetition and advances the stream past the copied bytes. Never pending.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `Fountain` only holds a slice and integers, so it is `Unpin`
        let this = self.get_mut();
        let chunk = this.chunk;
        let available = &chunk[this.consumed..];
        let len = buf.remaining().min(available.len());

        buf.put_slice(&available[..len]);
        // Advance the shared position (and `overall_read`) like the `BufRead`
        // path does; otherwise every read would return the same bytes again.
        io::BufRead::consume(this, len);
        Poll::Ready(Ok(()))
    }
}
303
#[cfg(feature = "async-tokio")]
impl<'a> tokio::io::AsyncBufRead for Fountain<'a> {
    /// Never pending: exposes the same buffer as the synchronous `BufRead` impl.
    #[inline]
    fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = self.get_mut();
        let filled = io::BufRead::fill_buf(this);
        Poll::Ready(filled)
    }

    /// Delegates to the synchronous `BufRead::consume`.
    #[inline]
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = Pin::into_inner(self);
        io::BufRead::consume(this, amt);
    }
}
316
317////////////////////////////////////////////////////////////////////////////////////////////////////
318
/// Returns `true` if the byte is XML whitespace: a space (`' '`), a carriage
/// return (`'\r'`), a line feed (`'\n'`) or a horizontal tab (`'\t'`).
#[inline]
pub const fn is_whitespace(b: u8) -> bool {
    matches!(b, b' ' | b'\r' | b'\n' | b'\t')
}

/// Returns the length in bytes of the name at the start of element-like
/// content. The name is the first word of `bytes`, where words are delimited
/// by XML whitespace. Returns 0 if `bytes` is empty or starts with whitespace.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`].
#[inline]
pub const fn name_len(mut bytes: &[u8]) -> usize {
    // Slice patterns (rather than indexing or iterators) keep this a `const fn`.
    let mut len = 0;
    while let [first, rest @ ..] = bytes {
        if is_whitespace(*first) {
            break;
        }
        len += 1;
        bytes = rest;
    }
    len
}

/// Returns a byte slice with leading XML whitespace bytes removed.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`].
#[inline]
pub const fn trim_xml_start(mut bytes: &[u8]) -> &[u8] {
    // Slice patterns (rather than indexing or iterators) keep this a `const fn`.
    while let [first, rest @ ..] = bytes {
        if !is_whitespace(*first) {
            break;
        }
        bytes = rest;
    }
    bytes
}

/// Returns a byte slice with trailing XML whitespace bytes removed.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`].
#[inline]
pub const fn trim_xml_end(mut bytes: &[u8]) -> &[u8] {
    // Slice patterns (rather than indexing or iterators) keep this a `const fn`.
    while let [rest @ .., last] = bytes {
        if !is_whitespace(*last) {
            break;
        }
        bytes = rest;
    }
    bytes
}

/// Returns a string slice with XML whitespace characters removed from both sides.
///
/// 'Whitespace' refers to the definition used by [`is_whitespace`]. Only those
/// four ASCII characters are trimmed; other Unicode whitespace (such as
/// U+00A0 NO-BREAK SPACE) is preserved.
#[inline]
pub fn trim_xml_spaces(text: &str) -> &str {
    // Trimming whole `char`s directly on the `&str` keeps the result valid
    // UTF-8 by construction. There is no need to re-validate it or to keep an
    // `unreachable!` branch. Non-ASCII chars are never XML whitespace, so the
    // `is_ascii` guard also makes the `as u8` conversion lossless.
    text.trim_matches(|c: char| c.is_ascii() && is_whitespace(c as u8))
}
390
391////////////////////////////////////////////////////////////////////////////////////////////////////
392
393/// Splits string into pieces which can be part of a single `CDATA` section.
394///
395/// Because CDATA cannot contain the `]]>` sequence, split the string between
396/// `]]` and `>`.
397#[derive(Debug, Clone)]
398pub(crate) struct CDataIterator<'a> {
399    /// The unprocessed data which should be emitted as `BytesCData` events.
400    /// At each iteration, the processed data is cut from this slice.
401    unprocessed: &'a str,
402    finished: bool,
403}
404
405impl<'a> CDataIterator<'a> {
406    pub fn new(value: &'a str) -> Self {
407        Self {
408            unprocessed: value,
409            finished: false,
410        }
411    }
412}
413
414impl<'a> Iterator for CDataIterator<'a> {
415    type Item = &'a str;
416
417    fn next(&mut self) -> Option<&'a str> {
418        if self.finished {
419            return None;
420        }
421
422        for gt in memchr::memchr_iter(b'>', self.unprocessed.as_bytes()) {
423            let (slice, rest) = self.unprocessed.split_at(gt);
424            if slice.ends_with("]]") {
425                self.unprocessed = rest;
426                return Some(slice);
427            }
428        }
429
430        self.finished = true;
431        Some(self.unprocessed)
432    }
433}
434
435impl FusedIterator for CDataIterator<'_> {}
436
437////////////////////////////////////////////////////////////////////////////////////////////////////
438
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn write_byte_string0() {
        let bytes = ByteBuf(vec![10, 32, 32, 32, 32, 32, 32, 32, 32]);
        assert_eq!(format!("{:?}", bytes), "\"0xA        \"");
    }

    #[test]
    fn write_byte_string1() {
        let bytes = ByteBuf(vec![
            104, 116, 116, 112, 58, 47, 47, 119, 119, 119, 46, 119, 51, 46, 111, 114, 103, 47, 50,
            48, 48, 50, 47, 48, 55, 47, 111, 119, 108, 35,
        ]);
        assert_eq!(
            format!("{:?}", bytes),
            r##""http://www.w3.org/2002/07/owl#""##
        );
    }

    #[test]
    fn write_byte_string3() {
        let bytes = ByteBuf(vec![
            67, 108, 97, 115, 115, 32, 73, 82, 73, 61, 34, 35, 66, 34,
        ]);
        assert_eq!(format!("{:?}", bytes), r##""Class IRI=\"#B\"""##);
    }

    /// Non-printable bytes (including DEL) are written in hex, without padding.
    #[test]
    fn write_byte_string_non_printable() {
        let bytes = Bytes(&[0, 9, 127, 255]);
        assert_eq!(format!("{:?}", bytes), "\"0x00x90x7F0xFF\"");
    }

    /// The `Cow` variant name wraps the formatted byte string.
    #[test]
    fn write_cow_string() {
        struct Wrapper<'a>(Cow<'a, [u8]>);

        impl Debug for Wrapper<'_> {
            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                super::write_cow_string(f, &self.0)
            }
        }

        assert_eq!(
            format!("{:?}", Wrapper(Cow::Borrowed(&b"a\"b"[..]))),
            r#"Borrowed("a\"b")"#
        );
        assert_eq!(
            format!("{:?}", Wrapper(Cow::Owned(vec![b'x', 0x7F]))),
            r#"Owned("x0x7F")"#
        );
    }

    #[test]
    fn is_whitespace() {
        for &b in &[b' ', b'\r', b'\n', b'\t'] {
            assert!(super::is_whitespace(b), "{:?}", b as char);
        }
        // Form feed, NUL and letters are not XML whitespace
        assert!(!super::is_whitespace(0x0C));
        assert!(!super::is_whitespace(0));
        assert!(!super::is_whitespace(b'a'));
    }

    #[test]
    fn name_len() {
        assert_eq!(super::name_len(b""), 0);
        assert_eq!(super::name_len(b" abc"), 0);
        assert_eq!(super::name_len(b" \t\r\n"), 0);

        assert_eq!(super::name_len(b"abc"), 3);
        assert_eq!(super::name_len(b"abc "), 3);

        assert_eq!(super::name_len(b"a bc"), 1);
        assert_eq!(super::name_len(b"ab\tc"), 2);
        assert_eq!(super::name_len(b"ab\rc"), 2);
        assert_eq!(super::name_len(b"ab\nc"), 2);
    }

    #[test]
    fn trim_xml_start() {
        assert_eq!(Bytes(super::trim_xml_start(b"")), Bytes(b""));
        assert_eq!(Bytes(super::trim_xml_start(b"abc")), Bytes(b"abc"));
        assert_eq!(
            Bytes(super::trim_xml_start(b"\r\n\t ab \t\r\nc \t\r\n")),
            Bytes(b"ab \t\r\nc \t\r\n")
        );
    }

    #[test]
    fn trim_xml_end() {
        assert_eq!(Bytes(super::trim_xml_end(b"")), Bytes(b""));
        assert_eq!(Bytes(super::trim_xml_end(b"abc")), Bytes(b"abc"));
        assert_eq!(
            Bytes(super::trim_xml_end(b"\r\n\t ab \t\r\nc \t\r\n")),
            Bytes(b"\r\n\t ab \t\r\nc")
        );
    }

    #[test]
    fn trim_xml_spaces() {
        assert_eq!(super::trim_xml_spaces(""), "");
        assert_eq!(super::trim_xml_spaces("abc"), "abc");
        assert_eq!(super::trim_xml_spaces(" \t\r\n"), "");
        assert_eq!(
            super::trim_xml_spaces("\r\n\t ab \t\r\nc \t\r\n"),
            "ab \t\r\nc"
        );
        // Whitespace outside of the XML definition (here U+00A0) is preserved
        assert_eq!(
            super::trim_xml_spaces("\u{A0}\u{e9}\u{A0} "),
            "\u{A0}\u{e9}\u{A0}"
        );
    }

    #[test]
    fn cdata_iterator() {
        fn pieces(text: &str) -> Vec<&str> {
            CDataIterator::new(text).collect()
        }

        // An empty string still produces one (empty) piece
        assert_eq!(pieces(""), vec![""]);
        assert_eq!(pieces("abc"), vec!["abc"]);
        // `>` that is not preceded by `]]` does not split
        assert_eq!(pieces("a>b]>c"), vec!["a>b]>c"]);
        assert_eq!(pieces("a]]>b"), vec!["a]]", ">b"]);
        assert_eq!(pieces("]]>]]>"), vec!["]]", ">]]", ">"]);

        // The iterator is fused
        let mut iter = CDataIterator::new("x");
        assert_eq!(iter.next(), Some("x"));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);
    }
}