1#![cfg(feature = "arithmetic")]
4
5use common::Generate;
6use core::ops::{Deref, Mul};
7use group::{Group, GroupEncoding, prime::PrimeCurveAffine};
8use rand_core::{CryptoRng, TryCryptoRng};
9use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11#[cfg(feature = "alloc")]
12use alloc::vec::Vec;
13#[cfg(feature = "serde")]
14use serdect::serde::{Deserialize, Serialize, de, ser};
15use zeroize::Zeroize;
16
17use crate::{BatchNormalize, CurveArithmetic, CurveGroup, NonZeroScalar, Scalar};
18
/// Wrapper around a point `P` that is meant to never hold the identity element.
///
/// The idea is the same as `core::num::NonZero*`: the checked constructors in this module
/// (for example [`NonIdentity::new`]) refuse a point equal to `P::default()`, so code holding a
/// `NonIdentity<P>` does not need to re-check for the identity.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
// `repr(transparent)` gives the wrapper the same layout as `P`; the reference casts performed by
// `array_as_inner` and `slice_as_inner` depend on this guarantee.
#[repr(transparent)]
pub struct NonIdentity<P> {
    /// The wrapped point. Constructors are responsible for ensuring it is not the identity.
    point: P,
}
31
32impl<P> NonIdentity<P>
33where
34 P: ConditionallySelectable + ConstantTimeEq + Default,
35{
36 pub fn new(point: P) -> CtOption<Self> {
38 CtOption::new(Self { point }, !point.ct_eq(&P::default()))
39 }
40
41 pub(crate) fn new_unchecked(point: P) -> Self {
42 Self { point }
43 }
44}
45
46impl<P> NonIdentity<P>
47where
48 P: ConditionallySelectable + ConstantTimeEq + Default + GroupEncoding,
49{
50 pub fn from_repr(repr: &P::Repr) -> CtOption<Self> {
52 Self::from_bytes(repr)
53 }
54}
55
56impl<P> NonIdentity<P> {
57 pub fn array_as_inner<const N: usize>(points: &[Self; N]) -> &[P; N] {
60 #[allow(unsafe_code)]
63 unsafe {
64 &*points.as_ptr().cast()
65 }
66 }
67
68 pub fn slice_as_inner(points: &[Self]) -> &[P] {
70 #[allow(unsafe_code)]
73 unsafe {
74 &*(core::ptr::from_ref(points) as *const [P])
75 }
76 }
77
78 #[deprecated(since = "0.14.0", note = "use `NonIdentity::array_as_inner` instead")]
81 pub fn cast_array_as_inner<const N: usize>(points: &[Self; N]) -> &[P; N] {
82 Self::array_as_inner(points)
83 }
84
85 #[deprecated(since = "0.14.0", note = "use `NonIdentity::slice_as_inner` instead")]
87 pub fn cast_slice_as_inner(points: &[Self]) -> &[P] {
88 Self::slice_as_inner(points)
89 }
90}
91
92impl<P: Copy> NonIdentity<P> {
93 pub fn to_point(self) -> P {
95 self.point
96 }
97}
98
99impl<P> NonIdentity<P>
100where
101 P: ConditionallySelectable + ConstantTimeEq + CurveGroup + Default,
102{
103 #[deprecated(since = "0.14.0", note = "use the `Generate` trait instead")]
105 pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
106 loop {
107 if let Some(point) = Self::new(P::random(rng)).into() {
108 break point;
109 }
110 }
111 }
112
113 pub fn to_affine(self) -> NonIdentity<P::Affine> {
115 NonIdentity {
116 point: self.point.to_affine(),
117 }
118 }
119
120 pub fn mul_by_generator<C: CurveArithmetic>(scalar: &NonZeroScalar<C>) -> Self
122 where
123 P: Group<Scalar = C::Scalar>,
124 {
125 Self {
126 point: P::mul_by_generator(scalar),
127 }
128 }
129}
130
131impl<P> NonIdentity<P>
132where
133 P: PrimeCurveAffine,
134{
135 pub fn to_curve(self) -> NonIdentity<P::Curve> {
137 NonIdentity {
138 point: self.point.to_curve(),
139 }
140 }
141}
142
143impl<P> AsRef<P> for NonIdentity<P> {
144 fn as_ref(&self) -> &P {
145 &self.point
146 }
147}
148
149impl<const N: usize, P> BatchNormalize<[Self; N]> for NonIdentity<P>
150where
151 P: CurveGroup + BatchNormalize<[P; N], Output = [P::Affine; N]>,
152{
153 type Output = [NonIdentity<P::Affine>; N];
154
155 fn batch_normalize(points: &[Self; N]) -> [NonIdentity<P::Affine>; N] {
156 let points = Self::array_as_inner::<N>(points);
157 let affine_points = <P as BatchNormalize<_>>::batch_normalize(points);
158 affine_points.map(|point| NonIdentity { point })
159 }
160}
161
/// Batch-normalize a slice of non-identity points into a `Vec` of affine points.
#[cfg(feature = "alloc")]
impl<P> BatchNormalize<[Self]> for NonIdentity<P>
where
    P: CurveGroup + BatchNormalize<[P], Output = Vec<P::Affine>>,
{
    type Output = Vec<NonIdentity<P::Affine>>;

    fn batch_normalize(points: &[Self]) -> Vec<NonIdentity<P::Affine>> {
        // Hand the inner points to `P`'s own batch implementation, then re-wrap each result.
        let inner: &[P] = Self::slice_as_inner(points);
        <P as BatchNormalize<[P]>>::batch_normalize(inner)
            .into_iter()
            .map(|affine| NonIdentity { point: affine })
            .collect()
    }
}
178
179impl<P> ConditionallySelectable for NonIdentity<P>
180where
181 P: ConditionallySelectable,
182{
183 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
184 Self {
185 point: P::conditional_select(&a.point, &b.point, choice),
186 }
187 }
188}
189
190impl<P> ConstantTimeEq for NonIdentity<P>
191where
192 P: ConstantTimeEq,
193{
194 fn ct_eq(&self, other: &Self) -> Choice {
195 self.point.ct_eq(&other.point)
196 }
197}
198
199impl<P> Deref for NonIdentity<P> {
200 type Target = P;
201
202 fn deref(&self) -> &Self::Target {
203 &self.point
204 }
205}
206
207impl<P> Generate for NonIdentity<P>
208where
209 P: ConditionallySelectable + ConstantTimeEq + Default + Generate,
210{
211 fn try_generate_from_rng<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
212 loop {
213 if let Some(point) = Self::new(P::try_generate_from_rng(rng)?).into() {
214 break Ok(point);
215 }
216 }
217 }
218}
219
220impl<P> GroupEncoding for NonIdentity<P>
221where
222 P: ConditionallySelectable + ConstantTimeEq + Default + GroupEncoding,
223{
224 type Repr = P::Repr;
225
226 fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
227 let point = P::from_bytes(bytes);
228 point.and_then(|point| CtOption::new(Self { point }, !point.ct_eq(&P::default())))
229 }
230
231 fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
232 P::from_bytes_unchecked(bytes).map(|point| Self { point })
233 }
234
235 fn to_bytes(&self) -> Self::Repr {
236 self.point.to_bytes()
237 }
238}
239
240impl<C, P> Mul<NonZeroScalar<C>> for NonIdentity<P>
241where
242 C: CurveArithmetic,
243 P: Copy + Mul<Scalar<C>, Output = P>,
244{
245 type Output = NonIdentity<P>;
246
247 fn mul(self, rhs: NonZeroScalar<C>) -> Self::Output {
248 &self * &rhs
249 }
250}
251
252impl<C, P> Mul<&NonZeroScalar<C>> for NonIdentity<P>
253where
254 C: CurveArithmetic,
255 P: Copy + Mul<Scalar<C>, Output = P>,
256{
257 type Output = NonIdentity<P>;
258
259 fn mul(self, rhs: &NonZeroScalar<C>) -> Self::Output {
260 self * *rhs
261 }
262}
263
264impl<C, P> Mul<NonZeroScalar<C>> for &NonIdentity<P>
265where
266 C: CurveArithmetic,
267 P: Copy + Mul<Scalar<C>, Output = P>,
268{
269 type Output = NonIdentity<P>;
270
271 fn mul(self, rhs: NonZeroScalar<C>) -> Self::Output {
272 NonIdentity {
273 point: self.point * *rhs.as_ref(),
274 }
275 }
276}
277
278impl<C, P> Mul<&NonZeroScalar<C>> for &NonIdentity<P>
279where
280 C: CurveArithmetic,
281 P: Copy + Mul<Scalar<C>, Output = P>,
282{
283 type Output = NonIdentity<P>;
284
285 fn mul(self, rhs: &NonZeroScalar<C>) -> Self::Output {
286 self * *rhs
287 }
288}
289
/// Serializes exactly like the wrapped point; the wrapper adds nothing to the output.
#[cfg(feature = "serde")]
impl<P> Serialize for NonIdentity<P>
where
    P: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        P::serialize(&self.point, serializer)
    }
}
302
/// Deserializes the inner point with `P`'s own implementation, then rejects the identity.
#[cfg(feature = "serde")]
impl<'de, P> Deserialize<'de> for NonIdentity<P>
where
    P: ConditionallySelectable + ConstantTimeEq + Default + Deserialize<'de> + GroupEncoding,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let point = P::deserialize(deserializer)?;

        // A point that fails the `NonIdentity::new` check is reported as a custom serde error.
        match Option::<Self>::from(Self::new(point)) {
            Some(non_identity) => Ok(non_identity),
            None => Err(de::Error::custom("expected non-identity point")),
        }
    }
}
317
318impl<P: Group> Zeroize for NonIdentity<P> {
319 fn zeroize(&mut self) {
320 self.point = P::generator();
321 }
322}
323
#[cfg(all(test, feature = "dev"))]
mod tests {
    use super::NonIdentity;
    use crate::BatchNormalize;
    use crate::dev::{AffinePoint, NonZeroScalar, ProjectivePoint, SecretKey};
    use group::GroupEncoding;
    use hex_literal::hex;
    use zeroize::Zeroize;

