1use crate::{Error, Result};
4
5#[cfg(feature = "algorithm")]
6use {
7 crate::{
8 EcdsaCurve, Signature, SignatureSize, SigningKey, VerifyingKey,
9 hazmat::{DigestAlgorithm, bits2field, sign_prehashed_rfc6979, verify_prehashed},
10 },
11 digest::{Digest, FixedOutputReset, Update},
12 elliptic_curve::{
13 AffinePoint, CurveArithmetic, FieldBytes, FieldBytesSize, Group, PrimeField,
14 ProjectivePoint, Scalar,
15 array::ArraySize,
16 bigint::CheckedAdd,
17 field,
18 ops::Invert,
19 ops::{LinearCombination, Reduce},
20 point::DecompressPoint,
21 sec1::{self, FromSec1Point, ToSec1Point},
22 subtle::CtOption,
23 },
24 signature::{
25 DigestSigner, MultipartSigner, RandomizedDigestSigner, Signer,
26 hazmat::{PrehashSigner, RandomizedPrehashSigner},
27 rand_core::TryCryptoRng,
28 },
29};
30
/// Recovery identifier for an ECDSA signature: the extra information needed,
/// alongside the signature and message, to reconstruct the signer's public key.
///
/// Stored as a single byte in `0..=RecoveryId::MAX`:
/// - bit 0 is set when the y-coordinate of the signature's `R` point is odd;
/// - bit 1 is set when the x-coordinate of `R` was reduced modulo the curve
///   order when computing `r`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct RecoveryId(pub(crate) u8);

impl RecoveryId {
    /// Largest valid encoded recovery ID byte.
    pub const MAX: u8 = 3;

    /// Build a recovery ID from its two flags.
    ///
    /// `is_y_odd` becomes bit 0 and `is_x_reduced` becomes bit 1.
    pub const fn new(is_y_odd: bool, is_x_reduced: bool) -> Self {
        let y_bit = is_y_odd as u8;
        let x_bit = (is_x_reduced as u8) << 1;
        Self(x_bit | y_bit)
    }

    /// Whether the x-coordinate of `R` was reduced (bit 1 of the encoding).
    pub const fn is_x_reduced(self) -> bool {
        ((self.0 >> 1) & 1) == 1
    }

    /// Whether the y-coordinate of `R` is odd (bit 0 of the encoding).
    pub const fn is_y_odd(self) -> bool {
        (self.0 & 1) == 1
    }

    /// Decode a recovery ID from a byte.
    ///
    /// Returns `None` if `byte` is greater than [`RecoveryId::MAX`].
    pub const fn from_byte(byte: u8) -> Option<Self> {
        if byte > Self::MAX {
            return None;
        }
        Some(Self(byte))
    }

