Skip to main content

p256/arithmetic/
scalar.rs

1//! Scalar field arithmetic modulo n = 115792089210356248762697446949407573529996955224135760342422259061068512044369
2
3use self::scalar_impl::barrett_reduce;
4use crate::{FieldBytes, NistP256, ORDER_HEX};
5use core::{
6    fmt::{self, Debug},
7    iter::{Product, Sum},
8    ops::{Add, AddAssign, Mul, MulAssign, Neg, Shr, ShrAssign, Sub, SubAssign},
9};
10use elliptic_curve::{
11    Curve, Generate,
12    bigint::{ArrayEncoding, Limb, Odd, U256, Uint, cpubits, modular::Retrieve},
13    ctutils,
14    group::ff::{self, Field, FromUniformBytes, PrimeField},
15    ops::{Invert, Reduce, ReduceNonZero},
16    rand_core::TryRng,
17    scalar::{FromUintUnchecked, IsHigh},
18    subtle::{
19        Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
20        CtOption,
21    },
22    zeroize::DefaultIsZeroes,
23};
24use primefield::{FieldExt, PrimeFieldExt};
25
// Choose the limb-size-specific backend module that provides `barrett_reduce`
// (imported above as `self::scalar_impl::barrett_reduce`): the `32` arm loads
// `scalar/scalar32.rs`, the `64` arm loads `scalar/scalar64.rs`.
// NOTE(review): the arm is presumably selected by the target's native word size
// as detected by the `cpubits!` macro — confirm against that macro's docs.
cpubits! {
    32 => {
        #[path = "scalar/scalar32.rs"]
        mod scalar_impl;
    }
    64 => {
        #[path = "scalar/scalar64.rs"]
        mod scalar_impl;
    }
}
36
37#[cfg(feature = "serde")]
38use {
39    elliptic_curve::ScalarValue,
40    serdect::serde::{Deserialize, Serialize, de, ser},
41};
42
/// Constant representing the modulus (the scalar field order, taken from `NistP256::ORDER`)
/// n = FFFFFFFF 00000000 FFFFFFFF FFFFFFFF BCE6FAAD A7179E84 F3B9CAC2 FC632551
pub(crate) const MODULUS: Odd<U256> = NistP256::ORDER;
46
/// `MODULUS / 2`, rounded down. Since `n` is odd this equals `(n - 1) / 2`;
/// it is the threshold above which [`IsHigh::is_high`] reports a scalar as "high".
const FRAC_MODULUS_2: Scalar = Scalar(MODULUS.as_ref().shr_vartime(1));
49
#[doc = primefield::monty_field_element_doc!("Scalars are elements in the finite field modulo n.")]
// Representation: the wrapped `U256` is the scalar's integer value itself, not a
// Montgomery form. `ONE` is `U256::ONE`, `to_bytes` encodes the inner value
// unchanged, and multiplication reduces with `barrett_reduce`.
// Arithmetic expects the value to lie in `[0, n)`: `from_repr` enforces this,
// `FromUintUnchecked` does not.
// NOTE(review): the doc macro is named `monty_field_element_doc` although this
// type is not Montgomery-encoded — confirm the generated text is accurate.
#[derive(Clone, Copy, Default)]
pub struct Scalar(pub(crate) U256);
53
54impl Scalar {
55    /// Zero scalar.
56    pub const ZERO: Self = Self(U256::ZERO);
57
58    /// Multiplicative identity.
59    pub const ONE: Self = Self(U256::ONE);
60
61    /// Returns the SEC1 encoding of this scalar.
62    pub fn to_bytes(&self) -> FieldBytes {
63        self.0.to_be_byte_array()
64    }
65
66    /// Returns self + rhs mod n
67    pub const fn add(&self, rhs: &Self) -> Self {
68        Self(self.0.add_mod(&rhs.0, NistP256::ORDER.as_nz_ref()))
69    }
70
71    /// Returns 2*self.
72    pub const fn double(&self) -> Self {
73        self.add(self)
74    }
75
76    /// Returns self - rhs mod n.
77    pub const fn sub(&self, rhs: &Self) -> Self {
78        Self(self.0.sub_mod(&rhs.0, NistP256::ORDER.as_nz_ref()))
79    }
80
81    /// Returns self * rhs mod n
82    pub const fn multiply(&self, rhs: &Self) -> Self {
83        let (lo, hi) = self.0.widening_mul(&rhs.0);
84        Self(barrett_reduce(lo, hi))
85    }
86
87    /// Returns self * self mod p
88    pub const fn square(&self) -> Self {
89        // Schoolbook multiplication.
90        self.multiply(self)
91    }
92
93    /// Right shifts the scalar.
94    ///
95    /// Note: not constant-time with respect to the `shift` parameter.
96    pub const fn shr_vartime(&self, shift: u32) -> Scalar {
97        Self(self.0.unbounded_shr_vartime(shift))
98    }
99
100    /// Compute [`FieldElement`] inversion: `1 / self`.
101    pub fn invert(&self) -> CtOption<Self> {
102        self.0
103            .invert_odd_mod(const { &Odd::from_be_hex(ORDER_HEX) })
104            .map(Self)
105            .into()
106    }
107
108    /// Compute [`FieldElement`] inversion: `1 / self` in variable-time.
109    pub fn invert_vartime(&self) -> CtOption<Self> {
110        self.0
111            .invert_odd_mod_vartime(const { &Odd::from_be_hex(ORDER_HEX) })
112            .map(Self)
113            .into()
114    }
115
116    /// Returns the multiplicative inverse of self.
117    ///
118    /// # Panics
119    /// Will panic in the event `self` is zero
120    const fn invert_unwrap(&self) -> Self {
121        Self(
122            self.0
123                .invert_odd_mod(const { &Odd::from_be_hex(ORDER_HEX) })
124                .expect_copied("input should be non-zero"),
125        )
126    }
127
128    /// Returns `self^exp`, where `exp` is a little-endian integer exponent.
129    ///
130    /// **This operation is variable time with respect to the exponent `exp`.**
131    ///
132    /// If the exponent is fixed, this operation is constant time.
133    pub const fn pow_vartime<const RHS_LIMBS: usize>(&self, exp: &Uint<RHS_LIMBS>) -> Self {
134        let mut res = Self::ONE;
135        let mut i = RHS_LIMBS;
136
137        while i > 0 {
138            i -= 1;
139
140            let mut j = Limb::BITS;
141            while j > 0 {
142                j -= 1;
143                res = res.square();
144
145                if ((exp.as_limbs()[i].0 >> j) & 1) == 1 {
146                    res = res.multiply(self);
147                }
148            }
149        }
150
151        res
152    }
153
154    /// Returns `self^(2^n) mod p`.
155    ///
156    /// **This operation is variable time with respect to the exponent `n`.**
157    ///
158    /// If the exponent is fixed, this operation is constant time.
159    pub const fn sqn_vartime(&self, n: usize) -> Self {
160        let mut x = *self;
161        let mut i = 0;
162        while i < n {
163            x = x.square();
164            i += 1;
165        }
166        x
167    }
168
169    /// Is integer representing equivalence class odd?
170    pub fn is_odd(&self) -> Choice {
171        self.0.is_odd().into()
172    }
173
174    /// Is integer representing equivalence class even?
175    pub fn is_even(&self) -> Choice {
176        !self.is_odd()
177    }
178}
179
// Additional trait impls for `Scalar`, generated by `elliptic_curve::scalar_impls!`.
// NOTE(review): the exact set of generated impls is defined by that macro upstream
// and is not visible in this file.
elliptic_curve::scalar_impls!(NistP256, Scalar);
181
182impl AsRef<Scalar> for Scalar {
183    fn as_ref(&self) -> &Scalar {
184        self
185    }
186}
187
impl Field for Scalar {
    const ZERO: Self = Self::ZERO;
    const ONE: Self = Self::ONE;

