hash2curve/hash2field/
expand_msg.rs1pub(super) mod xmd;
4pub(super) mod xof;
5
6use core::num::NonZero;
7
8use digest::{Digest, ExtendableOutput, Update, XofReader};
9use elliptic_curve::Error;
10use elliptic_curve::array::{Array, ArraySize};
11use xmd::ExpandMsgXmdError;
12use xof::ExpandMsgXofError;
13
/// Prefix hashed ahead of a domain separation tag that is too long to be used
/// verbatim (see [`MAX_DST_LEN`]).
///
/// NOTE(review): the string matches the oversize-DST prefix of RFC 9380
/// §5.3.3 — confirm against the spec before changing it.
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";

/// Largest total tag length, in bytes, that is used as-is.
///
/// Equal to 255: `Domain::len` reports the tag length as a single `u8`.
/// The widening `as` cast is lossless and usable in a `const` initializer.
const MAX_DST_LEN: usize = u8::MAX as usize;
18
19pub trait ExpandMsg<K> {
28 type Hash;
30 type Expander<'dst>: Expander + Sized;
32 type Error: core::error::Error;
34
35 fn expand_message<'dst>(
40 msg: &[&[u8]],
41 dst: &'dst [&[u8]],
42 len_in_bytes: NonZero<u16>,
43 ) -> Result<Self::Expander<'dst>, Self::Error>;
44}
45
46pub trait Expander {
48 fn fill_bytes(&mut self, okm: &mut [u8]) -> Result<usize, Error>;
54}
55
56#[derive(Debug)]
62pub(crate) enum Domain<'a, L: ArraySize> {
63 Hashed(Array<u8, L>),
65 Array(&'a [&'a [u8]]),
67}
68
69impl<'a, L: ArraySize> Domain<'a, L> {
70 pub fn xof<X>(dst: &'a [&'a [u8]]) -> Result<Self, ExpandMsgXofError>
71 where
72 X: Default + ExtendableOutput + Update,
73 {
74 let dst_len = dst.iter().map(|slice| slice.len()).sum::<usize>();
76
77 if dst_len == 0 {
78 Err(ExpandMsgXofError::EmptyDst)
79 } else if dst_len > MAX_DST_LEN {
80 if L::USIZE > u8::MAX.into() {
81 return Err(ExpandMsgXofError::DstSecurityLevel);
82 }
83 let mut data = Array::<u8, L>::default();
84 let mut hash = X::default();
85 hash.update(OVERSIZE_DST_SALT);
86
87 for slice in dst {
88 hash.update(slice);
89 }
90
91 hash.finalize_xof().read(&mut data);
92
93 Ok(Self::Hashed(data))
94 } else {
95 Ok(Self::Array(dst))
96 }
97 }
98
99 pub fn xmd<X>(dst: &'a [&'a [u8]]) -> Result<Self, ExpandMsgXmdError>
100 where
101 X: Digest<OutputSize = L>,
102 {
103 let dst_len = dst.iter().map(|slice| slice.len()).sum::<usize>();
105
106 if dst_len == 0 {
107 Err(ExpandMsgXmdError::EmptyDst)
108 } else if dst_len > MAX_DST_LEN {
109 if L::USIZE > u8::MAX.into() {
110 return Err(ExpandMsgXmdError::DstHash);
111 }
112 Ok(Self::Hashed({
113 let mut hash = X::new();
114 hash.update(OVERSIZE_DST_SALT);
115
116 for slice in dst {
117 hash.update(slice);
118 }
119
120 hash.finalize()
121 }))
122 } else {
123 Ok(Self::Array(dst))
124 }
125 }
126
127 pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
128 match self {
129 Self::Hashed(d) => hash.update(d),
130 Self::Array(d) => {
131 for d in d.iter() {
132 hash.update(d)
133 }
134 }
135 }
136 }
137
138 pub fn len(&self) -> u8 {
139 match self {
140 Self::Hashed(_) => L::U8,
142 Self::Array(d) => {
144 u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
145 }
146 }
147 }
148
149 #[cfg(test)]
150 pub fn assert(&self, bytes: &[u8]) {
151 let data = match self {
152 Domain::Hashed(d) => d.to_vec(),
153 Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
154 };
155 assert_eq!(data, bytes);
156 }
157
158 #[cfg(test)]
159 pub fn assert_dst(&self, bytes: &[u8]) {
160 let data = match self {
161 Domain::Hashed(d) => d.to_vec(),
162 Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
163 };
164 assert_eq!(data, &bytes[..bytes.len() - 1]);
165 assert_eq!(self.len(), bytes[bytes.len() - 1]);
166 }
167}