1use self::scalar_impl::barrett_reduce;
4use crate::{FieldBytes, NistP256, ORDER_HEX};
5use core::{
6 fmt::{self, Debug},
7 iter::{Product, Sum},
8 ops::{Add, AddAssign, Mul, MulAssign, Neg, Shr, ShrAssign, Sub, SubAssign},
9};
10use elliptic_curve::{
11 Curve, Generate,
12 bigint::{ArrayEncoding, Limb, Odd, U256, Uint, cpubits, modular::Retrieve},
13 ctutils,
14 group::ff::{self, Field, FromUniformBytes, PrimeField},
15 ops::{Invert, Reduce, ReduceNonZero},
16 rand_core::TryRng,
17 scalar::{FromUintUnchecked, IsHigh},
18 subtle::{
19 Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
20 CtOption,
21 },
22 zeroize::DefaultIsZeroes,
23};
24use primefield::{FieldExt, PrimeFieldExt};
25
// Select the limb-width-specific backend for `scalar_impl`, the module this
// file imports `barrett_reduce` from: `scalar/scalar32.rs` for the `32` arm
// and `scalar/scalar64.rs` for the `64` arm.
cpubits! {
    32 => {
        #[path = "scalar/scalar32.rs"]
        mod scalar_impl;
    }
    64 => {
        #[path = "scalar/scalar64.rs"]
        mod scalar_impl;
    }
}
36
37#[cfg(feature = "serde")]
38use {
39 elliptic_curve::ScalarValue,
40 serdect::serde::{Deserialize, Serialize, de, ser},
41};
42
/// The scalar field modulus `n`, taken from [`NistP256::ORDER`] as an odd 256-bit integer.
pub(crate) const MODULUS: Odd<U256> = NistP256::ORDER;
46
/// `MODULUS` shifted right by one bit, i.e. `floor(n / 2)`.
///
/// Used as the comparison threshold in the [`IsHigh`] implementation.
const FRAC_MODULUS_2: Scalar = Scalar(MODULUS.as_ref().shr_vartime(1));
49
// NOTE(review): the rustdoc text is generated by `primefield::monty_field_element_doc!`, yet in
// this file the wrapped integer is handled in plain form (`ONE` is `U256::ONE`, and `to_bytes`
// serialises `self.0` directly) — confirm the generated wording fits this representation.
#[doc = primefield::monty_field_element_doc!("Scalars are elements in the finite field modulo n.")]
#[derive(Clone, Copy, Default)]
pub struct Scalar(pub(crate) U256);
53
54impl Scalar {
55 pub const ZERO: Self = Self(U256::ZERO);
57
58 pub const ONE: Self = Self(U256::ONE);
60
61 pub fn to_bytes(&self) -> FieldBytes {
63 self.0.to_be_byte_array()
64 }
65
66 pub const fn add(&self, rhs: &Self) -> Self {
68 Self(self.0.add_mod(&rhs.0, NistP256::ORDER.as_nz_ref()))
69 }
70
71 pub const fn double(&self) -> Self {
73 self.add(self)
74 }
75
76 pub const fn sub(&self, rhs: &Self) -> Self {
78 Self(self.0.sub_mod(&rhs.0, NistP256::ORDER.as_nz_ref()))
79 }
80
81 pub const fn multiply(&self, rhs: &Self) -> Self {
83 let (lo, hi) = self.0.widening_mul(&rhs.0);
84 Self(barrett_reduce(lo, hi))
85 }
86
87 pub const fn square(&self) -> Self {
89 self.multiply(self)
91 }
92
93 pub const fn shr_vartime(&self, shift: u32) -> Scalar {
97 Self(self.0.unbounded_shr_vartime(shift))
98 }
99
100 pub fn invert(&self) -> CtOption<Self> {
102 self.0
103 .invert_odd_mod(const { &Odd::from_be_hex(ORDER_HEX) })
104 .map(Self)
105 .into()
106 }
107
108 pub fn invert_vartime(&self) -> CtOption<Self> {
110 self.0
111 .invert_odd_mod_vartime(const { &Odd::from_be_hex(ORDER_HEX) })
112 .map(Self)
113 .into()
114 }
115
116 const fn invert_unwrap(&self) -> Self {
121 Self(
122 self.0
123 .invert_odd_mod(const { &Odd::from_be_hex(ORDER_HEX) })
124 .expect_copied("input should be non-zero"),
125 )
126 }
127
128 pub const fn pow_vartime<const RHS_LIMBS: usize>(&self, exp: &Uint<RHS_LIMBS>) -> Self {
134 let mut res = Self::ONE;
135 let mut i = RHS_LIMBS;
136
137 while i > 0 {
138 i -= 1;
139
140 let mut j = Limb::BITS;
141 while j > 0 {
142 j -= 1;
143 res = res.square();
144
145 if ((exp.as_limbs()[i].0 >> j) & 1) == 1 {
146 res = res.multiply(self);
147 }
148 }
149 }
150
151 res
152 }
153
154 pub const fn sqn_vartime(&self, n: usize) -> Self {
160 let mut x = *self;
161 let mut i = 0;
162 while i < n {
163 x = x.square();
164 i += 1;
165 }
166 x
167 }
168
169 pub fn is_odd(&self) -> Choice {
171 self.0.is_odd().into()
172 }
173
174 pub fn is_even(&self) -> Choice {
176 !self.is_odd()
177 }
178}
179
// Macro-generated boilerplate from `elliptic_curve` tying `Scalar` to the `NistP256` curve.
// NOTE(review): presumably this is the source of conversions used in this file but not defined
// here (e.g. `Scalar::from(&SecretKey)` and `ScalarValue` <-> `Scalar`) — confirm against the macro.
elliptic_curve::scalar_impls!(NistP256, Scalar);
181
/// Identity borrow, so a `Scalar` can be passed where `AsRef<Scalar>` is expected.
impl AsRef<Scalar> for Scalar {
    fn as_ref(&self) -> &Scalar {
        self
    }
}
187
/// `ff::Field` implementation, mostly forwarding to the inherent `Scalar` methods.
impl Field for Scalar {
    const ZERO: Self = Self::ZERO;
    const ONE: Self = Self::ONE;

