1#![allow(non_snake_case)]
17
18mod elligator;
160
161#[cfg(feature = "alloc")]
162use alloc::vec::Vec;
163
164use core::array::TryFromSliceError;
165use core::borrow::Borrow;
166use core::fmt::Debug;
167use core::iter::Sum;
168use core::ops::{Add, Neg, Sub};
169use core::ops::{AddAssign, SubAssign};
170use core::ops::{Mul, MulAssign};
171
172#[cfg(feature = "digest")]
173use digest::Digest;
174#[cfg(feature = "digest")]
175use digest::array::typenum::U64;
176
177use crate::constants;
178use crate::field::FieldElement;
179
180#[cfg(feature = "group")]
181use {
182 group::{GroupEncoding, cofactor::CofactorGroup, prime::PrimeGroup},
183 rand_core::TryRng,
184 subtle::CtOption,
185};
186
187#[cfg(feature = "rand_core")]
188use {
189 core::convert::Infallible,
190 rand_core::{CryptoRng, TryCryptoRng},
191};
192
193use subtle::Choice;
194use subtle::ConditionallyNegatable;
195use subtle::ConditionallySelectable;
196use subtle::ConstantTimeEq;
197
198#[cfg(feature = "zeroize")]
199use zeroize::Zeroize;
200
201#[cfg(feature = "precomputed-tables")]
202use crate::edwards::EdwardsBasepointTable;
203use crate::edwards::EdwardsPoint;
204
205use crate::scalar::Scalar;
206
207#[cfg(feature = "precomputed-tables")]
208use crate::traits::BasepointTable;
209use crate::traits::Identity;
210#[cfg(feature = "alloc")]
211use crate::traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul};
212
/// A Ristretto point in its 32-byte wire encoding.
///
/// This is only an encoding: the bytes are not validated when the value is
/// constructed. Call [`CompressedRistretto::decompress`] to obtain a
/// [`RistrettoPoint`]; it returns `None` for an invalid encoding.
///
/// Equality is implemented by hand (a constant-time byte comparison) while
/// `Hash` is derived. The `allow` below silences the Clippy lint that flags
/// that combination. Because equality compares exactly the bytes that are
/// hashed, the two remain consistent.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Copy, Clone, Hash)]
pub struct CompressedRistretto(pub [u8; 32]);
224
225impl Eq for CompressedRistretto {}
226impl PartialEq for CompressedRistretto {
227 fn eq(&self, other: &Self) -> bool {
228 self.ct_eq(other).into()
229 }
230}
231
232impl ConstantTimeEq for CompressedRistretto {
233 fn ct_eq(&self, other: &CompressedRistretto) -> Choice {
234 self.as_bytes().ct_eq(other.as_bytes())
235 }
236}
237
238impl CompressedRistretto {
239 pub const fn to_bytes(&self) -> [u8; 32] {
241 self.0
242 }
243
244 pub const fn as_bytes(&self) -> &[u8; 32] {
246 &self.0
247 }
248
249 pub fn from_slice(bytes: &[u8]) -> Result<CompressedRistretto, TryFromSliceError> {
256 bytes.try_into().map(CompressedRistretto)
257 }
258
259 pub fn decompress(&self) -> Option<RistrettoPoint> {
267 let (s_encoding_is_canonical, s_is_negative, s) = decompress::step_1(self);
268
269 if (!s_encoding_is_canonical | s_is_negative).into() {
270 return None;
271 }
272
273 let (ok, t_is_negative, y_is_zero, res) = decompress::step_2(s);
274
275 if (!ok | t_is_negative | y_is_zero).into() {
276 None
277 } else {
278 Some(res)
279 }
280 }
281}
282
/// The two phases of Ristretto decoding.
///
/// Used by `CompressedRistretto::decompress`, which short-circuits after step 1.
/// With the `group` feature it is also used by `GroupEncoding::from_bytes`,
/// which always runs both steps.
///
/// Each step reports validity as `Choice` flags instead of branching, so the
/// caller decides how the flags are combined.
mod decompress {
    use super::*;

    /// Step 1: parse the field element `s` and check its encoding.
    ///
    /// Returns `(s_encoding_is_canonical, s_is_negative, s)`.
    ///
    /// Canonicity is checked by re-encoding `s` and comparing with the input.
    /// Any byte string that decodes to `s` but is not its canonical encoding
    /// is therefore flagged. NOTE(review): this presumably covers values
    /// `>= p` and any bits `FieldElement::from_bytes` ignores — confirm
    /// against `from_bytes`.
    pub(super) fn step_1(repr: &CompressedRistretto) -> (Choice, Choice, FieldElement) {
        let s = FieldElement::from_bytes(repr.as_bytes());
        // Round-trip the element: the input is canonical iff it re-encodes identically.
        let s_bytes_check = s.to_bytes();
        let s_encoding_is_canonical = s_bytes_check[..].ct_eq(repr.as_bytes());
        let s_is_negative = s.is_negative();

        (s_encoding_is_canonical, s_is_negative, s)
    }

    /// Step 2: recover extended coordinates `(X : Y : Z : T)` with `Z = 1` from `s`.
    ///
    /// Returns `(ok, t_is_negative, y_is_zero, point)`, where `ok` is the
    /// success flag of the inverse square root. `point` is a valid decoding
    /// only when `ok` is set and both other flags are clear. Callers must
    /// check all three before using it.
    pub(super) fn step_2(s: FieldElement) -> (Choice, Choice, Choice, RistrettoPoint) {
        let one = FieldElement::ONE;
        let ss = s.square();
        let u1 = &one - &ss; // 1 - s^2
        let u2 = &one + &ss; // 1 + s^2
        let u2_sqr = u2.square(); // (1 + s^2)^2
        // v = -d * (1 - s^2)^2 - (1 + s^2)^2
        let v = &(&(-&constants::EDWARDS_D) * &u1.square()) - &u2_sqr;

        // I = 1/sqrt(v * u2^2); `ok` reports whether that square root exists.
        let (ok, I) = (&v * &u2_sqr).invsqrt();
        let Dx = &I * &u2; // I * u2
        let Dy = &I * &(&Dx * &v); // I^2 * u2 * v, i.e. 1/u2 when `ok`
        // x = 2s * Dx, then forced non-negative.
        let mut x = &(&s + &s) * &Dx;
        let x_neg = x.is_negative();
        x.conditional_negate(x_neg);

        // y = (1 - s^2) / (1 + s^2) when `ok`
        let y = &u1 * &Dy;

        // With Z = 1, the extended coordinate T is simply x * y.
        let t = &x * &y;

        (
            ok,
            t.is_negative(),
            y.is_zero(),
            RistrettoPoint(EdwardsPoint {
                X: x,
                Y: y,
                Z: one,
                T: t,
            }),
        )
    }
}
346
347impl Identity for CompressedRistretto {
348 fn identity() -> CompressedRistretto {
349 CompressedRistretto([0u8; 32])
350 }
351}
352
353impl Default for CompressedRistretto {
354 fn default() -> CompressedRistretto {
355 CompressedRistretto::identity()
356 }
357}
358
359impl TryFrom<&[u8]> for CompressedRistretto {
360 type Error = TryFromSliceError;
361
362 fn try_from(slice: &[u8]) -> Result<CompressedRistretto, TryFromSliceError> {
363 Self::from_slice(slice)
364 }
365}
366
367#[cfg(feature = "serde")]
376use serde::de::Visitor;
377#[cfg(feature = "serde")]
378use serde::{Deserialize, Deserializer, Serialize, Serializer};
379
#[cfg(feature = "serde")]
impl Serialize for RistrettoPoint {
    /// Serializes the point as its 32-byte compressed encoding, written as a
    /// fixed-length tuple of `u8`. This is the same shape used for
    /// `CompressedRistretto`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeTuple;