    /// Encode this recovery ID as a byte.
    pub const fn to_byte(self) -> u8 {
        let Self(byte) = self;
        byte
    }
}
83
#[cfg(feature = "algorithm")]
impl RecoveryId {
    /// Determine, by trial recovery, which [`RecoveryId`] reproduces
    /// `verifying_key` from `msg` and `signature`.
    ///
    /// `msg` is hashed with the curve's associated digest (`C::Digest`).
    ///
    /// # Errors
    /// See [`RecoveryId::trial_recovery_from_prehash`].
    pub fn trial_recovery_from_msg<C>(
        verifying_key: &VerifyingKey<C>,
        msg: &[u8],
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
        AffinePoint<C>: DecompressPoint<C> + FromSec1Point<C> + ToSec1Point<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArraySize,
    {
        let digest = C::Digest::new_with_prefix(msg);
        Self::trial_recovery_from_digest(verifying_key, digest, signature)
    }

    /// Determine, by trial recovery, which [`RecoveryId`] reproduces
    /// `verifying_key` from an already-updated message `digest`.
    ///
    /// # Errors
    /// See [`RecoveryId::trial_recovery_from_prehash`].
    pub fn trial_recovery_from_digest<C, D>(
        verifying_key: &VerifyingKey<C>,
        digest: D,
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: EcdsaCurve + CurveArithmetic,
        D: Digest,
        AffinePoint<C>: DecompressPoint<C> + FromSec1Point<C> + ToSec1Point<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArraySize,
    {
        let prehash = digest.finalize();
        Self::trial_recovery_from_prehash(verifying_key, &prehash, signature)
    }

    /// Determine, by trial recovery, which [`RecoveryId`] reproduces
    /// `verifying_key` from a message `prehash` and `signature`.
    ///
    /// Every ID from `0` through [`RecoveryId::MAX`] is tried in ascending
    /// order and the first one whose recovered key equals `verifying_key`
    /// is returned.
    ///
    /// # Errors
    /// Returns an error if `prehash` cannot be converted to a field element,
    /// if `signature` does not verify under `verifying_key`, or if no
    /// candidate ID recovers `verifying_key`.
    pub fn trial_recovery_from_prehash<C>(
        verifying_key: &VerifyingKey<C>,
        prehash: &[u8],
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: EcdsaCurve + CurveArithmetic,
        AffinePoint<C>: DecompressPoint<C> + FromSec1Point<C> + ToSec1Point<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArraySize,
    {
        // Reject signatures that do not verify under the supplied key before
        // attempting any recovery.
        let public_point = ProjectivePoint::<C>::from(*verifying_key.as_affine());
        let field_bytes = bits2field::<C>(prehash)?;
        verify_prehashed::<C>(&public_point, &field_bytes, signature)?;

        (0..=Self::MAX)
            .map(Self)
            .find(|&candidate| {
                VerifyingKey::recover_from_prehash(prehash, signature, candidate)
                    .is_ok_and(|recovered| &recovered == verifying_key)
            })
            .ok_or_else(Error::new)
    }
}
155
156impl TryFrom<u8> for RecoveryId {
157 type Error = Error;
158
159 fn try_from(byte: u8) -> Result<Self> {
160 Self::from_byte(byte).ok_or_else(Error::new)
161 }
162}
163
164impl From<RecoveryId> for u8 {
165 fn from(id: RecoveryId) -> u8 {
166 id.0
167 }
168}
169
#[cfg(feature = "algorithm")]
impl<C> SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Sign a message prehash using RFC6979 with additional data drawn from
    /// `rng`, returning the signature and its [`RecoveryId`].
    ///
    /// Each attempt draws a fresh block of additional data. If signing fails
    /// with that block, a new block is drawn and signing is retried. The loop
    /// only exits on success or when `rng` reports an error.
    ///
    /// # Errors
    /// Returns an error if `prehash` cannot be converted to a field element or
    /// if `rng` fails to fill the additional-data buffer.
    pub fn sign_prehash_recoverable_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        prehash: &[u8],
    ) -> Result<(Signature<C>, RecoveryId)> {
        let field_bytes = bits2field::<C>(prehash)?;

        loop {
            let mut extra_data = FieldBytes::<C>::default();
            rng.try_fill_bytes(&mut extra_data).map_err(|_| Error::new())?;

            match sign_prehashed_rfc6979::<C, C::Digest>(
                self.as_nonzero_scalar(),
                &field_bytes,
                &extra_data,
            ) {
                Ok(signed) => return Ok(signed),
                // Any signing failure is retried with fresh additional data.
                Err(_) => continue,
            }
        }
    }

    /// Deterministically sign a message prehash using RFC6979 with no
    /// additional data, returning the signature and its [`RecoveryId`].
    ///
    /// # Errors
    /// Returns an error if `prehash` cannot be converted to a field element or
    /// if signing fails.
    pub fn sign_prehash_recoverable(&self, prehash: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        let field_bytes = bits2field::<C>(prehash)?;
        sign_prehashed_rfc6979::<C, C::Digest>(self.as_nonzero_scalar(), &field_bytes, &[])
    }

    /// Finalize `msg_digest` and sign the result, returning the signature and
    /// its [`RecoveryId`].
    ///
    /// # Errors
    /// See [`SigningKey::sign_prehash_recoverable`].
    pub fn sign_digest_recoverable<D>(&self, msg_digest: D) -> Result<(Signature<C>, RecoveryId)>
    where
        D: Digest,
    {
        let prehash = msg_digest.finalize();
        self.sign_prehash_recoverable(&prehash)
    }

    /// Hash `msg` with the curve's associated digest (`C::Digest`) and sign it,
    /// returning the signature and its [`RecoveryId`].
    ///
    /// # Errors
    /// See [`SigningKey::sign_prehash_recoverable`].
    pub fn sign_recoverable(&self, msg: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        let digest = C::Digest::new_with_prefix(msg);
        self.sign_digest_recoverable(digest)
    }
}
218
#[cfg(feature = "algorithm")]
impl<C, D> DigestSigner<D, (Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    D: Digest + Update,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Build a fresh `D`, let `f` feed the message into it, then produce a
    /// recoverable signature over the resulting digest.
    fn try_sign_digest<F: Fn(&mut D) -> Result<()>>(
        &self,
        f: F,
    ) -> Result<(Signature<C>, RecoveryId)> {
        let digest = {
            let mut hasher = D::new();
            f(&mut hasher)?;
            hasher
        };
        self.sign_digest_recoverable(digest)
    }
}
236
#[cfg(feature = "algorithm")]
impl<C> RandomizedPrehashSigner<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Delegates to [`SigningKey::sign_prehash_recoverable_with_rng`].
    fn sign_prehash_with_rng<R: TryCryptoRng + ?Sized>(
        &self,
        rng: &mut R,
        prehash: &[u8],
    ) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_prehash_recoverable_with_rng(self, rng, prehash)
    }
}
252
#[cfg(feature = "algorithm")]
impl<C, D> RandomizedDigestSigner<D, (Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    D: Digest + Update + FixedOutputReset,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Build a fresh `D`, let `f` feed the message into it, then produce a
    /// recoverable signature over the digest output using `rng` for the
    /// RFC6979 additional data.
    fn try_sign_digest_with_rng<R: TryCryptoRng + ?Sized, F: Fn(&mut D) -> Result<()>>(
        &self,
        rng: &mut R,
        f: F,
    ) -> Result<(Signature<C>, RecoveryId)> {
        let mut hasher = D::new();
        f(&mut hasher)?;
        let prehash = hasher.finalize_reset();
        self.sign_prehash_recoverable_with_rng(rng, &prehash)
    }
}
271
#[cfg(feature = "algorithm")]
impl<C> PrehashSigner<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Delegates to [`SigningKey::sign_prehash_recoverable`].
    fn sign_prehash(&self, prehash: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_prehash_recoverable(self, prehash)
    }
}
283
#[cfg(feature = "algorithm")]
impl<C> Signer<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Sign `msg` as a single-part message via the recoverable
    /// [`MultipartSigner`] implementation.
    fn try_sign(&self, msg: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        <Self as MultipartSigner<(Signature<C>, RecoveryId)>>::try_multipart_sign(self, &[msg])
    }
}
295
#[cfg(feature = "algorithm")]
impl<C> MultipartSigner<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: EcdsaCurve + CurveArithmetic + DigestAlgorithm,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
    SignatureSize<C>: ArraySize,
{
    /// Hash the concatenation of all parts of `msg`, in order, with the
    /// curve's associated digest (`C::Digest`), then sign the result.
    fn try_multipart_sign(&self, msg: &[&[u8]]) -> Result<(Signature<C>, RecoveryId)> {
        let mut digest = C::Digest::new();
        for part in msg {
            // Fully qualified: the digest type exposes `update` through more
            // than one trait.
            Update::update(&mut digest, part);
        }
        self.sign_digest_recoverable(digest)
    }
}
310
#[cfg(feature = "algorithm")]
impl<C> VerifyingKey<C>
where
    C: EcdsaCurve + CurveArithmetic,
    AffinePoint<C>: DecompressPoint<C> + FromSec1Point<C> + ToSec1Point<C>,
    FieldBytesSize<C>: sec1::ModulusSize,
    SignatureSize<C>: ArraySize,
{
    /// Recover a [`VerifyingKey`] from a message, its signature, and a
    /// [`RecoveryId`].
    ///
    /// `msg` is hashed with the curve's associated digest (`C::Digest`).
    ///
    /// # Errors
    /// See [`VerifyingKey::recover_from_prehash`].
    pub fn recover_from_msg(
        msg: &[u8],
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self>
    where
        C: DigestAlgorithm,
    {
        Self::recover_from_digest(C::Digest::new_with_prefix(msg), signature, recovery_id)
    }

    /// Recover a [`VerifyingKey`] from a digest that has already been updated
    /// with the message, its signature, and a [`RecoveryId`].
    ///
    /// # Errors
    /// See [`VerifyingKey::recover_from_prehash`].
    pub fn recover_from_digest<D>(
        msg_digest: D,
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self>
    where
        D: Digest,
    {
        Self::recover_from_prehash(&msg_digest.finalize(), signature, recovery_id)
    }

    /// Recover a [`VerifyingKey`] from a message prehash, its signature, and a
    /// [`RecoveryId`].
    ///
    /// Reconstructs the point `R` from its x-coordinate and y-parity:
    /// - the x-coordinate is `r`, or `r + n` when the recovery ID says it was
    ///   reduced;
    /// - the y-parity is taken from the recovery ID.
    ///
    /// It then computes `Q = r^-1 * (s*R - z*G)`. Before returning, it checks
    /// that `signature` actually verifies under the recovered key.
    ///
    /// # Errors
    /// Returns an error if:
    /// - `prehash` cannot be converted into a field element;
    /// - restoring the unreduced x-coordinate (`r + n`) overflows;
    /// - no curve point has the computed x-coordinate and y-parity;
    /// - the recovered point is not a valid verifying key; or
    /// - `signature` does not verify under the recovered key.
    #[allow(non_snake_case)] // `R` mirrors the point name used in the ECDSA spec
    pub fn recover_from_prehash(
        prehash: &[u8],
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self> {
        let (r, s) = signature.split_scalars();
        // Computed once: reduced to a scalar here, and reused for the final
        // verification below.
        let field_bytes = bits2field::<C>(prehash)?;
        let z = Scalar::<C>::reduce(&field_bytes);

        // When the recovery ID says x was reduced, restore the original
        // x-coordinate by adding the curve order back to `r`.
        let r_bytes = if recovery_id.is_x_reduced() {
            let uint = field::bytes_to_uint::<C>(&r.to_repr())
                .checked_add(&C::ORDER)
                .into_option()
                .ok_or_else(Error::new)?;

            field::uint_to_bytes::<C>(&uint)
        } else {
            r.to_repr()
        };

        let R: ProjectivePoint<C> =
            AffinePoint::<C>::decompress(&r_bytes, u8::from(recovery_id.is_y_odd()).into())
                .into_option()
                .ok_or_else(Error::new)?
                .into();

        // Q = r^-1 * (s*R - z*G) = (-z * r^-1)*G + (s * r^-1)*R
        let r_inv = *r.invert();
        let u1 = -(r_inv * z);
        let u2 = r_inv * *s;
        let pk = ProjectivePoint::<C>::lincomb(&[(ProjectivePoint::<C>::generator(), u1), (R, u2)]);
        let vk = Self::from_affine(pk.into())?;

        // Ensure the signature verifies with the recovered key. Never hand back
        // a key under which the signature does not verify, e.g. if the
        // restored x-coordinate does not correspond to the signer's `R`.
        verify_prehashed::<C>(
            &ProjectivePoint::<C>::from(*vk.as_affine()),
            &field_bytes,
            signature,
        )?;

        Ok(vk)
    }
}
391
#[cfg(test)]
mod tests {
    use super::RecoveryId;

    /// Parse `byte`, which the caller knows to be a valid recovery ID.
    fn parse(byte: u8) -> RecoveryId {
        RecoveryId::try_from(byte).expect("RecoveryId")
    }

    #[test]
    fn new() {
        let cases = [
            (false, false, 0u8),
            (true, false, 1),
            (false, true, 2),
            (true, true, 3),
        ];
        for (y_odd, x_reduced, expected) in cases {
            assert_eq!(RecoveryId::new(y_odd, x_reduced).to_byte(), expected);
        }
    }

    #[test]
    fn try_from() {
        for n in 0u8..=3 {
            assert_eq!(parse(n).to_byte(), n);
        }

        assert!((4u8..=255).all(|n| RecoveryId::try_from(n).is_err()));
    }

    #[test]
    fn is_x_reduced() {
        for (byte, expected) in [(0, false), (1, false), (2, true), (3, true)] {
            assert_eq!(parse(byte).is_x_reduced(), expected);
        }
    }

    #[test]
    fn is_y_odd() {
        for (byte, expected) in [(0, false), (1, true), (2, false), (3, true)] {
            assert_eq!(parse(byte).is_y_odd(), expected);
        }
    }
}