    /// Samples a uniformly random scalar by rejection sampling 32-byte strings
    /// until one parses as a canonical scalar (below `n`).
    ///
    /// Errors from the RNG are propagated unchanged.
    fn try_random<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut bytes = FieldBytes::default();

        // Generate a uniformly random scalar using rejection sampling,
        // which produces a uniformly random distribution of scalars.
        //
        // This method is not constant time, but should be secure so long as
        // rejected RNG outputs are unrelated to future ones (which is a
        // necessary property of a `CryptoRng`).
        //
        // With an unbiased RNG, the probability of failing to complete after 4
        // iterations is vanishingly small.
        loop {
            rng.try_fill_bytes(&mut bytes)?;
            // `from_repr` only yields `Some` for values in `[0, n)`.
            if let Some(scalar) = Scalar::from_repr(bytes).into() {
                return Ok(scalar);
            }
        }
    }

    /// Delegates to the inherent [`Scalar::square`].
    fn square(&self) -> Self {
        Scalar::square(self)
    }

    /// Returns `self + self mod n` via the inherent [`Scalar::add`].
    fn double(&self) -> Self {
        self.add(self)
    }

    /// Delegates to the inherent [`Scalar::invert`].
    fn invert(&self) -> CtOption<Self> {
        Scalar::invert(self)
    }

    /// Tonelli–Shanks square root (constant-time variant) for `q mod 16 = 1`.
    ///
    /// Here `q = n`, with `n - 1 = 2^S * t`, `S = 4` and `t` odd.
    /// <https://eprint.iacr.org/2012/685.pdf> (page 12, algorithm 5)
    ///
    /// Returns `None` when `self` has no square root: the result is only marked
    /// valid if `x^2 == self` for the computed `x`.
    #[allow(clippy::many_single_char_names)]
    fn sqrt(&self) -> CtOption<Self> {
        // EXP = (t - 1) / 2, where n - 1 = 2^4 * t.
        const EXP: U256 =
            U256::from_be_hex("07fffffff800000007fffffffffffffffde737d56d38bcf4279dce5617e3192a");

        // Note: `pow_vartime` is constant-time with respect to `self`
        let w = self.pow_vartime(&EXP);

        // x = self^((t + 1) / 2) is the candidate root; b = self^t measures how
        // far x^2 is from self.
        let mut v = Self::S;
        let mut x = *self * w;
        let mut b = x * w;
        let mut z = Self::ROOT_OF_UNITY;

        for max_v in (1..=Self::S).rev() {
            let mut k = 1;
            let mut tmp = b.square();
            let mut j_less_than_v = Choice::from(1);

            // Branch-free inner search: every iteration performs the same
            // operations and uses `conditional_select` instead of branching.
            for j in 2..max_v {
                let tmp_is_one = tmp.ct_eq(&Self::ONE);
                let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
                tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
                let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
                j_less_than_v &= !ConstantTimeEq::ct_eq(&j, &v);
                k = u32::conditional_select(&j, &k, tmp_is_one);
                z = Self::conditional_select(&z, &new_z, j_less_than_v);
            }

            // Keep x unchanged once b == 1; otherwise correct it by z.
            let result = x * z;
            x = Self::conditional_select(&result, &x, b.ct_eq(&Self::ONE));
            z = z.square();
            b *= z;
            v = k;
        }

        CtOption::new(x, x.square().ct_eq(self))
    }

    /// Delegates to `ff`'s generic `sqrt_ratio` helper.
    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        ff::helpers::sqrt_ratio_generic(num, div)
    }
}
268
269impl Generate for Scalar {
270    fn try_generate_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
271        Self::try_random(rng)
272    }
273}
274
275impl PrimeField for Scalar {
276    type Repr = FieldBytes;
277
278    const MODULUS: &'static str = ORDER_HEX;
279    const NUM_BITS: u32 = 256;
280    const CAPACITY: u32 = 255;
281    const TWO_INV: Self = Self(U256::from_u8(2)).invert_unwrap();
282    const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7));
283    const S: u32 = 4;
284    const ROOT_OF_UNITY: Self = Self(U256::from_be_hex(
285        "ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602",
286    ));
287    const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unwrap();
288    const DELTA: Self = Self(U256::from_u64(33232930569601));
289
290    /// Attempts to parse the given byte array as an SEC1-encoded scalar.
291    ///
292    /// Returns None if the byte array does not contain a big-endian integer in the range
293    /// [0, p).
294    fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
295        let inner = U256::from_be_byte_array(bytes);
296        CtOption::new(
297            Self(inner),
298            ConstantTimeLess::ct_lt(&inner, &NistP256::ORDER),
299        )
300    }
301
302    fn to_repr(&self) -> FieldBytes {
303        self.to_bytes()
304    }
305
306    fn is_odd(&self) -> Choice {
307        self.0.is_odd().into()
308    }
309}
310
// Opt into the `primefield` extension traits. No items are overridden here;
// NOTE(review): this presumably relies entirely on the traits' provided defaults.
impl FieldExt for Scalar {}
impl PrimeFieldExt for Scalar {}
313
314impl Retrieve for Scalar {
315    type Output = U256;
316
317    fn retrieve(&self) -> U256 {
318        self.0
319    }
320}
321
// Zeroization support via `zeroize`'s `DefaultIsZeroes` marker.
// NOTE(review): presumably zeroizing overwrites the value with
// `Scalar::default()` (the derived default, i.e. an all-zero `U256`).
impl DefaultIsZeroes for Scalar {}

