1use crate::{Error, Result};
9use base16ct::HexDisplay;
10use core::{
11 cmp::Ordering,
12 fmt::{self, Debug},
13 hash::{Hash, Hasher},
14 ops::{Add, Sub},
15 str,
16};
17use hybrid_array::{Array, ArraySize, typenum::U1};
18
19#[cfg(feature = "alloc")]
20use alloc::boxed::Box;
21#[cfg(feature = "ctutils")]
22use ctutils::{Choice, CtAssign, CtAssignSlice, CtEq, CtEqSlice, CtSelect};
23#[cfg(feature = "serde")]
24use serdect::serde::{Deserialize, Serialize, de, ser};
25#[cfg(feature = "zeroize")]
26use zeroize::Zeroize;
27
/// Type-level size, in bytes, of a single coordinate stored in an [`EncodedPoint`].
///
/// The associated types express the sizes of the serialized point formats in terms
/// of `Self`, so that buffers of the right length can be chosen at compile time.
pub trait ModulusSize: ArraySize + Add<U1, Output = Self::CompressedPointSize> {
    /// Size of a compressed point: `Self + U1` (one tag byte followed by one
    /// coordinate).
    type CompressedPointSize: ArraySize + Add<Self, Output = Self::UncompressedPointSize>;

    /// Size of an uncompressed point: `Self::CompressedPointSize + Self` (one tag
    /// byte followed by two coordinates).
    type UncompressedPointSize: ArraySize;

