Skip to main content

elliptic_curve/point/
non_identity.rs

1//! Non-identity point type.
2
3#![cfg(feature = "arithmetic")]
4
5use common::Generate;
6use core::ops::{Deref, Mul};
7use group::{Group, GroupEncoding, prime::PrimeCurveAffine};
8use rand_core::{CryptoRng, TryCryptoRng};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "alloc")]
12use alloc::vec::Vec;
13#[cfg(feature = "serde")]
14use serdect::serde::{Deserialize, Serialize, de, ser};
15use zeroize::Zeroize;
16
17use crate::{BatchNormalize, CurveArithmetic, CurveGroup, NonZeroScalar, Scalar};
18
/// Non-identity point type.
///
/// Wraps a point `P` whose value is never the identity point, in the same spirit as the
/// `core::num::NonZero*` types.
///
/// In ECC this lets APIs state in the type system that a given point (and arithmetic that
/// preserves it) cannot be the identity.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
// `repr(transparent)` gives `NonIdentity<P>` exactly the layout of `P`; the pointer casts in
// `array_as_inner` / `slice_as_inner` rely on this for soundness.
#[repr(transparent)]
pub struct NonIdentity<P> {
    point: P,
}
31
32impl<P> NonIdentity<P>
33where
34    P: ConditionallySelectable + ConstantTimeEq + Default,
35{
36    /// Create a [`NonIdentity`] from a point.
37    pub fn new(point: P) -> CtOption<Self> {
38        CtOption::new(Self { point }, !point.ct_eq(&P::default()))
39    }
40
41    pub(crate) fn new_unchecked(point: P) -> Self {
42        Self { point }
43    }
44}
45
46impl<P> NonIdentity<P>
47where
48    P: ConditionallySelectable + ConstantTimeEq + Default + GroupEncoding,
49{
50    /// Decode a [`NonIdentity`] from its encoding.
51    pub fn from_repr(repr: &P::Repr) -> CtOption<Self> {
52        Self::from_bytes(repr)
53    }
54}
55
56impl<P> NonIdentity<P> {
57    /// Transform array reference containing [`NonIdentity`] points to an array reference to the
58    /// inner point type.
59    pub fn array_as_inner<const N: usize>(points: &[Self; N]) -> &[P; N] {
60        // SAFETY: `NonIdentity` is a `repr(transparent)` newtype for `P` so it's safe to cast to
61        // the inner `P` type.
62        #[allow(unsafe_code)]
63        unsafe {
64            &*points.as_ptr().cast()
65        }
66    }
67
68    /// Transform slice containing [`NonIdentity`] points to a slice of the inner point type.
69    pub fn slice_as_inner(points: &[Self]) -> &[P] {
70        // SAFETY: `NonIdentity` is a `repr(transparent)` newtype for `P` so it's safe to cast to
71        // the inner `P` type.
72        #[allow(unsafe_code)]
73        unsafe {
74            &*(core::ptr::from_ref(points) as *const [P])
75        }
76    }
77
78    /// Transform array reference containing [`NonIdentity`] points to an array reference to the
79    /// inner point type.
80    #[deprecated(since = "0.14.0", note = "use `NonIdentity::array_as_inner` instead")]
81    pub fn cast_array_as_inner<const N: usize>(points: &[Self; N]) -> &[P; N] {
82        Self::array_as_inner(points)
83    }
84
85    /// Transform slice containing [`NonIdentity`] points to a slice of the inner point type.
86    #[deprecated(since = "0.14.0", note = "use `NonIdentity::slice_as_inner` instead")]
87    pub fn cast_slice_as_inner(points: &[Self]) -> &[P] {
88        Self::slice_as_inner(points)
89    }
90}
91
92impl<P: Copy> NonIdentity<P> {
93    /// Return wrapped point.
94    pub fn to_point(self) -> P {
95        self.point
96    }
97}
98
99impl<P> NonIdentity<P>
100where
101    P: ConditionallySelectable + ConstantTimeEq + CurveGroup + Default,
102{
103    /// Generate a random `NonIdentity<ProjectivePoint>`.
104    #[deprecated(since = "0.14.0", note = "use the `Generate` trait instead")]
105    pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
106        loop {
107            if let Some(point) = Self::new(P::random(rng)).into() {
108                break point;
109            }
110        }
111    }
112
113    /// Converts this element into its affine representation.
114    pub fn to_affine(self) -> NonIdentity<P::Affine> {
115        NonIdentity {
116            point: self.point.to_affine(),
117        }
118    }
119
120    /// Multiply by the generator of the prime-order subgroup.
121    pub fn mul_by_generator<C: CurveArithmetic>(scalar: &NonZeroScalar<C>) -> Self
122    where
123        P: Group<Scalar = C::Scalar>,
124    {
125        Self {
126            point: P::mul_by_generator(scalar),
127        }
128    }
129}
130
131impl<P> NonIdentity<P>
132where
133    P: PrimeCurveAffine,
134{
135    /// Converts this element to its curve representation.
136    pub fn to_curve(self) -> NonIdentity<P::Curve> {
137        NonIdentity {
138            point: self.point.to_curve(),
139        }
140    }
141}
142
143impl<P> AsRef<P> for NonIdentity<P> {
144    fn as_ref(&self) -> &P {
145        &self.point
146    }
147}
148
149impl<const N: usize, P> BatchNormalize<[Self; N]> for NonIdentity<P>
150where
151    P: CurveGroup + BatchNormalize<[P; N], Output = [P::Affine; N]>,
152{
153    type Output = [NonIdentity<P::Affine>; N];
154
155    fn batch_normalize(points: &[Self; N]) -> [NonIdentity<P::Affine>; N] {
156        let points = Self::array_as_inner::<N>(points);
157        let affine_points = <P as BatchNormalize<_>>::batch_normalize(points);
158        affine_points.map(|point| NonIdentity { point })
159    }
160}
161
#[cfg(feature = "alloc")]
impl<P> BatchNormalize<[Self]> for NonIdentity<P>
where
    P: CurveGroup + BatchNormalize<[P], Output = Vec<P::Affine>>,
{
    type Output = Vec<NonIdentity<P::Affine>>;

    /// Normalize a slice of points by viewing it as `[P]`, delegating to `P`'s batch
    /// normalization, and re-wrapping each affine result.
    fn batch_normalize(points: &[Self]) -> Vec<NonIdentity<P::Affine>> {
        let inner: &[P] = Self::slice_as_inner(points);
        let affine: Vec<P::Affine> = <P as BatchNormalize<[P]>>::batch_normalize(inner);
        affine.into_iter().map(|point| NonIdentity { point }).collect()
    }
}
178
179impl<P> ConditionallySelectable for NonIdentity<P>
180where
181    P: ConditionallySelectable,
182{
183    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
184        Self {
185            point: P::conditional_select(&a.point, &b.point, choice),
186        }
187    }
188}
189
190impl<P> ConstantTimeEq for NonIdentity<P>
191where
192    P: ConstantTimeEq,
193{
194    fn ct_eq(&self, other: &Self) -> Choice {
195        self.point.ct_eq(&other.point)
196    }
197}
198
199impl<P> Deref for NonIdentity<P> {
200    type Target = P;
201
202    fn deref(&self) -> &Self::Target {
203        &self.point
204    }
205}
206
207impl<P> Generate for NonIdentity<P>
208where
209    P: ConditionallySelectable + ConstantTimeEq + Default + Generate,
210{
211    fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
212        loop {
213            if let Some(point) = Self::new(P::try_generate_from_rng(rng)?).into() {
214                break Ok(point);
215            }
216        }
217    }
218}
219
220impl<P> GroupEncoding for NonIdentity<P>
221where
222    P: ConditionallySelectable + ConstantTimeEq + Default + GroupEncoding,
223{
224    type Repr = P::Repr;
225
226    fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
227        let point = P::from_bytes(bytes);
228        point.and_then(|point| CtOption::new(Self { point }, !point.ct_eq(&P::default())))
229    }
230
231    fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
232        P::from_bytes_unchecked(bytes).map(|point| Self { point })
233    }
234
235    fn to_bytes(&self) -> Self::Repr {
236        self.point.to_bytes()
237    }
238}
239
240impl<C, P> Mul<NonZeroScalar<C>> for NonIdentity<P>
241where
242    C: CurveArithmetic,
243    P: Copy + Mul<Scalar<C>, Output = P>,
244{
245    type Output = NonIdentity<P>;
246
247    fn mul(self, rhs: NonZeroScalar<C>) -> Self::Output {
248        &self * &rhs
249    }
250}
251
252impl<C, P> Mul<&NonZeroScalar<C>> for NonIdentity<P>
253where
254    C: CurveArithmetic,
255    P: Copy + Mul<Scalar<C>, Output = P>,
256{
257    type Output = NonIdentity<P>;
258
259    fn mul(self, rhs: &NonZeroScalar<C>) -> Self::Output {
260        self * *rhs
261    }
262}
263
264impl<C, P> Mul<NonZeroScalar<C>> for &NonIdentity<P>
265where
266    C: CurveArithmetic,
267    P: Copy + Mul<Scalar<C>, Output = P>,
268{
269    type Output = NonIdentity<P>;
270
271    fn mul(self, rhs: NonZeroScalar<C>) -> Self::Output {
272        NonIdentity {
273            point: self.point * *rhs.as_ref(),
274        }
275    }
276}
277
278impl<C, P> Mul<&NonZeroScalar<C>> for &NonIdentity<P>
279where
280    C: CurveArithmetic,
281    P: Copy + Mul<Scalar<C>, Output = P>,
282{
283    type Output = NonIdentity<P>;
284
285    fn mul(self, rhs: &NonZeroScalar<C>) -> Self::Output {
286        self * *rhs
287    }
288}
289
#[cfg(feature = "serde")]
impl<P> Serialize for NonIdentity<P>
where
    P: Serialize,
{
    /// Serialize exactly as the inner point; the wrapper adds nothing to the encoding.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        P::serialize(&self.point, serializer)
    }
}
302
#[cfg(feature = "serde")]
impl<'de, P> Deserialize<'de> for NonIdentity<P>
where
    P: ConditionallySelectable + ConstantTimeEq + Default + Deserialize<'de> + GroupEncoding,
{
    /// Deserialize the inner point, then reject it with a custom error if it is the
    /// identity.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let point = P::deserialize(deserializer)?;
        match Self::new(point).into_option() {
            Some(non_identity) => Ok(non_identity),
            None => Err(de::Error::custom("expected non-identity point")),
        }
    }
}
317
318impl<P: Group> Zeroize for NonIdentity<P> {
319    fn zeroize(&mut self) {
320        self.point = P::generator();
321    }
322}
323
#[cfg(all(test, feature = "dev"))]
mod tests {
    use super::NonIdentity;
    use crate::BatchNormalize;
    use crate::dev::{AffinePoint, NonZeroScalar, ProjectivePoint, SecretKey};
    use group::GroupEncoding;
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// Compressed encoding of a fixed, valid, non-identity test point.
    const EXAMPLE_POINT: [u8; 33] =
        hex!("02c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");

