1#![allow(clippy::unusual_byte_groupings)]
4
// Backend arithmetic for the scalar field. The source file is selected by the
// target's pointer width (32-bit or 64-bit limbs).
#[cfg_attr(target_pointer_width = "32", path = "scalar/p384_scalar_32.rs")]
#[cfg_attr(target_pointer_width = "64", path = "scalar/p384_scalar_64.rs")]
// NOTE(review): these lints are presumably silenced because the backend is
// machine-generated (`fiat_*` naming) and not meant to be hand-edited — confirm.
#[allow(
    clippy::identity_op,
    clippy::too_many_arguments,
    clippy::unnecessary_cast
)]
mod scalar_impl;
13
14use self::scalar_impl::*;
15use crate::{FieldBytes, NistP384, SecretKey, ORDER_HEX, U384};
16use core::{
17    iter::{Product, Sum},
18    ops::{AddAssign, MulAssign, Neg, Shr, ShrAssign, SubAssign},
19};
20use elliptic_curve::{
21    bigint::{self, ArrayEncoding, Limb},
22    ff::PrimeField,
23    ops::{Invert, Reduce},
24    scalar::{FromUintUnchecked, IsHigh},
25    subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption},
26    Curve as _, Error, Result, ScalarPrimitive,
27};
28
29#[cfg(feature = "bits")]
30use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
31
32#[cfg(feature = "serde")]
33use serdect::serde::{de, ser, Deserialize, Serialize};
34
35#[cfg(doc)]
36use core::ops::{Add, Mul, Sub};
37
/// Element of the NIST P-384 scalar field, i.e. an integer modulo the curve
/// order `NistP384::ORDER`.
///
/// The inner `U384` holds the value in the backend's Montgomery-domain
/// representation (see the `fiat_p384_scalar_to_montgomery` /
/// `fiat_p384_scalar_from_montgomery` primitives wired in below); use
/// `to_canonical()` to obtain the plain integer.
///
/// NOTE(review): `PartialOrd`/`Ord` are derived, so they compare the inner
/// representation rather than the canonical integer — confirm this is intended.
#[derive(Clone, Copy, Debug, PartialOrd, Ord)]
pub struct Scalar(U384);
71
// Wires `Scalar` to the `fiat_p384_scalar_*` Montgomery-domain primitives
// (conversion in/out of the Montgomery domain, add, sub, mul, negate, square)
// with `NistP384::ORDER` as the modulus.
// NOTE(review): the generated item set is defined in `primeorder`; the rest of
// this file relies on at least `ONE`, `from_u64`, `from_hex`, `from_uint`,
// `from_uint_unchecked`, `from_bytes`, `to_bytes`, `to_canonical`, `is_zero`,
// `is_odd` and `square` presumably coming from this expansion — confirm.
primeorder::impl_mont_field_element!(
    NistP384,
    Scalar,
    FieldBytes,
    U384,
    NistP384::ORDER,
    fiat_p384_scalar_montgomery_domain_field_element,
    fiat_p384_scalar_from_montgomery,
    fiat_p384_scalar_to_montgomery,
    fiat_p384_scalar_add,
    fiat_p384_scalar_sub,
    fiat_p384_scalar_mul,
    fiat_p384_scalar_opp,
    fiat_p384_scalar_square
);
87
88impl Scalar {
89    pub fn invert(&self) -> CtOption<Self> {
91        CtOption::new(self.invert_unchecked(), !self.is_zero())
92    }
93
94    const fn invert_unchecked(&self) -> Self {
98        let words = impl_field_invert!(
99            self.to_canonical().as_words(),
100            Self::ONE.0.to_words(),
101            Limb::BITS,
102            bigint::nlimbs!(U384::BITS),
103            fiat_p384_scalar_mul,
104            fiat_p384_scalar_opp,
105            fiat_p384_scalar_divstep_precomp,
106            fiat_p384_scalar_divstep,
107            fiat_p384_scalar_msat,
108            fiat_p384_scalar_selectznz,
109        );
110
111        Self(U384::from_words(words))
112    }
113
114    pub fn sqrt(&self) -> CtOption<Self> {
116        let t1 = *self;
119        let t10 = t1.square();
120        let t11 = *self * t10;
121        let t101 = t10 * t11;
122        let t111 = t10 * t101;
123        let t1001 = t10 * t111;
124        let t1011 = t10 * t1001;
125        let t1101 = t10 * t1011;
126        let t1111 = t10 * t1101;
127        let t11110 = t1111.square();
128        let t11111 = t1 * t11110;
129        let t1111100 = t11111.sqn(2);
130        let t11111000 = t1111100.square();
131        let i14 = t11111000.square();
132        let i20 = i14.sqn(5) * i14;
133        let i31 = i20.sqn(10) * i20;
134        let i58 = (i31.sqn(4) * t11111000).sqn(21) * i31;
135        let i110 = (i58.sqn(3) * t1111100).sqn(47) * i58;
136        let x194 = i110.sqn(95) * i110 * t1111;
137        let i225 = ((x194.sqn(6) * t111).sqn(3) * t11).sqn(7);
138        let i235 = ((t1101 * i225).sqn(6) * t1101).square() * t1;
139        let i258 = ((i235.sqn(11) * t11111).sqn(2) * t1).sqn(8);
140        let i269 = ((t1101 * i258).sqn(2) * t11).sqn(6) * t1011;
141        let i286 = ((i269.sqn(4) * t111).sqn(6) * t11111).sqn(5);
142        let i308 = ((t1011 * i286).sqn(10) * t1101).sqn(9) * t1101;
143        let i323 = ((i308.sqn(4) * t1011).sqn(6) * t1001).sqn(3);
144        let i340 = ((t1 * i323).sqn(7) * t1011).sqn(7) * t101;
145        let i357 = ((i340.sqn(5) * t111).sqn(5) * t1111).sqn(5);
146        let i369 = ((t1011 * i357).sqn(4) * t1011).sqn(5) * t111;
147        let i387 = ((i369.sqn(3) * t11).sqn(7) * t11).sqn(6);
148        let i397 = ((t1011 * i387).sqn(4) * t101).sqn(3) * t11;
149        let i413 = ((i397.sqn(4) * t11).sqn(4) * t11).sqn(6);
150        let i427 = ((t101 * i413).sqn(5) * t101).sqn(6) * t1011;
151        let x = i427.sqn(3) * t101;
152        CtOption::new(x, x.square().ct_eq(&t1))
153    }
154
155    fn sqn(&self, n: usize) -> Self {
156        let mut x = *self;
157        for _ in 0..n {
158            x = x.square();
159        }
160        x
161    }
162
163    pub const fn shr_vartime(&self, shift: usize) -> Scalar {
167        Self(self.0.shr_vartime(shift))
168    }
169}
170
171impl AsRef<Scalar> for Scalar {
172    fn as_ref(&self) -> &Scalar {
173        self
174    }
175}
176
177impl FromUintUnchecked for Scalar {
178    type Uint = U384;
179
180    fn from_uint_unchecked(uint: Self::Uint) -> Self {
181        Self::from_uint_unchecked(uint)
182    }
183}
184
185impl Invert for Scalar {
186    type Output = CtOption<Self>;
187
188    fn invert(&self) -> CtOption<Self> {
189        self.invert()
190    }
191}
192
193impl IsHigh for Scalar {
194    fn is_high(&self) -> Choice {
195        const MODULUS_SHR1: U384 = NistP384::ORDER.shr_vartime(1);
196        self.to_canonical().ct_gt(&MODULUS_SHR1)
197    }
198}
199
200impl Shr<usize> for Scalar {
201    type Output = Self;
202
203    fn shr(self, rhs: usize) -> Self::Output {
204        self.shr_vartime(rhs)
205    }
206}
207
208impl Shr<usize> for &Scalar {
209    type Output = Scalar;
210
211    fn shr(self, rhs: usize) -> Self::Output {
212        self.shr_vartime(rhs)
213    }
214}
215
216impl ShrAssign<usize> for Scalar {
217    fn shr_assign(&mut self, rhs: usize) {
218        *self = *self >> rhs;
219    }
220}
221
222impl PrimeField for Scalar {
223    type Repr = FieldBytes;
224
225    const MODULUS: &'static str = ORDER_HEX;
226    const CAPACITY: u32 = 383;
227    const NUM_BITS: u32 = 384;
228    const TWO_INV: Self = Self::from_u64(2).invert_unchecked();
229    const MULTIPLICATIVE_GENERATOR: Self = Self::from_u64(2);
230    const S: u32 = 1;
231    const ROOT_OF_UNITY: Self = Self::from_hex("ffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52972");
232    const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unchecked();
233    const DELTA: Self = Self::from_u64(4);
234
235    #[inline]
236    fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
237        Self::from_bytes(&bytes)
238    }
239
240    #[inline]
241    fn to_repr(&self) -> FieldBytes {
242        self.to_bytes()
243    }
244
245    #[inline]
246    fn is_odd(&self) -> Choice {
247        self.is_odd()
248    }
249}
250
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    type ReprBits = fiat_p384_scalar_montgomery_domain_field_element;

    /// Bit view built from the words of the canonical value of `self`.
    fn to_le_bits(&self) -> ScalarBits {
        let canonical_words = self.to_canonical().to_words();
        canonical_words.into()
    }

    /// Bit view built from the words of the group order.
    fn char_le_bits() -> ScalarBits {
        let order_words = NistP384::ORDER.to_words();
        order_words.into()
    }
}
263
264impl Reduce<U384> for Scalar {
265    type Bytes = FieldBytes;
266
267    fn reduce(w: U384) -> Self {
268        let (r, underflow) = w.sbb(&NistP384::ORDER, Limb::ZERO);
269        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
270        Self::from_uint_unchecked(U384::conditional_select(&w, &r, !underflow))
271    }
272
273    #[inline]
274    fn reduce_bytes(bytes: &FieldBytes) -> Self {
275        Self::reduce(U384::from_be_byte_array(*bytes))
276    }
277}
278
279impl From<ScalarPrimitive<NistP384>> for Scalar {
280    fn from(w: ScalarPrimitive<NistP384>) -> Self {
281        Scalar::from(&w)
282    }
283}
284
285impl From<&ScalarPrimitive<NistP384>> for Scalar {
286    fn from(w: &ScalarPrimitive<NistP384>) -> Scalar {
287        Scalar::from_uint_unchecked(*w.as_uint())
288    }
289}
290
291impl From<Scalar> for ScalarPrimitive<NistP384> {
292    fn from(scalar: Scalar) -> ScalarPrimitive<NistP384> {
293        ScalarPrimitive::from(&scalar)
294    }
295}
296
297impl From<&Scalar> for ScalarPrimitive<NistP384> {
298    fn from(scalar: &Scalar) -> ScalarPrimitive<NistP384> {
299        ScalarPrimitive::new(scalar.into()).unwrap()
300    }
301}
302
303impl From<Scalar> for FieldBytes {
304    fn from(scalar: Scalar) -> Self {
305        scalar.to_repr()
306    }
307}
308
309impl From<&Scalar> for FieldBytes {
310    fn from(scalar: &Scalar) -> Self {
311        scalar.to_repr()
312    }
313}
314
315impl From<Scalar> for U384 {
316    fn from(scalar: Scalar) -> U384 {
317        U384::from(&scalar)
318    }
319}
320
321impl From<&Scalar> for U384 {
322    fn from(scalar: &Scalar) -> U384 {
323        scalar.to_canonical()
324    }
325}
326
327impl From<&SecretKey> for Scalar {
328    fn from(secret_key: &SecretKey) -> Scalar {
329        *secret_key.to_nonzero_scalar()
330    }
331}
332
333impl TryFrom<U384> for Scalar {
334    type Error = Error;
335
336    fn try_from(w: U384) -> Result<Self> {
337        Option::from(Self::from_uint(w)).ok_or(Error)
338    }
339}
340
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    /// Serializes by converting to `ScalarPrimitive` and delegating to its
    /// `Serialize` impl.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let primitive: ScalarPrimitive<NistP384> = ScalarPrimitive::from(self);
        primitive.serialize(serializer)
    }
}
350
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    /// Deserializes a `ScalarPrimitive` and converts it into a `Scalar`;
    /// decoding errors are propagated unchanged.
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let primitive: ScalarPrimitive<NistP384> = ScalarPrimitive::deserialize(deserializer)?;
        Ok(Self::from(primitive))
    }
}
360
#[cfg(test)]
mod tests {
    use super::Scalar;
    use crate::FieldBytes;
    use elliptic_curve::ff::PrimeField;
    use primeorder::impl_primefield_tests;