    /// Size of an untagged point (two coordinates and no tag byte); subtracting
    /// `Self` from it yields `Self` again.
    type UntaggedPointSize: ArraySize + Sub<Self, Output = Self>;
}
45
/// Blanket implementation: any [`ArraySize`] for which the required type-level
/// additions and subtraction are defined is automatically a [`ModulusSize`].
impl<T> ModulusSize for T
where
    T: ArraySize,
    // `T + 1` must itself be a usable array size (compressed point).
    T: Add<U1, Output: ArraySize>,
    // `T + T` must be an array size that splits back into two `T`-sized halves
    // (untagged point).
    T: Add<T, Output: ArraySize + Sub<T, Output = T>>,
    // `(T + 1) + T` must be a usable array size (uncompressed point).
    <T as Add<U1>>::Output: Add<T, Output: 'static + ArraySize>,
{
    type CompressedPointSize = <T as Add<U1>>::Output;
    type UncompressedPointSize = <Self::CompressedPointSize as Add<T>>::Output;
    type UntaggedPointSize = <T as Add<T>>::Output;
}
57
/// Tagged byte encoding of an elliptic curve point.
///
/// The point is held in a fixed-size buffer large enough for the uncompressed
/// form. The first byte is the [`Tag`]; shorter encodings (identity, compressed,
/// compact) occupy only a prefix of the buffer, whose length is reported by
/// [`EncodedPoint::len`].
#[derive(Clone, Default)]
pub struct EncodedPoint<Size>
where
    Size: ModulusSize,
{
    /// Backing storage: tag byte followed by the coordinate bytes.
    bytes: Array<u8, Size::UncompressedPointSize>,
}
70
71#[allow(clippy::len_without_is_empty)]
72impl<Size> EncodedPoint<Size>
73where
74 Size: ModulusSize,
75{
76 pub fn from_bytes(input: impl AsRef<[u8]>) -> Result<Self> {
83 let input = input.as_ref();
84
85 let tag = input
87 .first()
88 .cloned()
89 .ok_or(Error::PointEncoding)
90 .and_then(Tag::from_u8)?;
91
92 let expected_len = tag.message_len(Size::to_usize());
94
95 if input.len() != expected_len {
96 return Err(Error::PointEncoding);
97 }
98
99 let mut bytes = Array::default();
100 bytes[..expected_len].copy_from_slice(input);
101 Ok(Self { bytes })
102 }
103
104 pub fn from_untagged_bytes(bytes: &Array<u8, Size::UntaggedPointSize>) -> Self {
108 let (x, y) = bytes.split_ref();
109 Self::from_affine_coordinates(x, y, false)
110 }
111
112 pub fn from_affine_coordinates(
115 x: &Array<u8, Size>,
116 y: &Array<u8, Size>,
117 compress: bool,
118 ) -> Self {
119 let tag = if compress {
120 Tag::compress_y(y.as_slice())
121 } else {
122 Tag::Uncompressed
123 };
124
125 let mut bytes = Array::default();
126 bytes[0] = tag.into();
127 bytes[1..(Size::to_usize() + 1)].copy_from_slice(x);
128
129 if !compress {
130 bytes[(Size::to_usize() + 1)..].copy_from_slice(y);
131 }
132
133 Self { bytes }
134 }
135
136 pub fn identity() -> Self {
139 Self::default()
140 }
141
142 pub fn len(&self) -> usize {
144 self.tag().message_len(Size::to_usize())
145 }
146
147 pub fn as_bytes(&self) -> &[u8] {
149 &self.bytes[..self.len()]
150 }
151
152 #[cfg(feature = "alloc")]
154 pub fn to_bytes(&self) -> Box<[u8]> {
155 self.as_bytes().to_vec().into_boxed_slice()
156 }
157
158 pub fn is_compact(&self) -> bool {
160 self.tag().is_compact()
161 }
162
163 pub fn is_compressed(&self) -> bool {
165 self.tag().is_compressed()
166 }
167
168 pub fn is_identity(&self) -> bool {
170 self.tag().is_identity()
171 }
172
173 pub fn compress(&self) -> Self {
175 match self.coordinates() {
176 Coordinates::Compressed { .. }
177 | Coordinates::Compact { .. }
178 | Coordinates::Identity => self.clone(),
179 Coordinates::Uncompressed { x, y } => Self::from_affine_coordinates(x, y, true),
180 }
181 }
182
183 pub fn tag(&self) -> Tag {
185 Tag::from_u8(self.bytes[0]).expect("invalid tag")
187 }
188
189 #[inline]
191 pub fn coordinates(&self) -> Coordinates<'_, Size> {
192 if self.is_identity() {
193 return Coordinates::Identity;
194 }
195
196 let (x_bytes, y_bytes) = self.bytes[1..].split_at(Size::to_usize());
197 let x = x_bytes.try_into().expect("size invariants were violated");
198
199 if self.is_compressed() {
200 Coordinates::Compressed {
201 x,
202 y_is_odd: self.tag() as u8 & 1 == 1,
203 }
204 } else if self.is_compact() {
205 Coordinates::Compact { x }
206 } else {
207 Coordinates::Uncompressed {
208 x,
209 y: y_bytes.try_into().expect("size invariants were violated"),
210 }
211 }
212 }
213
214 pub fn x(&self) -> Option<&Array<u8, Size>> {
218 match self.coordinates() {
219 Coordinates::Identity => None,
220 Coordinates::Compressed { x, .. } => Some(x),
221 Coordinates::Uncompressed { x, .. } => Some(x),
222 Coordinates::Compact { x } => Some(x),
223 }
224 }
225
226 pub fn y(&self) -> Option<&Array<u8, Size>> {
230 match self.coordinates() {
231 Coordinates::Compressed { .. } | Coordinates::Identity => None,
232 Coordinates::Uncompressed { y, .. } => Some(y),
233 Coordinates::Compact { .. } => None,
234 }
235 }
236}
237
238impl<Size> AsRef<[u8]> for EncodedPoint<Size>
239where
240 Size: ModulusSize,
241{
242 #[inline]
243 fn as_ref(&self) -> &[u8] {
244 self.as_bytes()
245 }
246}
247
248impl<Size> Copy for EncodedPoint<Size>
249where
250 Size: ModulusSize,
251 <Size::UncompressedPointSize as ArraySize>::ArrayType<u8>: Copy,
252{
253}
254
255impl<Size> Debug for EncodedPoint<Size>
256where
257 Size: ModulusSize,
258{
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 write!(f, "EncodedPoint({:?})", self.coordinates())
261 }
262}
263
264impl<Size: ModulusSize> Eq for EncodedPoint<Size> {}
265
266impl<Size> PartialEq for EncodedPoint<Size>
267where
268 Size: ModulusSize,
269{
270 fn eq(&self, other: &Self) -> bool {
271 self.as_bytes() == other.as_bytes()
272 }
273}
274
275impl<Size> Hash for EncodedPoint<Size>
276where
277 Size: ModulusSize,
278{
279 fn hash<H: Hasher>(&self, state: &mut H) {
280 self.as_bytes().hash(state)
281 }
282}
283
284impl<Size: ModulusSize> PartialOrd for EncodedPoint<Size>
285where
286 Size: ModulusSize,
287{
288 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
289 Some(self.cmp(other))
290 }
291}
292
293impl<Size: ModulusSize> Ord for EncodedPoint<Size>
294where
295 Size: ModulusSize,
296{
297 fn cmp(&self, other: &Self) -> Ordering {
298 self.as_bytes().cmp(other.as_bytes())
299 }
300}
301
302impl<Size: ModulusSize> TryFrom<&[u8]> for EncodedPoint<Size>
303where
304 Size: ModulusSize,
305{
306 type Error = Error;
307
308 fn try_from(bytes: &[u8]) -> Result<Self> {
309 Self::from_bytes(bytes)
310 }
311}
312
313impl<Size> fmt::Display for EncodedPoint<Size>
314where
315 Size: ModulusSize,
316{
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 write!(f, "{self:X}")
319 }
320}
321
322impl<Size> fmt::LowerHex for EncodedPoint<Size>
323where
324 Size: ModulusSize,
325{
326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 write!(f, "{:x}", HexDisplay(self.as_bytes()))
328 }
329}
330
331impl<Size> fmt::UpperHex for EncodedPoint<Size>
332where
333 Size: ModulusSize,
334{
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 write!(f, "{:X}", HexDisplay(self.as_bytes()))
337 }
338}
339
340impl<Size> str::FromStr for EncodedPoint<Size>
345where
346 Size: ModulusSize,
347{
348 type Err = Error;
349
350 fn from_str(hex: &str) -> Result<Self> {
351 let mut buf = Array::<u8, Size::UncompressedPointSize>::default();
352 base16ct::mixed::decode(hex, &mut buf)
353 .map_err(|_| Error::PointEncoding)
354 .and_then(Self::from_bytes)
355 }
356}
357
/// Conditional assignment applied to every byte of the backing buffer.
#[cfg(feature = "ctutils")]
impl<Size: ModulusSize> CtAssign for EncodedPoint<Size> {
    fn ct_assign(&mut self, other: &Self, choice: Choice) {
        // Both buffers have the same fixed length, so zipping visits every byte.
        for (dst, src) in self.bytes.iter_mut().zip(other.bytes.iter()) {
            dst.ct_assign(src, choice);
        }
    }
}
/// Opts encoded points into the slice-level conditional assignment trait.
#[cfg(feature = "ctutils")]
impl<Size> CtAssignSlice for EncodedPoint<Size> where Size: ModulusSize {}
371
/// Equality check over the entire backing buffer, delegated to the slice `CtEq` impl.
#[cfg(feature = "ctutils")]
impl<Size: ModulusSize> CtEq for EncodedPoint<Size> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs = self.bytes.as_slice();
        let rhs = other.bytes.as_slice();
        lhs.ct_eq(rhs)
    }
}
/// Opts encoded points into the slice-level equality trait.
#[cfg(feature = "ctutils")]
impl<Size> CtEqSlice for EncodedPoint<Size> where Size: ModulusSize {}
383
/// Conditional selection between two encoded points, performed byte by byte over
/// the whole backing buffer.
#[cfg(feature = "ctutils")]
impl<Size: ModulusSize> CtSelect for EncodedPoint<Size> {
    fn ct_select(&self, other: &Self, choice: Choice) -> Self {
        let mut bytes = Array::<u8, Size::UncompressedPointSize>::default();

        // All three buffers share the same fixed length.
        let pairs = self.bytes.iter().zip(other.bytes.iter());
        for (out, (lhs, rhs)) in bytes.iter_mut().zip(pairs) {
            *out = lhs.ct_select(rhs, choice);
        }

        Self { bytes }
    }
}
399
/// Conditional selection between two encoded points via `subtle`, performed byte
/// by byte over the whole backing buffer.
#[cfg(feature = "subtle")]
impl<Size> subtle::ConditionallySelectable for EncodedPoint<Size>
where
    Size: ModulusSize,
    <Size::UncompressedPointSize as ArraySize>::ArrayType<u8>: Copy,
{
    fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
        let mut bytes = Array::<u8, Size::UncompressedPointSize>::default();

        // All three buffers share the same fixed length.
        let pairs = a.bytes.iter().zip(b.bytes.iter());
        for (out, (lhs, rhs)) in bytes.iter_mut().zip(pairs) {
            // Fully qualified so the call resolves even though the trait is not
            // imported into this module's scope.
            *out = <u8 as subtle::ConditionallySelectable>::conditional_select(lhs, rhs, choice);
        }

        Self { bytes }
    }
}
416
/// Serializes the encoded point through `serdect`'s upper-case-hex-or-binary
/// slice serializer.
#[cfg(feature = "serde")]
impl<Size: ModulusSize> Serialize for EncodedPoint<Size> {
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let encoded: &[u8] = self.as_bytes();
        serdect::slice::serialize_hex_upper_or_bin(&encoded, serializer)
    }
}
429
/// Deserializes an encoded point from hex or binary input, validating the decoded
/// bytes with [`EncodedPoint::from_bytes`].
#[cfg(feature = "serde")]
impl<'de, Size: ModulusSize> Deserialize<'de> for EncodedPoint<Size> {
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let raw = serdect::slice::deserialize_hex_or_bin_vec(deserializer)?;