// Equality (see the `PartialEq` impl in this file) compares the wrapped integers,
// which is a full equivalence relation, so `Eq` holds.
impl Eq for Scalar {}
325
326impl FromUintUnchecked for Scalar {
327    type Uint = U256;
328
329    fn from_uint_unchecked(uint: Self::Uint) -> Self {
330        Self(uint)
331    }
332}
333
334impl Invert for Scalar {
335    type Output = CtOption<Self>;
336
337    fn invert(&self) -> CtOption<Self> {
338        self.invert()
339    }
340
341    fn invert_vartime(&self) -> CtOption<Self> {
342        self.invert_vartime()
343    }
344}
345
346impl IsHigh for Scalar {
347    fn is_high(&self) -> Choice {
348        ConstantTimeGreater::ct_gt(&self.0, &FRAC_MODULUS_2.0)
349    }
350}
351
352impl Shr<usize> for Scalar {
353    type Output = Self;
354
355    fn shr(self, rhs: usize) -> Self::Output {
356        self.shr_vartime(rhs as u32)
357    }
358}
359
360impl Shr<usize> for &Scalar {
361    type Output = Scalar;
362
363    fn shr(self, rhs: usize) -> Self::Output {
364        self.shr_vartime(rhs as u32)
365    }
366}
367
368impl ShrAssign<usize> for Scalar {
369    fn shr_assign(&mut self, rhs: usize) {
370        *self = *self >> rhs;
371    }
372}
373
374impl PartialEq for Scalar {
375    fn eq(&self, other: &Self) -> bool {
376        self.ct_eq(other).into()
377    }
378}
379
380impl PartialOrd for Scalar {
381    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
382        Some(self.cmp(other))
383    }
384}
385
386impl Ord for Scalar {
387    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
388        self.0.cmp(&other.0)
389    }
390}
391
392impl From<u32> for Scalar {
393    fn from(k: u32) -> Self {
394        Scalar(k.into())
395    }
396}
397
398impl From<u64> for Scalar {
399    fn from(k: u64) -> Self {
400        Scalar(k.into())
401    }
402}
403
404impl From<u128> for Scalar {
405    fn from(k: u128) -> Self {
406        Scalar(k.into())
407    }
408}
409
410impl From<Scalar> for FieldBytes {
411    fn from(scalar: Scalar) -> Self {
412        scalar.to_bytes()
413    }
414}
415
416impl From<&Scalar> for FieldBytes {
417    fn from(scalar: &Scalar) -> Self {
418        scalar.to_bytes()
419    }
420}
421
422impl From<Scalar> for U256 {
423    fn from(scalar: Scalar) -> U256 {
424        scalar.0
425    }
426}
427
428impl From<&Scalar> for U256 {
429    fn from(scalar: &Scalar) -> U256 {
430        scalar.0
431    }
432}
433
434impl FromUniformBytes<64> for Scalar {
435    fn from_uniform_bytes(bytes: &[u8; 64]) -> Self {
436        Self(barrett_reduce(
437            U256::from_be_slice(&bytes[32..]),
438            U256::from_be_slice(&bytes[..32]),
439        ))
440    }
441}
442
443impl Add<Scalar> for Scalar {
444    type Output = Scalar;
445
446    fn add(self, other: Scalar) -> Scalar {
447        Scalar::add(&self, &other)
448    }
449}
450
451impl Add<&Scalar> for &Scalar {
452    type Output = Scalar;
453
454    fn add(self, other: &Scalar) -> Scalar {
455        Scalar::add(self, other)
456    }
457}
458
459impl Add<&Scalar> for Scalar {
460    type Output = Scalar;
461
462    fn add(self, other: &Scalar) -> Scalar {
463        Scalar::add(&self, other)
464    }
465}
466
467impl AddAssign<Scalar> for Scalar {
468    fn add_assign(&mut self, rhs: Scalar) {
469        *self = Scalar::add(self, &rhs);
470    }
471}
472
473impl AddAssign<&Scalar> for Scalar {
474    fn add_assign(&mut self, rhs: &Scalar) {
475        *self = Scalar::add(self, rhs);
476    }
477}
478
479impl Sub<Scalar> for Scalar {
480    type Output = Scalar;
481
482    fn sub(self, other: Scalar) -> Scalar {
483        Scalar::sub(&self, &other)
484    }
485}
486
487impl Sub<&Scalar> for &Scalar {
488    type Output = Scalar;
489
490    fn sub(self, other: &Scalar) -> Scalar {
491        Scalar::sub(self, other)
492    }
493}
494
495impl Sub<&Scalar> for Scalar {
496    type Output = Scalar;
497
498    fn sub(self, other: &Scalar) -> Scalar {
499        Scalar::sub(&self, other)
500    }
501}
502
503impl SubAssign<Scalar> for Scalar {
504    fn sub_assign(&mut self, rhs: Scalar) {
505        *self = Scalar::sub(self, &rhs);
506    }
507}
508
509impl SubAssign<&Scalar> for Scalar {
510    fn sub_assign(&mut self, rhs: &Scalar) {
511        *self = Scalar::sub(self, rhs);
512    }
513}
514
515impl Mul<Scalar> for Scalar {
516    type Output = Scalar;
517
518    fn mul(self, other: Scalar) -> Scalar {
519        Scalar::multiply(&self, &other)
520    }
521}
522
523impl Mul<&Scalar> for &Scalar {
524    type Output = Scalar;
525
526    fn mul(self, other: &Scalar) -> Scalar {
527        Scalar::multiply(self, other)
528    }
529}
530
531impl Mul<&Scalar> for Scalar {
532    type Output = Scalar;
533
534    fn mul(self, other: &Scalar) -> Scalar {
535        Scalar::multiply(&self, other)
536    }
537}
538
539impl MulAssign<Scalar> for Scalar {
540    fn mul_assign(&mut self, rhs: Scalar) {
541        *self = Scalar::multiply(self, &rhs);
542    }
543}
544
545impl MulAssign<&Scalar> for Scalar {
546    fn mul_assign(&mut self, rhs: &Scalar) {
547        *self = Scalar::multiply(self, rhs);
548    }
549}
550
551impl Neg for Scalar {
552    type Output = Scalar;
553
554    fn neg(self) -> Scalar {
555        Scalar::ZERO - self
556    }
557}
558
559impl Neg for &Scalar {
560    type Output = Scalar;
561
562    fn neg(self) -> Scalar {
563        Scalar::ZERO - self
564    }
565}
566
567impl Reduce<U256> for Scalar {
568    fn reduce(w: &U256) -> Self {
569        let (r, underflow) = w.borrowing_sub(&NistP256::ORDER, Limb::ZERO);
570        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
571        Self(U256::conditional_select(w, &r, !underflow))
572    }
573}
574
575impl Reduce<FieldBytes> for Scalar {
576    #[inline]
577    fn reduce(bytes: &FieldBytes) -> Self {
578        Self::reduce(&U256::from_be_byte_array(*bytes))
579    }
580}
581
582impl ReduceNonZero<U256> for Scalar {
583    fn reduce_nonzero(w: &U256) -> Self {
584        const ORDER_MINUS_ONE: U256 = NistP256::ORDER.as_ref().wrapping_sub(&U256::ONE);
585        let (r, underflow) = w.borrowing_sub(&ORDER_MINUS_ONE, Limb::ZERO);
586        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
587        Self(U256::conditional_select(w, &r, !underflow).wrapping_add(&U256::ONE))
588    }
589}
590
591impl ReduceNonZero<FieldBytes> for Scalar {
592    #[inline]
593    fn reduce_nonzero(bytes: &FieldBytes) -> Self {
594        Self::reduce_nonzero(&U256::from_be_byte_array(*bytes))
595    }
596}
597
598impl Sum for Scalar {
599    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
600        iter.reduce(Add::add).unwrap_or(Self::ZERO)
601    }
602}
603
604impl<'a> Sum<&'a Scalar> for Scalar {
605    fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
606        iter.copied().sum()
607    }
608}
609
610impl Product for Scalar {
611    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
612        iter.reduce(Mul::mul).unwrap_or(Self::ONE)
613    }
614}
615
616impl<'a> Product<&'a Scalar> for Scalar {
617    fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
618        iter.copied().product()
619    }
620}
621
622impl ConditionallySelectable for Scalar {
623    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
624        Self(U256::conditional_select(&a.0, &b.0, choice))
625    }
626}
627
628impl ConstantTimeEq for Scalar {
629    fn ct_eq(&self, other: &Self) -> Choice {
630        ConstantTimeEq::ct_eq(&self.0, &other.0)
631    }
632}
633
634impl ctutils::CtEq for Scalar {
635    fn ct_eq(&self, other: &Self) -> ctutils::Choice {
636        ConstantTimeEq::ct_eq(self, other).into()
637    }
638}
639
640impl ctutils::CtSelect for Scalar {
641    fn ct_select(&self, other: &Self, choice: ctutils::Choice) -> Self {
642        ConditionallySelectable::conditional_select(self, other, choice.into())
643    }
644}
645
646impl Debug for Scalar {
647    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
648        write!(f, "Scalar(0x{:X})", &self.0)
649    }
650}
651
/// Serializes through `elliptic_curve`'s `ScalarValue` representation.
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let value = ScalarValue::from(self);
        value.serialize(serializer)
    }
}

