Skip to main content

elliptic_curve/scalar/
nonzero.rs

1//! Non-zero scalar type.
2
3use crate::{
4    CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarValue, SecretKey,
5    ops::{Invert, Reduce, ReduceNonZero},
6    point::NonIdentity,
7    scalar::IsHigh,
8};
9use base16ct::HexDisplay;
10use common::Generate;
11use core::{
12    fmt,
13    ops::{Deref, Mul, MulAssign, Neg},
14    str,
15};
16use ff::{Field, PrimeField};
17use rand_core::{CryptoRng, TryCryptoRng};
18use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
19use zeroize::Zeroize;
20
21#[cfg(feature = "serde")]
22use serdect::serde::{Deserialize, Serialize, de, ser};
23
24/// Non-zero scalar type.
25///
26/// This type ensures that its value is not zero, ala `core::num::NonZero*`.
27/// To do this, the generic `S` type must impl both `Default` and
28/// `ConstantTimeEq`, with the requirement that `S::default()` returns 0.
29///
30/// In the context of ECC, it's useful for ensuring that scalar multiplication
31/// cannot result in the point at infinity.
32#[derive(Clone)]
33// `repr` is needed for `unsafe` safety invariants below
34#[repr(transparent)]
35pub struct NonZeroScalar<C>
36where
37    C: CurveArithmetic,
38{
39    scalar: Scalar<C>,
40}
41
42impl<C: CurveArithmetic> fmt::Debug for NonZeroScalar<C> {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        f.debug_struct("NonZeroScalar").finish_non_exhaustive()
45    }
46}
47
48impl<C> NonZeroScalar<C>
49where
50    C: CurveArithmetic,
51{
52    /// Create a [`NonZeroScalar`] from a scalar.
53    pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
54        CtOption::new(Self { scalar }, !scalar.is_zero())
55    }
56
57    /// Decode a [`NonZeroScalar`] from a big endian-serialized field element.
58    pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
59        Scalar::<C>::from_repr(repr).and_then(Self::new)
60    }
61
62    /// Create a [`NonZeroScalar`] from a `C::Uint`.
63    pub fn from_uint(uint: C::Uint) -> CtOption<Self> {
64        ScalarValue::new(uint).and_then(|scalar| Self::new(scalar.into()))
65    }
66
67    /// Transform array reference containing [`NonZeroScalar`]s to an array reference to the inner
68    /// scalar type.
69    pub fn cast_array_as_inner<const N: usize>(scalars: &[Self; N]) -> &[Scalar<C>; N] {
70        // SAFETY: `NonZeroScalar` is a `repr(transparent)` newtype for `Scalar<C>` so it's safe to
71        // cast to the inner scalar type.
72        #[allow(unsafe_code)]
73        unsafe {
74            &*scalars.as_ptr().cast()
75        }
76    }
77
78    /// Transform slice containing [`NonZeroScalar`]s to a slice of the inner scalar type.
79    pub fn cast_slice_as_inner(scalars: &[Self]) -> &[Scalar<C>] {
80        // SAFETY: `NonZeroScalar` is a `repr(transparent)` newtype for `Scalar<C>` so it's safe to
81        // cast to the inner scalar type.
82        #[allow(unsafe_code)]
83        unsafe {
84            &*(core::ptr::from_ref(scalars) as *const [Scalar<C>])
85        }
86    }
87
88    /// Deprecated: Generate a random [`NonZeroScalar`].
89    #[deprecated(since = "0.14.0", note = "use the `Generate` trait instead")]
90    pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
91        Self::generate_from_rng(rng)
92    }
93}
94
95impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
96where
97    C: CurveArithmetic,
98{
99    fn as_ref(&self) -> &Scalar<C> {
100        &self.scalar
101    }
102}
103
104impl<C> ConditionallySelectable for NonZeroScalar<C>
105where
106    C: CurveArithmetic,
107{
108    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
109        Self {
110            scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
111        }
112    }
113}
114
115impl<C> ConstantTimeEq for NonZeroScalar<C>
116where
117    C: CurveArithmetic,
118{
119    fn ct_eq(&self, other: &Self) -> Choice {
120        self.scalar.ct_eq(&other.scalar)
121    }
122}
123
124impl<C> Copy for NonZeroScalar<C> where C: CurveArithmetic {}
125
126impl<C> Deref for NonZeroScalar<C>
127where
128    C: CurveArithmetic,
129{
130    type Target = Scalar<C>;
131
132    fn deref(&self) -> &Scalar<C> {
133        &self.scalar
134    }
135}
136
137impl<C> Eq for NonZeroScalar<C> where C: CurveArithmetic {}
138
139impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
140where
141    C: CurveArithmetic,
142{
143    fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
144        Self::from(&scalar)
145    }
146}
147
148impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
149where
150    C: CurveArithmetic,
151{
152    fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
153        scalar.to_repr()
154    }
155}
156
157impl<C> From<NonZeroScalar<C>> for ScalarValue<C>
158where
159    C: CurveArithmetic,
160{
161    #[inline]
162    fn from(scalar: NonZeroScalar<C>) -> ScalarValue<C> {
163        Self::from(&scalar)
164    }
165}
166
167impl<C> From<&NonZeroScalar<C>> for ScalarValue<C>
168where
169    C: CurveArithmetic,
170{
171    fn from(scalar: &NonZeroScalar<C>) -> ScalarValue<C> {
172        scalar.scalar.into()
173    }
174}
175
176impl<C> From<SecretKey<C>> for NonZeroScalar<C>
177where
178    C: CurveArithmetic,
179{
180    fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
181        Self::from(&sk)
182    }
183}
184
185impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
186where
187    C: CurveArithmetic,
188{
189    fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
190        let scalar = sk.as_scalar_value().to_scalar();
191        debug_assert!(!bool::from(scalar.is_zero()));
192        Self { scalar }
193    }
194}
195
196impl<C> Generate for NonZeroScalar<C>
197where
198    C: CurveArithmetic,
199{
200    fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
201        // Use rejection sampling to eliminate zero values.
202        // While this method isn't constant-time, the attacker shouldn't learn
203        // anything about unrelated outputs so long as `rng` is a secure `CryptoRng`.
204        loop {
205            if let Some(result) = Self::new(Scalar::<C>::try_generate_from_rng(rng)?).into() {
206                break Ok(result);
207            }
208        }
209    }
210}
211
212impl<C> Invert for NonZeroScalar<C>
213where
214    C: CurveArithmetic,
215    Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
216{
217    type Output = Self;
218
219    fn invert(&self) -> Self {
220        Self {
221            // This will always succeed since `scalar` will never be 0
222            scalar: Invert::invert(&self.scalar).unwrap(),
223        }
224    }
225
226    fn invert_vartime(&self) -> Self::Output {
227        Self {
228            // This will always succeed since `scalar` will never be 0
229            scalar: Invert::invert_vartime(&self.scalar).unwrap(),
230        }
231    }
232}
233
234impl<C> IsHigh for NonZeroScalar<C>
235where
236    C: CurveArithmetic,
237{
238    fn is_high(&self) -> Choice {
239        self.scalar.is_high()
240    }
241}
242
243impl<C> Neg for NonZeroScalar<C>
244where
245    C: CurveArithmetic,
246{
247    type Output = NonZeroScalar<C>;
248
249    fn neg(self) -> NonZeroScalar<C> {
250        let scalar = -self.scalar;
251        debug_assert!(!bool::from(scalar.is_zero()));
252        NonZeroScalar { scalar }
253    }
254}
255
256impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
257where
258    C: PrimeCurve + CurveArithmetic,
259{
260    type Output = Self;
261
262    #[inline]
263    fn mul(self, other: Self) -> Self {
264        Self::mul(self, &other)
265    }
266}
267
268impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
269where
270    C: PrimeCurve + CurveArithmetic,
271{
272    type Output = Self;
273
274    fn mul(self, other: &Self) -> Self {
275        // Multiplication is modulo a prime, so the product of two non-zero
276        // scalars is also non-zero.
277        let scalar = self.scalar * other.scalar;
278        debug_assert!(!bool::from(scalar.is_zero()));
279        NonZeroScalar { scalar }
280    }
281}
282
283impl<C, P> Mul<NonIdentity<P>> for NonZeroScalar<C>
284where
285    C: CurveArithmetic,
286    NonIdentity<P>: Mul<NonZeroScalar<C>, Output = NonIdentity<P>>,
287{
288    type Output = NonIdentity<P>;
289
290    fn mul(self, rhs: NonIdentity<P>) -> Self::Output {
291        rhs * self
292    }
293}
294
295impl<C, P> Mul<&NonIdentity<P>> for NonZeroScalar<C>
296where
297    C: CurveArithmetic,
298    for<'a> &'a NonIdentity<P>: Mul<NonZeroScalar<C>, Output = NonIdentity<P>>,
299{
300    type Output = NonIdentity<P>;
301
302    fn mul(self, rhs: &NonIdentity<P>) -> Self::Output {
303        rhs * self
304    }
305}
306
307impl<C, P> Mul<NonIdentity<P>> for &NonZeroScalar<C>
308where
309    C: CurveArithmetic,
310    for<'a> NonIdentity<P>: Mul<&'a NonZeroScalar<C>, Output = NonIdentity<P>>,
311{
312    type Output = NonIdentity<P>;
313
314    fn mul(self, rhs: NonIdentity<P>) -> Self::Output {
315        rhs * self
316    }
317}
318
319impl<C, P> Mul<&NonIdentity<P>> for &NonZeroScalar<C>
320where
321    C: CurveArithmetic,
322    for<'a> &'a NonIdentity<P>: Mul<&'a NonZeroScalar<C>, Output = NonIdentity<P>>,
323{
324    type Output = NonIdentity<P>;
325
326    fn mul(self, rhs: &NonIdentity<P>) -> Self::Output {
327        rhs * self
328    }
329}
330
331impl<C> MulAssign for NonZeroScalar<C>
332where
333    C: PrimeCurve + CurveArithmetic,
334{
335    fn mul_assign(&mut self, rhs: Self) {
336        *self = *self * rhs;
337    }
338}
339
340impl<C> PartialEq for NonZeroScalar<C>
341where
342    C: CurveArithmetic,
343{
344    fn eq(&self, other: &Self) -> bool {
345        self.scalar.eq(&other.scalar)
346    }
347}
348
349impl<C, T> Reduce<T> for NonZeroScalar<C>
350where
351    C: CurveArithmetic,
352    Scalar<C>: ReduceNonZero<T>,
353{
354    fn reduce(n: &T) -> Self {
355        <Self as ReduceNonZero<T>>::reduce_nonzero(n)
356    }
357}
358
359impl<C, T> ReduceNonZero<T> for NonZeroScalar<C>
360where
361    C: CurveArithmetic,
362    Scalar<C>: ReduceNonZero<T>,
363{
364    fn reduce_nonzero(n: &T) -> Self {
365        let scalar = Scalar::<C>::reduce_nonzero(n);
366        debug_assert!(!bool::from(scalar.is_zero()));
367        Self { scalar }
368    }
369}
370
371impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
372where
373    C: CurveArithmetic,
374{
375    type Error = Error;
376
377    fn try_from(bytes: &[u8]) -> Result<Self, Error> {
378        NonZeroScalar::from_repr(bytes.try_into()?)
379            .into_option()
380            .ok_or(Error)
381    }
382}
383
384impl<C> Zeroize for NonZeroScalar<C>
385where
386    C: CurveArithmetic,
387{
388    fn zeroize(&mut self) {
389        // Use zeroize's volatile writes to ensure value is cleared.
390        self.scalar.zeroize();
391
392        // Write a 1 instead of a 0 to ensure this type's non-zero invariant
393        // is upheld.
394        self.scalar = Scalar::<C>::ONE;
395    }
396}
397
398impl<C> fmt::Display for NonZeroScalar<C>
399where
400    C: CurveArithmetic,
401{
402    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403        write!(f, "{self:X}")
404    }
405}
406
407impl<C> fmt::LowerHex for NonZeroScalar<C>
408where
409    C: CurveArithmetic,
410{
411    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
412        write!(f, "{:x}", HexDisplay(&self.to_repr()))
413    }
414}
415
416impl<C> fmt::UpperHex for NonZeroScalar<C>
417where
418    C: CurveArithmetic,
419{
420    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421        write!(f, "{:}", HexDisplay(&self.to_repr()))
422    }
423}
424
425impl<C> str::FromStr for NonZeroScalar<C>
426where
427    C: CurveArithmetic,
428{
429    type Err = Error;
430
431    fn from_str(hex: &str) -> Result<Self, Error> {
432        let mut bytes = FieldBytes::<C>::default();
433
434        if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
435            Self::from_repr(bytes).into_option().ok_or(Error)
436        } else {
437            Err(Error)
438        }
439    }
440}
441
/// Serialization goes through [`ScalarValue`]'s `Serialize` impl.
#[cfg(feature = "serde")]
impl<C: CurveArithmetic> Serialize for NonZeroScalar<C> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let value = ScalarValue::<C>::from(self);
        value.serialize(serializer)
    }
}
454
/// Deserialization goes through [`ScalarValue`] and then rejects a zero value.
#[cfg(feature = "serde")]
impl<'de, C: CurveArithmetic> Deserialize<'de> for NonZeroScalar<C> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let value = ScalarValue::<C>::deserialize(deserializer)?;

        match Self::new(value.into()).into_option() {
            Some(nonzero) => Ok(nonzero),
            None => Err(de::Error::custom("expected non-zero scalar")),
        }
    }
}
470
#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// Decoding bytes with `from_repr` and re-encoding with `to_repr` yields the same bytes.
    #[test]
    fn round_trip() {
        let expected = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
        let decoded = NonZeroScalar::from_repr(expected.into()).unwrap();
        let encoded = decoded.to_repr();
        assert_eq!(&expected, encoded.as_slice());
    }

    /// Zeroizing leaves the scalar equal to `ONE`.
    #[test]
    fn zeroize() {
        let forty_two = Scalar::from(42u64);
        let mut nonzero = NonZeroScalar::new(forty_two).unwrap();
        nonzero.zeroize();
        assert_eq!(*nonzero, Scalar::ONE);
    }
}