    /// Draws a scalar by rejection sampling: fill a field-sized buffer from
    /// `rng` and retry until `from_repr` accepts it (i.e. the value is below
    /// the curve order). RNG errors are propagated to the caller.
    fn try_random<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut bytes = FieldBytes::default();

        loop {
            rng.try_fill_bytes(&mut bytes)?;
            if let Some(scalar) = Scalar::from_repr(bytes).into() {
                return Ok(scalar);
            }
        }
    }

    /// Forwards to the inherent [`Scalar::square`].
    fn square(&self) -> Self {
        Scalar::square(self)
    }

    /// Returns `self + self` via the inherent [`Scalar::add`].
    fn double(&self) -> Self {
        self.add(self)
    }

    /// Forwards to the inherent [`Scalar::invert`].
    fn invert(&self) -> CtOption<Self> {
        Scalar::invert(self)
    }

    /// Computes a square root of `self`, returning an empty `CtOption` when the
    /// candidate does not square back to `self`.
    ///
    /// NOTE(review): the structure looks like a constant-time Tonelli–Shanks
    /// variant driven by `Self::S` and `Self::ROOT_OF_UNITY` — confirm against
    /// the reference the original authors used.
    #[allow(clippy::many_single_char_names)]
    fn sqrt(&self) -> CtOption<Self> {
        // NOTE(review): presumably `(t - 1) / 2` where `n - 1 = 2^S * t`; verify against ORDER_HEX.
        const EXP: U256 =
            U256::from_be_hex("07fffffff800000007fffffffffffffffde737d56d38bcf4279dce5617e3192a");

        // `pow_vartime` branches only on the bits of the (fixed) exponent.
        let w = self.pow_vartime(&EXP);

        let mut v = Self::S;
        // x = self^(EXP + 1), b = self^(2 * EXP + 1)
        let mut x = *self * w;
        let mut b = x * w;
        let mut z = Self::ROOT_OF_UNITY;

        // The outer loop always runs `S` times and the inner loop bounds depend
        // only on `max_v`; all data-dependent decisions go through
        // `conditional_select` rather than branches.
        for max_v in (1..=Self::S).rev() {
            let mut k = 1;
            let mut tmp = b.square();
            let mut j_less_than_v = Choice::from(1);

            for j in 2..max_v {
                let tmp_is_one = tmp.ct_eq(&Self::ONE);
                let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
                tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
                let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
                // Stays set only while `j` has not yet reached `v`.
                j_less_than_v &= !ConstantTimeEq::ct_eq(&j, &v);
                k = u32::conditional_select(&j, &k, tmp_is_one);
                z = Self::conditional_select(&z, &new_z, j_less_than_v);
            }

            // Keep `x` unchanged once `b` has become one.
            let result = x * z;
            x = Self::conditional_select(&result, &x, b.ct_eq(&Self::ONE));
            z = z.square();
            b *= z;
            v = k;
        }

        // Only report `x` when it really is a square root of `self`.
        CtOption::new(x, x.square().ct_eq(self))
    }

    /// Delegates to the generic helper from the `ff` crate.
    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        ff::helpers::sqrt_ratio_generic(num, div)
    }
}
268
269impl Generate for Scalar {
270 fn try_generate_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
271 Self::try_random(rng)
272 }
273}
274
275impl PrimeField for Scalar {
276 type Repr = FieldBytes;
277
278 const MODULUS: &'static str = ORDER_HEX;
279 const NUM_BITS: u32 = 256;
280 const CAPACITY: u32 = 255;
281 const TWO_INV: Self = Self(U256::from_u8(2)).invert_unwrap();
282 const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7));
283 const S: u32 = 4;
284 const ROOT_OF_UNITY: Self = Self(U256::from_be_hex(
285 "ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602",
286 ));
287 const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unwrap();
288 const DELTA: Self = Self(U256::from_u64(33232930569601));
289
290 fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
295 let inner = U256::from_be_byte_array(bytes);
296 CtOption::new(
297 Self(inner),
298 ConstantTimeLess::ct_lt(&inner, &NistP256::ORDER),
299 )
300 }
301
302 fn to_repr(&self) -> FieldBytes {
303 self.to_bytes()
304 }
305
306 fn is_odd(&self) -> Choice {
307 self.0.is_odd().into()
308 }
309}
310
// Marker impl: no items are overridden, so only the trait's provided items (if any) apply.
impl FieldExt for Scalar {}
// Marker impl: no items are overridden, so only the trait's provided items (if any) apply.
impl PrimeFieldExt for Scalar {}
313
/// Exposes the wrapped integer through the `Retrieve` trait.
impl Retrieve for Scalar {
    type Output = U256;

