1use crate::{FieldBytes, NistP521, Uint};
14use core::{
15 cmp::Ordering,
16 fmt::{self, Debug},
17 iter::{Product, Sum},
18 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
19};
20use elliptic_curve::{
21 Error, Generate,
22 array::Array,
23 bigint::{self, Limb, Odd, Word, cpubits, modular::Retrieve},
24 ff::{self, Field, PrimeField},
25 field::bytes_to_uint,
26 ops::{BatchInvert, Invert},
27 rand_core::TryRng,
28 subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption},
29 zeroize::DefaultIsZeroes,
30};
31use primefield::{FieldExt, PrimeFieldExt};
32
// Select the fiat-crypto backend whose limb width matches the target's native
// word size (32-bit or 64-bit). The `#[path]`-included files are excluded from
// rustfmt and have several lints silenced; presumably this is because they are
// machine-generated (NOTE(review): confirm against the generator).
cpubits! {
    32 => {
        #[allow(clippy::needless_lifetimes, clippy::unnecessary_cast)]
        #[allow(dead_code)]
        #[rustfmt::skip]
        #[path = "field/p521_32.rs"]
        mod field_impl;
    }
    64 => {
        #[allow(clippy::needless_lifetimes, clippy::unnecessary_cast)]
        #[allow(dead_code)]
        #[rustfmt::skip]
        #[path = "field/p521_64.rs"]
        mod field_impl;
    }
}
50
51mod loose;
52
53use self::field_impl::*;
54pub(crate) use self::loose::LooseFieldElement;
55
/// Big-endian hex encoding of the field modulus p = 2^521 - 1.
///
/// The string is left-padded with zeros to the full width of [`Uint`] on the
/// current target, and that width differs between 32-bit and 64-bit limbs.
/// The padding lets it be fed straight to `Uint::from_be_hex`.
const MODULUS_HEX: &str = {
    cpubits! {
        32 => {
            "000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
        }
        64 => {
            "00000000000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
        }
    }
};