        let encoded = self.compress();
        let mut tuple = serializer.serialize_tuple(encoded.as_bytes().len())?;
        encoded
            .as_bytes()
            .iter()
            .try_for_each(|byte| tuple.serialize_element(byte))?;
        tuple.end()
    }
}
394
#[cfg(feature = "serde")]
impl Serialize for CompressedRistretto {
    /// Serializes the encoding as a fixed-length tuple of 32 `u8` values.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeTuple;

        let bytes = self.as_bytes();
        let mut tuple = serializer.serialize_tuple(bytes.len())?;
        for b in bytes {
            tuple.serialize_element(b)?;
        }
        tuple.end()
    }
}
409
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RistrettoPoint {
    /// Reads a 32-element `u8` tuple and decompresses it.
    ///
    /// Fails if fewer than 32 elements are available or if the bytes are not
    /// a valid Ristretto encoding.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct PointVisitor;

        impl<'de> Visitor<'de> for PointVisitor {
            type Value = RistrettoPoint;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("a valid point in Ristretto format")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<RistrettoPoint, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                use serde::de::Error;

                let mut encoding = [0u8; 32];
                for (index, slot) in encoding.iter_mut().enumerate() {
                    *slot = seq
                        .next_element()?
                        .ok_or_else(|| Error::invalid_length(index, &"expected 32 bytes"))?;
                }

                CompressedRistretto(encoding)
                    .decompress()
                    .ok_or_else(|| Error::custom("decompression failed"))
            }
        }

        deserializer.deserialize_tuple(32, PointVisitor)
    }
}
445
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for CompressedRistretto {
    /// Reads a 32-element `u8` tuple.
    ///
    /// The bytes are not validated as a point encoding.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct EncodingVisitor;

        impl<'de> Visitor<'de> for EncodingVisitor {
            type Value = CompressedRistretto;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("32 bytes of data")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<CompressedRistretto, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut encoding = [0u8; 32];
                for (index, slot) in encoding.iter_mut().enumerate() {
                    let next: Option<u8> = seq.next_element()?;
                    *slot = next.ok_or_else(|| {
                        serde::de::Error::invalid_length(index, &"expected 32 bytes")
                    })?;
                }
                Ok(CompressedRistretto(encoding))
            }
        }

        deserializer.deserialize_tuple(32, EncodingVisitor)
    }
}
479
/// An element of the Ristretto group.
///
/// Internally this wraps an [`EdwardsPoint`] representative. Several Edwards
/// representatives denote the same Ristretto element. Equality
/// ([`ConstantTimeEq`]) treats the stored point and the other representatives
/// listed by `coset4` as equal. Use [`RistrettoPoint::compress`] to obtain a
/// byte encoding.
#[derive(Copy, Clone)]
pub struct RistrettoPoint(pub(crate) EdwardsPoint);
497
impl RistrettoPoint {
    /// Encodes this point as a [`CompressedRistretto`].
    ///
    /// The serialized field element `s` is forced to be non-negative, which
    /// matches the check performed by [`CompressedRistretto::decompress`].
    /// Every sign-dependent choice is made with `conditional_assign` or
    /// `conditional_negate` rather than a branch.
    pub fn compress(&self) -> CompressedRistretto {
        let mut X = self.0.X;
        let mut Y = self.0.Y;
        let Z = &self.0.Z;
        let T = &self.0.T;

        // u1 = (Z + Y)(Z - Y), u2 = X * Y
        let u1 = &(Z + &Y) * &(Z - &Y);
        let u2 = &X * &Y;
        // One inverse square root of u1 * u2^2 yields both i1 and i2 below.
        // Its success flag is discarded. NOTE(review): this presumably relies
        // on u1 * u2^2 always being a square for a valid point — confirm.
        let (_, invsqrt) = (&u1 * &u2.square()).invsqrt();
        let i1 = &invsqrt * &u1;
        let i2 = &invsqrt * &u2;
        let z_inv = &i1 * &(&i2 * T);
        let mut den_inv = i2;

        // Alternative values used when the representative is "rotated":
        // - X and Y are swapped after multiplying each by sqrt(-1);
        // - the denominator becomes i1 * INVSQRT_A_MINUS_D.
        let iX = &X * &constants::SQRT_M1;
        let iY = &Y * &constants::SQRT_M1;
        let ristretto_magic = &constants::INVSQRT_A_MINUS_D;
        let enchanted_denominator = &i1 * ristretto_magic;

        // Rotate exactly when T * z_inv is negative.
        let rotate = (T * &z_inv).is_negative();

        X.conditional_assign(&iY, rotate);
        Y.conditional_assign(&iX, rotate);
        den_inv.conditional_assign(&enchanted_denominator, rotate);

        // Negate Y when X * z_inv is negative.
        Y.conditional_negate((&X * &z_inv).is_negative());

        let mut s = &den_inv * &(Z - &Y);
        // Emit the non-negative representative |s|.
        let s_is_negative = s.is_negative();
        s.conditional_negate(s_is_negative);

        CompressedRistretto(s.to_bytes())
    }