    /// Returns a copy of the wrapped `U256` without any conversion.
    fn retrieve(&self) -> U256 {
        self.0
    }
}
321
// Opts `Scalar` into `zeroize`'s "overwrite with `Default::default()`" strategy.
impl DefaultIsZeroes for Scalar {}
323
// Equality is total; the comparison itself lives in the `PartialEq` impl.
impl Eq for Scalar {}
325
/// Wraps a raw integer as a `Scalar` without any range check.
impl FromUintUnchecked for Scalar {
    type Uint = U256;

    /// Stores `uint` as-is; no reduction or comparison against the order is performed here.
    fn from_uint_unchecked(uint: Self::Uint) -> Self {
        Self(uint)
    }
}
333
334impl Invert for Scalar {
335 type Output = CtOption<Self>;
336
337 fn invert(&self) -> CtOption<Self> {
338 self.invert()
339 }
340
341 fn invert_vartime(&self) -> CtOption<Self> {
342 self.invert_vartime()
343 }
344}
345
346impl IsHigh for Scalar {
347 fn is_high(&self) -> Choice {
348 ConstantTimeGreater::ct_gt(&self.0, &FRAC_MODULUS_2.0)
349 }
350}
351
352impl Shr<usize> for Scalar {
353 type Output = Self;
354
355 fn shr(self, rhs: usize) -> Self::Output {
356 self.shr_vartime(rhs as u32)
357 }
358}
359
360impl Shr<usize> for &Scalar {
361 type Output = Scalar;
362
363 fn shr(self, rhs: usize) -> Self::Output {
364 self.shr_vartime(rhs as u32)
365 }
366}
367
368impl ShrAssign<usize> for Scalar {
369 fn shr_assign(&mut self, rhs: usize) {
370 *self = *self >> rhs;
371 }
372}
373
374impl PartialEq for Scalar {
375 fn eq(&self, other: &Self) -> bool {
376 self.ct_eq(other).into()
377 }
378}
379
380impl PartialOrd for Scalar {
381 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
382 Some(self.cmp(other))
383 }
384}
385
/// Orders scalars by their wrapped integers.
impl Ord for Scalar {
    /// Delegates to the comparison of the inner `U256` values.
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0.cmp(&other.0)
    }
}
391
392impl From<u32> for Scalar {
393 fn from(k: u32) -> Self {
394 Scalar(k.into())
395 }
396}
397
398impl From<u64> for Scalar {
399 fn from(k: u64) -> Self {
400 Scalar(k.into())
401 }
402}
403
404impl From<u128> for Scalar {
405 fn from(k: u128) -> Self {
406 Scalar(k.into())
407 }
408}
409
410impl From<Scalar> for FieldBytes {
411 fn from(scalar: Scalar) -> Self {
412 scalar.to_bytes()
413 }
414}
415
416impl From<&Scalar> for FieldBytes {
417 fn from(scalar: &Scalar) -> Self {
418 scalar.to_bytes()
419 }
420}
421
422impl From<Scalar> for U256 {
423 fn from(scalar: Scalar) -> U256 {
424 scalar.0
425 }
426}
427
428impl From<&Scalar> for U256 {
429 fn from(scalar: &Scalar) -> U256 {
430 scalar.0
431 }
432}
433
434impl FromUniformBytes<64> for Scalar {
435 fn from_uniform_bytes(bytes: &[u8; 64]) -> Self {
436 Self(barrett_reduce(
437 U256::from_be_slice(&bytes[32..]),
438 U256::from_be_slice(&bytes[..32]),
439 ))
440 }
441}
442
443impl Add<Scalar> for Scalar {
444 type Output = Scalar;
445
446 fn add(self, other: Scalar) -> Scalar {
447 Scalar::add(&self, &other)
448 }
449}
450
451impl Add<&Scalar> for &Scalar {
452 type Output = Scalar;
453
454 fn add(self, other: &Scalar) -> Scalar {
455 Scalar::add(self, other)
456 }
457}
458
459impl Add<&Scalar> for Scalar {
460 type Output = Scalar;
461
462 fn add(self, other: &Scalar) -> Scalar {
463 Scalar::add(&self, other)
464 }
465}
466
467impl AddAssign<Scalar> for Scalar {
468 fn add_assign(&mut self, rhs: Scalar) {
469 *self = Scalar::add(self, &rhs);
470 }
471}
472
473impl AddAssign<&Scalar> for Scalar {
474 fn add_assign(&mut self, rhs: &Scalar) {
475 *self = Scalar::add(self, rhs);
476 }
477}
478
479impl Sub<Scalar> for Scalar {
480 type Output = Scalar;
481
482 fn sub(self, other: Scalar) -> Scalar {
483 Scalar::sub(&self, &other)
484 }
485}
486
487impl Sub<&Scalar> for &Scalar {
488 type Output = Scalar;
489
490 fn sub(self, other: &Scalar) -> Scalar {
491 Scalar::sub(self, other)
492 }
493}
494
495impl Sub<&Scalar> for Scalar {
496 type Output = Scalar;
497
498 fn sub(self, other: &Scalar) -> Scalar {
499 Scalar::sub(&self, other)
500 }
501}
502
503impl SubAssign<Scalar> for Scalar {
504 fn sub_assign(&mut self, rhs: Scalar) {
505 *self = Scalar::sub(self, &rhs);
506 }
507}
508
509impl SubAssign<&Scalar> for Scalar {
510 fn sub_assign(&mut self, rhs: &Scalar) {
511 *self = Scalar::sub(self, rhs);
512 }
513}
514
515impl Mul<Scalar> for Scalar {
516 type Output = Scalar;
517
518 fn mul(self, other: Scalar) -> Scalar {
519 Scalar::multiply(&self, &other)
520 }
521}
522
523impl Mul<&Scalar> for &Scalar {
524 type Output = Scalar;
525
526 fn mul(self, other: &Scalar) -> Scalar {
527 Scalar::multiply(self, other)
528 }
529}
530
531impl Mul<&Scalar> for Scalar {
532 type Output = Scalar;
533
534 fn mul(self, other: &Scalar) -> Scalar {
535 Scalar::multiply(&self, other)
536 }
537}
538
539impl MulAssign<Scalar> for Scalar {
540 fn mul_assign(&mut self, rhs: Scalar) {
541 *self = Scalar::multiply(self, &rhs);
542 }
543}
544
545impl MulAssign<&Scalar> for Scalar {
546 fn mul_assign(&mut self, rhs: &Scalar) {
547 *self = Scalar::multiply(self, rhs);
548 }
549}
550
551impl Neg for Scalar {
552 type Output = Scalar;
553
554 fn neg(self) -> Scalar {
555 Scalar::ZERO - self
556 }
557}
558
559impl Neg for &Scalar {
560 type Output = Scalar;
561
562 fn neg(self) -> Scalar {
563 Scalar::ZERO - self
564 }
565}
566
567impl Reduce<U256> for Scalar {
568 fn reduce(w: &U256) -> Self {
569 let (r, underflow) = w.borrowing_sub(&NistP256::ORDER, Limb::ZERO);
570 let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
571 Self(U256::conditional_select(w, &r, !underflow))
572 }
573}
574
575impl Reduce<FieldBytes> for Scalar {
576 #[inline]
577 fn reduce(bytes: &FieldBytes) -> Self {
578 Self::reduce(&U256::from_be_byte_array(*bytes))
579 }
580}
581
582impl ReduceNonZero<U256> for Scalar {
583 fn reduce_nonzero(w: &U256) -> Self {
584 const ORDER_MINUS_ONE: U256 = NistP256::ORDER.as_ref().wrapping_sub(&U256::ONE);
585 let (r, underflow) = w.borrowing_sub(&ORDER_MINUS_ONE, Limb::ZERO);
586 let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
587 Self(U256::conditional_select(w, &r, !underflow).wrapping_add(&U256::ONE))
588 }
589}
590
591impl ReduceNonZero<FieldBytes> for Scalar {
592 #[inline]
593 fn reduce_nonzero(bytes: &FieldBytes) -> Self {
594 Self::reduce_nonzero(&U256::from_be_byte_array(*bytes))
595 }
596}
597
598impl Sum for Scalar {
599 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
600 iter.reduce(Add::add).unwrap_or(Self::ZERO)
601 }
602}
603
604impl<'a> Sum<&'a Scalar> for Scalar {
605 fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
606 iter.copied().sum()
607 }
608}
609
610impl Product for Scalar {
611 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
612 iter.reduce(Mul::mul).unwrap_or(Self::ONE)
613 }
614}
615
616impl<'a> Product<&'a Scalar> for Scalar {
617 fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
618 iter.copied().product()
619 }
620}
621
622impl ConditionallySelectable for Scalar {
623 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
624 Self(U256::conditional_select(&a.0, &b.0, choice))
625 }
626}
627
628impl ConstantTimeEq for Scalar {
629 fn ct_eq(&self, other: &Self) -> Choice {
630 ConstantTimeEq::ct_eq(&self.0, &other.0)
631 }
632}
633
634impl ctutils::CtEq for Scalar {
635 fn ct_eq(&self, other: &Self) -> ctutils::Choice {
636 ConstantTimeEq::ct_eq(self, other).into()
637 }
638}
639
640impl ctutils::CtSelect for Scalar {
641 fn ct_select(&self, other: &Self, choice: ctutils::Choice) -> Self {
642 ConditionallySelectable::conditional_select(self, other, choice.into())
643 }
644}
645
646impl Debug for Scalar {
647 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
648 write!(f, "Scalar(0x{:X})", &self.0)
649 }
650}
651
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    /// Serialises by first converting to `ScalarValue` and using its `Serialize` impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let value = ScalarValue::from(self);
        value.serialize(serializer)
    }
}
661
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    /// Deserialises a `ScalarValue` and converts it into a `Scalar`;
    /// deserialiser errors are propagated unchanged.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let value = ScalarValue::deserialize(deserializer)?;
        Ok(value.into())
    }
}
671
#[cfg(test)]
mod tests {
    use super::{Scalar, U256};
    use crate::{FieldBytes, NistP256, SecretKey};
    use elliptic_curve::{Curve, array::Array, group::ff::PrimeField, ops::ReduceNonZero};

