Skip to main content

hash2curve/hash2field/
expand_msg.rs

//! `expand_message` interface for `hash_to_field`.
2
3pub(super) mod xmd;
4pub(super) mod xof;
5
6use core::num::NonZero;
7
8use digest::{Digest, ExtendableOutput, Update, XofReader};
9use elliptic_curve::Error;
10use elliptic_curve::array::{Array, ArraySize};
11use xmd::ExpandMsgXmdError;
12use xof::ExpandMsgXofError;
13
/// Prefix absorbed ahead of a domain separation tag that is longer than
/// [`MAX_DST_LEN`] bytes, before that tag is hashed down (RFC 9380, section 5.3.3).
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";
/// Longest domain separation tag, in bytes, that is used verbatim; longer tags are hashed.
const MAX_DST_LEN: usize = u8::MAX as usize;
18
19/// Trait for types implementing expand_message interface for `hash_to_field`.
20///
21/// `K` is the target security level in bytes:
22/// <https://www.rfc-editor.org/rfc/rfc9380.html#section-8.9-2.2>
23/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-target-security-levels>
24///
25/// # Errors
26/// See implementors of [`ExpandMsg`] for errors.
27pub trait ExpandMsg<K> {
28    /// The hash used by this implementation.
29    type Hash;
30    /// Type holding data for the [`Expander`].
31    type Expander<'dst>: Expander + Sized;
32    /// Error returned by [`ExpandMsg::expand_message`].
33    type Error: core::error::Error;
34
35    /// Expands `msg` to the required number of bytes.
36    ///
37    /// Returns an expander that can be used to call `read` until enough
38    /// bytes have been consumed
39    fn expand_message<'dst>(
40        msg: &[&[u8]],
41        dst: &'dst [&[u8]],
42        len_in_bytes: NonZero<u16>,
43    ) -> Result<Self::Expander<'dst>, Self::Error>;
44}
45
46/// Expander that, call `read` until enough bytes have been consumed.
47pub trait Expander {
48    /// Fill the array with the expanded bytes, returning how many bytes were read.
49    ///
50    /// # Errors
51    ///
52    /// If no bytes are left.
53    fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error>;
54}
55
56/// The domain separation tag
57///
58/// Implements [section 5.3.3 of RFC9380][dst].
59///
60/// [dst]: https://www.rfc-editor.org/rfc/rfc9380.html#name-using-dsts-longer-than-255-
61#[derive(Debug)]
62pub(crate) enum Domain<'a, L: ArraySize> {
63    /// > 255
64    Hashed(Array<u8, L>),
65    /// <= 255
66    Array(&'a [&'a [u8]]),
67}
68
69impl<'a, L: ArraySize> Domain<'a, L> {
70    pub fn xof<X>(dst: &'a [&'a [u8]]) -> Result<Self, ExpandMsgXofError>
71    where
72        X: Default + ExtendableOutput + Update,
73    {
74        // https://www.rfc-editor.org/rfc/rfc9380.html#section-3.1-4.2
75        let dst_len = dst.iter().map(|slice| slice.len()).sum::<usize>();
76
77        if dst_len == 0 {
78            Err(ExpandMsgXofError::EmptyDst)
79        } else if dst_len > MAX_DST_LEN {
80            if L::USIZE > u8::MAX.into() {
81                return Err(ExpandMsgXofError::DstSecurityLevel);
82            }
83            let mut data = Array::<u8, L>::default();
84            let mut hash = X::default();
85            hash.update(OVERSIZE_DST_SALT);
86
87            for slice in dst {
88                hash.update(slice);
89            }
90
91            hash.finalize_xof().read(&mut data);
92
93            Ok(Self::Hashed(data))
94        } else {
95            Ok(Self::Array(dst))
96        }
97    }
98
99    pub fn xmd<X>(dst: &'a [&'a [u8]]) -> Result<Self, ExpandMsgXmdError>
100    where
101        X: Digest<OutputSize = L>,
102    {
103        // https://www.rfc-editor.org/rfc/rfc9380.html#section-3.1-4.2
104        let dst_len = dst.iter().map(|slice| slice.len()).sum::<usize>();
105
106        if dst_len == 0 {
107            Err(ExpandMsgXmdError::EmptyDst)
108        } else if dst_len > MAX_DST_LEN {
109            if L::USIZE > u8::MAX.into() {
110                return Err(ExpandMsgXmdError::DstHash);
111            }
112            Ok(Self::Hashed({
113                let mut hash = X::new();
114                hash.update(OVERSIZE_DST_SALT);
115
116                for slice in dst {
117                    hash.update(slice);
118                }
119
120                hash.finalize()
121            }))
122        } else {
123            Ok(Self::Array(dst))
124        }
125    }
126
127    pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
128        match self {
129            Self::Hashed(d) => hash.update(d),
130            Self::Array(d) => {
131                for d in d.iter() {
132                    hash.update(d)
133                }
134            }
135        }
136    }
137
138    pub fn len(&self) -> u8 {
139        match self {
140            // Can't overflow because it's checked on creation.
141            Self::Hashed(_) => L::U8,
142            // Can't overflow because it's checked on creation.
143            Self::Array(d) => {
144                u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
145            }
146        }
147    }
148
149    #[cfg(test)]
150    pub fn assert(&self, bytes: &[u8]) {
151        let data = match self {
152            Domain::Hashed(d) => d.to_vec(),
153            Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
154        };
155        assert_eq!(data, bytes);
156    }
157
158    #[cfg(test)]
159    pub fn assert_dst(&self, bytes: &[u8]) {
160        let data = match self {
161            Domain::Hashed(d) => d.to_vec(),
162            Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
163        };
164        assert_eq!(data, &bytes[..bytes.len() - 1]);
165        assert_eq!(self.len(), bytes[bytes.len() - 1]);
166    }
167}