    /// Computes the compressed encodings of `[2]P` for every `P` in `points`,
    /// in input order.
    ///
    /// A single [`compress`](Self::compress) needs one inverse square root per
    /// point. For the *doubled* points, this routine instead needs one field
    /// inversion per point. All of those inversions are shared through a
    /// single batch inversion (`FieldElement::invert_batch_alloc`).
    ///
    /// # Example
    ///
    #[cfg_attr(feature = "rand_core", doc = "```")]
    #[cfg_attr(not(feature = "rand_core"), doc = "```ignore")]
    /// # use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// # use curve25519_dalek::scalar::Scalar;
    /// let points: Vec<RistrettoPoint> = (1u64..=8)
    ///     .map(|k| RISTRETTO_BASEPOINT_POINT * Scalar::from(k))
    ///     .collect();
    ///
    /// let compressed = RistrettoPoint::double_and_compress_batch(&points);
    ///
    /// for (p, p2_compressed) in points.iter().zip(compressed.iter()) {
    ///     assert_eq!(*p2_compressed, (p + p).compress());
    /// }
    /// ```
    #[cfg(feature = "alloc")]
    pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
    where
        I: IntoIterator<Item = &'a RistrettoPoint>,
    {
        // Per-point intermediate values needed to encode the doubled point.
        #[derive(Copy, Clone, Debug)]
        struct BatchCompressState {
            e: FieldElement,
            f: FieldElement,
            g: FieldElement,
            h: FieldElement,
            eg: FieldElement,
            fh: FieldElement,
        }

        impl BatchCompressState {
            /// The product e*f*g*h, which is the value that gets batch-inverted.
            fn efgh(&self) -> FieldElement {
                &self.eg * &self.fh
            }
        }

        impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
            #[rustfmt::skip] // keep the explanatory comments aligned
            fn from(P: &'a RistrettoPoint) -> BatchCompressState {
                let XX = P.0.X.square();
                let YY = P.0.Y.square();
                let ZZ = P.0.Z.square();
                let dTT = &P.0.T.square() * &constants::EDWARDS_D;

                let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
                let f = &ZZ + &dTT;                  // = Z^2 + d*T^2
                let g = &YY + &XX;                   // = Y^2 + X^2
                let h = &ZZ - &dTT;                  // = Z^2 - d*T^2
                let eg = &e * &g;
                let fh = &f * &h;

                BatchCompressState{ e, f, g, h, eg, fh }
            }
        }

        let states: Vec<BatchCompressState> =
            points.into_iter().map(BatchCompressState::from).collect();

        // Invert every e*f*g*h with a single batched inversion.
        let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();

        FieldElement::invert_batch_alloc(&mut invs[..]);

        states
            .iter()
            .zip(invs.iter())
            .map(|(state, inv): (&BatchCompressState, &FieldElement)| {
                // With inv = 1/(e*f*g*h): Zinv = 1/(f*h) and Tinv = 1/(e*g).
                let Zinv = &state.eg * inv;
                let Tinv = &state.fh * inv;

                let mut magic = constants::INVSQRT_A_MINUS_D;

                // negcheck1 selects the rotated variant of (e, g, h) and magic.
                let negcheck1 = (&state.eg * &Zinv).is_negative();

                let mut e = state.e;
                let mut g = state.g;
                let mut h = state.h;

                let minus_e = -&e;
                let f_times_sqrta = &state.f * &constants::SQRT_M1;

                // Rotated variant: (e, g, h) <- (g, -e, f*sqrt(-1)), magic <- sqrt(-1).
                e.conditional_assign(&state.g, negcheck1);
                g.conditional_assign(&minus_e, negcheck1);
                h.conditional_assign(&f_times_sqrta, negcheck1);

                magic.conditional_assign(&constants::SQRT_M1, negcheck1);

                // Sign normalization: negate g when h*e*Zinv is negative.
                let negcheck2 = (&(&h * &e) * &Zinv).is_negative();

                g.conditional_negate(negcheck2);

                let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));

                // Emit the non-negative representative |s|, as in `compress`.
                let s_is_negative = s.is_negative();
                s.conditional_negate(s_is_negative);