/// The field modulus p = 2^521 - 1 as a [`Uint`].
pub(crate) const MODULUS: Uint = Uint::from_be_hex(MODULUS_HEX);
69
/// Element of the base field of NIST P-521, i.e. integers modulo p = 2^521 - 1.
///
/// Internally this wraps fiat-crypto's "tight" limb representation. The limbs
/// are not guaranteed to be fully reduced, which is why equality and parity
/// go through the canonical `to_bytes` encoding rather than comparing limbs.
#[derive(Clone, Copy)]
pub struct FieldElement(pub(crate) fiat_p521_tight_field_element);
73
74impl FieldElement {
75 pub const ZERO: Self = Self::from_u64(0);
77
78 pub const ONE: Self = Self::from_u64(1);
80
81 cpubits! {
82 32 => { const LIMBS: usize = 19; }
83 64 => { const LIMBS: usize = 9; }
84 }
85
86 pub fn from_bytes(repr: &FieldBytes) -> CtOption<Self> {
88 Self::from_uint(bytes_to_uint::<NistP521>(repr))
89 }
90
91 pub fn from_slice(slice: &[u8]) -> elliptic_curve::Result<Self> {
93 let field_bytes = FieldBytes::try_from(slice).map_err(|_| Error)?;
94 Self::from_bytes(&field_bytes).into_option().ok_or(Error)
95 }
96
97 pub fn from_uint(uint: Uint) -> CtOption<Self> {
99 let is_some = uint.ct_lt(&MODULUS);
100 CtOption::new(Self::from_uint_unchecked(uint), is_some)
101 }
102
103 pub(crate) const fn from_hex(hex: &str) -> Self {
111 assert!(
112 hex.len() == 521usize.div_ceil(8) * 2,
113 "hex is the wrong length (expected 132 hex chars)"
114 );
115
116 let mut hex_bytes = [b'0'; { Uint::BITS as usize / 4 }];
118
119 let offset = hex_bytes.len() - hex.len();
120 let mut i = 0;
121 while i < hex.len() {
122 hex_bytes[i + offset] = hex.as_bytes()[i];
123 i += 1;
124 }
125
126 let uint = match core::str::from_utf8(&hex_bytes) {
127 Ok(padded_hex) => Uint::from_be_hex(padded_hex),
128 Err(_) => panic!("invalid hex string"),
129 };
130
131 assert!(matches!(uint.cmp_vartime(&MODULUS), Ordering::Less));
132 Self::from_uint_unchecked(uint)
133 }
134
135 pub const fn from_u64(w: u64) -> Self {
137 Self::from_uint_unchecked(Uint::from_u64(w))
138 }
139
140 pub(crate) const fn from_uint_unchecked(w: Uint) -> Self {
146 let le_bytes_wide = w.to_le_bytes();
150
151 let mut le_bytes = [0u8; 66];
152 let mut i = 0;
153
154 while i < le_bytes.len() {
156 le_bytes[i] = le_bytes_wide.as_slice()[i];
157 i += 1;
158 }
159
160 let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
163 fiat_p521_from_bytes(&mut out, &le_bytes);
164 Self(out)
165 }
166
167 pub const fn to_bytes(self) -> FieldBytes {
169 const BYTES: usize = 66;
170
171 let mut ret = [0u8; BYTES];
172 fiat_p521_to_bytes(&mut ret, &self.0);
173
174 let mut i = 0;
177 while i < (BYTES / 2) {
178 let j = BYTES - i - 1;
179 let tmp = ret[i];
180 ret[i] = ret[j];
181 ret[j] = tmp;
182 i += 1;
183 }
184
185 Array(ret)
186 }
187
188 pub fn is_odd(&self) -> Choice {
194 Choice::from(self.0[0] as u8 & 1)
195 }
196
197 pub fn is_even(&self) -> Choice {
203 !self.is_odd()
204 }
205
206 pub fn is_zero(&self) -> Choice {
212 self.ct_eq(&Self::ZERO)
213 }
214
215 #[inline]
217 pub const fn add_loose(&self, rhs: &Self) -> LooseFieldElement {
218 let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
219 fiat_p521_add(&mut out, &self.0, &rhs.0);
220 LooseFieldElement(out)
221 }
222
223 #[inline]
225 #[must_use]
226 pub const fn double_loose(&self) -> LooseFieldElement {
227 self.add_loose(self)
228 }
229
230 #[inline]
232 pub const fn sub_loose(&self, rhs: &Self) -> LooseFieldElement {
233 let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
234 fiat_p521_sub(&mut out, &self.0, &rhs.0);
235 LooseFieldElement(out)
236 }
237
238 #[inline]
240 pub const fn neg_loose(&self) -> LooseFieldElement {
241 let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
242 fiat_p521_opp(&mut out, &self.0);
243 LooseFieldElement(out)
244 }
245
246 #[inline]
248 pub const fn add(&self, rhs: &Self) -> Self {
249 let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
250 fiat_p521_carry_add(&mut out, &self.0, &rhs.0);
251 Self(out)
252 }
253
254 #[inline]
256 pub const fn sub(&self, rhs: &Self) -> Self {
257 let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
258 fiat_p521_carry_sub(&mut out, &self.0, &rhs.0);
259 Self(out)
260 }
261
262 #[inline]
264 pub const fn neg(&self) -> Self {
265 let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
266 fiat_p521_carry_opp(&mut out, &self.0);
267 Self(out)
268 }
269
270 #[inline]
272 #[must_use]
273 pub const fn double(&self) -> Self {
274 self.add(self)
275 }
276
277 #[inline]
279 pub const fn multiply(&self, rhs: &Self) -> Self {
280 self.relax().multiply(&rhs.relax())
281 }
282
283 #[inline]
285 pub const fn square(&self) -> Self {
286 self.relax().square()
287 }
288
289 const fn sqn(&self, n: usize) -> Self {
291 self.sqn_vartime(n)
292 }
293
294 pub const fn pow_vartime<const RHS_LIMBS: usize>(&self, exp: &bigint::Uint<RHS_LIMBS>) -> Self {
300 let mut res = Self::ONE;
301 let mut i = RHS_LIMBS;
302
303 while i > 0 {
304 i -= 1;
305
306 let mut j = Limb::BITS;
307 while j > 0 {
308 j -= 1;
309 res = res.square();
310
311 if ((exp.as_limbs()[i].0 >> j) & 1) == 1 {
312 res = res.multiply(self);
313 }
314 }
315 }
316
317 res
318 }
319
320 pub const fn sqn_vartime(&self, n: usize) -> Self {
326 let mut x = *self;
327 let mut i = 0;
328 while i < n {
329 x = x.square();
330 i += 1;
331 }
332 x
333 }
334
335 pub fn invert(&self) -> CtOption<Self> {
337 self.to_uint()
338 .invert_odd_mod(const { &Odd::from_be_hex(MODULUS_HEX) })
339 .map(Self::from_uint_unchecked)
340 .into()
341 }
342
343 pub fn invert_vartime(&self) -> CtOption<Self> {
345 self.to_uint()
346 .invert_odd_mod_vartime(const { &Odd::from_be_hex(MODULUS_HEX) })
347 .map(Self::from_uint_unchecked)
348 .into()
349 }
350
351 const fn invert_unwrap(&self) -> Self {
356 Self::from_uint_unchecked(
357 self.to_uint()
358 .invert_odd_mod(const { &Odd::from_be_hex(MODULUS_HEX) })
359 .expect_copied("input should be non-zero"),
360 )
361 }
362
363 pub fn sqrt(&self) -> CtOption<Self> {
375 let sqrt = self.sqn(519);
376 CtOption::new(sqrt, sqrt.square().ct_eq(self))
377 }
378
379 #[inline]
381 pub const fn relax(&self) -> LooseFieldElement {
382 let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
383 fiat_p521_relax(&mut out, &self.0);
384 LooseFieldElement(out)
385 }
386
387 #[inline]
389 pub(crate) const fn to_uint(self) -> Uint {
390 let field_bytes = self.to_bytes();
391 let mut uint_bytes = [0u8; Uint::LIMBS * Limb::BYTES];
392
393 let offset = uint_bytes.len() - field_bytes.0.len();
394 let mut i = 0;
395 while i < field_bytes.0.len() {
396 uint_bytes[i + offset] = field_bytes.0[i];
397 i += 1
398 }
399
400 Uint::from_be_slice(&uint_bytes)
401 }
402}
403
404impl AsRef<fiat_p521_tight_field_element> for FieldElement {
405 fn as_ref(&self) -> &fiat_p521_tight_field_element {
406 &self.0
407 }
408}
409
410impl BatchInvert for FieldElement {}
411
412impl Default for FieldElement {
413 fn default() -> Self {
414 Self::ZERO
415 }
416}
417
418impl Debug for FieldElement {
419 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
443 let bytes = self.to_bytes();
444 let formatter = base16ct::HexDisplay(&bytes);
445 f.debug_tuple("FieldElement")
446 .field(&format_args!("0x{formatter:X}"))
447 .finish()
448 }
449}
450
451impl Eq for FieldElement {}
452impl PartialEq for FieldElement {
453 fn eq(&self, rhs: &Self) -> bool {
454 self.ct_eq(rhs).into()
455 }
456}
457
458impl From<u32> for FieldElement {
459 fn from(n: u32) -> FieldElement {
460 Self::from_uint_unchecked(Uint::from(n))
461 }
462}
463
464impl From<u64> for FieldElement {
465 fn from(n: u64) -> FieldElement {
466 Self::from_uint_unchecked(Uint::from(n))
467 }
468}
469
470impl From<u128> for FieldElement {
471 fn from(n: u128) -> FieldElement {
472 Self::from_uint_unchecked(Uint::from(n))
473 }
474}
475
476impl ConditionallySelectable for FieldElement {
477 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
478 let out = <[Word; Self::LIMBS]>::conditional_select(&a.0.0, &b.0.0, choice);
479 Self(fiat_p521_tight_field_element(out))
480 }
481}
482
483impl ConstantTimeEq for FieldElement {
484 fn ct_eq(&self, other: &Self) -> Choice {
485 let a = self.to_bytes();
486 let b = other.to_bytes();
487 a.ct_eq(&b)
488 }
489}
490
491impl DefaultIsZeroes for FieldElement {}
492
493impl Field for FieldElement {
494 const ZERO: Self = Self::ZERO;
495 const ONE: Self = Self::ONE;
496
497 fn try_random<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
498 let mut bytes = <FieldBytes>::default();
500
501 loop {
502 rng.try_fill_bytes(&mut bytes)?;
503 if let Some(fe) = Self::from_bytes(&bytes).into() {
504 return Ok(fe);
505 }
506 }
507 }
508
509 fn is_zero(&self) -> Choice {
510 Self::ZERO.ct_eq(self)
511 }
512
513 fn square(&self) -> Self {
514 self.square()
515 }
516
517 fn double(&self) -> Self {
518 self.double()
519 }
520
521 fn invert(&self) -> CtOption<Self> {
522 self.invert()
523 }
524
525 fn sqrt(&self) -> CtOption<Self> {
526 self.sqrt()
527 }
528
529 fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
530 ff::helpers::sqrt_ratio_generic(num, div)
531 }
532}
533
534impl FieldExt for FieldElement {}
535impl PrimeFieldExt for FieldElement {}
536
537impl Generate for FieldElement {
538 fn try_generate_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
539 Self::try_random(rng)
540 }
541}
542
impl PrimeField for FieldElement {
    // Canonical big-endian 66-byte encoding (see `FieldElement::to_bytes`).
    type Repr = FieldBytes;