/// Deserializes through `elliptic_curve`'s `ScalarValue` representation.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let value = ScalarValue::deserialize(deserializer)?;
        Ok(value.into())
    }
}
671
#[cfg(test)]
mod tests {
    use super::{Scalar, U256};
    use crate::{FieldBytes, NistP256, SecretKey};
    use elliptic_curve::{Curve, array::Array, group::ff::PrimeField, ops::ReduceNonZero};

    // Shared prime-field test suite from the `primefield` crate.
    primefield::test_primefield!(Scalar, U256);

    /// Encoding 42 as big-endian bytes (in the last 8 of 32 bytes) and parsing it
    /// back must reproduce the same bytes.
    #[test]
    fn from_to_bytes_roundtrip() {
        let k: u64 = 42;
        let mut bytes = FieldBytes::default();
        bytes[24..].copy_from_slice(k.to_be_bytes().as_ref());

        let scalar = Scalar::from_repr(bytes).unwrap();
        assert_eq!(bytes, scalar.to_bytes());
    }

    /// Basic tests that multiplication works.
    #[test]
    fn multiply() {
        let one = Scalar::ONE;
        let two = one + one;
        let three = two + one;
        let six = three + three;
        assert_eq!(six, two * three);

        let minus_two = -two;
        let minus_three = -three;
        assert_eq!(two, -minus_two);

        // Commutativity, and (-2) * (-3) == 6.
        assert_eq!(minus_three * minus_two, minus_two * minus_three);
        assert_eq!(six, minus_two * minus_three);
    }