    /// 33-byte encoding of the non-identity point shared by the tests below.
    const POINT_BYTES: [u8; 33] =
        hex!("02c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721");

    /// Decode [`POINT_BYTES`] into a projective point.
    fn example_point() -> ProjectivePoint {
        ProjectivePoint::from_bytes(&POINT_BYTES.into()).unwrap()
    }

    /// A valid non-identity point is accepted in both projective and affine form.
    #[test]
    fn new_success() {
        let projective = example_point();
        let affine = AffinePoint::from(projective);

        assert!(bool::from(NonIdentity::new(projective).is_some()));
        assert!(bool::from(NonIdentity::new(affine).is_some()));
    }

    /// The default (identity) point is rejected in both projective and affine form.
    #[test]
    fn new_fail() {
        let projective = NonIdentity::new(ProjectivePoint::default());
        assert!(bool::from(projective.is_none()));

        let affine = NonIdentity::new(AffinePoint::default());
        assert!(bool::from(affine.is_none()));
    }

    /// Decoding followed by encoding reproduces the input bytes.
    #[test]
    fn round_trip() {
        let projective = NonIdentity::<ProjectivePoint>::from_repr(&POINT_BYTES.into()).unwrap();
        assert_eq!(&POINT_BYTES, projective.to_bytes().as_slice());

        let affine = NonIdentity::<AffinePoint>::from_repr(&POINT_BYTES.into()).unwrap();
        assert_eq!(&POINT_BYTES, affine.to_bytes().as_slice());
    }