    // Hex of p = 2^521 - 1, zero-padded to the `Uint` width for this target.
    const MODULUS: &'static str = MODULUS_HEX;
    const NUM_BITS: u32 = 521;
    // Every 520-bit integer is below p, so 520 bits fit without reduction.
    const CAPACITY: u32 = 520;
    const TWO_INV: Self = Self::from_u64(2).invert_unwrap();
    const MULTIPLICATIVE_GENERATOR: Self = Self::from_u64(3);
    // p - 1 = 2 * (2^520 - 1), and 2^520 - 1 is odd, so the 2-adicity S is 1.
    const S: u32 = 1;
    // With S = 1 the primitive 2^S-th root of unity is -1, i.e. p - 1 = 2^521 - 2.
    const ROOT_OF_UNITY: Self = Self::from_hex(
        "01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe",
    );
    const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unwrap();
    // MULTIPLICATIVE_GENERATOR^(2^S) = 3^2.
    const DELTA: Self = Self::from_u64(9);

    #[inline]
    fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
        Self::from_bytes(&bytes)
    }

    #[inline]
    fn to_repr(&self) -> FieldBytes {
        self.to_bytes()
    }

    // Forwards to the inherent `FieldElement::is_odd`.
    #[inline]
    fn is_odd(&self) -> Choice {
        self.is_odd()
    }
}
573
574impl Add for FieldElement {
579 type Output = FieldElement;
580
581 #[inline]
582 fn add(self, rhs: FieldElement) -> FieldElement {
583 Self::add(&self, &rhs)
584 }
585}
586
587impl Add<&FieldElement> for FieldElement {
588 type Output = FieldElement;
589
590 #[inline]
591 fn add(self, rhs: &FieldElement) -> FieldElement {
592 Self::add(&self, rhs)
593 }
594}
595
596impl Add<&FieldElement> for &FieldElement {
597 type Output = FieldElement;
598
599 #[inline]
600 fn add(self, rhs: &FieldElement) -> FieldElement {
601 FieldElement::add(self, rhs)
602 }
603}
604
605impl AddAssign<FieldElement> for FieldElement {
606 #[inline]
607 fn add_assign(&mut self, other: FieldElement) {
608 *self = *self + other;
609 }
610}
611
612impl AddAssign<&FieldElement> for FieldElement {
613 #[inline]
614 fn add_assign(&mut self, other: &FieldElement) {
615 *self = *self + other;
616 }
617}
618
619impl Sub for FieldElement {
620 type Output = FieldElement;
621
622 #[inline]
623 fn sub(self, rhs: FieldElement) -> FieldElement {
624 Self::sub(&self, &rhs)
625 }
626}
627
628impl Sub<&FieldElement> for FieldElement {
629 type Output = FieldElement;
630
631 #[inline]
632 fn sub(self, rhs: &FieldElement) -> FieldElement {
633 Self::sub(&self, rhs)
634 }
635}
636
637impl Sub<&FieldElement> for &FieldElement {
638 type Output = FieldElement;
639
640 #[inline]
641 fn sub(self, rhs: &FieldElement) -> FieldElement {
642 FieldElement::sub(self, rhs)
643 }
644}
645
646impl SubAssign<FieldElement> for FieldElement {
647 #[inline]
648 fn sub_assign(&mut self, other: FieldElement) {
649 *self = *self - other;
650 }
651}
652
653impl SubAssign<&FieldElement> for FieldElement {
654 #[inline]
655 fn sub_assign(&mut self, other: &FieldElement) {
656 *self = *self - other;
657 }
658}
659
660impl Mul for FieldElement {
661 type Output = FieldElement;
662
663 #[inline]
664 fn mul(self, rhs: FieldElement) -> FieldElement {
665 self.relax().mul(&rhs.relax())
666 }
667}
668
669impl Mul<&FieldElement> for FieldElement {
670 type Output = FieldElement;
671
672 #[inline]
673 fn mul(self, rhs: &FieldElement) -> FieldElement {
674 self.relax().mul(&rhs.relax())
675 }
676}
677
678impl Mul<&FieldElement> for &FieldElement {
679 type Output = FieldElement;
680
681 #[inline]
682 fn mul(self, rhs: &FieldElement) -> FieldElement {
683 self.relax().mul(&rhs.relax())
684 }
685}
686
687impl MulAssign<&FieldElement> for FieldElement {
688 #[inline]
689 fn mul_assign(&mut self, other: &FieldElement) {
690 *self = *self * other;
691 }
692}
693
694impl MulAssign for FieldElement {
695 #[inline]
696 fn mul_assign(&mut self, other: FieldElement) {
697 *self = *self * other;
698 }
699}
700
701impl Neg for FieldElement {
702 type Output = FieldElement;
703
704 #[inline]
705 fn neg(self) -> FieldElement {
706 Self::neg(&self)
707 }
708}
709
710impl Sum for FieldElement {
715 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
716 iter.reduce(Add::add).unwrap_or(Self::ZERO)
717 }
718}
719
720impl<'a> Sum<&'a FieldElement> for FieldElement {
721 fn sum<I: Iterator<Item = &'a FieldElement>>(iter: I) -> Self {
722 iter.copied().sum()
723 }
724}
725
726impl Product for FieldElement {
727 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
728 iter.reduce(Mul::mul).unwrap_or(Self::ONE)
729 }
730}
731
732impl<'a> Product<&'a FieldElement> for FieldElement {
733 fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
734 iter.copied().product()
735 }
736}
737
738impl Invert for FieldElement {
741 type Output = CtOption<Self>;
742
743 fn invert(&self) -> CtOption<Self> {
744 self.invert()
745 }
746
747 fn invert_vartime(&self) -> CtOption<Self> {
748 self.invert_vartime()
749 }
750}
751
752impl Retrieve for FieldElement {
753 type Output = Uint;
754
755 fn retrieve(&self) -> Uint {
756 self.to_uint()
757 }
758}
759
#[cfg(test)]
mod tests {
    use super::{FieldElement, Uint};
    use hex_literal::hex;

    primefield::test_primefield!(FieldElement, Uint);

    /// An all-0xFF encoding exceeds the modulus and must be rejected.
    #[test]
    fn decode_invalid_field_element_returns_err() {
        let too_large = hex!(
            "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"
        );
        let decoded = FieldElement::from_bytes(&too_large.into());
        assert_eq!(decoded.is_none().unwrap_u8(), 1);
    }

    /// `sqn(n)` must equal `n` explicit squarings, including `n = 0`.
    #[test]
    fn sqn_edge_cases() {
        let five = FieldElement::from_u64(5);
        assert_eq!(five.sqn(0), five);

        let mut expected = five;
        for n in 1..=2 {
            expected = expected.square();
            assert_eq!(five.sqn(n), expected);
        }
    }
}