                CompressedRistretto(s.to_bytes())
            })
            .collect()
    }

    /// Returns four Edwards representatives of this point.
    ///
    /// These are the stored representative plus its sums with
    /// `EIGHT_TORSION[2]`, `EIGHT_TORSION[4]` and `EIGHT_TORSION[6]`. All four
    /// compare equal as `RistrettoPoint`s (exercised by the `four_torsion_*`
    /// tests). Used by the `Debug` implementation.
    fn coset4(&self) -> [EdwardsPoint; 4] {
        [
            self.0,
            self.0 + constants::EIGHT_TORSION[2],
            self.0 + constants::EIGHT_TORSION[4],
            self.0 + constants::EIGHT_TORSION[6],
        ]
    }

    /// Returns a random point drawn from an infallible cryptographic RNG.
    ///
    /// Thin wrapper around [`RistrettoPoint::try_random`]. The error type is
    /// pinned to `Infallible`, so the `expect` below can never fire.
    #[cfg(feature = "rand_core")]
    pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
        Self::try_random(rng)
            .map_err(|_: Infallible| {})
            .expect("[bug] unfallible rng failed")
    }

    /// Returns a random point by filling 64 bytes from `rng` and mapping
    /// them with [`RistrettoPoint::from_uniform_bytes`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rng.try_fill_bytes`.
    #[cfg(feature = "rand_core")]
    pub fn try_random<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut uniform_bytes = [0u8; 64];
        rng.try_fill_bytes(&mut uniform_bytes)?;

        Ok(RistrettoPoint::from_uniform_bytes(&uniform_bytes))
    }

    /// Hashes `input` to a `RistrettoPoint` using any digest `D` with a
    /// 64-byte output.
    ///
    /// Convenience wrapper: it feeds `input` into a fresh `D::default()` and
    /// hands the hasher to [`RistrettoPoint::from_hash`].
    ///
    /// # Example
    ///
    #[cfg(feature = "digest")]
    #[cfg_attr(feature = "digest", doc = "```")]
    #[cfg_attr(not(feature = "digest"), doc = "```ignore")]
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// # use digest::{Digest, array::typenum::U64};
    /// // Any `Digest` with a 64-byte output can be plugged in as `D`.
    /// fn hash_to_point<D>(msg: &[u8]) -> RistrettoPoint
    /// where
    ///     D: Digest<OutputSize = U64> + Default,
    /// {
    ///     RistrettoPoint::hash_from_bytes::<D>(msg)
    /// }
    /// ```
    pub fn hash_from_bytes<D>(input: &[u8]) -> RistrettoPoint
    where
        D: Digest<OutputSize = U64> + Default,
    {
        let mut hash = D::default();
        hash.update(input);
        RistrettoPoint::from_hash(hash)
    }

    /// Finalizes a hasher with a 64-byte output and maps the digest to a point
    /// with [`RistrettoPoint::from_uniform_bytes`].
    ///
    /// Useful when the input is fed to the hasher incrementally.
    /// [`RistrettoPoint::hash_from_bytes`] covers the one-shot case.
    #[cfg(feature = "digest")]
    pub fn from_hash<D>(hash: D) -> RistrettoPoint
    where
        D: Digest<OutputSize = U64> + Default,
    {
        let output = hash.finalize();
        // The `U64` output size guarantees exactly 64 bytes here.
        let mut output_bytes = [0u8; 64];
        output_bytes.copy_from_slice(output.as_slice());

        RistrettoPoint::from_uniform_bytes(&output_bytes)
    }

    /// Maps 64 bytes to a point.
    ///
    /// Each 32-byte half is decoded with `FieldElement::from_bytes` and mapped
    /// through `elligator_ristretto_flavor`. The two resulting points are then
    /// added. As the name says, `bytes` is expected to be uniformly random,
    /// such as the RNG output in `try_random` or the digest in `from_hash`.
    pub fn from_uniform_bytes(bytes: &[u8; 64]) -> RistrettoPoint {
        // First half -> first Elligator image.
        let mut r_1_bytes = [0u8; 32];
        r_1_bytes.copy_from_slice(&bytes[0..32]);
        let r_1 = FieldElement::from_bytes(&r_1_bytes);
        let R_1 = RistrettoPoint::elligator_ristretto_flavor(&r_1);

        // Second half -> second Elligator image.
        let mut r_2_bytes = [0u8; 32];
        r_2_bytes.copy_from_slice(&bytes[32..64]);
        let r_2 = FieldElement::from_bytes(&r_2_bytes);
        let R_2 = RistrettoPoint::elligator_ristretto_flavor(&r_2);

        // The output is the sum of the two images.
        R_1 + R_2
    }
}
792
793impl Identity for RistrettoPoint {
794 fn identity() -> RistrettoPoint {
795 RistrettoPoint(EdwardsPoint::identity())
796 }
797}
798
799impl Default for RistrettoPoint {
800 fn default() -> RistrettoPoint {
801 RistrettoPoint::identity()
802 }
803}
804
805impl PartialEq for RistrettoPoint {
810 fn eq(&self, other: &RistrettoPoint) -> bool {
811 self.ct_eq(other).into()
812 }
813}
814
815impl ConstantTimeEq for RistrettoPoint {
816 fn ct_eq(&self, other: &RistrettoPoint) -> Choice {
823 let X1Y2 = &self.0.X * &other.0.Y;
824 let Y1X2 = &self.0.Y * &other.0.X;
825 let X1X2 = &self.0.X * &other.0.X;
826 let Y1Y2 = &self.0.Y * &other.0.Y;
827
828 X1Y2.ct_eq(&Y1X2) | X1X2.ct_eq(&Y1Y2)
829 }
830}
831
832impl Eq for RistrettoPoint {}
833
834impl<'a> Add<&'a RistrettoPoint> for &RistrettoPoint {
839 type Output = RistrettoPoint;
840
841 fn add(self, other: &'a RistrettoPoint) -> RistrettoPoint {
842 RistrettoPoint(self.0 + other.0)
843 }
844}
845
846define_add_variants!(
847 LHS = RistrettoPoint,
848 RHS = RistrettoPoint,
849 Output = RistrettoPoint
850);
851
852impl AddAssign<&RistrettoPoint> for RistrettoPoint {
853 fn add_assign(&mut self, _rhs: &RistrettoPoint) {
854 *self = (self as &RistrettoPoint) + _rhs;
855 }
856}
857
858define_add_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);
859
860impl<'a> Sub<&'a RistrettoPoint> for &RistrettoPoint {
861 type Output = RistrettoPoint;
862
863 fn sub(self, other: &'a RistrettoPoint) -> RistrettoPoint {
864 RistrettoPoint(self.0 - other.0)
865 }
866}
867
868define_sub_variants!(
869 LHS = RistrettoPoint,
870 RHS = RistrettoPoint,
871 Output = RistrettoPoint
872);
873
874impl SubAssign<&RistrettoPoint> for RistrettoPoint {
875 fn sub_assign(&mut self, _rhs: &RistrettoPoint) {
876 *self = (self as &RistrettoPoint) - _rhs;
877 }
878}
879
880define_sub_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);
881
882impl<T> Sum<T> for RistrettoPoint
883where
884 T: Borrow<RistrettoPoint>,
885{
886 fn sum<I>(iter: I) -> Self
887 where
888 I: Iterator<Item = T>,
889 {
890 iter.fold(RistrettoPoint::identity(), |acc, item| acc + item.borrow())
891 }
892}
893
894impl Neg for &RistrettoPoint {
895 type Output = RistrettoPoint;
896
897 fn neg(self) -> RistrettoPoint {
898 RistrettoPoint(-&self.0)
899 }
900}
901
902impl Neg for RistrettoPoint {
903 type Output = RistrettoPoint;
904
905 fn neg(self) -> RistrettoPoint {
906 -&self
907 }
908}
909
910impl<'a> MulAssign<&'a Scalar> for RistrettoPoint {
911 fn mul_assign(&mut self, scalar: &'a Scalar) {
912 let result = (self as &RistrettoPoint) * scalar;
913 *self = result;
914 }
915}
916
917impl<'a> Mul<&'a Scalar> for &RistrettoPoint {
918 type Output = RistrettoPoint;
919 fn mul(self, scalar: &'a Scalar) -> RistrettoPoint {
921 RistrettoPoint(self.0 * scalar)
922 }
923}
924
925impl<'a> Mul<&'a RistrettoPoint> for &Scalar {
926 type Output = RistrettoPoint;
927
928 fn mul(self, point: &'a RistrettoPoint) -> RistrettoPoint {
930 RistrettoPoint(self * point.0)
931 }
932}
933
934impl RistrettoPoint {
935 pub fn mul_base(scalar: &Scalar) -> Self {
940 #[cfg(not(feature = "precomputed-tables"))]
941 {
942 scalar * constants::RISTRETTO_BASEPOINT_POINT
943 }
944
945 #[cfg(feature = "precomputed-tables")]
946 {
947 scalar * constants::RISTRETTO_BASEPOINT_TABLE
948 }
949 }
950}
951
952define_mul_assign_variants!(LHS = RistrettoPoint, RHS = Scalar);
953
954define_mul_variants!(LHS = RistrettoPoint, RHS = Scalar, Output = RistrettoPoint);
955define_mul_variants!(LHS = Scalar, RHS = RistrettoPoint, Output = RistrettoPoint);
956
#[cfg(feature = "alloc")]
impl MultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    /// Computes `sum_i scalars[i] * points[i]`.
    ///
    /// Delegates to `EdwardsPoint::multiscalar_mul` on the underlying Edwards
    /// representatives.
    fn multiscalar_mul<I, J>(scalars: I, points: J) -> RistrettoPoint
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<RistrettoPoint>,
    {
        let edwards = points.into_iter().map(|p| {
            let point: &RistrettoPoint = p.borrow();
            point.0
        });
        let sum = EdwardsPoint::multiscalar_mul(scalars, edwards);
        RistrettoPoint(sum)
    }
}
978}
979
#[cfg(feature = "alloc")]
impl VartimeMultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    /// Variable-time multiscalar multiplication over optional points.
    ///
    /// Returns `None` when the Edwards backend does, presumably because one of
    /// the points was `None` — see `EdwardsPoint::optional_multiscalar_mul`.
    fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<RistrettoPoint>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator<Item = Option<RistrettoPoint>>,
    {
        let edwards = points
            .into_iter()
            .map(|maybe_point| maybe_point.map(|RistrettoPoint(inner)| inner));
        let sum = EdwardsPoint::optional_multiscalar_mul(scalars, edwards)?;
        Some(RistrettoPoint(sum))
    }
}
995
/// Precomputed data for variable-time multiscalar multiplication against a
/// fixed set of static Ristretto points.
///
/// Thin wrapper around the backend's `VartimePrecomputedStraus`, built from
/// the Edwards representatives of the points. It is used through its
/// [`VartimePrecomputedMultiscalarMul`] implementation.
#[cfg(feature = "alloc")]
pub struct VartimeRistrettoPrecomputation(crate::backend::VartimePrecomputedStraus);
1005
#[cfg(feature = "alloc")]
impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation {
    type Point = RistrettoPoint;

    /// Builds the precomputation from the Edwards representatives of
    /// `static_points`.
    fn new<I>(static_points: I) -> Self
    where
        I: IntoIterator,
        I::Item: Borrow<Self::Point>,
    {
        let edwards_points = static_points.into_iter().map(|p| {
            let point: &RistrettoPoint = p.borrow();
            point.0
        });
        VartimeRistrettoPrecomputation(crate::backend::VartimePrecomputedStraus::new(
            edwards_points,
        ))
    }

    /// Number of static points covered by the precomputation.
    fn len(&self) -> usize {
        let Self(inner) = self;
        inner.len()
    }

    /// Whether the precomputation covers no static points.
    fn is_empty(&self) -> bool {
        let Self(inner) = self;
        inner.is_empty()
    }

    /// Combines the precomputed static terms with dynamic `(scalar, point)`
    /// terms. Returns `None` if the backend does.
    fn optional_mixed_multiscalar_mul<I, J, K>(
        &self,
        static_scalars: I,
        dynamic_scalars: J,
        dynamic_points: K,
    ) -> Option<Self::Point>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<Scalar>,
        K: IntoIterator<Item = Option<Self::Point>>,
    {
        let dynamic_edwards = dynamic_points
            .into_iter()
            .map(|maybe_point| maybe_point.map(|point| point.0));
        let sum = self.0.optional_mixed_multiscalar_mul(
            static_scalars,
            dynamic_scalars,
            dynamic_edwards,
        )?;
        Some(RistrettoPoint(sum))
    }
}
1050
1051impl RistrettoPoint {
1052 pub fn vartime_double_scalar_mul_basepoint(
1055 a: &Scalar,
1056 A: &RistrettoPoint,
1057 b: &Scalar,
1058 ) -> RistrettoPoint {
1059 RistrettoPoint(EdwardsPoint::vartime_double_scalar_mul_basepoint(
1060 a, &A.0, b,
1061 ))
1062 }
1063}
1064
/// A precomputed table of multiples of a basepoint, used for fixed-base
/// scalar multiplication in the Ristretto group.
///
/// Wraps an [`EdwardsBasepointTable`] built from the basepoint's Edwards
/// representative. `#[repr(transparent)]` gives it the same layout as the
/// wrapped table.
#[cfg(feature = "precomputed-tables")]
#[derive(Clone)]
#[repr(transparent)]
pub struct RistrettoBasepointTable(pub(crate) EdwardsBasepointTable);
1081
#[cfg(feature = "precomputed-tables")]
impl<'b> Mul<&'b Scalar> for &RistrettoBasepointTable {
    type Output = RistrettoPoint;

