/*!
Types and routines that support the wire format of finite automata.

Currently, this module just exports a few error types and some small helpers
for deserializing [dense DFAs](crate::dfa::dense::DFA) using correct alignment.
*/

/*
A collection of helper functions, types and traits for serializing automata.

This crate defines its own bespoke serialization mechanism for some structures
provided in the public API, namely, DFAs. A bespoke mechanism was developed
primarily because structures like automata demand a specific binary format.
Attempting to encode their rich structure in an existing serialization
format is just not feasible. Moreover, the format for each structure is
generally designed such that deserialization is cheap. More specifically,
deserialization can be done in constant time. (The idea being that you can
embed it into your binary or mmap it, and then use it immediately.)

In order to achieve this, the dense and sparse DFAs in this crate use an
in-memory representation that very closely corresponds to their binary
serialized form. This pervades and complicates everything, and in some cases,
requires dealing with alignment and reasoning about safety.

This technique does have major advantages. In particular, it permits doing
the potentially costly work of compiling a finite state machine in an offline
manner, and then loading it at runtime not only without having to re-compile
the regex, but even without the code required to do the compilation. This, for
example, permits one to use a pre-compiled DFA not only in environments without
Rust's standard library, but also in environments without a heap.

In the code below, whenever we insert some kind of padding, it's to enforce a
4-byte alignment, unless otherwise noted. This follows from u32 being the only
supported state ID type. (In a previous version of this library, DFAs were
generic over the state ID representation.)

Also, serialization generally requires the caller to specify endianness,
whereas deserialization always assumes native endianness (otherwise cheap
deserialization would be impossible). This implies that serializing a structure
generally requires serializing both its big-endian and little-endian variants,
and then loading the correct one based on the target's endianness.
*/
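
/*
The last paragraph above describes the endianness contract: serialization
takes an explicit byte order, while deserialization always reads in native
order. As a rough sketch (not this crate's API), the hypothetical helpers
`serialize`, `pick_native` and `deserialize_native` below apply that contract
to a plain list of u32 values.

```rust
/// Hypothetical serializer: writes each u32 using the requested byte order.
fn serialize(values: &[u32], big_endian: bool) -> Vec<u8> {
    let mut out = Vec::with_capacity(values.len() * 4);
    for &v in values {
        let bytes = if big_endian { v.to_be_bytes() } else { v.to_le_bytes() };
        out.extend_from_slice(&bytes);
    }
    out
}

/// Chooses the variant whose byte order matches the compilation target.
fn pick_native<'a>(big: &'a [u8], little: &'a [u8]) -> &'a [u8] {
    if cfg!(target_endian = "big") {
        big
    } else {
        little
    }
}

/// Decodes u32 values in native byte order, the only order that
/// deserialization ever assumes.
fn deserialize_native(bytes: &[u8]) -> Vec<u32> {
    bytes
        .chunks_exact(4)
        .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
```

Because the selected variant already matches the target, decoding is a plain
`from_ne_bytes` per value, with no byte swapping.
*/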

use core::{cmp, mem::size_of};

#[cfg(feature = "alloc")]
use alloc::{vec, vec::Vec};

use crate::util::{
    int::Pointer,
    primitives::{PatternID, PatternIDError, StateID, StateIDError},
};

/// A hack to align a smaller type `B` with a bigger type `T`.
///
/// The usual use of this is with `B = [u8]` and `T = u32`. That is,
/// it permits aligning a sequence of bytes on a 4-byte boundary. This
/// is useful in contexts where one wants to embed a serialized [dense
/// DFA](crate::dfa::dense::DFA) into a Rust program while guaranteeing the
/// alignment required for the DFA.
///
/// See [`dense::DFA::from_bytes`](crate::dfa::dense::DFA::from_bytes) for an
/// example of how to use this type.
#[repr(C)]
#[derive(Debug)]
pub struct AlignAs<B: ?Sized, T> {
    /// A zero-sized field indicating the alignment we want.
    pub _align: [T; 0],
    /// A possibly unsized field containing a sequence of bytes.
    pub bytes: B,
}
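
/*
A minimal sketch of the pattern `AlignAs` enables, using a local copy of the
type so the example stands alone. `ALIGNED` and `aligned_bytes` are
hypothetical names, and the fixed byte array stands in for embedded data such
as `*include_bytes!("...")`.

```rust
/// Local copy of the wrapper, mirroring the definition above.
#[repr(C)]
pub struct AlignAs<B: ?Sized, T> {
    /// Zero-length, so it occupies no bytes, but it raises the struct's
    /// alignment to that of `T`.
    pub _align: [T; 0],
    /// Placed at offset 0 by `repr(C)`, so it inherits that alignment.
    pub bytes: B,
}

/// Stand-in for embedded serialized data.
static ALIGNED: &AlignAs<[u8], u32> = &AlignAs {
    _align: [],
    bytes: [0xABu8; 10],
};

/// Returns the embedded bytes, which start on a `u32` boundary.
fn aligned_bytes() -> &'static [u8] {
    &ALIGNED.bytes
}
```

The reference to an `AlignAs<[u8; 10], u32>` unsizes to `AlignAs<[u8], u32>`,
and the struct's alignment (that of `u32`) carries over to `bytes` because the
field sits at offset 0.
*/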

/// An error that occurs when serializing an object from this crate.
///
/// Serialization, as used in this crate, universally refers to the process
/// of transforming a structure (like a DFA) into a custom binary format
/// represented by `&[u8]`. Serialization is generally infallible. However,
/// it can fail when caller-provided buffer sizes are too small. When that
/// occurs, a serialization error is reported.
///
/// A `SerializeError` provides no introspection capabilities. Its only
/// supported operation is conversion to a human-readable error message.
///
/// This error type implements the `std::error::Error` trait only when the
/// `std` feature is enabled. The type itself is defined in all
/// configurations.
#[derive(Debug)]
pub struct SerializeError {
    /// The name of the thing that a buffer is too small for.
    ///
    /// Currently, the only kind of serialization error is one that is
    /// committed by a caller: providing a destination buffer that is too
    /// small to fit the serialized object. This makes sense conceptually,
    /// since every valid inhabitant of a type should be serializable.
    ///
    /// This is somewhat exposed in the public API of this crate. For example,
    /// the `to_bytes_{big,little}_endian` APIs return a `Vec<u8>` and are
    /// guaranteed to never panic or error. This is only possible because the
    /// implementation guarantees that it will allocate a `Vec<u8>` that is
    /// big enough.
    ///
    /// In summary, if a new serialization error kind needs to be added, then
    /// it will need careful consideration.
    what: &'static str,
}

impl SerializeError {
    pub(crate) fn buffer_too_small(what: &'static str) -> SerializeError {
        SerializeError { what }
    }
}

impl core::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        write!(f, "destination buffer is too small to write {}", self.what)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SerializeError {}

/// An error that occurs when deserializing an object defined in this crate.
///
/// Serialization, as used in this crate, universally refers to the process
/// of transforming a structure (like a DFA) into a custom binary format
/// represented by `&[u8]`. Deserialization, then, refers to the process of
/// cheaply converting this binary format back to the object's in-memory
/// representation as defined in this crate. To the extent possible,
/// deserialization will report this error whenever this process fails.
///
/// A `DeserializeError` provides no introspection capabilities. Its only
/// supported operation is conversion to a human-readable error message.
///
/// This error type implements the `std::error::Error` trait only when the
/// `std` feature is enabled. The type itself is defined in all
/// configurations.
#[derive(Debug)]
pub struct DeserializeError(DeserializeErrorKind);

#[derive(Debug)]
enum DeserializeErrorKind {
    Generic { msg: &'static str },
    BufferTooSmall { what: &'static str },
    InvalidUsize { what: &'static str },
    VersionMismatch { expected: u32, found: u32 },
    EndianMismatch { expected: u32, found: u32 },
    AlignmentMismatch { alignment: usize, address: usize },
    LabelMismatch { expected: &'static str },
    ArithmeticOverflow { what: &'static str },
    PatternID { err: PatternIDError, what: &'static str },
    StateID { err: StateIDError, what: &'static str },
}

impl DeserializeError {
    pub(crate) fn generic(msg: &'static str) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::Generic { msg })
    }

    pub(crate) fn buffer_too_small(what: &'static str) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::BufferTooSmall { what })
    }

    fn invalid_usize(what: &'static str) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::InvalidUsize { what })
    }

    fn version_mismatch(expected: u32, found: u32) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::VersionMismatch {
            expected,
            found,
        })
    }

    fn endian_mismatch(expected: u32, found: u32) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::EndianMismatch {
            expected,
            found,
        })
    }

    fn alignment_mismatch(
        alignment: usize,
        address: usize,
    ) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::AlignmentMismatch {
            alignment,
            address,
        })
    }

    fn label_mismatch(expected: &'static str) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::LabelMismatch { expected })
    }

    fn arithmetic_overflow(what: &'static str) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::ArithmeticOverflow { what })
    }

    fn pattern_id_error(
        err: PatternIDError,
        what: &'static str,
    ) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::PatternID { err, what })
    }

    pub(crate) fn state_id_error(
        err: StateIDError,
        what: &'static str,
    ) -> DeserializeError {
        DeserializeError(DeserializeErrorKind::StateID { err, what })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DeserializeError {}

impl core::fmt::Display for DeserializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        use self::DeserializeErrorKind::*;

        match self.0 {
            Generic { msg } => write!(f, "{}", msg),
            BufferTooSmall { what } => {
                write!(f, "buffer is too small to read {}", what)
            }
            InvalidUsize { what } => {
                write!(f, "{} is too big to fit in a usize", what)
            }
            VersionMismatch { expected, found } => write!(
                f,
                "unsupported version: \
                 expected version {} but found version {}",
                expected, found,
            ),
            EndianMismatch { expected, found } => write!(
                f,
                "endianness mismatch: expected 0x{:X} but got 0x{:X}. \
                 (Are you trying to load an object serialized with a \
                 different endianness?)",
                expected, found,
            ),
            AlignmentMismatch { alignment, address } => write!(
                f,
                "alignment mismatch: slice starts at address \
                 0x{:X}, which is not aligned to a {} byte boundary",
                address, alignment,
            ),
            LabelMismatch { expected } => write!(
                f,
                "label mismatch: start of serialized object should \
                 contain a NUL terminated {:?} label, but a different \
                 label was found",
                expected,
            ),
            ArithmeticOverflow { what } => {
                write!(f, "arithmetic overflow for {}", what)
            }
            PatternID { ref err, what } => {
                write!(f, "failed to read pattern ID for {}: {}", what, err)
            }
            StateID { ref err, what } => {
                write!(f, "failed to read state ID for {}: {}", what, err)
            }
        }
    }
}
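
/*
The two error types above expose nothing but `Debug`, `Display` and (with
`std`) `std::error::Error`. The sketch below reproduces that opaque-error
pattern in miniature: a public newtype around a private kind enum. `LoadError`
and `Kind` are hypothetical names used only for illustration.

```rust
use std::fmt;

/// Private kinds: callers cannot match on them, so variants can be added or
/// changed without breaking downstream code.
#[derive(Debug)]
enum Kind {
    BufferTooSmall { what: &'static str },
    VersionMismatch { expected: u32, found: u32 },
}

/// Public, opaque error whose only supported operation is formatting.
#[derive(Debug)]
pub struct LoadError(Kind);

impl LoadError {
    fn buffer_too_small(what: &'static str) -> LoadError {
        LoadError(Kind::BufferTooSmall { what })
    }

    fn version_mismatch(expected: u32, found: u32) -> LoadError {
        LoadError(Kind::VersionMismatch { expected, found })
    }
}

impl fmt::Display for LoadError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.0 {
            Kind::BufferTooSmall { what } => {
                write!(f, "buffer is too small to read {}", what)
            }
            Kind::VersionMismatch { expected, found } => write!(
                f,
                "unsupported version: expected version {} but found version {}",
                expected, found,
            ),
        }
    }
}

impl std::error::Error for LoadError {}
```

Keeping the kind private leaves room to add failure modes without a breaking
change, at the cost of callers being unable to branch on the kind.
*/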

/// Safely converts a `&[u32]` to `&[StateID]` with zero cost.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn u32s_to_state_ids(slice: &[u32]) -> &[StateID] {
    // SAFETY: This is safe because StateID is defined to have the same memory
    // representation as a u32 (it is repr(transparent)). While not every u32
    // is a "valid" StateID, callers are not permitted to rely on the validity
    // of StateIDs for memory safety. It can only lead to logical errors. (This
    // is why StateID::new_unchecked is safe.)
    unsafe {
        core::slice::from_raw_parts(
            slice.as_ptr().cast::<StateID>(),
            slice.len(),
        )
    }
}

/// Safely converts a `&mut [u32]` to `&mut [StateID]` with zero cost.
pub(crate) fn u32s_to_state_ids_mut(slice: &mut [u32]) -> &mut [StateID] {
    // SAFETY: This is safe because StateID is defined to have the same memory
    // representation as a u32 (it is repr(transparent)). While not every u32
    // is a "valid" StateID, callers are not permitted to rely on the validity
    // of StateIDs for memory safety. It can only lead to logical errors. (This
    // is why StateID::new_unchecked is safe.)
    unsafe {
        core::slice::from_raw_parts_mut(
            slice.as_mut_ptr().cast::<StateID>(),
            slice.len(),
        )
    }
}

/// Safely converts a `&[u32]` to `&[PatternID]` with zero cost.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn u32s_to_pattern_ids(slice: &[u32]) -> &[PatternID] {
    // SAFETY: This is safe because PatternID is defined to have the same
    // memory representation as a u32 (it is repr(transparent)). While not
    // every u32 is a "valid" PatternID, callers are not permitted to rely
    // on the validity of PatternIDs for memory safety. It can only lead to
    // logical errors. (This is why PatternID::new_unchecked is safe.)
    unsafe {
        core::slice::from_raw_parts(
            slice.as_ptr().cast::<PatternID>(),
            slice.len(),
        )
    }
}
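
/*
The three conversions above rely on `repr(transparent)` guaranteeing that a
newtype over `u32` has exactly the layout of `u32`. A self-contained sketch of
the same cast, with a hypothetical `Id` type standing in for `StateID` and
`PatternID`:

```rust
/// Hypothetical ID type that shares the memory layout of `u32`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Id(u32);

impl Id {
    pub fn as_u32(self) -> u32 {
        self.0
    }
}

/// Reinterprets `&[u32]` as `&[Id]` without copying.
pub fn u32s_to_ids(slice: &[u32]) -> &[Id] {
    // SAFETY: `Id` is `repr(transparent)` over `u32`, so both element types
    // have identical size and alignment, and `Id` has no invariants beyond
    // those of `u32`. Pointer and length come from a live slice, and the
    // returned slice borrows from it, so it cannot outlive the data.
    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<Id>(), slice.len()) }
}
```

No data is copied: the returned slice points at the same memory as the input.
*/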

/// Checks that the given slice has an alignment that matches `T`.
///
/// This is useful for checking that a slice has an appropriate alignment
/// before casting it to a `&[T]`. Note, though, that correct alignment alone
/// is not sufficient to justify such a cast for an arbitrary `T`.
pub(crate) fn check_alignment<T>(
    slice: &[u8],
) -> Result<(), DeserializeError> {
    let alignment = core::mem::align_of::<T>();
    let address = slice.as_ptr().as_usize();
    if address % alignment == 0 {
        return Ok(());
    }
    Err(DeserializeError::alignment_mismatch(alignment, address))
}

/// Reads a possibly empty amount of padding, up to 7 bytes, from the beginning
/// of the given slice. All padding bytes must be NUL bytes.
///
/// This is useful because it can, in theory, be necessary to pad the
/// beginning of a serialized object with NUL bytes to ensure that it starts
/// at a correctly aligned address. These padding bytes should come immediately
/// before the label.
///
/// This returns the number of bytes read from the given slice.
pub(crate) fn skip_initial_padding(slice: &[u8]) -> usize {
    let mut nread = 0;
    while nread < 7 && nread < slice.len() && slice[nread] == 0 {
        nread += 1;
    }
    nread
}
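
/*
A sketch of how the two helpers above might be combined when locating the
start of a serialized object: skip up to 7 NUL padding bytes, then require the
remainder to be aligned. `skip_padding`, `is_aligned_for` and
`start_of_object` are hypothetical stand-ins, and the exact order in which
callers in this crate apply the real helpers is an assumption here.

```rust
/// Counts leading NUL bytes, stopping after 7, like `skip_initial_padding`.
fn skip_padding(slice: &[u8]) -> usize {
    slice.iter().take(7).take_while(|&&b| b == 0).count()
}

/// Reports whether `slice` starts at an address suitably aligned for `T`,
/// the same test `check_alignment` performs.
fn is_aligned_for<T>(slice: &[u8]) -> bool {
    slice.as_ptr() as usize % std::mem::align_of::<T>() == 0
}

/// Hypothetical combination: skip the padding, then require the remainder
/// (where the label begins) to be aligned for `T`.
fn start_of_object<T>(slice: &[u8]) -> Option<&[u8]> {
    let rest = &slice[skip_padding(slice)..];
    if is_aligned_for::<T>(rest) {
        Some(rest)
    } else {
        None
    }
}
```
*/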

/// Allocate a byte buffer of the given size, along with some initial padding
/// such that `buf[padding..]` has the same alignment as `T`, where the
/// alignment of `T` must be at most `8`. Callers must treat the first
/// `padding` bytes (the second return value) as padding that must not be
/// overwritten. In all cases, the following identity holds:
///
/// ```ignore
/// let (buf, padding) = alloc_aligned_buffer::<StateID>(SIZE);
/// assert_eq!(SIZE, buf[padding..].len());
/// ```
///
/// In practice, padding is often zero.
///
/// The requirement for `8` as a maximum here is somewhat arbitrary. In
/// practice, we never need anything bigger in this crate, and so this function
/// does some sanity asserts under the assumption of a max alignment of `8`.
#[cfg(feature = "alloc")]
pub(crate) fn alloc_aligned_buffer<T>(size: usize) -> (Vec<u8>, usize) {
    // NOTE: This is a kludge because there's no easy way to allocate a Vec<u8>
    // with an alignment guaranteed to be greater than 1. We could create a
    // Vec<u32>, but this cannot be safely transmuted to a Vec<u8> without
    // concern, since reallocing or dropping the Vec<u8> is UB (different
    // alignment than the initial allocation). We could define a wrapper type
    // to manage this for us, but it seems like more machinery than it's worth.
    let buf = vec![0; size];
    let align = core::mem::align_of::<T>();
    let address = buf.as_ptr().as_usize();
    if address % align == 0 {
        return (buf, 0);
    }
    // Let's try this again. We have to create a totally new alloc with
    // the maximum amount of bytes we might need. We can't just extend our
    // pre-existing 'buf' because that might create a new alloc with a
    // different alignment.
    let extra = align - 1;
    let mut buf = vec![0; size + extra];
    let address = buf.as_ptr().as_usize();
    // The code below does not handle the case where 'address' is aligned to T
    // (it would compute a padding of 'align' instead of 0), so if we got
    // lucky and 'address' is now aligned to T (when it previously wasn't),
    // then we're done.
    if address % align == 0 {
        buf.truncate(size);
        return (buf, 0);
    }
    let padding = ((address & !(align - 1)).checked_add(align).unwrap())
        .checked_sub(address)
        .unwrap();
    assert!(padding <= 7, "padding of {} is bigger than 7", padding);
    assert!(
        padding <= extra,
        "padding of {} is bigger than extra {} bytes",
        padding,
        extra
    );
    buf.truncate(size + padding);
    assert_eq!(size + padding, buf.len());
    assert_eq!(
        0,
        buf[padding..].as_ptr().as_usize() % align,
        "expected end of initial padding to be aligned to {}",
        align,
    );
    (buf, padding)
}
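
/*
A worked version of the padding arithmetic in `alloc_aligned_buffer`. For
example, with `align = 4` and `address = 0x1001`, rounding down gives
`0x1000`, adding `align` gives `0x1004`, and the padding is
`0x1004 - 0x1001 = 3`. The sketch below (hypothetical helper names) checks
that rule against the usual modular formula and shows why an already-aligned
address has to be handled separately.

```rust
/// The rule used by `alloc_aligned_buffer`: round `address` down to a
/// multiple of `align`, step forward one `align`, and subtract `address`.
/// For an already-aligned address this yields `align`, not 0.
fn padding_from_mask(address: usize, align: usize) -> usize {
    assert!(align.is_power_of_two());
    ((address & !(align - 1)) + align) - address
}

/// An equivalent rule that also returns 0 for aligned addresses.
fn padding_modular(address: usize, align: usize) -> usize {
    (align - address % align) % align
}

/// Hypothetical helper: the offset into an over-allocated buffer at which a
/// `T`-aligned region begins.
fn aligned_offset<T>(buf: &[u8]) -> usize {
    padding_modular(buf.as_ptr() as usize, std::mem::align_of::<T>())
}
```

Over-allocating by `align - 1` bytes, as the function above does, guarantees
that `buf[padding..]` still holds at least `size` bytes.
*/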

/// Reads a NUL terminated label starting at the beginning of the given slice.
///
/// If a NUL terminated label could not be found, then an error is returned.
/// Similarly, if a label is found but doesn't match the expected label, then
/// an error is returned.
///
/// Upon success, the total number of bytes read (including padding bytes) is
/// returned.
pub(crate) fn read_label(
    slice: &[u8],
    expected_label: &'static str,
) -> Result<usize, DeserializeError> {
    // Set an upper bound on how many bytes we scan for a NUL. Since no label
    // may be longer than 255 bytes (see 'write_label_len'), its NUL terminator
    // must appear within the first 256 bytes. If we can't find one within
    // that range, then we have corrupted data.
    let first_nul =
        slice[..cmp::min(slice.len(), 256)].iter().position(|&b| b == 0);
    let first_nul = match first_nul {
        Some(first_nul) => first_nul,
        None => {
            return Err(DeserializeError::generic(
                "could not find NUL terminated label \
                 at start of serialized object",
            ));
        }
    };
    let len = first_nul + padding_len(first_nul);
    if slice.len() < len {
        return Err(DeserializeError::generic(
            "could not find properly sized label at start of serialized object"
        ));
    }
    if expected_label.as_bytes() != &slice[..first_nul] {
        return Err(DeserializeError::label_mismatch(expected_label));
    }
    Ok(len)
}

/// Writes the given label to the buffer as a NUL terminated string. The label
/// given must not contain NUL, otherwise this will panic. Similarly, the label
/// must not be longer than 255 bytes, otherwise this will panic.
///
/// Additional NUL bytes are written as necessary to ensure that the number of
/// bytes written is always a multiple of 4.
///
/// Upon success, the total number of bytes written (including padding) is
/// returned.
pub(crate) fn write_label(
    label: &str,
    dst: &mut [u8],
) -> Result<usize, SerializeError> {
    let nwrite = write_label_len(label);
    if dst.len() < nwrite {
        return Err(SerializeError::buffer_too_small("label"));
    }
    dst[..label.len()].copy_from_slice(label.as_bytes());
    for i in 0..(nwrite - label.len()) {
        dst[label.len() + i] = 0;
    }
    assert_eq!(nwrite % 4, 0);
    Ok(nwrite)
}

/// Returns the total number of bytes (including padding) that would be written
/// for the given label. This panics if the given label contains a NUL byte or
/// is longer than 255 bytes. (The size restriction exists so that searching
/// for a label during deserialization can be done in small bounded space.)
pub(crate) fn write_label_len(label: &str) -> usize {
    assert!(label.len() <= 255, "label must not be longer than 255 bytes");
    assert!(label.bytes().all(|b| b != 0), "label must not contain NUL bytes");
    let label_len = label.len() + 1; // +1 for the NUL terminator
    label_len + padding_len(label_len)
}
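
/*
A sketch of the label layout produced by `write_label`: the label bytes, a NUL
terminator, then NUL padding up to a multiple of 4. For example, the 10-byte
label "rust-regex" plus its NUL is 11 bytes, padded to 12. The real
`padding_len` is defined elsewhere in this module; the version below is an
assumption that rounds up to the next multiple of 4. This sketch computes the
padded length from the terminator-inclusive length, mirroring
`write_label_len`, and uses plain string errors for brevity.

```rust
/// Assumed padding rule: NUL bytes needed to round `n` up to a multiple of 4.
fn padding_len(n: usize) -> usize {
    (4 - (n % 4)) % 4
}

/// Appends `label`, a NUL terminator and NUL padding to `dst`, returning the
/// number of bytes appended.
fn write_label(label: &str, dst: &mut Vec<u8>) -> usize {
    assert!(label.len() <= 255 && !label.bytes().any(|b| b == 0));
    let len = label.len() + 1;
    let total = len + padding_len(len);
    dst.extend_from_slice(label.as_bytes());
    dst.resize(dst.len() + (total - label.len()), 0);
    total
}

/// Finds the first NUL within 256 bytes, checks the label, and returns the
/// number of bytes the label occupies (terminator and padding included).
fn read_label(slice: &[u8], expected: &str) -> Result<usize, &'static str> {
    let window = &slice[..slice.len().min(256)];
    let nul = window
        .iter()
        .position(|&b| b == 0)
        .ok_or("no NUL terminated label")?;
    let total = (nul + 1) + padding_len(nul + 1);
    if slice.len() < total {
        return Err("label is truncated");
    }
    if &slice[..nul] != expected.as_bytes() {
        return Err("label mismatch");
    }
    Ok(total)
}
```
*/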

/// Reads the endianness check from the beginning of the given slice and
/// confirms that the endianness of the serialized object matches the expected
/// endianness. If the slice is too small or if the endianness check fails,
/// this returns an error.
///
/// Upon success, the total number of bytes read is returned.
pub(crate) fn read_endianness_check(
    slice: &[u8],
) -> Result<usize, DeserializeError> {
    let (n, nr) = try_read_u32(slice, "endianness check")?;
    assert_eq!(nr, write_endianness_check_len());
    if n != 0xFEFF {
        return Err(DeserializeError::endian_mismatch(0xFEFF, n));
    }
    Ok(nr)
}

/// Writes 0xFEFF as an integer using the given endianness.
///
/// This is useful for writing into the header of a serialized object. It can
/// be read during deserialization as a sanity check to ensure the proper
/// endianness is used.
///
/// Upon success, the total number of bytes written is returned.
pub(crate) fn write_endianness_check<E: Endian>(
    dst: &mut [u8],
) -> Result<usize, SerializeError> {
    let nwrite = write_endianness_check_len();
    if dst.len() < nwrite {
        return Err(SerializeError::buffer_too_small("endianness check"));
    }
    E::write_u32(0xFEFF, dst);
    Ok(nwrite)
}

/// Returns the number of bytes written by the endianness check.
pub(crate) fn write_endianness_check_len() -> usize {
    size_of::<u32>()
}
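
/*
Why 0xFEFF works as an endianness marker: written little-endian it is the byte
sequence `FF FE 00 00`, and reading those bytes as big-endian yields
`0xFFFE0000` rather than `0xFEFF` (and vice versa), so a byte-order mismatch
is always detected. A minimal sketch with hypothetical helper names:

```rust
/// The marker written into serialized headers.
const ENDIANNESS_CHECK: u32 = 0xFEFF;

/// Writes the marker using an explicitly chosen byte order.
fn write_check(big_endian: bool) -> [u8; 4] {
    if big_endian {
        ENDIANNESS_CHECK.to_be_bytes()
    } else {
        ENDIANNESS_CHECK.to_le_bytes()
    }
}

/// Reads the marker in native byte order and returns the value found on a
/// mismatch, much like `DeserializeError::endian_mismatch` records it.
fn read_check(bytes: [u8; 4]) -> Result<(), u32> {
    let n = u32::from_ne_bytes(bytes);
    if n == ENDIANNESS_CHECK {
        Ok(())
    } else {
        Err(n)
    }
}
```
*/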

/// Reads a version number from the beginning of the given slice and confirms
/// that it matches the expected version number given. If the slice is too
/// small or if the version numbers don't match, this returns an error.
///
/// Upon success, the total number of bytes read is returned.
///
/// N.B. Currently, we require that the version number is an exact match.
/// In the future, if we bump the version number without a semver bump, then
/// we'll need to relax this a bit and support older versions.
pub(crate) fn read_version(
    slice: &[u8],
    expected_version: u32,
) -> Result<usize, DeserializeError> {
    let (n, nr) = try_read_u32(slice, "version")?;
    assert_eq!(nr, write_version_len());
    if n != expected_version {
        return Err(DeserializeError::version_mismatch(expected_version, n));
    }
    Ok(nr)
}

/// Writes the given version number to the beginning of the given slice.
///
/// This is useful for writing into the header of a serialized object. It can
/// be read during deserialization as a sanity check to ensure that the library
/// code supports the format of the serialized object.
///
/// Upon success, the total number of bytes written is returned.
pub(crate) fn write_version<E: Endian>(
    version: u32,
    dst: &mut [u8],
) -> Result<usize, SerializeError> {
    let nwrite = write_version_len();
    if dst.len() < nwrite {
        return Err(SerializeError::buffer_too_small("version number"));
    }
    E::write_u32(version, dst);
    Ok(nwrite)
}

/// Returns the number of bytes written by writing the version number.
pub(crate) fn write_version_len() -> usize {
    size_of::<u32>()
}
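
/*
The endianness check and the version number are natural neighbors in a header.
The sketch below illustrates such a layout and the exact-match version policy
described above. `VERSION`, `write_header`, `read_u32` and `read_header` are
hypothetical; per the label docs above, a real serialized object also starts
with its label, which this sketch omits.

```rust
/// Hypothetical format version for this sketch.
const VERSION: u32 = 2;

/// Writes a minimal header in native order: endianness marker, then version.
fn write_header(dst: &mut Vec<u8>, version: u32) {
    dst.extend_from_slice(&0xFEFFu32.to_ne_bytes());
    dst.extend_from_slice(&version.to_ne_bytes());
}

/// Reads a native-endian u32 from the front of `slice`.
fn read_u32(slice: &[u8]) -> Result<u32, String> {
    let b = slice.get(..4).ok_or("buffer too small")?;
    Ok(u32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
}

/// Validates the header, requiring an exact version match, and returns the
/// number of bytes consumed.
fn read_header(slice: &[u8], expected_version: u32) -> Result<usize, String> {
    let marker = read_u32(slice)?;
    if marker != 0xFEFF {
        return Err(format!("endianness mismatch: found 0x{:X}", marker));
    }
    let version = read_u32(&slice[4..])?;
    if version != expected_version {
        return Err(format!(
            "unsupported version: expected version {} but found version {}",
            expected_version, version
        ));
    }
    Ok(8)
}
```
*/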

/// Reads a pattern ID from the given slice. If the slice has insufficient
/// length, then this panics. If the deserialized integer exceeds the pattern
/// ID limit for the current target, then this returns an error.
///
/// Upon success, this also returns the number of bytes read.
pub(crate) fn read_pattern_id(
    slice: &[u8],
    what: &'static str,
) -> Result<(PatternID, usize), DeserializeError> {
    let bytes: [u8; PatternID::SIZE] =
        slice[..PatternID::SIZE].try_into().unwrap();
    let pid = PatternID::from_ne_bytes(bytes)
        .map_err(|err| DeserializeError::pattern_id_error(err, what))?;
    Ok((pid, PatternID::SIZE))
}

/// Reads a pattern ID from the given slice. If the slice has insufficient
/// length, then this panics. Otherwise, the deserialized integer is assumed
/// to be a valid pattern ID.
///
/// This also returns the number of bytes read.
pub(crate) fn read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize) {
    let pid = PatternID::from_ne_bytes_unchecked(
        slice[..PatternID::SIZE].try_into().unwrap(),
    );
    (pid, PatternID::SIZE)
}

/// Write the given pattern ID to the beginning of the given slice of bytes
/// using the specified endianness. The given slice must have length at least
/// `PatternID::SIZE`, or else this panics. Upon success, the total number of
/// bytes written is returned.
pub(crate) fn write_pattern_id<E: Endian>(
    pid: PatternID,
    dst: &mut [u8],
) -> usize {
    E::write_u32(pid.as_u32(), dst);
    PatternID::SIZE
}

/// Attempts to read a state ID from the given slice. If the slice has an
/// insufficient number of bytes or if the state ID exceeds the limit for
/// the current target, then this returns an error.
///
/// Upon success, this also returns the number of bytes read.
pub(crate) fn try_read_state_id(
    slice: &[u8],
    what: &'static str,
) -> Result<(StateID, usize), DeserializeError> {
    if slice.len() < StateID::SIZE {
        return Err(DeserializeError::buffer_too_small(what));
    }
    read_state_id(slice, what)
}

/// Reads a state ID from the given slice. If the slice has insufficient
/// length, then this panics. If the deserialized integer exceeds the state ID
/// limit for the current target, then this returns an error.
///
/// Upon success, this also returns the number of bytes read.
pub(crate) fn read_state_id(
    slice: &[u8],
    what: &'static str,
) -> Result<(StateID, usize), DeserializeError> {
    let bytes: [u8; StateID::SIZE] =
        slice[..StateID::SIZE].try_into().unwrap();
    let sid = StateID::from_ne_bytes(bytes)
        .map_err(|err| DeserializeError::state_id_error(err, what))?;
    Ok((sid, StateID::SIZE))
}

/// Reads a state ID from the given slice. If the slice has insufficient
/// length, then this panics. Otherwise, the deserialized integer is assumed
/// to be a valid state ID.
///
/// This also returns the number of bytes read.
pub(crate) fn read_state_id_unchecked(slice: &[u8]) -> (StateID, usize) {
    let sid = StateID::from_ne_bytes_unchecked(
        slice[..StateID::SIZE].try_into().unwrap(),
    );
    (sid, StateID::SIZE)
}

/// Write the given state ID to the beginning of the given slice of bytes
/// using the specified endianness. The given slice must have length at least
/// `StateID::SIZE`, or else this panics. Upon success, the total number of
/// bytes written is returned.
pub(crate) fn write_state_id<E: Endian>(
    sid: StateID,
    dst: &mut [u8],
) -> usize {
    E::write_u32(sid.as_u32(), dst);
    StateID::SIZE
}
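
/*
The checked and unchecked readers above differ only in whether the decoded
integer is validated against the ID limit. A minimal sketch of that split,
where `Id` and the limit `MAX_ID` are hypothetical (the real limits come from
`StateID` and `PatternID`):

```rust
/// Hypothetical upper bound on valid IDs for this sketch.
const MAX_ID: u32 = 1 << 20;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Id(u32);

/// Checked read: rejects integers above the limit. Panics if `slice` has
/// fewer than 4 bytes. Returns the ID and the number of bytes consumed.
fn read_id(slice: &[u8]) -> Result<(Id, usize), String> {
    let n = u32::from_ne_bytes([slice[0], slice[1], slice[2], slice[3]]);
    if n > MAX_ID {
        return Err(format!("ID {} exceeds the limit of {}", n, MAX_ID));
    }
    Ok((Id(n), 4))
}

/// Unchecked read: trusts the bytes. In this sketch an out-of-range ID can
/// only cause logic errors, since no unsafe code depends on it.
fn read_id_unchecked(slice: &[u8]) -> (Id, usize) {
    let n = u32::from_ne_bytes([slice[0], slice[1], slice[2], slice[3]]);
    (Id(n), 4)
}
```
*/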

/// Try to read a u16 as a usize from the beginning of the given slice in
/// native endian format. If the slice has fewer than 2 bytes or if the
/// deserialized number cannot be represented by usize, then this returns an
/// error. The error message will include the `what` description of what is
/// being deserialized, for better error messages. `what` should be a noun in
/// singular form.
///
/// Upon success, this also returns the number of bytes read.
pub(crate) fn try_read_u16_as_usize(
    slice: &[u8],
    what: &'static str,
) -> Result<(usize, usize), DeserializeError> {
    try_read_u16(slice, what).and_then(|(n, nr)| {
        usize::try_from(n)
            .map(|n| (n, nr))
            .map_err(|_| DeserializeError::invalid_usize(what))
    })
}

/// Try to read a u32 as a usize from the beginning of the given slice in
/// native endian format. If the slice has fewer than 4 bytes or if the
/// deserialized number cannot be represented by usize, then this returns an
/// error. The error message will include the `what` description of what is
/// being deserialized, for better error messages. `what` should be a noun in
/// singular form.
///
/// Upon success, this also returns the number of bytes read.
pub(crate) fn try_read_u32_as_usize(
    slice: &[u8],
    what: &'static str,
) -> Result<(usize, usize), DeserializeError> {
    try_read_u32(slice, what).and_then(|(n, nr)| {
        usize::try_from(n)
            .map(|n| (n, nr))
            .map_err(|_| DeserializeError::invalid_usize(what))
    })
}
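
/*
A self-contained sketch of the read-then-convert flow in
`try_read_u32_as_usize`, using `String` errors in place of `DeserializeError`
(an assumption made for brevity). The `(value, bytes_read)` return shape lets
callers advance an offset after each read.

```rust
use std::convert::TryFrom;

/// Reads a native-endian u32 from the front of `slice` and converts it to
/// usize. Returns the value and the number of bytes consumed (4).
fn try_read_u32_as_usize(
    slice: &[u8],
    what: &str,
) -> Result<(usize, usize), String> {
    if slice.len() < 4 {
        return Err(format!("buffer is too small to read {}", what));
    }
    let n = u32::from_ne_bytes([slice[0], slice[1], slice[2], slice[3]]);
    // This conversion can only fail on targets where usize is narrower than
    // 32 bits.
    let n = usize::try_from(n)
        .map_err(|_| format!("{} is too big to fit in a usize", what))?;
    Ok((n, 4))
}
```
*/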
704
705/// Try to read a u16 from the beginning of the given slice in native endian
706/// format. If the slice has fewer than 2 bytes, then this returns an error.
707/// The error message will include the `what` description of what is being
708/// deserialized, for better error messages. `what` should be a noun in
709/// singular form.
710///
711/// Upon success, this also returns the number of bytes read.
712pub(crate) fn try_read_u16(
713    slice: &[u8],
714    what: &'static str,
715) -> Result<(u16, usize), DeserializeError> {
716    check_slice_len(slice, size_of::<u16>(), what)?;
717    Ok((read_u16(slice), size_of::<u16>()))
718}
719
720/// Try to read a u32 from the beginning of the given slice in native endian
721/// format. If the slice has fewer than 4 bytes, then this returns an error.
722/// The error message will include the `what` description of what is being
723/// deserialized, for better error messages. `what` should be a noun in
724/// singular form.
725///
726/// Upon success, this also returns the number of bytes read.
pub(crate) fn try_read_u32(
    slice: &[u8],
    what: &'static str,
) -> Result<(u32, usize), DeserializeError> {
    check_slice_len(slice, size_of::<u32>(), what)?;
    Ok((read_u32(slice), size_of::<u32>()))
}

/// Try to read a u128 from the beginning of the given slice in native endian
/// format. If the slice has fewer than 16 bytes, then this returns an error.
/// The error message will include the `what` description of what is being
/// deserialized, for better error messages. `what` should be a noun in
/// singular form.
///
/// Upon success, this also returns the number of bytes read.
pub(crate) fn try_read_u128(
    slice: &[u8],
    what: &'static str,
) -> Result<(u128, usize), DeserializeError> {
    check_slice_len(slice, size_of::<u128>(), what)?;
    Ok((read_u128(slice), size_of::<u128>()))
}

/// Read a u16 from the beginning of the given slice in native endian format.
/// If the slice has fewer than 2 bytes, then this panics.
///
/// Marked as inline (when the `perf-inline` feature is enabled) to speed up
/// sparse DFA searching, which decodes integers from the automaton at search
/// time.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn read_u16(slice: &[u8]) -> u16 {
    let bytes: [u8; 2] = slice[..size_of::<u16>()].try_into().unwrap();
    u16::from_ne_bytes(bytes)
}

/// Read a u32 from the beginning of the given slice in native endian format.
/// If the slice has fewer than 4 bytes, then this panics.
///
/// Marked as inline (when the `perf-inline` feature is enabled) to speed up
/// sparse DFA searching, which decodes integers from the automaton at search
/// time.
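///
/// To illustrate what "native endian" means here, a std-only sketch follows;
/// both helper names are hypothetical. On a little endian target,
/// `from_ne_bytes` agrees with `from_le_bytes`, and on a big endian target it
/// agrees with `from_be_bytes`:
///
/// ```
/// use std::convert::TryInto;
///
/// // Hypothetical mirror of the decode step: take the first four bytes
/// // (panicking if there are fewer) in the target's native byte order.
/// fn read_u32_sketch(slice: &[u8]) -> u32 {
///     let bytes: [u8; 4] = slice[..4].try_into().unwrap();
///     u32::from_ne_bytes(bytes)
/// }
///
/// // Hypothetical reference decoding that picks an explicit byte order
/// // based on the target's endianness, for comparison.
/// fn read_u32_explicit(bytes: [u8; 4]) -> u32 {
///     if cfg!(target_endian = "little") {
///         u32::from_le_bytes(bytes)
///     } else {
///         u32::from_be_bytes(bytes)
///     }
/// }
/// ```
///
/// For example, `[1, 0, 0, 0, 0xFF]` decodes to `1` on little endian targets
/// and to `0x0100_0000` on big endian targets; bytes past the fourth are
/// ignored.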
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn read_u32(slice: &[u8]) -> u32 {
    let bytes: [u8; 4] = slice[..size_of::<u32>()].try_into().unwrap();
    u32::from_ne_bytes(bytes)
}

/// Read a u128 from the beginning of the given slice in native endian format.
/// If the slice has fewer than 16 bytes, then this panics.
pub(crate) fn read_u128(slice: &[u8]) -> u128 {
    let bytes: [u8; 16] = slice[..size_of::<u128>()].try_into().unwrap();
    u128::from_ne_bytes(bytes)
}

/// Checks that the given slice has at least `at_least_len` elements. If it is
/// shorter, then a "buffer too small" error is returned with `what`
/// describing what the buffer represents.
pub(crate) fn check_slice_len<T>(
    slice: &[T],
    at_least_len: usize,
    what: &'static str,
) -> Result<(), DeserializeError> {
    if slice.len() < at_least_len {
        return Err(DeserializeError::buffer_too_small(what));
    }
    Ok(())
}

/// Multiply the given numbers, and on overflow, return an error that includes
/// `what` in the error message.
///
/// This is useful when doing arithmetic with untrusted data.
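///
/// A minimal std-only sketch of the pattern, assuming a hypothetical
/// serialized table whose byte length is `header_len + row_len * rows`, where
/// every input is read from untrusted data. Each step uses a `checked_*`
/// method and turns `None` into an error instead of silently wrapping:
///
/// ```
/// // Hypothetical: compute a table's total byte length without overflow.
/// fn table_byte_len(
///     header_len: usize,
///     row_len: usize,
///     rows: usize,
/// ) -> Result<usize, String> {
///     let body = row_len
///         .checked_mul(rows)
///         .ok_or_else(|| "overflow computing table body".to_string())?;
///     header_len
///         .checked_add(body)
///         .ok_or_else(|| "overflow computing table length".to_string())
/// }
/// ```
///
/// With this sketch, `table_byte_len(8, 4, 10)` is `Ok(48)`, while
/// `table_byte_len(8, usize::MAX, 2)` is an error rather than a wrapped value.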
pub(crate) fn mul(
    a: usize,
    b: usize,
    what: &'static str,
) -> Result<usize, DeserializeError> {
    match a.checked_mul(b) {
        Some(c) => Ok(c),
        None => Err(DeserializeError::arithmetic_overflow(what)),
    }
}

/// Add the given numbers, and on overflow, return an error that includes
/// `what` in the error message.
///
/// This is useful when doing arithmetic with untrusted data.
pub(crate) fn add(
    a: usize,
    b: usize,
    what: &'static str,
) -> Result<usize, DeserializeError> {
    match a.checked_add(b) {
        Some(c) => Ok(c),
        None => Err(DeserializeError::arithmetic_overflow(what)),
    }
}

/// Shift `a` left by `b`, and on overflow, return an error that includes
/// `what` in the error message.
///
/// This is useful when doing arithmetic with untrusted data.
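///
/// Note that `usize::checked_shl` returns `None` only when the shift amount
/// is at least `usize::BITS`; it does not report set bits that are shifted
/// out of the high end. A std-only sketch of a shift that also catches lost
/// bits, using the hypothetical helper name `shl_lossless`, shifts the result
/// back and compares it with the input:
///
/// ```
/// // Hypothetical helper: `None` if the shift amount is too large or if
/// // any set bit of `a` would be discarded by the shift.
/// fn shl_lossless(a: usize, amount: u32) -> Option<usize> {
///     let c = a.checked_shl(amount)?;
///     if c >> amount == a {
///         Some(c)
///     } else {
///         None
///     }
/// }
/// ```
///
/// For instance, `usize::MAX.checked_shl(1)` is `Some(_)` even though the top
/// bit is lost, whereas `shl_lossless(usize::MAX, 1)` is `None`.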
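pub(crate) fn shl(
    a: usize,
    b: usize,
    what: &'static str,
) -> Result<usize, DeserializeError> {
    let amount = u32::try_from(b)
        .map_err(|_| DeserializeError::arithmetic_overflow(what))?;
    // `checked_shl` only fails when `amount >= usize::BITS`, so also shift
    // back to detect any set bits of `a` that were shifted out.
    match a.checked_shl(amount) {
        Some(c) if c >> amount == a => Ok(c),
        _ => Err(DeserializeError::arithmetic_overflow(what)),
    }
}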

/// Returns the number of padding bytes that must be appended to a buffer of
/// the given length to make its total length a multiple of 4. The return
/// value is always less than 4.
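///
/// As a worked illustration (std only, hypothetical names): because
/// `n & 0b11` equals `n % 4` and is at most 3, `4 - (n & 0b11)` never
/// underflows, and the final `& 0b11` maps a result of 4 (already aligned)
/// back to 0:
///
/// ```
/// // Hypothetical mirror of the bit-masking formula.
/// fn padding_len_sketch(n: usize) -> usize {
///     (4 - (n & 0b11)) & 0b11
/// }
///
/// // The same computation written with remainders, for comparison.
/// fn padding_len_mod(n: usize) -> usize {
///     (4 - n % 4) % 4
/// }
/// ```
///
/// For example, a 9-byte buffer needs 3 padding bytes (9 + 3 = 12), while a
/// 12-byte buffer needs none.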
pub(crate) fn padding_len(non_padding_len: usize) -> usize {
    (4 - (non_padding_len & 0b11)) & 0b11
}

/// A simple trait for writing code generic over endianness.
///
/// This is similar to what the `byteorder` crate provides, but we only need
/// a very small subset.
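///
/// A self-contained sketch of how such a trait is used, mirroring its shape
/// with only a `write_u32` method. The `Endian`, `LE` and `BE` names below
/// are local stand-ins, and `write_pair` is a hypothetical serializer:
///
/// ```
/// trait Endian {
///     fn write_u32(n: u32, dst: &mut [u8]);
/// }
///
/// // Uninhabited types used purely as type-level tags.
/// enum LE {}
/// enum BE {}
///
/// impl Endian for LE {
///     fn write_u32(n: u32, dst: &mut [u8]) {
///         dst[..4].copy_from_slice(&n.to_le_bytes());
///     }
/// }
///
/// impl Endian for BE {
///     fn write_u32(n: u32, dst: &mut [u8]) {
///         dst[..4].copy_from_slice(&n.to_be_bytes());
///     }
/// }
///
/// // Hypothetical serializer written once, generic over byte order.
/// // Returns the number of bytes written.
/// fn write_pair<E: Endian>(a: u32, b: u32, dst: &mut [u8]) -> usize {
///     E::write_u32(a, &mut dst[0..4]);
///     E::write_u32(b, &mut dst[4..8]);
///     8
/// }
/// ```
///
/// With an 8-byte buffer, `write_pair::<LE>(1, 2, &mut buf)` writes
/// `[1, 0, 0, 0, 2, 0, 0, 0]`, and `write_pair::<BE>` writes
/// `[0, 0, 0, 1, 0, 0, 0, 2]`.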
pub(crate) trait Endian {
    /// Writes a u16 to the given destination buffer in a particular
    /// endianness. If the destination buffer has a length smaller than 2, then
    /// this panics.
    fn write_u16(n: u16, dst: &mut [u8]);

    /// Writes a u32 to the given destination buffer in a particular
    /// endianness. If the destination buffer has a length smaller than 4, then
    /// this panics.
    fn write_u32(n: u32, dst: &mut [u8]);

    /// Writes a u128 to the given destination buffer in a particular
    /// endianness. If the destination buffer has a length smaller than 16,
    /// then this panics.
    fn write_u128(n: u128, dst: &mut [u8]);
}

/// Little endian writing.
pub(crate) enum LE {}
/// Big endian writing.
pub(crate) enum BE {}

/// Native endian writing: an alias for `LE` on little endian targets.
#[cfg(target_endian = "little")]
pub(crate) type NE = LE;
/// Native endian writing: an alias for `BE` on big endian targets.
#[cfg(target_endian = "big")]
pub(crate) type NE = BE;

impl Endian for LE {
    fn write_u16(n: u16, dst: &mut [u8]) {
        dst[..2].copy_from_slice(&n.to_le_bytes());
    }

    fn write_u32(n: u32, dst: &mut [u8]) {
        dst[..4].copy_from_slice(&n.to_le_bytes());
    }

    fn write_u128(n: u128, dst: &mut [u8]) {
        dst[..16].copy_from_slice(&n.to_le_bytes());
    }
}

impl Endian for BE {
    fn write_u16(n: u16, dst: &mut [u8]) {
        dst[..2].copy_from_slice(&n.to_be_bytes());
    }

    fn write_u32(n: u32, dst: &mut [u8]) {
        dst[..4].copy_from_slice(&n.to_be_bytes());
    }

    fn write_u128(n: u128, dst: &mut [u8]) {
        dst[..16].copy_from_slice(&n.to_be_bytes());
    }
}

#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    #[test]
    fn labels() {
        let mut buf = [0; 1024];

        let nwrite = write_label("fooba", &mut buf).unwrap();
        assert_eq!(nwrite, 8);
        assert_eq!(&buf[..nwrite], b"fooba\x00\x00\x00");

        let nread = read_label(&buf, "fooba").unwrap();
        assert_eq!(nread, 8);
    }

    #[test]
    #[should_panic]
    fn bad_label_interior_nul() {
        // interior NULs are not allowed
        write_label("foo\x00bar", &mut [0; 1024]).unwrap();
    }

    #[test]
    fn bad_label_almost_too_long() {
        // ok: 255 bytes is the longest label allowed
        write_label(&"z".repeat(255), &mut [0; 1024]).unwrap();
    }

    #[test]
    #[should_panic]
    fn bad_label_too_long() {
        // labels longer than 255 bytes are banned
        write_label(&"z".repeat(256), &mut [0; 1024]).unwrap();
    }

    #[test]
    fn padding() {
        assert_eq!(0, padding_len(8));
        assert_eq!(3, padding_len(9));
        assert_eq!(2, padding_len(10));
        assert_eq!(1, padding_len(11));
        assert_eq!(0, padding_len(12));
        assert_eq!(3, padding_len(13));
        assert_eq!(2, padding_len(14));
        assert_eq!(1, padding_len(15));
        assert_eq!(0, padding_len(16));
    }
}