Skip to main content

der/reader/
position.rs

1//! Position tracking for processing nested input messages using only the stack.
2
3use crate::{Error, ErrorKind, Length, Result};
4
5/// State tracker for the current position in the input.
6#[derive(Clone, Debug)]
7pub(super) struct Position {
8    /// Input length (in bytes after Base64 decoding).
9    input_len: Length,
10
11    /// Position in the input buffer (in bytes after Base64 decoding).
12    position: Length,
13}
14
15impl Position {
16    /// Create a new position tracker with the given overall length.
17    pub(super) fn new(input_len: Length) -> Self {
18        Self {
19            input_len,
20            position: Length::ZERO,
21        }
22    }
23
24    /// Get the input length.
25    pub(super) fn input_len(&self) -> Length {
26        self.input_len
27    }
28
29    /// Get the current position.
30    pub(super) fn current(&self) -> Length {
31        self.position
32    }
33
34    /// Advance the current position by the given amount.
35    ///
36    /// # Returns
37    ///
38    /// The new current position.
39    pub(super) fn advance(&mut self, amount: Length) -> Result<Length> {
40        let new_position = (self.position + amount)?;
41
42        if new_position > self.input_len {
43            return Err(ErrorKind::Incomplete {
44                expected_len: new_position,
45                actual_len: self.input_len,
46            }
47            .at(self.position));
48        }
49
50        self.position = new_position;
51        Ok(new_position)
52    }
53
54    /// Split a nested position tracker of the given size.
55    ///
56    /// # Returns
57    ///
58    /// A [`Resumption`] value which can be used to continue parsing the outer message.
59    pub(super) fn split_nested(&mut self, len: Length) -> Result<Resumption> {
60        let nested_input_len = (self.position + len)?;
61
62        if nested_input_len > self.input_len {
63            return Err(Error::incomplete(self.input_len));
64        }
65
66        let resumption = Resumption {
67            input_len: self.input_len,
68        };
69        self.input_len = nested_input_len;
70        Ok(resumption)
71    }
72
73    /// Resume processing the rest of a message after processing a nested inner portion.
74    pub(super) fn resume_nested(&mut self, resumption: Resumption) {
75        self.input_len = resumption.input_len;
76    }
77}
78
79/// Resumption state needed to continue processing a message after handling a nested inner portion.
80#[derive(Debug)]
81pub(super) struct Resumption {
82    /// Outer input length.
83    input_len: Length,
84}
85
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::Position;
    use crate::{ErrorKind, Length};

    const EXAMPLE_LEN: Length = match Length::new_usize(42) {
        Ok(len) => len,
        Err(_) => panic!("invalid example len"),
    };

    /// `EXAMPLE_LEN + 1`: the length required by an operation one byte past the example input.
    fn example_len_plus_one() -> Length {
        (EXAMPLE_LEN + Length::ONE).unwrap()
    }

    #[test]
    fn initial_state() {
        let pos = Position::new(EXAMPLE_LEN);
        assert_eq!(pos.input_len(), EXAMPLE_LEN);
        assert_eq!(pos.current(), Length::ZERO);
    }

    #[test]
    fn advance() {
        let mut pos = Position::new(EXAMPLE_LEN);

        // advance 1 byte: success
        let new_pos = pos.advance(Length::ONE).unwrap();
        assert_eq!(new_pos, Length::ONE);
        assert_eq!(pos.current(), Length::ONE);

        // advance to end: success
        let end_pos = pos.advance((EXAMPLE_LEN - Length::ONE).unwrap()).unwrap();
        assert_eq!(end_pos, EXAMPLE_LEN);
        assert_eq!(pos.current(), EXAMPLE_LEN);

        // advance one byte past end: error naming the required and available lengths
        let err = pos.advance(Length::ONE).unwrap_err();
        assert_eq!(
            err.kind(),
            ErrorKind::Incomplete {
                expected_len: example_len_plus_one(),
                actual_len: EXAMPLE_LEN,
            }
        );

        // a failed advance leaves the position untouched
        assert_eq!(pos.current(), EXAMPLE_LEN);
    }

    #[test]
    fn nested() {
        let mut pos = Position::new(EXAMPLE_LEN);

        // split first byte
        let resumption = pos.split_nested(Length::ONE).unwrap();
        assert_eq!(pos.current(), Length::ZERO);
        assert_eq!(pos.input_len(), Length::ONE);

        // advance one byte
        assert_eq!(pos.advance(Length::ONE).unwrap(), Length::ONE);

        // can't advance two bytes: the nested length is the limit, not the outer length
        let err = pos.advance(Length::ONE).unwrap_err();
        assert_eq!(
            err.kind(),
            ErrorKind::Incomplete {
                expected_len: (Length::ONE + Length::ONE).unwrap(),
                actual_len: Length::ONE,
            }
        );

        // resume processing the rest of the message
        // TODO(tarcieri): should we fail here if we previously failed reading a nested message?
        pos.resume_nested(resumption);

        assert_eq!(pos.current(), Length::ONE);
        assert_eq!(pos.input_len(), EXAMPLE_LEN);

        // try to split one byte past end: error
        let err = pos.split_nested(EXAMPLE_LEN).unwrap_err();
        assert_eq!(
            err.kind(),
            ErrorKind::Incomplete {
                expected_len: example_len_plus_one(),
                actual_len: EXAMPLE_LEN,
            }
        );

        // a failed split leaves the tracker untouched
        assert_eq!(pos.current(), Length::ONE);
        assert_eq!(pos.input_len(), EXAMPLE_LEN);

        // splitting exactly the remaining bytes succeeds and ends at the outer end
        let remaining = (EXAMPLE_LEN - Length::ONE).unwrap();
        let resumption = pos.split_nested(remaining).unwrap();
        assert_eq!(pos.current(), Length::ONE);
        assert_eq!(pos.input_len(), EXAMPLE_LEN);
        pos.resume_nested(resumption);
        assert_eq!(pos.input_len(), EXAMPLE_LEN);
    }
}
150}