    /// Fixed-base multiplication `[scalar]B` using the precomputed table.
    fn mul(self, scalar: &'b Scalar) -> RistrettoPoint {
        let RistrettoBasepointTable(table) = self;
        RistrettoPoint(table * scalar)
    }
}

#[cfg(feature = "precomputed-tables")]
impl<'a> Mul<&'a RistrettoBasepointTable> for &Scalar {
    type Output = RistrettoPoint;

    /// Fixed-base multiplication `[self]B` using the precomputed table.
    fn mul(self, basepoint_table: &'a RistrettoBasepointTable) -> RistrettoPoint {
        let RistrettoBasepointTable(table) = basepoint_table;
        RistrettoPoint(self * table)
    }
}

#[cfg(feature = "precomputed-tables")]
impl RistrettoBasepointTable {
    /// Precomputes a table for `basepoint` from its Edwards representative.
    pub fn create(basepoint: &RistrettoPoint) -> RistrettoBasepointTable {
        let RistrettoPoint(edwards) = basepoint;
        Self(EdwardsBasepointTable::create(edwards))
    }

    /// Returns the basepoint recorded in the underlying Edwards table,
    /// presumably the one given to `create`.
    pub fn basepoint(&self) -> RistrettoPoint {
        let Self(table) = self;
        RistrettoPoint(table.basepoint())
    }
}
1112
1113impl ConditionallySelectable for RistrettoPoint {
1118 fn conditional_select(
1143 a: &RistrettoPoint,
1144 b: &RistrettoPoint,
1145 choice: Choice,
1146 ) -> RistrettoPoint {
1147 RistrettoPoint(EdwardsPoint::conditional_select(&a.0, &b.0, choice))
1148 }
1149}
1150
1151impl Debug for CompressedRistretto {
1156 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1157 write!(f, "CompressedRistretto: {:?}", self.as_bytes())
1158 }
1159}
1160
1161impl Debug for RistrettoPoint {
1162 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1163 let coset = self.coset4();
1164 write!(
1165 f,
1166 "RistrettoPoint: coset \n{:?}\n{:?}\n{:?}\n{:?}",
1167 coset[0], coset[1], coset[2], coset[3]
1168 )
1169 }
1170}
1171
#[cfg(feature = "group")]
impl group::Group for RistrettoPoint {
    type Scalar = Scalar;

    /// Fills 64 bytes from `rng` and maps them with `from_uniform_bytes`.
    /// Propagates any RNG error.
    fn try_random<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut wide = [0u8; 64];
        rng.try_fill_bytes(&mut wide)?;
        Ok(Self::from_uniform_bytes(&wide))
    }

    fn identity() -> Self {
        <Self as Identity>::identity()
    }

    /// The Ristretto basepoint.
    fn generator() -> Self {
        constants::RISTRETTO_BASEPOINT_POINT
    }

    /// Constant-time comparison against the identity element.
    fn is_identity(&self) -> Choice {
        let neutral = <Self as Identity>::identity();
        self.ct_eq(&neutral)
    }

    fn double(&self) -> Self {
        self + self
    }
}
1205
#[cfg(feature = "group")]
impl GroupEncoding for RistrettoPoint {
    type Repr = [u8; 32];

    /// Branch-free decoding.
    ///
    /// Unlike `CompressedRistretto::decompress`, both decoding steps always
    /// run. Their validity flags are folded into the returned `CtOption`
    /// instead of causing an early return.
    fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
        let encoded = CompressedRistretto(*bytes);
        let (canonical, negative, s) = decompress::step_1(&encoded);
        let (sqrt_ok, t_negative, y_zero, point) = decompress::step_2(s);