    // Shared field test-suite generated by `primefield`.
    primefield::test_primefield!(Scalar, U256);

    /// A small value placed in the low bytes of a field-sized buffer survives
    /// `from_repr` followed by `to_bytes`.
    #[test]
    fn from_to_bytes_roundtrip() {
        let value: u64 = 42;
        let mut encoded = FieldBytes::default();
        encoded[24..].copy_from_slice(&value.to_be_bytes());

        let decoded = Scalar::from_repr(encoded).unwrap();
        assert_eq!(encoded, decoded.to_bytes());
    }

    /// Basic multiplication and negation sanity checks.
    #[test]
    fn multiply() {
        let one = Scalar::ONE;
        let two = one + one;
        let three = two + one;
        let six = three + three;
        assert_eq!(six, two * three);

        let neg_two = -two;
        let neg_three = -three;
        assert_eq!(two, -neg_two);

        assert_eq!(neg_three * neg_two, neg_two * neg_three);
        assert_eq!(six, neg_two * neg_three);
    }

    /// A scalar round-trips through `SecretKey`.
    #[test]
    fn from_ec_secret() {
        let original = Scalar::ONE;
        let secret = SecretKey::from_bytes(&original.to_bytes()).unwrap();
        let recovered = Scalar::from(&secret);
        assert_eq!(original.0, recovered.0);
    }

    /// `reduce_nonzero` around zero, around the order, and just above the order.
    #[test]
    fn reduce_nonzero() {
        // All-zero bytes go through the `FieldBytes` impl.
        assert_eq!(Scalar::reduce_nonzero(&Array::default()).0, U256::ONE);

        let order: U256 = *NistP256::ORDER.as_ref();
        let cases = [
            // Small inputs.
            (U256::ONE, U256::from_u8(2)),
            (U256::from_u8(2), U256::from_u8(3)),
            // At and just below the order.
            (order, U256::from_u8(2)),
            (order.wrapping_sub(&U256::from_u8(1)), U256::ONE),
            (
                order.wrapping_sub(&U256::from_u8(2)),
                order.wrapping_sub(&U256::ONE),
            ),
            (
                order.wrapping_sub(&U256::from_u8(3)),
                order.wrapping_sub(&U256::from_u8(2)),
            ),
            // Just above the order.
            (order.wrapping_add(&U256::ONE), U256::from_u8(3)),
            (order.wrapping_add(&U256::from_u8(2)), U256::from_u8(4)),
        ];

        for (input, expected) in cases {
            assert_eq!(Scalar::reduce_nonzero(&input).0, expected);
        }
    }
}