        match Self::from_bytes(raw) {
            Ok(point) => Ok(point),
            // Surface point-encoding failures as a deserializer error.
            Err(err) => Err(de::Error::custom(err)),
        }
    }
}
443
/// Wipes the backing buffer and then resets the value to the identity encoding.
#[cfg(feature = "zeroize")]
impl<Size: ModulusSize> Zeroize for EncodedPoint<Size> {
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.bytes);
        *self = Self::identity();
    }
}
454
/// Borrowed view of the coordinate data stored in an [`EncodedPoint`], with one
/// variant per kind of encoding.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Coordinates<'a, Size: ModulusSize> {
    /// The identity point, which carries no coordinate data.
    Identity,

    /// Compact encoding: only the x-coordinate is stored.
    Compact {
        /// x-coordinate bytes.
        x: &'a Array<u8, Size>,
    },

    /// Compressed encoding: the x-coordinate plus the parity of y.
    Compressed {
        /// x-coordinate bytes.
        x: &'a Array<u8, Size>,

        /// Whether the tag marks the y-coordinate as odd.
        y_is_odd: bool,
    },

    /// Uncompressed encoding: both coordinates are stored.
    Uncompressed {
        /// x-coordinate bytes.
        x: &'a Array<u8, Size>,

        /// y-coordinate bytes.
        y: &'a Array<u8, Size>,
    },
}
486
487impl<Size: ModulusSize> Coordinates<'_, Size> {
488 pub fn tag(&self) -> Tag {
490 match self {
491 Coordinates::Compact { .. } => Tag::Compact,
492 Coordinates::Compressed { y_is_odd, .. } => {
493 if *y_is_odd {
494 Tag::CompressedOddY
495 } else {
496 Tag::CompressedEvenY
497 }
498 }
499 Coordinates::Identity => Tag::Identity,
500 Coordinates::Uncompressed { .. } => Tag::Uncompressed,
501 }
502 }
503}
504
/// Tag byte that starts every serialized point and identifies its encoding.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Tag {
    /// Identity point (tag byte `0x00`).
    Identity = 0,

    /// Compressed point whose y-coordinate is even (tag byte `0x02`).
    CompressedEvenY = 2,

    /// Compressed point whose y-coordinate is odd (tag byte `0x03`).
    CompressedOddY = 3,

    /// Uncompressed point (tag byte `0x04`).
    Uncompressed = 4,

    /// Compact point, storing only the x-coordinate (tag byte `0x05`).
    Compact = 5,
}
524
525impl Tag {
526 pub fn from_u8(byte: u8) -> Result<Self> {
528 match byte {
529 0 => Ok(Tag::Identity),
530 2 => Ok(Tag::CompressedEvenY),
531 3 => Ok(Tag::CompressedOddY),
532 4 => Ok(Tag::Uncompressed),
533 5 => Ok(Tag::Compact),
534 _ => Err(Error::PointEncoding),
535 }
536 }
537
538 pub fn is_compact(self) -> bool {
540 matches!(self, Tag::Compact)
541 }
542
543 pub fn is_compressed(self) -> bool {
545 matches!(self, Tag::CompressedEvenY | Tag::CompressedOddY)
546 }
547
548 pub fn is_identity(self) -> bool {
550 self == Tag::Identity
551 }
552
553 pub fn message_len(self, field_element_size: usize) -> usize {
557 1 + match self {
558 Tag::Identity => 0,
559 Tag::CompressedEvenY | Tag::CompressedOddY => field_element_size,
560 Tag::Uncompressed => field_element_size * 2,
561 Tag::Compact => field_element_size,
562 }
563 }
564
565 fn compress_y(y: &[u8]) -> Self {
567 if y.as_ref().last().expect("empty y-coordinate") & 1 == 1 {
569 Tag::CompressedOddY
570 } else {
571 Tag::CompressedEvenY
572 }
573 }
574}
575
576impl TryFrom<u8> for Tag {
577 type Error = Error;
578
579 fn try_from(byte: u8) -> Result<Self> {
580 Self::from_u8(byte)
581 }
582}
583
584impl From<Tag> for u8 {
585 fn from(tag: Tag) -> u8 {
586 tag as u8
587 }
588}
589
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{Coordinates, Tag};
    use core::str::FromStr;
    use hex_literal::hex;
    use hybrid_array::typenum::U32;

    #[cfg(feature = "alloc")]
    use alloc::string::ToString;

    #[cfg(feature = "subtle")]
    use subtle::ConditionallySelectable;

    /// Encoded point with 32-byte coordinates, used by every test below.
    type EncodedPoint = super::EncodedPoint<U32>;

    /// Serialized identity point: a lone zero tag byte.
    const IDENTITY_BYTES: [u8; 1] = [0];

    /// Example uncompressed point: tag `0x04`, x = `0x11…`, y = `0x22…`.
    const UNCOMPRESSED_BYTES: [u8; 65] = hex!(
        "0411111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222"
    );

    /// Compressed form of the example point: tag `0x02` (even y), x = `0x11…`.
    const COMPRESSED_BYTES: [u8; 33] =
        hex!("021111111111111111111111111111111111111111111111111111111111111111");

    #[test]
    fn decode_compressed_point() {
        // Even y-coordinate
        let compressed_even_y_bytes =
            hex!("020100000000000000000000000000000000000000000000000000000000000000");

        let compressed_even_y = EncodedPoint::from_bytes(&compressed_even_y_bytes[..]).unwrap();

        assert!(compressed_even_y.is_compressed());
        assert_eq!(compressed_even_y.tag(), Tag::CompressedEvenY);
        assert_eq!(compressed_even_y.len(), 33);
        assert_eq!(compressed_even_y.as_bytes(), &compressed_even_y_bytes[..]);

        assert_eq!(
            compressed_even_y.coordinates(),
            Coordinates::Compressed {
                x: &hex!("0100000000000000000000000000000000000000000000000000000000000000").into(),
                y_is_odd: false
            }
        );

        assert_eq!(
            compressed_even_y.x().unwrap(),
            &hex!("0100000000000000000000000000000000000000000000000000000000000000")
        );
        assert_eq!(compressed_even_y.y(), None);

        // Odd y-coordinate
        let compressed_odd_y_bytes =
            hex!("030200000000000000000000000000000000000000000000000000000000000000");

        let compressed_odd_y = EncodedPoint::from_bytes(&compressed_odd_y_bytes[..]).unwrap();

        assert!(compressed_odd_y.is_compressed());
        assert_eq!(compressed_odd_y.tag(), Tag::CompressedOddY);
        assert_eq!(compressed_odd_y.len(), 33);
        assert_eq!(compressed_odd_y.as_bytes(), &compressed_odd_y_bytes[..]);

        assert_eq!(
            compressed_odd_y.coordinates(),
            Coordinates::Compressed {
                x: &hex!("0200000000000000000000000000000000000000000000000000000000000000").into(),
                y_is_odd: true
            }
        );

        assert_eq!(
            compressed_odd_y.x().unwrap(),
            &hex!("0200000000000000000000000000000000000000000000000000000000000000")
        );
        assert_eq!(compressed_odd_y.y(), None);
    }

    #[test]
    fn decode_uncompressed_point() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        assert!(!uncompressed_point.is_compressed());
        assert_eq!(uncompressed_point.tag(), Tag::Uncompressed);
        assert_eq!(uncompressed_point.len(), 65);
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        assert_eq!(
            uncompressed_point.coordinates(),
            Coordinates::Uncompressed {
                x: &hex!("1111111111111111111111111111111111111111111111111111111111111111").into(),
                y: &hex!("2222222222222222222222222222222222222222222222222222222222222222").into()
            }
        );

        assert_eq!(
            uncompressed_point.x().unwrap(),
            &hex!("1111111111111111111111111111111111111111111111111111111111111111")
        );
        assert_eq!(
            uncompressed_point.y().unwrap(),
            &hex!("2222222222222222222222222222222222222222222222222222222222222222")
        );
    }

    #[test]
    fn decode_compact_point() {
        // Same payload as the compressed example, but behind the compact tag.
        let mut compact_bytes = COMPRESSED_BYTES;
        compact_bytes[0] = 0x05;

        let compact_point = EncodedPoint::from_bytes(&compact_bytes[..]).unwrap();

        assert!(compact_point.is_compact());
        assert!(!compact_point.is_compressed());
        assert!(!compact_point.is_identity());
        assert_eq!(compact_point.tag(), Tag::Compact);
        assert_eq!(compact_point.len(), 33);
        assert_eq!(compact_point.as_bytes(), &compact_bytes[..]);

        assert_eq!(
            compact_point.coordinates(),
            Coordinates::Compact {
                x: &hex!("1111111111111111111111111111111111111111111111111111111111111111").into(),
            }
        );

        assert_eq!(
            compact_point.x().unwrap(),
            &hex!("1111111111111111111111111111111111111111111111111111111111111111")
        );
        assert_eq!(compact_point.y(), None);

        // Compressing a compact point leaves it untouched.
        assert_eq!(compact_point.compress(), compact_point);
    }

    #[test]
    fn decode_identity() {
        let identity_point = EncodedPoint::from_bytes(&IDENTITY_BYTES[..]).unwrap();
        assert!(identity_point.is_identity());
        assert_eq!(identity_point.tag(), Tag::Identity);
        assert_eq!(identity_point.len(), 1);
        assert_eq!(identity_point.as_bytes(), &IDENTITY_BYTES[..]);
        assert_eq!(identity_point.coordinates(), Coordinates::Identity);
        assert_eq!(identity_point.x(), None);
        assert_eq!(identity_point.y(), None);
    }

    #[test]
    fn decode_invalid_tag() {
        let mut compressed_bytes = COMPRESSED_BYTES;
        let mut uncompressed_bytes = UNCOMPRESSED_BYTES;

        for bytes in &mut [&mut compressed_bytes[..], &mut uncompressed_bytes[..]] {
            for tag in 0..=0xFF {
                // Skip the tags that can be valid for a 33- or 65-byte message.
                if tag == 2 || tag == 3 || tag == 4 || tag == 5 {
                    continue;
                }

                (*bytes)[0] = tag;
                let decode_result = EncodedPoint::from_bytes(&*bytes);
                assert!(decode_result.is_err());
            }
        }
    }

    #[test]
    fn decode_mismatched_tag_length() {
        // Single-coordinate tags on a two-coordinate message.
        for tag in [2u8, 3, 5] {
            let mut bytes = UNCOMPRESSED_BYTES;
            bytes[0] = tag;
            assert!(EncodedPoint::from_bytes(&bytes[..]).is_err());
        }

        // Uncompressed tag on a single-coordinate message.
        let mut bytes = COMPRESSED_BYTES;
        bytes[0] = 4;
        assert!(EncodedPoint::from_bytes(&bytes[..]).is_err());
    }

    #[test]
    fn decode_truncated_point() {
        for bytes in &[&COMPRESSED_BYTES[..], &UNCOMPRESSED_BYTES[..]] {
            for len in 0..bytes.len() {
                let decode_result = EncodedPoint::from_bytes(&bytes[..len]);
                assert!(decode_result.is_err());
            }
        }
    }

    #[test]
    fn from_untagged_point() {
        let untagged_bytes = hex!(
            "11111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222"
        );
        let uncompressed_point = EncodedPoint::from_untagged_bytes(&untagged_bytes.into());
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[test]
    fn from_affine_coordinates() {
        let x = hex!("1111111111111111111111111111111111111111111111111111111111111111");
        let y = hex!("2222222222222222222222222222222222222222222222222222222222222222");

        let uncompressed_point = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), false);
        assert_eq!(uncompressed_point.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        let compressed_point = EncodedPoint::from_affine_coordinates(&x.into(), &y.into(), true);
        assert_eq!(compressed_point.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[test]
    fn compress() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        let compressed_point = uncompressed_point.compress();
        assert_eq!(compressed_point.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "subtle")]
    #[test]
    fn conditional_select() {
        let a = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        let b = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        let a_selected = EncodedPoint::conditional_select(&a, &b, 0.into());
        assert_eq!(a, a_selected);

        let b_selected = EncodedPoint::conditional_select(&a, &b, 1.into());
        assert_eq!(b, b_selected);
    }

    #[test]
    fn identity() {
        let identity_point = EncodedPoint::identity();
        assert_eq!(identity_point.tag(), Tag::Identity);
        assert_eq!(identity_point.len(), 1);
        assert_eq!(identity_point.as_bytes(), &IDENTITY_BYTES[..]);

        // The identity is also the default value.
        assert_eq!(identity_point, EncodedPoint::default());
    }

    #[test]
    fn decode_hex() {
        let point = EncodedPoint::from_str(
            "021111111111111111111111111111111111111111111111111111111111111111",
        )
        .unwrap();
        assert_eq!(point.as_bytes(), COMPRESSED_BYTES);
    }

    #[test]
    fn decode_invalid_hex() {
        // Empty input carries no tag byte.
        assert!(EncodedPoint::from_str("").is_err());
        // Odd number of hex digits.
        assert!(EncodedPoint::from_str("0").is_err());
        // Characters outside the hex alphabet.
        assert!(EncodedPoint::from_str("zz").is_err());
        // Valid hex whose tag byte is not recognized.
        assert!(EncodedPoint::from_str("01").is_err());
    }

    #[test]
    fn tag_round_trip() {
        for byte in [0u8, 2, 3, 4, 5] {
            let tag = Tag::from_u8(byte).unwrap();
            assert_eq!(u8::from(tag), byte);
        }

        for byte in [1u8, 6, 0xFF] {
            assert!(Tag::from_u8(byte).is_err());
        }
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_bytes() {
        let uncompressed_point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        assert_eq!(&*uncompressed_point.to_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_string() {
        let point = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        assert_eq!(
            point.to_string(),
            "021111111111111111111111111111111111111111111111111111111111111111"
        );
    }
}