        let s_valid = canonical & !negative;
        let point_valid = sqrt_ok & !t_negative & !y_zero;
        CtOption::new(point, s_valid & point_valid)
    }

    /// Performs the same full validation as `from_bytes`.
    fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
        <Self as GroupEncoding>::from_bytes(bytes)
    }

    /// The 32-byte Ristretto encoding produced by `compress`.
    fn to_bytes(&self) -> Self::Repr {
        *self.compress().as_bytes()
    }
}
1230
#[cfg(feature = "group")]
impl PrimeGroup for RistrettoPoint {}

/// Ristretto is used as a prime-order group (see the `PrimeGroup` impl), so
/// cofactor handling is trivial: every element is already in the subgroup.
#[cfg(feature = "group")]
impl CofactorGroup for RistrettoPoint {
    type Subgroup = Self;

    /// No-op: returns the point unchanged.
    fn clear_cofactor(&self) -> Self::Subgroup {
        *self
    }

    /// Always succeeds.
    fn into_subgroup(self) -> CtOption<Self::Subgroup> {
        let always = Choice::from(1u8);
        CtOption::new(self, always)
    }

    /// Always true.
    fn is_torsion_free(&self) -> Choice {
        Choice::from(1u8)
    }
}
1251
#[cfg(feature = "zeroize")]
impl Zeroize for CompressedRistretto {
    /// Zeroizes the 32 encoded bytes.
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.0);
    }
}