    /// Zeroizing replaces the wrapped point with the generator.
    #[test]
    fn zeroize() {
        let mut wrapped = NonIdentity::new(example_point()).unwrap();
        wrapped.zeroize();

        assert_eq!(wrapped.to_point(), ProjectivePoint::Generator);
    }

    /// `mul_by_generator` agrees with deriving a public key from the same scalar.
    #[test]
    fn mul_by_generator() {
        let scalar = NonZeroScalar::from_repr(
            hex!("c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721").into(),
        )
        .unwrap();
        let product = NonIdentity::<ProjectivePoint>::mul_by_generator(&scalar);

        let public_key = SecretKey::from(scalar).public_key();

        assert_eq!(product.to_point(), public_key.to_projective());
    }

    /// Array batch normalization matches per-point `to_affine`.
    #[test]
    fn batch_normalize() {
        let wrapped = NonIdentity::new(example_point()).unwrap();
        let points = [wrapped, wrapped];

        let affine_points = NonIdentity::batch_normalize(&points);

        for (point, affine_point) in points.into_iter().zip(affine_points) {
            assert_eq!(point.to_affine(), affine_point);
        }
    }

    /// Slice batch normalization matches per-point `to_affine`.
    #[test]
    #[cfg(feature = "alloc")]
    fn batch_normalize_alloc() {
        let wrapped = NonIdentity::new(example_point()).unwrap();
        let points = vec![wrapped, wrapped];

        let affine_points = NonIdentity::batch_normalize(points.as_slice());

        for (point, affine_point) in points.into_iter().zip(affine_points) {
            assert_eq!(point.to_affine(), affine_point);
        }
    }
}