    /// `t = (modulus - 1) >> S`, least-significant 64-bit limb first.
    const T: [u64; 6] = [
        0x76760cb5666294b9,
        0xac0d06d9245853bd,
        0xe3b1a6c0fa1b96ef,
        0xffffffffffffffff,
        0xffffffffffffffff,
        0x7fffffffffffffff,
    ];

    impl_primefield_tests!(Scalar, T);

    /// Decoding a small big-endian value and re-encoding it yields the same bytes.
    #[test]
    fn from_to_bytes_roundtrip() {
        let value: u64 = 42;
        let mut encoded = FieldBytes::default();
        // The value occupies the last 8 of the 48 bytes (big-endian).
        encoded[40..].copy_from_slice(&value.to_be_bytes());

        let decoded = Scalar::from_repr(encoded).unwrap();
        assert_eq!(encoded, decoded.to_bytes());
    }

    /// Basic checks that multiplication agrees with repeated addition and
    /// behaves correctly under negation.
    #[test]
    fn multiply() {
        let one = Scalar::ONE;
        let two = one + one;
        let three = two + one;
        let six = three + three;
        assert_eq!(six, two * three);

        let neg_two = -two;
        let neg_three = -three;
        assert_eq!(two, -neg_two);

        // Commutativity, and (-2) * (-3) = 6.
        assert_eq!(neg_three * neg_two, neg_two * neg_three);
        assert_eq!(six, neg_two * neg_three);
    }

    /// Basic checks that inversion yields a multiplicative inverse and
    /// commutes with negation.
    #[test]
    fn invert() {
        let one = Scalar::ONE;
        let three = one + one + one;
        let three_inv = three.invert().unwrap();
        assert_eq!(three * three_inv, one);

        let neg_three = -three;
        let neg_three_inv = neg_three.invert().unwrap();
        assert_eq!(neg_three_inv, -three_inv);
        assert_eq!(three * neg_three_inv, -one);
    }

    /// Square roots of small perfect squares must square back to the input.
    #[test]
    fn sqrt() {
        let perfect_squares = [1u64, 4, 9, 16, 25, 36, 49, 64];
        for n in perfect_squares.iter().copied() {
            let scalar = Scalar::from(n);
            let root = scalar.sqrt().unwrap();
            assert_eq!(root.square(), scalar);
        }
    }
}