#[cfg(feature = "zeroize")]
impl Zeroize for RistrettoPoint {
    /// Delegates to the Edwards representative's `Zeroize` implementation.
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.0);
    }
}
1269
1270#[cfg(test)]
1275mod test {
1276 use super::*;
1277 use crate::edwards::CompressedEdwardsY;
1278 #[cfg(feature = "rand_core")]
1279 use getrandom::{SysRng, rand_core::UnwrapErr};
1280 #[cfg(feature = "group")]
1281 use proptest::prelude::*;
1282
1283 #[test]
1284 #[cfg(feature = "serde")]
1285 fn serde_postcard_basepoint_roundtrip() {
1286 let encoded = postcard::to_allocvec(&constants::RISTRETTO_BASEPOINT_POINT).unwrap();
1287 let enc_compressed =
1288 postcard::to_allocvec(&constants::RISTRETTO_BASEPOINT_COMPRESSED).unwrap();
1289 assert_eq!(encoded, enc_compressed);
1290
1291 assert_eq!(encoded.len(), 32);
1293
1294 let dec_uncompressed: RistrettoPoint = postcard::from_bytes(&encoded).unwrap();
1295 let dec_compressed: CompressedRistretto = postcard::from_bytes(&encoded).unwrap();
1296
1297 assert_eq!(dec_uncompressed, constants::RISTRETTO_BASEPOINT_POINT);
1298 assert_eq!(dec_compressed, constants::RISTRETTO_BASEPOINT_COMPRESSED);
1299
1300 let raw_bytes = constants::RISTRETTO_BASEPOINT_COMPRESSED.as_bytes();
1304 let bp: RistrettoPoint = postcard::from_bytes(raw_bytes).unwrap();
1305 assert_eq!(bp, constants::RISTRETTO_BASEPOINT_POINT);
1306 }
1307
1308 #[test]
1309 fn scalarmult_ristrettopoint_works_both_ways() {
1310 let P = constants::RISTRETTO_BASEPOINT_POINT;
1311 let s = Scalar::from(999u64);
1312
1313 let P1 = P * s;
1314 let P2 = s * P;
1315
1316 assert!(P1.compress().as_bytes() == P2.compress().as_bytes());
1317 }
1318
1319 #[test]
1320 #[cfg(feature = "alloc")]
1321 fn impl_sum() {
1322 let BASE = constants::RISTRETTO_BASEPOINT_POINT;
1324
1325 let s1 = Scalar::from(999u64);
1326 let P1 = BASE * s1;
1327
1328 let s2 = Scalar::from(333u64);
1329 let P2 = BASE * s2;
1330
1331 let vec = vec![P1, P2];
1332 let sum: RistrettoPoint = vec.iter().sum();
1333
1334 assert_eq!(sum, P1 + P2);
1335
1336 let empty_vector: Vec<RistrettoPoint> = vec![];
1338 let sum: RistrettoPoint = empty_vector.iter().sum();
1339
1340 assert_eq!(sum, RistrettoPoint::identity());
1341
1342 let s = Scalar::from(2u64);
1344 let mapped = vec.iter().map(|x| x * s);
1345 let sum: RistrettoPoint = mapped.sum();
1346
1347 assert_eq!(sum, P1 * s + P2 * s);
1348 }
1349
1350 #[test]
1351 fn decompress_negative_s_fails() {
1352 let bad_compressed = CompressedRistretto(constants::EDWARDS_D.to_bytes());
1354 assert!(bad_compressed.decompress().is_none());
1355 }
1356
1357 #[test]
1358 fn decompress_id() {
1359 let compressed_id = CompressedRistretto::identity();
1360 let id = compressed_id.decompress().unwrap();
1361 let mut identity_in_coset = false;
1362 for P in &id.coset4() {
1363 if P.compress() == CompressedEdwardsY::identity() {
1364 identity_in_coset = true;
1365 }
1366 }
1367 assert!(identity_in_coset);
1368 }
1369
1370 #[test]
1371 fn compress_id() {
1372 let id = RistrettoPoint::identity();
1373 assert_eq!(id.compress(), CompressedRistretto::identity());
1374 }
1375
1376 #[test]
1377 fn basepoint_roundtrip() {
1378 let bp_compressed_ristretto = constants::RISTRETTO_BASEPOINT_POINT.compress();
1379 let bp_recaf = bp_compressed_ristretto.decompress().unwrap().0;
1380 let diff = constants::RISTRETTO_BASEPOINT_POINT.0 - bp_recaf;
1382 let diff4 = diff.mul_by_pow_2(2);
1383 assert_eq!(diff4.compress(), CompressedEdwardsY::identity());
1384 }
1385
1386 #[test]
1387 fn encodings_of_small_multiples_of_basepoint() {
1388 let compressed = [
1391 CompressedRistretto([
1392 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1393 0, 0, 0, 0,
1394 ]),
1395 CompressedRistretto([
1396 226, 242, 174, 10, 106, 188, 78, 113, 168, 132, 169, 97, 197, 0, 81, 95, 88, 227,
1397 11, 106, 165, 130, 221, 141, 182, 166, 89, 69, 224, 141, 45, 118,
1398 ]),
1399 CompressedRistretto([
1400 106, 73, 50, 16, 247, 73, 156, 209, 127, 236, 181, 16, 174, 12, 234, 35, 161, 16,
1401 232, 213, 185, 1, 248, 172, 173, 211, 9, 92, 115, 163, 185, 25,
1402 ]),
1403 CompressedRistretto([
1404 148, 116, 31, 93, 93, 82, 117, 94, 206, 79, 35, 240, 68, 238, 39, 213, 209, 234,
1405 30, 43, 209, 150, 180, 98, 22, 107, 22, 21, 42, 157, 2, 89,
1406 ]),
1407 CompressedRistretto([
1408 218, 128, 134, 39, 115, 53, 139, 70, 111, 250, 223, 224, 179, 41, 58, 179, 217,
1409 253, 83, 197, 234, 108, 149, 83, 88, 245, 104, 50, 45, 175, 106, 87,
1410 ]),
1411 CompressedRistretto([
1412 232, 130, 177, 49, 1, 107, 82, 193, 211, 51, 112, 128, 24, 124, 247, 104, 66, 62,
1413 252, 203, 181, 23, 187, 73, 90, 184, 18, 196, 22, 15, 244, 78,
1414 ]),
1415 CompressedRistretto([
1416 246, 71, 70, 211, 201, 43, 19, 5, 14, 216, 216, 2, 54, 167, 240, 0, 124, 59, 63,
1417 150, 47, 91, 167, 147, 209, 154, 96, 30, 187, 29, 244, 3,
1418 ]),
1419 CompressedRistretto([
1420 68, 245, 53, 32, 146, 110, 200, 31, 189, 90, 56, 120, 69, 190, 183, 223, 133, 169,
1421 106, 36, 236, 225, 135, 56, 189, 207, 166, 167, 130, 42, 23, 109,
1422 ]),
1423 CompressedRistretto([
1424 144, 50, 147, 216, 242, 40, 126, 190, 16, 226, 55, 77, 193, 165, 62, 11, 200, 135,
1425 229, 146, 105, 159, 2, 208, 119, 213, 38, 60, 221, 85, 96, 28,
1426 ]),
1427 CompressedRistretto([
1428 2, 98, 42, 206, 143, 115, 3, 163, 28, 175, 198, 63, 143, 196, 143, 220, 22, 225,
1429 200, 200, 210, 52, 178, 240, 214, 104, 82, 130, 169, 7, 96, 49,
1430 ]),
1431 CompressedRistretto([
1432 32, 112, 111, 215, 136, 178, 114, 10, 30, 210, 165, 218, 212, 149, 43, 1, 244, 19,
1433 188, 240, 231, 86, 77, 232, 205, 200, 22, 104, 158, 45, 185, 95,
1434 ]),
1435 CompressedRistretto([
1436 188, 232, 63, 139, 165, 221, 47, 165, 114, 134, 76, 36, 186, 24, 16, 249, 82, 43,
1437 198, 0, 74, 254, 149, 135, 122, 199, 50, 65, 202, 253, 171, 66,
1438 ]),
1439 CompressedRistretto([
1440 228, 84, 158, 225, 107, 154, 160, 48, 153, 202, 32, 140, 103, 173, 175, 202, 250,
1441 76, 63, 62, 78, 83, 3, 222, 96, 38, 227, 202, 143, 248, 68, 96,
1442 ]),
1443 CompressedRistretto([
1444 170, 82, 224, 0, 223, 46, 22, 245, 95, 177, 3, 47, 195, 59, 196, 39, 66, 218, 214,
1445 189, 90, 143, 192, 190, 1, 103, 67, 108, 89, 72, 80, 31,
1446 ]),
1447 CompressedRistretto([
1448 70, 55, 107, 128, 244, 9, 178, 157, 194, 181, 246, 240, 197, 37, 145, 153, 8, 150,
1449 229, 113, 111, 65, 71, 124, 211, 0, 133, 171, 127, 16, 48, 30,
1450 ]),
1451 CompressedRistretto([
1452 224, 196, 24, 247, 200, 217, 196, 205, 215, 57, 91, 147, 234, 18, 79, 58, 217, 144,
1453 33, 187, 104, 29, 252, 51, 2, 169, 217, 154, 46, 83, 230, 78,
1454 ]),
1455 ];
1456 let mut bp = RistrettoPoint::identity();
1457 for point in compressed {
1458 assert_eq!(bp.compress(), point);
1459 bp += constants::RISTRETTO_BASEPOINT_POINT;
1460 }
1461 }
1462
1463 #[test]
1464 fn four_torsion_basepoint() {
1465 let bp = constants::RISTRETTO_BASEPOINT_POINT;
1466 let bp_coset = bp.coset4();
1467 for point in bp_coset {
1468 assert_eq!(bp, RistrettoPoint(point));
1469 }
1470 }
1471
1472 #[cfg(feature = "rand_core")]
1473 #[test]
1474 fn four_torsion_random() {
1475 let mut rng = UnwrapErr(SysRng);
1476 let P = RistrettoPoint::mul_base(&Scalar::random(&mut rng));
1477 let P_coset = P.coset4();
1478 for point in P_coset {
1479 assert_eq!(P, RistrettoPoint(point));
1480 }
1481 }
1482
1483 #[cfg(feature = "rand_core")]
1484 #[test]
1485 fn random_roundtrip() {
1486 let mut rng = UnwrapErr(SysRng);
1487 for _ in 0..100 {
1488 let P = RistrettoPoint::mul_base(&Scalar::random(&mut rng));
1489 let compressed_P = P.compress();
1490 let Q = compressed_P.decompress().unwrap();
1491 assert_eq!(P, Q);
1492 }
1493 }
1494
1495 #[test]
1496 #[cfg(all(feature = "alloc", feature = "rand_core", feature = "group"))]
1497 fn double_and_compress_1024_random_points() {
1498 use group::Group;
1499 let mut rng = SysRng;
1500
1501 let mut points: Vec<RistrettoPoint> = (0..1024)
1502 .map(|_| RistrettoPoint::try_random(&mut rng).unwrap())
1503 .collect();
1504 points[500] = <RistrettoPoint as Group>::identity();
1505
1506 let compressed = RistrettoPoint::double_and_compress_batch(&points);
1507
1508 for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
1509 assert_eq!(*P2_compressed, (P + P).compress());
1510 }
1511 }
1512
1513 #[cfg(feature = "group")]
1514 proptest! {
1515 #[test]
1516 fn multiply_double_and_compress_random_points(
1517 p1 in any::<[u8; 64]>(),
1518 p2 in any::<[u8; 64]>(),
1519 s1 in any::<[u8; 32]>(),
1520 s2 in any::<[u8; 32]>(),
1521 ) {
1522 use group::Group;
1523
1524 let scalars = [
1525 Scalar::from_bytes_mod_order(s1),
1526 Scalar::ZERO,
1527 Scalar::from_bytes_mod_order(s2),
1528 ];
1529
1530 let points = [
1531 RistrettoPoint::from_uniform_bytes(&p1),
1532 <RistrettoPoint as Group>::identity(),
1533 RistrettoPoint::from_uniform_bytes(&p2),
1534 ];
1535
1536 let multiplied_points: [_; 3] =
1537 core::array::from_fn(|i| scalars[i].div_by_2() * points[i]);
1538 let compressed = RistrettoPoint::double_and_compress_batch(&multiplied_points);
1539
1540 for ((s, P), P2_compressed) in scalars.iter().zip(points).zip(compressed) {
1541 prop_assert_eq!(P2_compressed, (s * P).compress());
1542 }
1543 }
1544 }
1545
1546 #[test]
1547 #[cfg(all(feature = "alloc", feature = "rand_core"))]
1548 fn vartime_precomputed_vs_nonprecomputed_multiscalar() {
1549 let mut rng = UnwrapErr(SysRng);
1550
1551 let static_scalars = (0..128)
1552 .map(|_| Scalar::random(&mut rng))
1553 .collect::<Vec<_>>();
1554
1555 let dynamic_scalars = (0..128)
1556 .map(|_| Scalar::random(&mut rng))
1557 .collect::<Vec<_>>();
1558
1559 let check_scalar: Scalar = static_scalars
1560 .iter()
1561 .chain(dynamic_scalars.iter())
1562 .map(|s| s * s)
1563 .sum();
1564
1565 let static_points = static_scalars
1566 .iter()
1567 .map(RistrettoPoint::mul_base)
1568 .collect::<Vec<_>>();
1569 let dynamic_points = dynamic_scalars
1570 .iter()
1571 .map(RistrettoPoint::mul_base)
1572 .collect::<Vec<_>>();
1573
1574 let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
1575
1576 assert_eq!(precomputation.len(), 128);
1577 assert!(!precomputation.is_empty());
1578
1579 let P = precomputation.vartime_mixed_multiscalar_mul(
1580 &static_scalars,
1581 &dynamic_scalars,
1582 &dynamic_points,
1583 );
1584
1585 use crate::traits::VartimeMultiscalarMul;
1586 let Q = RistrettoPoint::vartime_multiscalar_mul(
1587 static_scalars.iter().chain(dynamic_scalars.iter()),
1588 static_points.iter().chain(dynamic_points.iter()),
1589 );
1590
1591 let R = RistrettoPoint::mul_base(&check_scalar);
1592
1593 assert_eq!(P.compress(), R.compress());
1594 assert_eq!(Q.compress(), R.compress());
1595 }
1596
1597 #[test]
1598 #[cfg(all(feature = "alloc", feature = "rand_core"))]
1599 fn partial_precomputed_mixed_multiscalar_empty() {
1600 let mut rng = UnwrapErr(SysRng);
1601
1602 let n_static = 16;
1603 let n_dynamic = 8;
1604
1605 let static_points = (0..n_static)
1606 .map(|_| RistrettoPoint::random(&mut rng))
1607 .collect::<Vec<_>>();
1608
1609 let static_scalars = Vec::new();
1611
1612 let dynamic_points = (0..n_dynamic)
1613 .map(|_| RistrettoPoint::random(&mut rng))
1614 .collect::<Vec<_>>();
1615
1616 let dynamic_scalars = (0..n_dynamic)
1617 .map(|_| Scalar::random(&mut rng))
1618 .collect::<Vec<_>>();
1619
1620 let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
1622 let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul(
1623 &static_scalars,
1624 &dynamic_scalars,
1625 &dynamic_points,
1626 );
1627
1628 let mut result_manual = RistrettoPoint::identity();
1630 for i in 0..static_scalars.len() {
1631 result_manual += static_points[i] * static_scalars[i];
1632 }
1633 for i in 0..n_dynamic {
1634 result_manual += dynamic_points[i] * dynamic_scalars[i];
1635 }
1636
1637 assert_eq!(result_multiscalar, result_manual);
1638 }
1639
1640 #[test]
1641 #[cfg(all(feature = "alloc", feature = "rand_core"))]
1642 fn partial_precomputed_mixed_multiscalar() {
1643 let mut rng = UnwrapErr(SysRng);
1644
1645 let n_static = 16;
1646 let n_dynamic = 8;
1647
1648 let static_points = (0..n_static)
1649 .map(|_| RistrettoPoint::random(&mut rng))
1650 .collect::<Vec<_>>();
1651
1652 let static_scalars = (0..n_static - 1)
1654 .map(|_| Scalar::random(&mut rng))
1655 .collect::<Vec<_>>();
1656
1657 let dynamic_points = (0..n_dynamic)
1658 .map(|_| RistrettoPoint::random(&mut rng))
1659 .collect::<Vec<_>>();
1660
1661 let dynamic_scalars = (0..n_dynamic)
1662 .map(|_| Scalar::random(&mut rng))
1663 .collect::<Vec<_>>();
1664
1665 let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
1667 let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul(
1668 &static_scalars,
1669 &dynamic_scalars,
1670 &dynamic_points,
1671 );
1672
1673 let mut result_manual = RistrettoPoint::identity();
1675 for i in 0..static_scalars.len() {
1676 result_manual += static_points[i] * static_scalars[i];
1677 }
1678 for i in 0..n_dynamic {
1679 result_manual += dynamic_points[i] * dynamic_scalars[i];
1680 }
1681
1682 assert_eq!(result_multiscalar, result_manual);
1683 }
1684
1685 #[test]
1686 #[cfg(all(feature = "alloc", feature = "rand_core"))]
1687 fn partial_precomputed_multiscalar() {
1688 let mut rng = UnwrapErr(SysRng);
1689
1690 let n_static = 16;
1691
1692 let static_points = (0..n_static)
1693 .map(|_| RistrettoPoint::random(&mut rng))
1694 .collect::<Vec<_>>();
1695
1696 let static_scalars = (0..n_static - 1)
1698 .map(|_| Scalar::random(&mut rng))
1699 .collect::<Vec<_>>();
1700
1701 let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
1703 let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars);
1704
1705 let mut result_manual = RistrettoPoint::identity();
1707 for i in 0..static_scalars.len() {
1708 result_manual += static_points[i] * static_scalars[i];
1709 }
1710
1711 assert_eq!(result_multiscalar, result_manual);
1712 }
1713
1714 #[test]
1715 #[cfg(all(feature = "alloc", feature = "rand_core"))]
1716 fn partial_precomputed_multiscalar_empty() {
1717 let mut rng = UnwrapErr(SysRng);
1718
1719 let n_static = 16;
1720
1721 let static_points = (0..n_static)
1722 .map(|_| RistrettoPoint::random(&mut rng))
1723 .collect::<Vec<_>>();
1724
1725 let static_scalars = Vec::new();
1727
1728 let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
1730 let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars);
1731
1732 let mut result_manual = RistrettoPoint::identity();
1734 for i in 0..static_scalars.len() {
1735 result_manual += static_points[i] * static_scalars[i];
1736 }
1737
1738 assert_eq!(result_multiscalar, result_manual);
1739 }
1740}