1use crate::{
4    asn1::*, Encode, EncodeValue, ErrorKind, Header, Length, Result, Tag, TagMode, TagNumber,
5    Tagged, Writer,
6};
7
/// [`Writer`] which encodes into a caller-provided mutable byte slice.
///
/// Output is written front-to-back. Once any operation marks the writer as failed,
/// further `encode`/`reserve` calls and `finish` return an [`ErrorKind::Failed`] error.
#[derive(Debug)]
pub struct SliceWriter<'a> {
    /// Output buffer that encoded bytes are written into.
    bytes: &'a mut [u8],

    /// Set once an operation on this writer has failed; checked before every write.
    failed: bool,

    /// Offset of the next byte to be written, i.e. the number of bytes written so far.
    position: Length,
}
20
21impl<'a> SliceWriter<'a> {
22    pub fn new(bytes: &'a mut [u8]) -> Self {
24        Self {
25            bytes,
26            failed: false,
27            position: Length::ZERO,
28        }
29    }
30
31    pub fn encode<T: Encode>(&mut self, encodable: &T) -> Result<()> {
33        if self.is_failed() {
34            self.error(ErrorKind::Failed)?
35        }
36
37        encodable.encode(self).map_err(|e| {
38            self.failed = true;
39            e.nested(self.position)
40        })
41    }
42
43    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
46        self.failed = true;
47        Err(kind.at(self.position))
48    }
49
50    pub fn is_failed(&self) -> bool {
52        self.failed
53    }
54
55    pub fn finish(self) -> Result<&'a [u8]> {
58        let position = self.position;
59
60        if self.is_failed() {
61            return Err(ErrorKind::Failed.at(position));
62        }
63
64        self.bytes
65            .get(..usize::try_from(position)?)
66            .ok_or_else(|| ErrorKind::Overlength.at(position))
67    }
68
69    pub fn context_specific<T>(
71        &mut self,
72        tag_number: TagNumber,
73        tag_mode: TagMode,
74        value: &T,
75    ) -> Result<()>
76    where
77        T: EncodeValue + Tagged,
78    {
79        ContextSpecificRef {
80            tag_number,
81            tag_mode,
82            value,
83        }
84        .encode(self)
85    }
86
87    pub fn sequence<F>(&mut self, length: Length, f: F) -> Result<()>
92    where
93        F: FnOnce(&mut SliceWriter<'_>) -> Result<()>,
94    {
95        Header::new(Tag::Sequence, length).and_then(|header| header.encode(self))?;
96
97        let mut nested_encoder = SliceWriter::new(self.reserve(length)?);
98        f(&mut nested_encoder)?;
99
100        if nested_encoder.finish()?.len() == usize::try_from(length)? {
101            Ok(())
102        } else {
103            self.error(ErrorKind::Length { tag: Tag::Sequence })
104        }
105    }
106
107    fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
110        if self.is_failed() {
111            return Err(ErrorKind::Failed.at(self.position));
112        }
113
114        let len = len
115            .try_into()
116            .or_else(|_| self.error(ErrorKind::Overflow))?;
117
118        let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
119        let slice = self
120            .bytes
121            .get_mut(self.position.try_into()?..end.try_into()?)
122            .ok_or_else(|| ErrorKind::Overlength.at(end))?;
123
124        self.position = end;
125        Ok(slice)
126    }
127}
128
129impl<'a> Writer for SliceWriter<'a> {
130    fn write(&mut self, slice: &[u8]) -> Result<()> {
131        self.reserve(slice.len())?.copy_from_slice(slice);
132        Ok(())
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::SliceWriter;
    use crate::{Encode, ErrorKind, Length};

    /// Encoding into a zero-length buffer must fail with `Overlength`. The error is
    /// reported at the position just past the first byte that did not fit.
    #[test]
    fn overlength_message() {
        let mut storage: [u8; 0] = [];
        let mut slice_writer = SliceWriter::new(&mut storage);

        let failure = false
            .encode(&mut slice_writer)
            .expect_err("encoding into an empty buffer should fail");

        assert_eq!(failure.kind(), ErrorKind::Overlength);
        assert_eq!(failure.position(), Some(Length::ONE));
    }
}