    /// Tests that a Scalar can be safely converted to a SecretKey and back
    #[test]
    fn from_ec_secret() {
        let scalar = Scalar::ONE;
        let secret = SecretKey::from_bytes(&scalar.to_bytes()).unwrap();
        let rederived_scalar = Scalar::from(&secret);
        assert_eq!(scalar.0, rederived_scalar.0);
    }

    /// Expected values follow `(w mod (n - 1)) + 1`, checked around 0, n - 3..n
    /// and just above n.
    #[test]
    fn reduce_nonzero() {
        // All-zero bytes map to 1.
        assert_eq!(Scalar::reduce_nonzero(&Array::default()).0, U256::ONE,);
        assert_eq!(Scalar::reduce_nonzero(&U256::ONE).0, U256::from_u8(2),);
        assert_eq!(
            Scalar::reduce_nonzero(&U256::from_u8(2)).0,
            U256::from_u8(3),
        );

        assert_eq!(
            Scalar::reduce_nonzero(NistP256::ORDER.as_ref()).0,
            U256::from_u8(2),
        );
        assert_eq!(
            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_sub(&U256::from_u8(1))).0,
            U256::ONE,
        );
        assert_eq!(
            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_sub(&U256::from_u8(2))).0,
            NistP256::ORDER.wrapping_sub(&U256::ONE),
        );
        assert_eq!(
            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_sub(&U256::from_u8(3))).0,
            NistP256::ORDER.wrapping_sub(&U256::from_u8(2)),
        );

        assert_eq!(
            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_add(&U256::ONE)).0,
            U256::from_u8(3),
        );
        assert_eq!(
            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_add(&U256::from_u8(2))).0,
            U256::from_u8(4),
        );
    }
}