    /// Decode [`EXAMPLE_POINT`] as a projective point.
    fn example_point() -> ProjectivePoint {
        ProjectivePoint::from_bytes(&EXAMPLE_POINT.into()).unwrap()
    }

    #[test]
    fn new_success() {
        let projective = example_point();
        assert!(bool::from(NonIdentity::new(projective).is_some()));

        let affine = AffinePoint::from(projective);
        assert!(bool::from(NonIdentity::new(affine).is_some()));
    }

    #[test]
    fn new_fail() {
        let from_projective = NonIdentity::new(ProjectivePoint::default());
        let from_affine = NonIdentity::new(AffinePoint::default());

        assert!(bool::from(from_projective.is_none()));
        assert!(bool::from(from_affine.is_none()));
    }

    #[test]
    fn round_trip() {
        let encoded = EXAMPLE_POINT;

        let projective = NonIdentity::<ProjectivePoint>::from_repr(&encoded.into()).unwrap();
        assert_eq!(&encoded, projective.to_bytes().as_slice());

        let affine = NonIdentity::<AffinePoint>::from_repr(&encoded.into()).unwrap();
        assert_eq!(&encoded, affine.to_bytes().as_slice());
    }

    #[test]
    fn zeroize() {
        let mut wrapped = NonIdentity::new(example_point()).unwrap();
        wrapped.zeroize();

        assert_eq!(wrapped.to_point(), ProjectivePoint::Generator);
    }

    #[test]
    fn mul_by_generator() {
        let scalar = NonZeroScalar::from_repr(
            hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721").into(),
        )
        .unwrap();
        let product = NonIdentity::<ProjectivePoint>::mul_by_generator(&scalar);

        let public_key = SecretKey::from(scalar).public_key();

        assert_eq!(product.to_point(), public_key.to_projective());
    }

    #[test]
    fn batch_normalize() {
        let wrapped = NonIdentity::new(example_point()).unwrap();
        let batch = [wrapped, wrapped];
        let normalized = NonIdentity::batch_normalize(&batch);

        for (projective, affine) in batch.into_iter().zip(normalized) {
            assert_eq!(projective.to_affine(), affine);
        }
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn batch_normalize_alloc() {
        let wrapped = NonIdentity::new(example_point()).unwrap();
        let batch = vec![wrapped, wrapped];

        let normalized = NonIdentity::batch_normalize(batch.as_slice());

        for (projective, affine) in batch.into_iter().zip(normalized) {
            assert_eq!(projective.to_affine(), affine);
        }
    }
}
427}