1#[cfg(all(feature = "pkcs8", feature = "sec1"))]
9mod pkcs8;
10
11use crate::{Curve, Error, FieldBytes, Result, ScalarValue};
12use array::typenum::Unsigned;
13use common::{Generate, InvalidKey, KeySizeUser, TryKeyInit};
14use core::fmt::{self, Debug};
15use rand_core::{CryptoRng, TryCryptoRng};
16use subtle::{Choice, ConstantTimeEq, CtOption};
17use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
18
19#[cfg(feature = "ecdh")]
20use crate::ecdh;
21#[cfg(feature = "arithmetic")]
22use crate::{CurveArithmetic, NonZeroScalar, PublicKey};
23#[cfg(all(feature = "arithmetic", feature = "pem"))]
24use alloc::string::String;
25#[cfg(feature = "pem")]
26use pem_rfc7468::{self as pem, PemLabel};
27#[cfg(all(feature = "alloc", feature = "arithmetic", feature = "sec1"))]
28use {
29 crate::{
30 AffinePoint,
31 sec1::{FromSec1Point, ToSec1Point},
32 },
33 alloc::vec::Vec,
34 sec1::der::Encode,
35};
36#[cfg(feature = "sec1")]
37use {
38 crate::{
39 DecodeError, DecodeResult, FieldBytesSize,
40 sec1::{ModulusSize, Sec1Point, ValidatePublicKey},
41 },
42 sec1::der::{self, Decode, oid::AssociatedOid},
43};
44
45#[cfg(all(doc, feature = "pkcs8"))]
46use {crate::pkcs8::DecodePrivateKey, core::str::FromStr};
47
48#[derive(Clone)]
74pub struct SecretKey<C: Curve> {
75 inner: ScalarValue<C>,
77}
78
79impl<C> SecretKey<C>
80where
81 C: Curve,
82{
    /// Minimum accepted length, in bytes, of a serialized secret key.
    ///
    /// `from_slice` rejects inputs shorter than this rather than left-padding
    /// them up to the curve's field size.
    pub const MIN_SIZE: usize = 24;
87
88 pub fn from_scalar(scalar: impl Into<ScalarValue<C>>) -> CtOption<Self> {
94 let inner = scalar.into();
95 CtOption::new(Self { inner }, !inner.is_zero())
96 }
97
98 pub fn as_scalar_value(&self) -> &ScalarValue<C> {
106 &self.inner
107 }
108
109 #[cfg(feature = "arithmetic")]
117 pub fn to_nonzero_scalar(&self) -> NonZeroScalar<C>
118 where
119 C: CurveArithmetic,
120 {
121 self.into()
122 }
123
124 #[cfg(feature = "arithmetic")]
126 pub fn public_key(&self) -> PublicKey<C>
127 where
128 C: CurveArithmetic,
129 {
130 PublicKey::from_secret_scalar(&self.to_nonzero_scalar())
131 }
132
133 pub fn from_bytes(bytes: &FieldBytes<C>) -> Result<Self> {
138 let inner = ScalarValue::<C>::from_bytes(bytes)
139 .into_option()
140 .ok_or(Error)?;
141
142 if inner.is_zero().into() {
143 return Err(Error);
144 }
145
146 Ok(Self { inner })
147 }
148
149 pub fn from_slice(slice: &[u8]) -> Result<Self> {
164 if let Ok(field_bytes) = <&FieldBytes<C>>::try_from(slice) {
165 Self::from_bytes(field_bytes)
166 } else if (Self::MIN_SIZE..C::FieldBytesSize::USIZE).contains(&slice.len()) {
167 let mut bytes = Zeroizing::new(FieldBytes::<C>::default());
168 let offset = C::FieldBytesSize::USIZE.saturating_sub(slice.len());
169 bytes[offset..].copy_from_slice(slice);
170 Self::from_bytes(&bytes)
171 } else {
172 Err(Error)
173 }
174 }
175
176 pub fn to_bytes(&self) -> FieldBytes<C> {
178 self.inner.to_bytes()
179 }
180
181 #[cfg(feature = "ecdh")]
185 pub fn diffie_hellman(&self, public_key: &PublicKey<C>) -> ecdh::SharedSecret<C>
186 where
187 C: CurveArithmetic,
188 {
189 ecdh::diffie_hellman(self.to_nonzero_scalar(), public_key.as_affine())
190 }
191
192 #[cfg(any(feature = "pkcs8", feature = "sec1"))]
206 #[allow(clippy::missing_panics_doc, reason = "should not panic")]
207 pub fn from_der(der_bytes: &[u8]) -> DecodeResult<Self>
208 where
209 C: AssociatedOid + Curve + ValidatePublicKey,
210 FieldBytesSize<C>: ModulusSize,
211 {
212 #[allow(unused_assignments)]
213 let mut err: Option<DecodeError> = None;
214
215 #[cfg(feature = "pkcs8")]
216 match ::pkcs8::DecodePrivateKey::from_pkcs8_der(der_bytes) {
217 Ok(sk) => return Ok(sk),
218 Err(e) => err = Some(e.into()),
219 }
220
221 #[cfg(feature = "sec1")]
222 match Self::from_sec1_der(der_bytes) {
223 Ok(sk) => return Ok(sk),
224 Err(e) => {
225 let _ = err.get_or_insert(e);
227 }
228 }
229
230 Err(err.expect("should be set"))
231 }
232
233 #[cfg(feature = "pem")]
244 pub fn from_pem(pem: &str) -> DecodeResult<Self>
245 where
246 C: AssociatedOid + Curve + ValidatePublicKey,
247 FieldBytesSize<C>: ModulusSize,
248 {
249 let label = pem_rfc7468::decode_label(pem.as_bytes()).map_err(DecodeError::Pem)?;
250
251 if ::pkcs8::PrivateKeyInfoRef::validate_pem_label(label).is_ok() {
252 return ::pkcs8::DecodePrivateKey::from_pkcs8_pem(pem).map_err(DecodeError::Pkcs8);
253 } else if ::sec1::EcPrivateKey::validate_pem_label(label).is_ok() {
254 return ::sec1::DecodeEcPrivateKey::from_sec1_pem(pem).map_err(DecodeError::Sec1);
255 }
256
257 Err(pem_rfc7468::Error::Label.into())
258 }
259
260 #[cfg(feature = "sec1")]
266 pub fn from_sec1_der(der_bytes: &[u8]) -> DecodeResult<Self>
267 where
268 C: AssociatedOid + Curve + ValidatePublicKey,
269 FieldBytesSize<C>: ModulusSize,
270 {
271 let sec1_key = sec1::EcPrivateKey::try_from(der_bytes)?;
272 Self::try_from(sec1_key).map_err(|e| DecodeError::Sec1(e.into()))
273 }
274
275 #[cfg(all(feature = "alloc", feature = "arithmetic", feature = "sec1"))]
280 pub fn to_sec1_der(&self) -> der::Result<Zeroizing<Vec<u8>>>
281 where
282 C: AssociatedOid + CurveArithmetic,
283 AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
284 FieldBytesSize<C>: ModulusSize,
285 {
286 let private_key_bytes = Zeroizing::new(self.to_bytes());
287 let public_key_bytes = self.public_key().to_sec1_point(false);
288 let parameters = sec1::EcParameters::NamedCurve(C::OID);
289
290 let ec_private_key = Zeroizing::new(
291 sec1::EcPrivateKey {
292 private_key: &private_key_bytes,
293 parameters: Some(parameters),
294 public_key: Some(public_key_bytes.as_bytes()),
295 }
296 .to_der()?,
297 );
298
299 Ok(ec_private_key)
300 }
301
302 #[cfg(feature = "pem")]
314 pub fn from_sec1_pem(s: &str) -> DecodeResult<Self>
315 where
316 C: AssociatedOid + Curve + ValidatePublicKey,
317 FieldBytesSize<C>: ModulusSize,
318 {
319 let (label, der_bytes) = pem::decode_vec(s.as_bytes()).map_err(DecodeError::Pem)?;
320
321 if label != sec1::EcPrivateKey::PEM_LABEL {
322 return Err(pem_rfc7468::Error::Label.into());
323 }
324
325 Self::from_sec1_der(&der_bytes)
326 }
327
328 #[cfg(feature = "pem")]
336 pub fn to_sec1_pem(&self, line_ending: pem::LineEnding) -> Result<Zeroizing<String>>
337 where
338 C: AssociatedOid + CurveArithmetic,
339 AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
340 FieldBytesSize<C>: ModulusSize,
341 {
342 self.to_sec1_der()
343 .ok()
344 .and_then(|der| {
345 pem::encode_string(sec1::EcPrivateKey::PEM_LABEL, line_ending, &der).ok()
346 })
347 .map(Zeroizing::new)
348 .ok_or(Error)
349 }
350
351 #[deprecated(since = "0.14.0", note = "use the `Generate` trait instead")]
353 pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
354 Self::generate_from_rng(rng)
355 }
356}
357
358impl<C> ConstantTimeEq for SecretKey<C>
359where
360 C: Curve,
361{
362 fn ct_eq(&self, other: &Self) -> Choice {
363 self.inner.ct_eq(&other.inner)
364 }
365}
366
367impl<C> Debug for SecretKey<C>
368where
369 C: Curve,
370{
371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372 f.debug_struct(core::any::type_name::<Self>())
373 .finish_non_exhaustive()
374 }
375}
376
377impl<C> Drop for SecretKey<C>
378where
379 C: Curve,
380{
381 fn drop(&mut self) {
382 self.inner.zeroize();
383 }
384}
385impl<C> ZeroizeOnDrop for SecretKey<C> where C: Curve {}
386
387impl<C: Curve> Eq for SecretKey<C> {}
388
389impl<C> PartialEq for SecretKey<C>
390where
391 C: Curve,
392{
393 fn eq(&self, other: &Self) -> bool {
394 self.ct_eq(other).into()
395 }
396}
397
398impl<C> Generate for SecretKey<C>
399where
400 C: Curve,
401{
402 fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(
403 rng: &mut R,
404 ) -> core::result::Result<Self, R::Error> {
405 Ok(Self {
406 inner: ScalarValue::<C>::try_generate_from_rng(rng)?,
407 })
408 }
409}
410
411impl<C> KeySizeUser for SecretKey<C>
412where
413 C: Curve,
414{
415 type KeySize = C::FieldBytesSize;
416}
417
418impl<C> TryKeyInit for SecretKey<C>
419where
420 C: Curve,
421{
422 fn new(key_bytes: &FieldBytes<C>) -> core::result::Result<Self, InvalidKey> {
423 Self::from_bytes(key_bytes).map_err(|_| InvalidKey)
424 }
425}
426
#[cfg(feature = "sec1")]
impl<C> sec1::DecodeEcPrivateKey for SecretKey<C>
where
    C: AssociatedOid + Curve + ValidatePublicKey,
    FieldBytesSize<C>: ModulusSize,
{
    /// Parse a SEC1 `EcPrivateKey` from DER and convert it into a secret key.
    fn from_sec1_der(bytes: &[u8]) -> sec1::Result<Self> {
        let parsed = sec1::EcPrivateKey::from_der(bytes)?;
        let secret_key = Self::try_from(parsed)?;
        Ok(secret_key)
    }
}
437
#[cfg(all(feature = "alloc", feature = "arithmetic", feature = "sec1"))]
impl<C> sec1::EncodeEcPrivateKey for SecretKey<C>
where
    C: AssociatedOid + CurveArithmetic,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
    FieldBytesSize<C>: ModulusSize,
{
    /// Encode this key as a SEC1 `EcPrivateKey` inside a `SecretDocument`.
    ///
    /// The structure carries the private scalar, the curve OID as parameters,
    /// and the public key bytes from `to_sec1_point(false)`.
    fn to_sec1_der(&self) -> sec1::Result<der::SecretDocument> {
        // Keep the raw scalar bytes in a buffer that is wiped on drop.
        let secret_bytes = Zeroizing::new(self.to_bytes());
        let public_point = self.public_key().to_sec1_point(false);

        let ec_private_key = sec1::EcPrivateKey {
            private_key: &secret_bytes,
            parameters: Some(C::OID.into()),
            public_key: Some(public_point.as_bytes()),
        };

        let document = der::SecretDocument::encode_msg(&ec_private_key)?;
        Ok(document)
    }
}
456
#[cfg(feature = "sec1")]
impl<C> TryFrom<sec1::EcPrivateKey<'_>> for SecretKey<C>
where
    C: AssociatedOid + Curve + ValidatePublicKey,
    FieldBytesSize<C>: ModulusSize,
{
    type Error = der::Error;

    /// Convert a parsed SEC1 `EcPrivateKey` into a secret key.
    ///
    /// # Errors
    ///
    /// - `ObjectIdentifier` value error: named-curve parameters are present
    ///   and differ from `C::OID`.
    /// - `OctetString` value error: the private key bytes are rejected by
    ///   `from_slice`.
    /// - `BitString` value error: a public key is present but does not parse
    ///   or fails `C::validate_public_key` against the secret key.
    fn try_from(sec1_private_key: sec1::EcPrivateKey<'_>) -> der::Result<Self> {
        // A named curve, when given, must be the curve this key type is for.
        match sec1_private_key.parameters {
            Some(sec1::EcParameters::NamedCurve(curve_oid)) if C::OID != curve_oid => {
                return Err(der::Tag::ObjectIdentifier.value_error().into());
            }
            _ => {}
        }

        let secret_key = Self::from_slice(sec1_private_key.private_key)
            .map_err(|_| der::Tag::OctetString.value_error())?;

        // An embedded public key is optional, but if present it must validate.
        if let Some(pk_bytes) = sec1_private_key.public_key {
            let public_point = Sec1Point::<C>::from_bytes(pk_bytes)
                .map_err(|_| der::Tag::BitString.value_error())?;

            C::validate_public_key(&secret_key, &public_point)
                .map_err(|_| der::Tag::BitString.value_error())?;
        }

        Ok(secret_key)
    }
}
487
#[cfg(feature = "arithmetic")]
impl<C: CurveArithmetic> From<NonZeroScalar<C>> for SecretKey<C> {
    /// Build a secret key from an owned non-zero scalar by delegating to the
    /// by-reference conversion.
    fn from(scalar: NonZeroScalar<C>) -> Self {
        Self::from(&scalar)
    }
}
497
#[cfg(feature = "arithmetic")]
impl<C: CurveArithmetic> From<&NonZeroScalar<C>> for SecretKey<C> {
    /// Build a secret key by converting the borrowed non-zero scalar into a
    /// [`ScalarValue`].
    fn from(scalar: &NonZeroScalar<C>) -> Self {
        let inner: ScalarValue<C> = scalar.into();
        Self { inner }
    }
}