1use crate::{
4 CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarValue, SecretKey,
5 ops::{Invert, Reduce, ReduceNonZero},
6 point::NonIdentity,
7 scalar::IsHigh,
8};
9use base16ct::HexDisplay;
10use common::Generate;
11use core::{
12 fmt,
13 ops::{Deref, Mul, MulAssign, Neg},
14 str,
15};
16use ff::{Field, PrimeField};
17use rand_core::{CryptoRng, TryCryptoRng};
18use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
19use zeroize::Zeroize;
20
21#[cfg(feature = "serde")]
22use serdect::serde::{Deserialize, Serialize, de, ser};
23
24#[derive(Clone)]
33#[repr(transparent)]
35pub struct NonZeroScalar<C>
36where
37 C: CurveArithmetic,
38{
39 scalar: Scalar<C>,
40}
41
42impl<C: CurveArithmetic> fmt::Debug for NonZeroScalar<C> {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.debug_struct("NonZeroScalar").finish_non_exhaustive()
45 }
46}
47
48impl<C> NonZeroScalar<C>
49where
50 C: CurveArithmetic,
51{
52 pub fn new(scalar: Scalar<C>) -> CtOption<Self> {
54 CtOption::new(Self { scalar }, !scalar.is_zero())
55 }
56
57 pub fn from_repr(repr: FieldBytes<C>) -> CtOption<Self> {
59 Scalar::<C>::from_repr(repr).and_then(Self::new)
60 }
61
62 pub fn from_uint(uint: C::Uint) -> CtOption<Self> {
64 ScalarValue::new(uint).and_then(|scalar| Self::new(scalar.into()))
65 }
66
67 pub fn cast_array_as_inner<const N: usize>(scalars: &[Self; N]) -> &[Scalar<C>; N] {
70 #[allow(unsafe_code)]
73 unsafe {
74 &*scalars.as_ptr().cast()
75 }
76 }
77
78 pub fn cast_slice_as_inner(scalars: &[Self]) -> &[Scalar<C>] {
80 #[allow(unsafe_code)]
83 unsafe {
84 &*(core::ptr::from_ref(scalars) as *const [Scalar<C>])
85 }
86 }
87
88 #[deprecated(since = "0.14.0", note = "use the `Generate` trait instead")]
90 pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
91 Self::generate_from_rng(rng)
92 }
93}
94
95impl<C> AsRef<Scalar<C>> for NonZeroScalar<C>
96where
97 C: CurveArithmetic,
98{
99 fn as_ref(&self) -> &Scalar<C> {
100 &self.scalar
101 }
102}
103
104impl<C> ConditionallySelectable for NonZeroScalar<C>
105where
106 C: CurveArithmetic,
107{
108 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
109 Self {
110 scalar: Scalar::<C>::conditional_select(&a.scalar, &b.scalar, choice),
111 }
112 }
113}
114
115impl<C> ConstantTimeEq for NonZeroScalar<C>
116where
117 C: CurveArithmetic,
118{
119 fn ct_eq(&self, other: &Self) -> Choice {
120 self.scalar.ct_eq(&other.scalar)
121 }
122}
123
124impl<C> Copy for NonZeroScalar<C> where C: CurveArithmetic {}
125
126impl<C> Deref for NonZeroScalar<C>
127where
128 C: CurveArithmetic,
129{
130 type Target = Scalar<C>;
131
132 fn deref(&self) -> &Scalar<C> {
133 &self.scalar
134 }
135}
136
137impl<C> Eq for NonZeroScalar<C> where C: CurveArithmetic {}
138
139impl<C> From<NonZeroScalar<C>> for FieldBytes<C>
140where
141 C: CurveArithmetic,
142{
143 fn from(scalar: NonZeroScalar<C>) -> FieldBytes<C> {
144 Self::from(&scalar)
145 }
146}
147
148impl<C> From<&NonZeroScalar<C>> for FieldBytes<C>
149where
150 C: CurveArithmetic,
151{
152 fn from(scalar: &NonZeroScalar<C>) -> FieldBytes<C> {
153 scalar.to_repr()
154 }
155}
156
157impl<C> From<NonZeroScalar<C>> for ScalarValue<C>
158where
159 C: CurveArithmetic,
160{
161 #[inline]
162 fn from(scalar: NonZeroScalar<C>) -> ScalarValue<C> {
163 Self::from(&scalar)
164 }
165}
166
167impl<C> From<&NonZeroScalar<C>> for ScalarValue<C>
168where
169 C: CurveArithmetic,
170{
171 fn from(scalar: &NonZeroScalar<C>) -> ScalarValue<C> {
172 scalar.scalar.into()
173 }
174}
175
176impl<C> From<SecretKey<C>> for NonZeroScalar<C>
177where
178 C: CurveArithmetic,
179{
180 fn from(sk: SecretKey<C>) -> NonZeroScalar<C> {
181 Self::from(&sk)
182 }
183}
184
185impl<C> From<&SecretKey<C>> for NonZeroScalar<C>
186where
187 C: CurveArithmetic,
188{
189 fn from(sk: &SecretKey<C>) -> NonZeroScalar<C> {
190 let scalar = sk.as_scalar_value().to_scalar();
191 debug_assert!(!bool::from(scalar.is_zero()));
192 Self { scalar }
193 }
194}
195
196impl<C> Generate for NonZeroScalar<C>
197where
198 C: CurveArithmetic,
199{
200 fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
201 loop {
205 if let Some(result) = Self::new(Scalar::<C>::try_generate_from_rng(rng)?).into() {
206 break Ok(result);
207 }
208 }
209 }
210}
211
212impl<C> Invert for NonZeroScalar<C>
213where
214 C: CurveArithmetic,
215 Scalar<C>: Invert<Output = CtOption<Scalar<C>>>,
216{
217 type Output = Self;
218
219 fn invert(&self) -> Self {
220 Self {
221 scalar: Invert::invert(&self.scalar).unwrap(),
223 }
224 }
225
226 fn invert_vartime(&self) -> Self::Output {
227 Self {
228 scalar: Invert::invert_vartime(&self.scalar).unwrap(),
230 }
231 }
232}
233
234impl<C> IsHigh for NonZeroScalar<C>
235where
236 C: CurveArithmetic,
237{
238 fn is_high(&self) -> Choice {
239 self.scalar.is_high()
240 }
241}
242
243impl<C> Neg for NonZeroScalar<C>
244where
245 C: CurveArithmetic,
246{
247 type Output = NonZeroScalar<C>;
248
249 fn neg(self) -> NonZeroScalar<C> {
250 let scalar = -self.scalar;
251 debug_assert!(!bool::from(scalar.is_zero()));
252 NonZeroScalar { scalar }
253 }
254}
255
256impl<C> Mul<NonZeroScalar<C>> for NonZeroScalar<C>
257where
258 C: PrimeCurve + CurveArithmetic,
259{
260 type Output = Self;
261
262 #[inline]
263 fn mul(self, other: Self) -> Self {
264 Self::mul(self, &other)
265 }
266}
267
268impl<C> Mul<&NonZeroScalar<C>> for NonZeroScalar<C>
269where
270 C: PrimeCurve + CurveArithmetic,
271{
272 type Output = Self;
273
274 fn mul(self, other: &Self) -> Self {
275 let scalar = self.scalar * other.scalar;
278 debug_assert!(!bool::from(scalar.is_zero()));
279 NonZeroScalar { scalar }
280 }
281}
282
283impl<C, P> Mul<NonIdentity<P>> for NonZeroScalar<C>
284where
285 C: CurveArithmetic,
286 NonIdentity<P>: Mul<NonZeroScalar<C>, Output = NonIdentity<P>>,
287{
288 type Output = NonIdentity<P>;
289
290 fn mul(self, rhs: NonIdentity<P>) -> Self::Output {
291 rhs * self
292 }
293}
294
295impl<C, P> Mul<&NonIdentity<P>> for NonZeroScalar<C>
296where
297 C: CurveArithmetic,
298 for<'a> &'a NonIdentity<P>: Mul<NonZeroScalar<C>, Output = NonIdentity<P>>,
299{
300 type Output = NonIdentity<P>;
301
302 fn mul(self, rhs: &NonIdentity<P>) -> Self::Output {
303 rhs * self
304 }
305}
306
307impl<C, P> Mul<NonIdentity<P>> for &NonZeroScalar<C>
308where
309 C: CurveArithmetic,
310 for<'a> NonIdentity<P>: Mul<&'a NonZeroScalar<C>, Output = NonIdentity<P>>,
311{
312 type Output = NonIdentity<P>;
313
314 fn mul(self, rhs: NonIdentity<P>) -> Self::Output {
315 rhs * self
316 }
317}
318
319impl<C, P> Mul<&NonIdentity<P>> for &NonZeroScalar<C>
320where
321 C: CurveArithmetic,
322 for<'a> &'a NonIdentity<P>: Mul<&'a NonZeroScalar<C>, Output = NonIdentity<P>>,
323{
324 type Output = NonIdentity<P>;
325
326 fn mul(self, rhs: &NonIdentity<P>) -> Self::Output {
327 rhs * self
328 }
329}
330
331impl<C> MulAssign for NonZeroScalar<C>
332where
333 C: PrimeCurve + CurveArithmetic,
334{
335 fn mul_assign(&mut self, rhs: Self) {
336 *self = *self * rhs;
337 }
338}
339
340impl<C> PartialEq for NonZeroScalar<C>
341where
342 C: CurveArithmetic,
343{
344 fn eq(&self, other: &Self) -> bool {
345 self.scalar.eq(&other.scalar)
346 }
347}
348
349impl<C, T> Reduce<T> for NonZeroScalar<C>
350where
351 C: CurveArithmetic,
352 Scalar<C>: ReduceNonZero<T>,
353{
354 fn reduce(n: &T) -> Self {
355 <Self as ReduceNonZero<T>>::reduce_nonzero(n)
356 }
357}
358
359impl<C, T> ReduceNonZero<T> for NonZeroScalar<C>
360where
361 C: CurveArithmetic,
362 Scalar<C>: ReduceNonZero<T>,
363{
364 fn reduce_nonzero(n: &T) -> Self {
365 let scalar = Scalar::<C>::reduce_nonzero(n);
366 debug_assert!(!bool::from(scalar.is_zero()));
367 Self { scalar }
368 }
369}
370
371impl<C> TryFrom<&[u8]> for NonZeroScalar<C>
372where
373 C: CurveArithmetic,
374{
375 type Error = Error;
376
377 fn try_from(bytes: &[u8]) -> Result<Self, Error> {
378 NonZeroScalar::from_repr(bytes.try_into()?)
379 .into_option()
380 .ok_or(Error)
381 }
382}
383
384impl<C> Zeroize for NonZeroScalar<C>
385where
386 C: CurveArithmetic,
387{
388 fn zeroize(&mut self) {
389 self.scalar.zeroize();
391
392 self.scalar = Scalar::<C>::ONE;
395 }
396}
397
398impl<C> fmt::Display for NonZeroScalar<C>
399where
400 C: CurveArithmetic,
401{
402 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
403 write!(f, "{self:X}")
404 }
405}
406
407impl<C> fmt::LowerHex for NonZeroScalar<C>
408where
409 C: CurveArithmetic,
410{
411 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
412 write!(f, "{:x}", HexDisplay(&self.to_repr()))
413 }
414}
415
416impl<C> fmt::UpperHex for NonZeroScalar<C>
417where
418 C: CurveArithmetic,
419{
420 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421 write!(f, "{:}", HexDisplay(&self.to_repr()))
422 }
423}
424
425impl<C> str::FromStr for NonZeroScalar<C>
426where
427 C: CurveArithmetic,
428{
429 type Err = Error;
430
431 fn from_str(hex: &str) -> Result<Self, Error> {
432 let mut bytes = FieldBytes::<C>::default();
433
434 if base16ct::mixed::decode(hex, &mut bytes)?.len() == bytes.len() {
435 Self::from_repr(bytes).into_option().ok_or(Error)
436 } else {
437 Err(Error)
438 }
439 }
440}
441
/// Serialized through `ScalarValue`, so the wire format matches that type's `Serialize` impl.
#[cfg(feature = "serde")]
impl<C: CurveArithmetic> Serialize for NonZeroScalar<C> {
    fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let value: ScalarValue<C> = self.into();
        value.serialize(serializer)
    }
}
454
/// Deserialized through `ScalarValue`, rejecting zero with a custom deserialization error.
#[cfg(feature = "serde")]
impl<'de, C: CurveArithmetic> Deserialize<'de> for NonZeroScalar<C> {
    fn deserialize<D: de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let value = ScalarValue::<C>::deserialize(deserializer)?;
        match Self::new(value.into()).into_option() {
            Some(nonzero) => Ok(nonzero),
            None => Err(de::Error::custom("expected non-zero scalar")),
        }
    }
}
470
#[cfg(all(test, feature = "dev"))]
mod tests {
    use crate::dev::{NonZeroScalar, Scalar};
    use ff::{Field, PrimeField};
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// Decoding then re-encoding a known non-zero scalar reproduces the original bytes.
    #[test]
    fn round_trip() {
        let encoded = hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");
        let decoded = NonZeroScalar::from_repr(encoded.into()).unwrap();
        let reencoded = decoded.to_repr();
        assert_eq!(&encoded, reencoded.as_slice());
    }

    /// Zeroizing leaves the value at `ONE`, not zero, preserving the invariant.
    #[test]
    fn zeroize() {
        let mut nonzero = NonZeroScalar::new(Scalar::from(42u64)).unwrap();
        nonzero.zeroize();
        assert_eq!(*nonzero, Scalar::ONE);
    }
}