// -*- mode: rust; -*-
//
// This file is part of curve25519-dalek.
// Copyright (c) 2016-2021 isis lovecruft
// Copyright (c) 2016-2020 Henry de Valence
// See LICENSE for licensing information.
//
// Authors:
// - isis agora lovecruft
// - Henry de Valence

// We allow non snake_case names because coordinates in projective space are
// traditionally denoted by the capitalisation of their respective
// counterparts in affine space.  Yeah, you heard me, rustc, I'm gonna have my
// affine and projective cakes and eat both of them too.
#![allow(non_snake_case)]

//! An implementation of [Ristretto][ristretto_main], which provides a
//! prime-order group.
//!
//! # The Ristretto Group
//!
//! Ristretto is a modification of Mike Hamburg's Decaf scheme to work
//! with cofactor-\\(8\\) curves, such as Curve25519.
//!
//! The introduction of the Decaf paper, [_Decaf:
//! Eliminating cofactors through point
//! compression_](https://eprint.iacr.org/2015/673.pdf), notes that while
//! most cryptographic systems require a group of prime order, most
//! concrete implementations using elliptic curve groups fall short –
//! they either provide a group of prime order, but with incomplete or
//! variable-time addition formulae (for instance, most Weierstrass
//! models), or else they provide a fast and safe implementation of a
//! group whose order is not quite a prime \\(q\\), but \\(hq\\) for a
//! small cofactor \\(h\\) (for instance, Edwards curves, which have
//! cofactor at least \\(4\\)).
//!
//! This abstraction mismatch is commonly “handled” by pushing the
//! complexity upwards, adding ad-hoc protocol modifications.  But
//! these modifications require careful analysis and are a recurring
//! source of [vulnerabilities][cryptonote] and [design
//! complications][ed25519_hkd].
//!
//! Instead, Decaf (and Ristretto) use a quotient group to implement a
//! prime-order group using a non-prime-order curve.  This provides
//! the correct abstraction for cryptographic systems, while retaining
//! the speed and safety benefits of an Edwards curve.
//!
//! Decaf is named “after the procedure which divides the effect of
//! coffee by \\(4\\)”.  However, Curve25519 has a cofactor of
//! \\(8\\).  To eliminate its cofactor, Ristretto restricts further;
//! this [additional restriction][ristretto_coffee] gives the
//! _Ristretto_ encoding.
//!
//! More details on why Ristretto is necessary can be found in the
//! [Why Ristretto?][why_ristretto] section of the Ristretto website.
//!
//! Ristretto points are provided in `curve25519-dalek` by the
//! `RistrettoPoint` struct.
//!
//! ## Encoding and Decoding
//!
//! Encoding is done by converting to and from a `CompressedRistretto`
//! struct, which is a typed wrapper around `[u8; 32]`.
//!
//! The encoding is not batchable, but it is possible to
//! double-and-encode in a batch using
//! `RistrettoPoint::double_and_compress_batch`.
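//!
//! As a minimal sketch of the round trip (assuming only the public
//! `constants::RISTRETTO_BASEPOINT_POINT` constant and the default
//! features), compressing a point to its 32 bytes and decompressing
//! those bytes recovers an equal point:
//!
//! ```
//! use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
//! use curve25519_dalek::ristretto::CompressedRistretto;
//! use curve25519_dalek::scalar::Scalar;
//!
//! // An arbitrary non-identity point: 42 times the basepoint.
//! let p = Scalar::from(42u64) * RISTRETTO_BASEPOINT_POINT;
//!
//! // Encode to the canonical 32-byte wire format...
//! let bytes: [u8; 32] = p.compress().to_bytes();
//!
//! // ...and decode it again; canonical bytes always decompress.
//! let q = CompressedRistretto(bytes).decompress().expect("canonical encoding");
//! assert_eq!(p, q);
//! ```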
//!
//! ## Equality Testing
//!
//! Testing equality of points on an Edwards curve in projective
//! coordinates requires an expensive inversion.  By contrast, equality
//! checking in the Ristretto group can be done in projective
//! coordinates without requiring an inversion, so it is much faster.
//!
//! The `RistrettoPoint` struct implements the
//! [`subtle::ConstantTimeEq`] trait for constant-time equality
//! checking, and also uses this to ensure `Eq` equality checking
//! runs in constant time.
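//!
//! A minimal sketch (assuming the default features, and using the
//! `subtle` crate directly): two different computations of
//! \\(\[2\]B\\) need not produce the same internal coordinates, yet
//! both `ct_eq` and `==` report them as equal.
//!
//! ```
//! use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as B;
//! use curve25519_dalek::scalar::Scalar;
//! use subtle::ConstantTimeEq;
//!
//! let doubled = B + B;
//! let scaled = Scalar::from(2u64) * B;
//!
//! // `ct_eq` returns a `subtle::Choice`; `==` is built on the same check.
//! assert!(bool::from(doubled.ct_eq(&scaled)));
//! assert_eq!(doubled, scaled);
//! ```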
//!
//! ## Scalars
//!
//! Scalars are represented by the `Scalar` struct.  Each scalar has a
//! canonical representative mod the group order.  To attempt to load
//! a supposedly-canonical scalar, use
//! `Scalar::from_canonical_bytes()`. To check whether a
//! representative is canonical, use `Scalar::is_canonical()`.
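//!
//! A minimal sketch, assuming `Scalar::from_canonical_bytes` takes a
//! `[u8; 32]` and returns a `subtle::CtOption<Scalar>`: the encoding of
//! a reduced scalar is accepted, while an encoding of a value at least
//! as large as the group order is rejected.
//!
//! ```
//! use curve25519_dalek::scalar::Scalar;
//!
//! // The little-endian bytes of a reduced scalar load successfully.
//! let seven = Scalar::from(7u64);
//! let loaded = Scalar::from_canonical_bytes(seven.to_bytes());
//! assert_eq!(Option::<Scalar>::from(loaded), Some(seven));
//!
//! // 2^256 - 1 exceeds the group order, so it is not canonical.
//! let rejected = Scalar::from_canonical_bytes([0xff; 32]);
//! assert!(bool::from(rejected.is_none()));
//! ```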
//!
//! ## Scalar Multiplication
//!
//! Scalar multiplication on Ristretto points is provided by:
//!
//! * the `*` operator between a `Scalar` and a `RistrettoPoint`, which
//!   performs constant-time variable-base scalar multiplication;
//!
//! * the `*` operator between a `Scalar` and a
//!   `RistrettoBasepointTable`, which performs constant-time fixed-base
//!   scalar multiplication;
//!
//! * an implementation of the
//!   [`MultiscalarMul`](../traits/trait.MultiscalarMul.html) trait for
//!   constant-time variable-base multiscalar multiplication (requires
//!   the `alloc` feature);
//!
//! * an implementation of the
//!   [`VartimeMultiscalarMul`](../traits/trait.VartimeMultiscalarMul.html)
//!   trait for variable-time variable-base multiscalar multiplication
//!   (requires the `alloc` feature).
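//!
//! A minimal sketch of these entry points, assuming the default `alloc`
//! and `precomputed-tables` features and that
//! `constants::RISTRETTO_BASEPOINT_TABLE` is a reference to a
//! `RistrettoBasepointTable`:
//!
//! ```
//! use curve25519_dalek::constants::{RISTRETTO_BASEPOINT_POINT, RISTRETTO_BASEPOINT_TABLE};
//! use curve25519_dalek::ristretto::RistrettoPoint;
//! use curve25519_dalek::scalar::Scalar;
//! use curve25519_dalek::traits::{MultiscalarMul, VartimeMultiscalarMul};
//!
//! let a = Scalar::from(3u64);
//! let b = Scalar::from(5u64);
//!
//! // Variable-base and fixed-base multiplication agree.
//! let p = a * RISTRETTO_BASEPOINT_POINT; // 3B
//! let q = &b * RISTRETTO_BASEPOINT_TABLE; // 5B
//! assert_eq!(q, b * RISTRETTO_BASEPOINT_POINT);
//!
//! // a*p + b*q = (9 + 25)B, computed in constant and in variable time.
//! let expected = Scalar::from(34u64) * RISTRETTO_BASEPOINT_POINT;
//! assert_eq!(RistrettoPoint::multiscalar_mul(&[a, b], &[p, q]), expected);
//! assert_eq!(RistrettoPoint::vartime_multiscalar_mul(&[a, b], &[p, q]), expected);
//! ```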
//!
//! ## Random Points and Hashing to Ristretto
//!
//! The Ristretto group comes equipped with an Elligator map.  This is
//! used to implement
//!
//! * `RistrettoPoint::random()`, which generates random points from an
//!   RNG (requires the `rand_core` feature);
//!
//! * `RistrettoPoint::from_hash()` and
//!   `RistrettoPoint::hash_from_bytes()`, which perform hashing to the
//!   group (requires the `digest` feature).
//!
//! ## Implementation
//!
//! The Decaf suggestion is to use a quotient group, such as \\(\mathcal
//! E / \mathcal E\[4\]\\) or \\(2 \mathcal E / \mathcal E\[2\] \\), to
//! implement a prime-order group using a non-prime-order curve.
//!
//! This requires only changing
//!
//! 1. the function for equality checking (so that two representatives
//!    of the same coset are considered equal);
//! 2. the function for encoding (so that two representatives of the
//!    same coset are encoded as identical bitstrings);
//! 3. the function for decoding (so that only the canonical encoding of
//!    a coset is accepted).
//!
//! Internally, each coset is represented by a curve point; two points
//! \\( P, Q \\) may represent the same coset in the same way that two
//! points with different \\(X,Y,Z\\) coordinates may represent the
//! same point.  The group operations are carried out with no overhead
//! using Edwards formulas.
//!
//! Notes on the details of the encoding can be found in the
//! [Details][ristretto_notes] section of the Ristretto website.
//!
//! [cryptonote]:
//! https://moderncrypto.org/mail-archive/curves/2017/000898.html
//! [ed25519_hkd]:
//! https://moderncrypto.org/mail-archive/curves/2017/000858.html
//! [ristretto_coffee]:
//! https://en.wikipedia.org/wiki/Ristretto
//! [ristretto_notes]:
//! https://ristretto.group/details/index.html
//! [why_ristretto]:
//! https://ristretto.group/why_ristretto.html
//! [ristretto_main]:
//! https://ristretto.group/

mod elligator;

#[cfg(feature = "alloc")]
use alloc::vec::Vec;

use core::array::TryFromSliceError;
use core::borrow::Borrow;
use core::fmt::Debug;
use core::iter::Sum;
use core::ops::{Add, Neg, Sub};
use core::ops::{AddAssign, SubAssign};
use core::ops::{Mul, MulAssign};

#[cfg(feature = "digest")]
use digest::Digest;
#[cfg(feature = "digest")]
use digest::array::typenum::U64;

use crate::constants;
use crate::field::FieldElement;

#[cfg(feature = "group")]
use {
    group::{GroupEncoding, cofactor::CofactorGroup, prime::PrimeGroup},
    rand_core::TryRng,
    subtle::CtOption,
};

#[cfg(feature = "rand_core")]
use {
    core::convert::Infallible,
    rand_core::{CryptoRng, TryCryptoRng},
};

use subtle::Choice;
use subtle::ConditionallyNegatable;
use subtle::ConditionallySelectable;
use subtle::ConstantTimeEq;

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;

#[cfg(feature = "precomputed-tables")]
use crate::edwards::EdwardsBasepointTable;
use crate::edwards::EdwardsPoint;

use crate::scalar::Scalar;

#[cfg(feature = "precomputed-tables")]
use crate::traits::BasepointTable;
use crate::traits::Identity;
#[cfg(feature = "alloc")]
use crate::traits::{MultiscalarMul, VartimeMultiscalarMul, VartimePrecomputedMultiscalarMul};

// ------------------------------------------------------------------------
// Compressed points
// ------------------------------------------------------------------------

/// A Ristretto point, in compressed wire format.
///
/// The Ristretto encoding is canonical, so two points are equal if and
/// only if their encodings are equal.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Copy, Clone, Hash)]
pub struct CompressedRistretto(pub [u8; 32]);

impl Eq for CompressedRistretto {}
impl PartialEq for CompressedRistretto {
    fn eq(&self, other: &Self) -> bool {
        self.ct_eq(other).into()
    }
}

impl ConstantTimeEq for CompressedRistretto {
    fn ct_eq(&self, other: &CompressedRistretto) -> Choice {
        self.as_bytes().ct_eq(other.as_bytes())
    }
}

impl CompressedRistretto {
    /// Copy the bytes of this `CompressedRistretto`.
    pub const fn to_bytes(&self) -> [u8; 32] {
        self.0
    }

    /// View this `CompressedRistretto` as an array of bytes.
    pub const fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Construct a `CompressedRistretto` from a slice of bytes.
    ///
    /// # Errors
    ///
    /// Returns [`TryFromSliceError`] if the input `bytes` slice does not have
    /// a length of 32.
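    ///
    /// # Example
    ///
    /// A minimal sketch: only the slice length is checked here, so any
    /// 32-byte slice is accepted and other lengths are rejected.  Whether
    /// the bytes are a valid encoding is only checked by
    /// [`CompressedRistretto::decompress`].
    ///
    /// ```
    /// use curve25519_dalek::ristretto::CompressedRistretto;
    ///
    /// assert!(CompressedRistretto::from_slice(&[0u8; 32]).is_ok());
    /// assert!(CompressedRistretto::from_slice(&[0u8; 31]).is_err());
    /// assert!(CompressedRistretto::from_slice(&[0u8; 33]).is_err());
    /// ```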
    pub fn from_slice(bytes: &[u8]) -> Result<CompressedRistretto, TryFromSliceError> {
        bytes.try_into().map(CompressedRistretto)
    }

    /// Attempt to decompress to a `RistrettoPoint`.
    ///
    /// # Return
    ///
    /// - `Some(RistrettoPoint)` if `self` was the canonical encoding of a point;
    ///
    /// - `None` if `self` was not the canonical encoding of a point.
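    ///
    /// # Example
    ///
    /// A minimal sketch of the checks performed below: the all-zero
    /// encoding decodes to the identity, while a "negative" \\(s\\) (odd
    /// low bit) and bytes that do not round-trip through the field
    /// encoding are both rejected.
    ///
    /// ```
    /// use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
    /// use curve25519_dalek::traits::Identity;
    ///
    /// assert_eq!(
    ///     CompressedRistretto([0u8; 32]).decompress(),
    ///     Some(RistrettoPoint::identity())
    /// );
    ///
    /// // s = 1 has its low bit set, so it counts as negative.
    /// let mut one = [0u8; 32];
    /// one[0] = 1;
    /// assert!(CompressedRistretto(one).decompress().is_none());
    ///
    /// // All-0xff bytes set the high bit and are not canonical.
    /// assert!(CompressedRistretto([0xff; 32]).decompress().is_none());
    /// ```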
    pub fn decompress(&self) -> Option<RistrettoPoint> {
        let (s_encoding_is_canonical, s_is_negative, s) = decompress::step_1(self);

        if (!s_encoding_is_canonical | s_is_negative).into() {
            return None;
        }

        let (ok, t_is_negative, y_is_zero, res) = decompress::step_2(s);

        if (!ok | t_is_negative | y_is_zero).into() {
            None
        } else {
            Some(res)
        }
    }
}

mod decompress {
    use super::*;

    pub(super) fn step_1(repr: &CompressedRistretto) -> (Choice, Choice, FieldElement) {
        // Step 1. Check s for validity:
        // 1.a) s must be 32 bytes (we get this from the type system)
        // 1.b) s < p
        // 1.c) s is nonnegative
        //
        // Our decoding routine ignores the high bit, so the only
        // possible failure for 1.b) is if someone encodes s in 0..=18
        // as s+p in 2^255-19..=2^255-1.  We can check this by
        // converting back to bytes, and checking that we get the
        // original input, since our encoding routine is canonical.
        // The same round-trip check also rejects inputs with the high
        // bit set, since our encoding routine never sets it.

        let s = FieldElement::from_bytes(repr.as_bytes());
        let s_bytes_check = s.to_bytes();
        let s_encoding_is_canonical = s_bytes_check[..].ct_eq(repr.as_bytes());
        let s_is_negative = s.is_negative();

        (s_encoding_is_canonical, s_is_negative, s)
    }

    pub(super) fn step_2(s: FieldElement) -> (Choice, Choice, Choice, RistrettoPoint) {
        // Step 2.  Compute (X:Y:Z:T).
        let one = FieldElement::ONE;
        let ss = s.square();
        let u1 = &one - &ss; //  1 + as²
        let u2 = &one + &ss; //  1 - as²    where a=-1
        let u2_sqr = u2.square(); // (1 - as²)²

        // v == ad(1+as²)² - (1-as²)²            where d=-121665/121666
        let v = &(&(-&constants::EDWARDS_D) * &u1.square()) - &u2_sqr;

        let (ok, I) = (&v * &u2_sqr).invsqrt(); // 1/sqrt(v*u_2²)

        let Dx = &I * &u2; // 1/sqrt(v)
        let Dy = &I * &(&Dx * &v); // 1/u2

        // x == | 2s/sqrt(v) | == + sqrt(4s²/(ad(1+as²)² - (1-as²)²))
        let mut x = &(&s + &s) * &Dx;
        let x_neg = x.is_negative();
        x.conditional_negate(x_neg);

        // y == u1/u2 == (1+as²)/(1-as²)
        let y = &u1 * &Dy;

        // t == ((1+as²) sqrt(4s²/(ad(1+as²)² - (1-as²)²)))/(1-as²)
        let t = &x * &y;

        (
            ok,
            t.is_negative(),
            y.is_zero(),
            RistrettoPoint(EdwardsPoint {
                X: x,
                Y: y,
                Z: one,
                T: t,
            }),
        )
    }
}

impl Identity for CompressedRistretto {
    fn identity() -> CompressedRistretto {
        CompressedRistretto([0u8; 32])
    }
}

impl Default for CompressedRistretto {
    fn default() -> CompressedRistretto {
        CompressedRistretto::identity()
    }
}

impl TryFrom<&[u8]> for CompressedRistretto {
    type Error = TryFromSliceError;

    fn try_from(slice: &[u8]) -> Result<CompressedRistretto, TryFromSliceError> {
        Self::from_slice(slice)
    }
}

// ------------------------------------------------------------------------
// Serde support
// ------------------------------------------------------------------------
// Serializes to and from `RistrettoPoint` directly, doing compression
// and decompression internally.  This means that users can create
// structs containing `RistrettoPoint`s and use Serde's derived
// serializers to serialize those structures.

#[cfg(feature = "serde")]
use serde::de::Visitor;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[cfg(feature = "serde")]
impl Serialize for RistrettoPoint {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeTuple;
        let mut tup = serializer.serialize_tuple(32)?;
        for byte in self.compress().as_bytes().iter() {
            tup.serialize_element(byte)?;
        }
        tup.end()
    }
}

#[cfg(feature = "serde")]
impl Serialize for CompressedRistretto {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeTuple;
        let mut tup = serializer.serialize_tuple(32)?;
        for byte in self.as_bytes().iter() {
            tup.serialize_element(byte)?;
        }
        tup.end()
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for RistrettoPoint {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct RistrettoPointVisitor;

        impl<'de> Visitor<'de> for RistrettoPointVisitor {
            type Value = RistrettoPoint;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("a valid point in Ristretto format")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<RistrettoPoint, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut bytes = [0u8; 32];
                #[allow(clippy::needless_range_loop)]
                for i in 0..32 {
                    bytes[i] = seq
                        .next_element()?
                        .ok_or_else(|| serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
                }
                CompressedRistretto(bytes)
                    .decompress()
                    .ok_or_else(|| serde::de::Error::custom("decompression failed"))
            }
        }

        deserializer.deserialize_tuple(32, RistrettoPointVisitor)
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for CompressedRistretto {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct CompressedRistrettoVisitor;

        impl<'de> Visitor<'de> for CompressedRistrettoVisitor {
            type Value = CompressedRistretto;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("32 bytes of data")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<CompressedRistretto, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                let mut bytes = [0u8; 32];
                #[allow(clippy::needless_range_loop)]
                for i in 0..32 {
                    bytes[i] = seq
                        .next_element()?
                        .ok_or_else(|| serde::de::Error::invalid_length(i, &"expected 32 bytes"))?;
                }
                Ok(CompressedRistretto(bytes))
            }
        }

        deserializer.deserialize_tuple(32, CompressedRistrettoVisitor)
    }
}

// ------------------------------------------------------------------------
// Internal point representations
// ------------------------------------------------------------------------

/// A `RistrettoPoint` represents a point in the Ristretto group for
/// Curve25519.  Ristretto, a variant of Decaf, constructs a
/// prime-order group as a quotient group of a subgroup of (the
/// Edwards form of) Curve25519.
///
/// Internally, a `RistrettoPoint` is implemented as a wrapper type
/// around `EdwardsPoint`, with custom equality, compression, and
/// decompression routines to account for the quotient.  This means that
/// operations on `RistrettoPoint`s are exactly as fast as operations on
/// `EdwardsPoint`s.
///
#[derive(Copy, Clone)]
pub struct RistrettoPoint(pub(crate) EdwardsPoint);

impl RistrettoPoint {
    /// Compress this point using the Ristretto encoding.
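    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the public
    /// `constants::RISTRETTO_BASEPOINT_COMPRESSED` constant: the identity
    /// compresses to all-zero bytes, and the basepoint compresses to that
    /// constant.
    ///
    /// ```
    /// use curve25519_dalek::constants::{RISTRETTO_BASEPOINT_COMPRESSED, RISTRETTO_BASEPOINT_POINT};
    /// use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
    /// use curve25519_dalek::traits::Identity;
    ///
    /// assert_eq!(RistrettoPoint::identity().compress(), CompressedRistretto([0u8; 32]));
    /// assert_eq!(RISTRETTO_BASEPOINT_POINT.compress(), RISTRETTO_BASEPOINT_COMPRESSED);
    /// ```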
    pub fn compress(&self) -> CompressedRistretto {
        let mut X = self.0.X;
        let mut Y = self.0.Y;
        let Z = &self.0.Z;
        let T = &self.0.T;

        let u1 = &(Z + &Y) * &(Z - &Y);
        let u2 = &X * &Y;
        // Ignore return value since this is always square
        let (_, invsqrt) = (&u1 * &u2.square()).invsqrt();
        let i1 = &invsqrt * &u1;
        let i2 = &invsqrt * &u2;
        let z_inv = &i1 * &(&i2 * T);
        let mut den_inv = i2;

        let iX = &X * &constants::SQRT_M1;
        let iY = &Y * &constants::SQRT_M1;
        let ristretto_magic = &constants::INVSQRT_A_MINUS_D;
        let enchanted_denominator = &i1 * ristretto_magic;

        let rotate = (T * &z_inv).is_negative();

        X.conditional_assign(&iY, rotate);
        Y.conditional_assign(&iX, rotate);
        den_inv.conditional_assign(&enchanted_denominator, rotate);

        Y.conditional_negate((&X * &z_inv).is_negative());

        let mut s = &den_inv * &(Z - &Y);
        let s_is_negative = s.is_negative();
        s.conditional_negate(s_is_negative);

        CompressedRistretto(s.to_bytes())
    }

    /// Double-and-compress a batch of points.  The Ristretto encoding
    /// is not batchable, since it requires an inverse square root.
    ///
    /// However, given input points \\( P\_1, \ldots, P\_n, \\)
    /// it is possible to compute the encodings of their doubles \\(
    /// \mathrm{enc}( \[2\]P\_1), \ldots, \mathrm{enc}( \[2\]P\_n ) \\)
    /// in a batch.
    ///
    #[cfg_attr(feature = "rand_core", doc = "```")]
    #[cfg_attr(not(feature = "rand_core"), doc = "```ignore")]
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// use getrandom::{SysRng, rand_core::UnwrapErr};
    ///
    /// # // Need fn main() here in comment so the doctest compiles
    /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
    /// # fn main() {
    /// let mut rng = UnwrapErr(SysRng);
    ///
    /// let points: Vec<RistrettoPoint> =
    ///     (0..32).map(|_| RistrettoPoint::random(&mut rng)).collect();
    ///
    /// let compressed = RistrettoPoint::double_and_compress_batch(&points);
    ///
    /// for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
    ///     assert_eq!(*P2_compressed, (P + P).compress());
    /// }
    /// # }
    /// ```
    #[cfg(feature = "alloc")]
    pub fn double_and_compress_batch<'a, I>(points: I) -> Vec<CompressedRistretto>
    where
        I: IntoIterator<Item = &'a RistrettoPoint>,
    {
        #[derive(Copy, Clone, Debug)]
        struct BatchCompressState {
            e: FieldElement,
            f: FieldElement,
            g: FieldElement,
            h: FieldElement,
            eg: FieldElement,
            fh: FieldElement,
        }

        impl BatchCompressState {
            fn efgh(&self) -> FieldElement {
                &self.eg * &self.fh
            }
        }

        impl<'a> From<&'a RistrettoPoint> for BatchCompressState {
            #[rustfmt::skip] // keep alignment of explanatory comments
            fn from(P: &'a RistrettoPoint) -> BatchCompressState {
                let XX = P.0.X.square();
                let YY = P.0.Y.square();
                let ZZ = P.0.Z.square();
                let dTT = &P.0.T.square() * &constants::EDWARDS_D;

                let e = &P.0.X * &(&P.0.Y + &P.0.Y); // = 2*X*Y
                let f = &ZZ + &dTT;                  // = Z^2 + d*T^2
                let g = &YY + &XX;                   // = Y^2 - a*X^2
                let h = &ZZ - &dTT;                  // = Z^2 - d*T^2

                let eg = &e * &g;
                let fh = &f * &h;

                BatchCompressState{ e, f, g, h, eg, fh }
            }
        }

        let states: Vec<BatchCompressState> =
            points.into_iter().map(BatchCompressState::from).collect();

        let mut invs: Vec<FieldElement> = states.iter().map(|state| state.efgh()).collect();

        FieldElement::invert_batch_alloc(&mut invs[..]);

        states
            .iter()
            .zip(invs.iter())
            .map(|(state, inv): (&BatchCompressState, &FieldElement)| {
                let Zinv = &state.eg * inv;
                let Tinv = &state.fh * inv;

                let mut magic = constants::INVSQRT_A_MINUS_D;

                let negcheck1 = (&state.eg * &Zinv).is_negative();

                let mut e = state.e;
                let mut g = state.g;
                let mut h = state.h;

                let minus_e = -&e;
                let f_times_sqrta = &state.f * &constants::SQRT_M1;

                e.conditional_assign(&state.g, negcheck1);
                g.conditional_assign(&minus_e, negcheck1);
                h.conditional_assign(&f_times_sqrta, negcheck1);

                magic.conditional_assign(&constants::SQRT_M1, negcheck1);

                let negcheck2 = (&(&h * &e) * &Zinv).is_negative();

                g.conditional_negate(negcheck2);

                let mut s = &(&h - &g) * &(&magic * &(&g * &Tinv));

                let s_is_negative = s.is_negative();
                s.conditional_negate(s_is_negative);

                CompressedRistretto(s.to_bytes())
            })
            .collect()
    }

    /// Return the coset self + E\[4\], for debugging.
    fn coset4(&self) -> [EdwardsPoint; 4] {
        [
            self.0,
            self.0 + constants::EIGHT_TORSION[2],
            self.0 + constants::EIGHT_TORSION[4],
            self.0 + constants::EIGHT_TORSION[6],
        ]
    }

    /// Return a `RistrettoPoint` chosen uniformly at random using a user-provided RNG.
    ///
    /// # Inputs
    ///
    /// * `rng`: any RNG which implements the `CryptoRng` interface.
    ///
    /// # Returns
    ///
    /// A random element of the Ristretto group.
    ///
    /// # Implementation
    ///
    /// Uses the Ristretto-flavoured Elligator 2 map, so that the
    /// discrete log of the output point with respect to any other
    /// point should be unknown.  The map is applied twice and the
    /// results are added, to ensure a uniform distribution.
    #[cfg(feature = "rand_core")]
    pub fn random<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
        Self::try_random(rng)
            .map_err(|_: Infallible| {})
            .expect("[bug] infallible rng failed")
    }

    /// Return a `RistrettoPoint` chosen uniformly at random using a user-provided RNG.
    ///
    /// # Inputs
    ///
    /// * `rng`: any RNG which implements the `TryCryptoRng` interface.
    ///
    /// # Returns
    ///
    /// A random element of the Ristretto group, or the RNG's error if
    /// filling the random bytes fails.
    ///
    /// # Implementation
    ///
    /// Uses the Ristretto-flavoured Elligator 2 map, so that the
    /// discrete log of the output point with respect to any other
    /// point should be unknown.  The map is applied twice and the
    /// results are added, to ensure a uniform distribution.
    #[cfg(feature = "rand_core")]
    pub fn try_random<R: TryCryptoRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut uniform_bytes = [0u8; 64];
        rng.try_fill_bytes(&mut uniform_bytes)?;

        Ok(RistrettoPoint::from_uniform_bytes(&uniform_bytes))
    }

    #[cfg(feature = "digest")]
    /// Hash a slice of bytes into a `RistrettoPoint`.
    ///
    /// Takes a type parameter `D`, which is any `Digest` producing 64
    /// bytes of output.
    ///
    /// Convenience wrapper around `from_hash`.
    ///
    /// # Implementation
    ///
    /// Uses the Ristretto-flavoured Elligator 2 map, so that the
    /// discrete log of the output point with respect to any other
    /// point should be unknown.  The map is applied twice and the
    /// results are added, to ensure a uniform distribution.
    ///
    /// # Example
    ///
    #[cfg_attr(feature = "digest", doc = "```")]
    #[cfg_attr(not(feature = "digest"), doc = "```ignore")]
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// use sha2::Sha512;
    ///
    /// # // Need fn main() here in comment so the doctest compiles
    /// # // See https://doc.rust-lang.org/book/documentation.html#documentation-as-tests
    /// # fn main() {
    /// let msg = "To really appreciate architecture, you may even need to commit a murder";
    /// let P = RistrettoPoint::hash_from_bytes::<Sha512>(msg.as_bytes());
    /// # }
    /// ```
    ///
    pub fn hash_from_bytes<D>(input: &[u8]) -> RistrettoPoint
    where
        D: Digest<OutputSize = U64> + Default,
    {
        let mut hash = D::default();
        hash.update(input);
        RistrettoPoint::from_hash(hash)
    }

    #[cfg(feature = "digest")]
    /// Construct a `RistrettoPoint` from an existing `Digest` instance.
    ///
    /// Use this instead of `hash_from_bytes` if it is more convenient
    /// to stream data into the `Digest` than to pass a single byte
    /// slice.
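    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `sha2` crate as in the
    /// `hash_from_bytes` example: streaming the input in two pieces gives
    /// the same point as hashing the whole message in one call.
    ///
    #[cfg_attr(feature = "digest", doc = "```")]
    #[cfg_attr(not(feature = "digest"), doc = "```ignore")]
    /// use curve25519_dalek::ristretto::RistrettoPoint;
    /// use sha2::{Digest, Sha512};
    ///
    /// let mut hash = Sha512::new();
    /// hash.update(b"To really appreciate architecture, ");
    /// hash.update(b"you may even need to commit a murder");
    /// let streamed = RistrettoPoint::from_hash(hash);
    ///
    /// let one_shot = RistrettoPoint::hash_from_bytes::<Sha512>(
    ///     b"To really appreciate architecture, you may even need to commit a murder",
    /// );
    /// assert_eq!(streamed, one_shot);
    /// ```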
    pub fn from_hash<D>(hash: D) -> RistrettoPoint
    where
        D: Digest<OutputSize = U64> + Default,
    {
        // dealing with generic arrays is clumsy, until const generics land
        let output = hash.finalize();
        let mut output_bytes = [0u8; 64];
        output_bytes.copy_from_slice(output.as_slice());

        RistrettoPoint::from_uniform_bytes(&output_bytes)
    }

    /// Construct a `RistrettoPoint` from 64 bytes of data.
    ///
    /// If the input bytes are uniformly distributed, the resulting
    /// point will be uniformly distributed over the group, and its
    /// discrete log with respect to other points should be unknown.
    ///
    /// # Implementation
    ///
    /// This function splits the input array into two 32-byte halves,
    /// takes the low 255 bits of each half mod p, applies the
    /// Ristretto-flavoured Elligator map to each, and adds the results.
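    ///
    /// # Example
    ///
    /// A minimal sketch of one consequence of this construction: because
    /// the two halves are mapped independently and the results are
    /// added, swapping the halves yields the same point.
    ///
    /// ```
    /// use curve25519_dalek::ristretto::RistrettoPoint;
    ///
    /// // Arbitrary, non-uniform input bytes: 0, 1, 2, ..., 63.
    /// let mut bytes = [0u8; 64];
    /// for (i, b) in bytes.iter_mut().enumerate() {
    ///     *b = i as u8;
    /// }
    ///
    /// let mut swapped = [0u8; 64];
    /// swapped[..32].copy_from_slice(&bytes[32..]);
    /// swapped[32..].copy_from_slice(&bytes[..32]);
    ///
    /// assert_eq!(
    ///     RistrettoPoint::from_uniform_bytes(&bytes),
    ///     RistrettoPoint::from_uniform_bytes(&swapped)
    /// );
    /// ```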
    pub fn from_uniform_bytes(bytes: &[u8; 64]) -> RistrettoPoint {
        // This follows the one-way map construction from the Ristretto RFC:
        // https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-ristretto255-decaf448-04#section-4.3.4
        let mut r_1_bytes = [0u8; 32];
        r_1_bytes.copy_from_slice(&bytes[0..32]);
        let r_1 = FieldElement::from_bytes(&r_1_bytes);
        let R_1 = RistrettoPoint::elligator_ristretto_flavor(&r_1);

        let mut r_2_bytes = [0u8; 32];
        r_2_bytes.copy_from_slice(&bytes[32..64]);
        let r_2 = FieldElement::from_bytes(&r_2_bytes);
        let R_2 = RistrettoPoint::elligator_ristretto_flavor(&r_2);

        // Applying Elligator twice and adding the results ensures a
        // uniform distribution.
        R_1 + R_2
    }
}

impl Identity for RistrettoPoint {
    fn identity() -> RistrettoPoint {
        RistrettoPoint(EdwardsPoint::identity())
    }
}

impl Default for RistrettoPoint {
    fn default() -> RistrettoPoint {
        RistrettoPoint::identity()
    }
}

// ------------------------------------------------------------------------
// Equality
// ------------------------------------------------------------------------

impl PartialEq for RistrettoPoint {
    fn eq(&self, other: &RistrettoPoint) -> bool {
        self.ct_eq(other).into()
    }
}

impl ConstantTimeEq for RistrettoPoint {
    /// Test equality between two `RistrettoPoint`s.
    ///
    /// # Returns
    ///
    /// * `Choice(1)` if the two `RistrettoPoint`s are equal;
    /// * `Choice(0)` otherwise.
    fn ct_eq(&self, other: &RistrettoPoint) -> Choice {
        let X1Y2 = &self.0.X * &other.0.Y;
        let Y1X2 = &self.0.Y * &other.0.X;
        let X1X2 = &self.0.X * &other.0.X;
        let Y1Y2 = &self.0.Y * &other.0.Y;

        X1Y2.ct_eq(&Y1X2) | X1X2.ct_eq(&Y1Y2)
    }
}

impl Eq for RistrettoPoint {}

// ------------------------------------------------------------------------
// Arithmetic
// ------------------------------------------------------------------------

impl<'a> Add<&'a RistrettoPoint> for &RistrettoPoint {
    type Output = RistrettoPoint;

    fn add(self, other: &'a RistrettoPoint) -> RistrettoPoint {
        RistrettoPoint(self.0 + other.0)
    }
}

define_add_variants!(
    LHS = RistrettoPoint,
    RHS = RistrettoPoint,
    Output = RistrettoPoint
);

impl AddAssign<&RistrettoPoint> for RistrettoPoint {
    fn add_assign(&mut self, _rhs: &RistrettoPoint) {
        *self = (self as &RistrettoPoint) + _rhs;
    }
}

define_add_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);

impl<'a> Sub<&'a RistrettoPoint> for &RistrettoPoint {
    type Output = RistrettoPoint;

    fn sub(self, other: &'a RistrettoPoint) -> RistrettoPoint {
        RistrettoPoint(self.0 - other.0)
    }
}

define_sub_variants!(
    LHS = RistrettoPoint,
    RHS = RistrettoPoint,
    Output = RistrettoPoint
);

impl SubAssign<&RistrettoPoint> for RistrettoPoint {
    fn sub_assign(&mut self, _rhs: &RistrettoPoint) {
        *self = (self as &RistrettoPoint) - _rhs;
    }
}

define_sub_assign_variants!(LHS = RistrettoPoint, RHS = RistrettoPoint);

882impl<T> Sum<T> for RistrettoPoint
883where
884    T: Borrow<RistrettoPoint>,
885{
886    fn sum<I>(iter: I) -> Self
887    where
888        I: Iterator<Item = T>,
889    {
890        iter.fold(RistrettoPoint::identity(), |acc, item| acc + item.borrow())
891    }
892}
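/*
Illustrative sketch, not part of this crate: the `T: Borrow<RistrettoPoint>`
bound on `Sum` is what lets `impl_sum` in the tests below sum both
`vec.iter()` (which yields references) and mapped iterators (which yield
owned points). The same pattern is shown on a hypothetical `ToyPoint`
(integers mod 101 under addition), using only the standard library.

```rust
use std::borrow::Borrow;
use std::iter::Sum;
use std::ops::Add;

/// Hypothetical stand-in for a group element: integers mod 101 under addition.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ToyPoint(u64);

impl ToyPoint {
    fn identity() -> Self {
        ToyPoint(0)
    }
}

impl Add<&ToyPoint> for ToyPoint {
    type Output = ToyPoint;

    fn add(self, other: &ToyPoint) -> ToyPoint {
        ToyPoint((self.0 + other.0) % 101)
    }
}

// Same shape as the impl above: any `T: Borrow<ToyPoint>` can be summed,
// which covers both `&ToyPoint` (from `.iter()`) and owned `ToyPoint`s.
impl<T: Borrow<ToyPoint>> Sum<T> for ToyPoint {
    fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
        iter.fold(ToyPoint::identity(), |acc, item| acc + item.borrow())
    }
}
```

Because `fold` starts from the identity, an empty iterator sums to the
identity, which is also what `impl_sum` checks.
*/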

impl Neg for &RistrettoPoint {
    type Output = RistrettoPoint;

    fn neg(self) -> RistrettoPoint {
        RistrettoPoint(-&self.0)
    }
}

impl Neg for RistrettoPoint {
    type Output = RistrettoPoint;

    fn neg(self) -> RistrettoPoint {
        -&self
    }
}

impl<'a> MulAssign<&'a Scalar> for RistrettoPoint {
    fn mul_assign(&mut self, scalar: &'a Scalar) {
        let result = (self as &RistrettoPoint) * scalar;
        *self = result;
    }
}

impl<'a> Mul<&'a Scalar> for &RistrettoPoint {
    type Output = RistrettoPoint;

    /// Scalar multiplication: compute `scalar * self`.
    fn mul(self, scalar: &'a Scalar) -> RistrettoPoint {
        RistrettoPoint(self.0 * scalar)
    }
}

impl<'a> Mul<&'a RistrettoPoint> for &Scalar {
    type Output = RistrettoPoint;

    /// Scalar multiplication: compute `self * point`.
    fn mul(self, point: &'a RistrettoPoint) -> RistrettoPoint {
        RistrettoPoint(self * point.0)
    }
}

impl RistrettoPoint {
    /// Fixed-base scalar multiplication by the Ristretto base point.
    ///
    /// Uses precomputed basepoint tables when the `precomputed-tables` feature
    /// is enabled, trading off increased code size for ~4x better performance.
    pub fn mul_base(scalar: &Scalar) -> Self {
        #[cfg(not(feature = "precomputed-tables"))]
        {
            scalar * constants::RISTRETTO_BASEPOINT_POINT
        }

        #[cfg(feature = "precomputed-tables")]
        {
            scalar * constants::RISTRETTO_BASEPOINT_TABLE
        }
    }
}
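/*
Illustrative sketch, not part of this crate: why a precomputed basepoint
table speeds up fixed-base multiplication such as `mul_base`. It assumes a
toy group (integers modulo the prime N = 1_000_003 under addition) and an
unsigned radix-16 table; the crate's real tables are built from Edwards
points and differ in detail, so treat every name here as hypothetical. With
rows[i][j] = j * 16^i * B stored ahead of time, computing s * B needs one
lookup and one addition per 4-bit digit of s and no doublings, which is
where both the speedup and the extra table storage come from.

```rust
/// Toy group: integers modulo the prime N under addition (a hypothetical stand-in).
const N: u64 = 1_000_003;

fn add(a: u64, b: u64) -> u64 {
    (a + b) % N
}

/// Plain double-and-add, used to build the table and to check it.
fn double_and_add(scalar: u64, base: u64) -> u64 {
    let (mut acc, mut addend, mut k) = (0, base % N, scalar);
    while k > 0 {
        if k & 1 == 1 {
            acc = add(acc, addend);
        }
        addend = add(addend, addend);
        k >>= 1;
    }
    acc
}

/// rows[i][j] = j * 16^i * base, for 16 windows of 4 bits (enough for a u64 scalar).
struct ToyBasepointTable {
    rows: Vec<[u64; 16]>,
}

impl ToyBasepointTable {
    fn create(base: u64) -> Self {
        let mut rows = Vec::with_capacity(16);
        let mut window_base = base % N; // 16^i * base for the current row i
        for _ in 0..16 {
            let mut row = [0u64; 16];
            for j in 1..16 {
                row[j] = add(row[j - 1], window_base);
            }
            rows.push(row);
            window_base = double_and_add(16, window_base);
        }
        ToyBasepointTable { rows }
    }

    /// Fixed-base multiplication: one lookup and one addition per 4-bit
    /// window of the scalar, and no doublings at all.
    fn mul(&self, scalar: u64) -> u64 {
        let mut acc = 0;
        for (i, row) in self.rows.iter().enumerate() {
            let digit = (scalar >> (4 * i)) & 15;
            acc = add(acc, row[digit as usize]);
        }
        acc
    }
}
```

The table is built once per basepoint, so its cost is amortized over every
later multiplication by that same basepoint.
*/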

define_mul_assign_variants!(LHS = RistrettoPoint, RHS = Scalar);

define_mul_variants!(LHS = RistrettoPoint, RHS = Scalar, Output = RistrettoPoint);
define_mul_variants!(LHS = Scalar, RHS = RistrettoPoint, Output = RistrettoPoint);

// ------------------------------------------------------------------------
// Multiscalar Multiplication impls
// ------------------------------------------------------------------------

// These use iterator combinators to unwrap the underlying points and
// forward to the EdwardsPoint implementations.

#[cfg(feature = "alloc")]
impl MultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    fn multiscalar_mul<I, J>(scalars: I, points: J) -> RistrettoPoint
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<RistrettoPoint>,
    {
        let extended_points = points.into_iter().map(|P| P.borrow().0);
        RistrettoPoint(EdwardsPoint::multiscalar_mul(scalars, extended_points))
    }
}

#[cfg(feature = "alloc")]
impl VartimeMultiscalarMul for RistrettoPoint {
    type Point = RistrettoPoint;

    fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<RistrettoPoint>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator<Item = Option<RistrettoPoint>>,
    {
        let extended_points = points.into_iter().map(|opt_P| opt_P.map(|P| P.0));

        EdwardsPoint::optional_multiscalar_mul(scalars, extended_points).map(RistrettoPoint)
    }
}
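/*
Illustrative sketch, not part of this crate: the multiscalar impls above
forward to `EdwardsPoint`, which computes sum(s_i * P_i) in one combined pass
rather than as separate scalar multiplications. The toy below assumes
integers modulo N = 1_000_003 as the group and uses a bit-at-a-time
interleaving in the spirit of Straus's method, where all terms share a single
chain of doublings; the crate's backends are more elaborate. It also assumes,
as the `Option` plumbing in `optional_multiscalar_mul` suggests, that a single
`None` point makes the whole result `None`. All names are hypothetical.

```rust
/// Toy group: integers modulo the prime N under addition (a hypothetical stand-in).
const N: u64 = 1_000_003;

fn add(a: u64, b: u64) -> u64 {
    (a + b) % N
}

/// Reference answer: each s_i * P_i computed exactly in u128, then summed.
fn naive_multiscalar(scalars: &[u64], points: &[u64]) -> u64 {
    scalars.iter().zip(points).fold(0, |acc, (&s, &p)| {
        add(acc, ((s as u128 * p as u128) % N as u128) as u64)
    })
}

/// Interleaved evaluation: one doubling per bit position is shared by every
/// term, instead of one double-and-add chain per term.
fn straus_multiscalar(scalars: &[u64], points: &[u64]) -> u64 {
    let mut acc = 0;
    for bit in (0..64).rev() {
        acc = add(acc, acc);
        for (&s, &p) in scalars.iter().zip(points) {
            if (s >> bit) & 1 == 1 {
                acc = add(acc, p);
            }
        }
    }
    acc
}

/// Assumed `optional_` contract: any `None` point makes the result `None`.
fn optional_multiscalar(scalars: &[u64], points: &[Option<u64>]) -> Option<u64> {
    let points: Vec<u64> = points.iter().copied().collect::<Option<_>>()?;
    Some(straus_multiscalar(scalars, &points))
}
```

Sharing the doublings is the main saving: n separate multiplications need
about 64 * n doublings here, while the interleaved loop needs 64 in total.
*/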

/// Precomputation for variable-time multiscalar multiplication with `RistrettoPoint`s.
///
/// Note that for large numbers of `RistrettoPoint`s, this functionality may be less
/// efficient than the corresponding `VartimeMultiscalarMul` implementation.
// This wraps the inner implementation in a facade type so that we can
// decouple stability of the inner type from the stability of the
// outer type.
#[cfg(feature = "alloc")]
pub struct VartimeRistrettoPrecomputation(crate::backend::VartimePrecomputedStraus);

#[cfg(feature = "alloc")]
impl VartimePrecomputedMultiscalarMul for VartimeRistrettoPrecomputation {
    type Point = RistrettoPoint;

    fn new<I>(static_points: I) -> Self
    where
        I: IntoIterator,
        I::Item: Borrow<Self::Point>,
    {
        Self(crate::backend::VartimePrecomputedStraus::new(
            static_points.into_iter().map(|P| P.borrow().0),
        ))
    }

    fn len(&self) -> usize {
        self.0.len()
    }

    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    fn optional_mixed_multiscalar_mul<I, J, K>(
        &self,
        static_scalars: I,
        dynamic_scalars: J,
        dynamic_points: K,
    ) -> Option<Self::Point>
    where
        I: IntoIterator,
        I::Item: Borrow<Scalar>,
        J: IntoIterator,
        J::Item: Borrow<Scalar>,
        K: IntoIterator<Item = Option<Self::Point>>,
    {
        self.0
            .optional_mixed_multiscalar_mul(
                static_scalars,
                dynamic_scalars,
                dynamic_points.into_iter().map(|P_opt| P_opt.map(|P| P.0)),
            )
            .map(RistrettoPoint)
    }
}
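/*
Illustrative sketch, not part of this crate: a toy analogue of
`VartimeRistrettoPrecomputation`. Per-point lookup tables for the static
points are built once in `new` and reused, while tables for dynamic points
are built on every call. Following the `partial_precomputed_*` tests below,
static points without a matching scalar contribute nothing, as if their
scalar were zero. The group (integers modulo N = 1_000_003), the 4-bit
window, and every name here are assumptions for illustration.

```rust
/// Toy group: integers modulo the prime N under addition (a hypothetical stand-in).
const N: u64 = 1_000_003;

fn add(a: u64, b: u64) -> u64 {
    (a + b) % N
}

/// Multiples 0*P, 1*P, ..., 15*P: everything one 4-bit window can ask for.
fn window_table(point: u64) -> [u64; 16] {
    let mut table = [0u64; 16];
    for j in 1..16 {
        table[j] = add(table[j - 1], point);
    }
    table
}

/// Hypothetical analogue of a precomputation over static points.
struct ToyPrecomputation {
    static_tables: Vec<[u64; 16]>,
}

impl ToyPrecomputation {
    /// Build the static tables once; every later call reuses them.
    fn new(static_points: &[u64]) -> Self {
        let static_tables = static_points.iter().map(|&p| window_table(p)).collect();
        ToyPrecomputation { static_tables }
    }

    fn optional_mixed_mul(
        &self,
        static_scalars: &[u64],
        dynamic_scalars: &[u64],
        dynamic_points: &[Option<u64>],
    ) -> Option<u64> {
        // A single missing dynamic point makes the whole result `None`.
        let dynamic_tables: Vec<[u64; 16]> = dynamic_points
            .iter()
            .map(|&p| p.map(window_table))
            .collect::<Option<_>>()?;

        // `zip` stops at the shorter input, so static points without a
        // scalar are skipped, exactly as if their scalar were zero.
        let mut terms: Vec<(u64, &[u64; 16])> =
            static_scalars.iter().copied().zip(&self.static_tables).collect();
        terms.extend(dynamic_scalars.iter().copied().zip(&dynamic_tables));

        // Straus-style evaluation: 16 windows of 4 bits, most significant
        // first, with the 4 doublings per window shared by all terms.
        let mut acc = 0;
        for w in (0..16).rev() {
            for _ in 0..4 {
                acc = add(acc, acc);
            }
            for &(s, table) in &terms {
                let digit = (s >> (4 * w)) & 15;
                acc = add(acc, table[digit as usize]);
            }
        }
        Some(acc)
    }
}
```

Precomputation only pays off when the same static points are reused across
many calls; the doc comment above already notes that for large inputs the
non-precomputed `VartimeMultiscalarMul` path may be faster.
*/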

impl RistrettoPoint {
    /// Compute \\(aA + bB\\) in variable time, where \\(B\\) is the
    /// Ristretto basepoint.
    pub fn vartime_double_scalar_mul_basepoint(
        a: &Scalar,
        A: &RistrettoPoint,
        b: &Scalar,
    ) -> RistrettoPoint {
        RistrettoPoint(EdwardsPoint::vartime_double_scalar_mul_basepoint(
            a, &A.0, b,
        ))
    }
}

/// A precomputed table of multiples of a basepoint, used to accelerate
/// scalar multiplication.
///
/// A precomputed table of multiples of the Ristretto basepoint is
/// available in the `constants` module:
/// ```
/// use curve25519_dalek::constants::RISTRETTO_BASEPOINT_TABLE;
/// use curve25519_dalek::scalar::Scalar;
///
/// let a = Scalar::from(87329482u64);
/// let P = &a * RISTRETTO_BASEPOINT_TABLE;
/// ```
#[cfg(feature = "precomputed-tables")]
#[derive(Clone)]
#[repr(transparent)]
pub struct RistrettoBasepointTable(pub(crate) EdwardsBasepointTable);

#[cfg(feature = "precomputed-tables")]
impl<'b> Mul<&'b Scalar> for &RistrettoBasepointTable {
    type Output = RistrettoPoint;

    fn mul(self, scalar: &'b Scalar) -> RistrettoPoint {
        RistrettoPoint(&self.0 * scalar)
    }
}

#[cfg(feature = "precomputed-tables")]
impl<'a> Mul<&'a RistrettoBasepointTable> for &Scalar {
    type Output = RistrettoPoint;

    fn mul(self, basepoint_table: &'a RistrettoBasepointTable) -> RistrettoPoint {
        RistrettoPoint(self * &basepoint_table.0)
    }
}

#[cfg(feature = "precomputed-tables")]
impl RistrettoBasepointTable {
    /// Create a precomputed table of multiples of the given `basepoint`.
    pub fn create(basepoint: &RistrettoPoint) -> RistrettoBasepointTable {
        RistrettoBasepointTable(EdwardsBasepointTable::create(&basepoint.0))
    }

    /// Get the basepoint for this table as a `RistrettoPoint`.
    pub fn basepoint(&self) -> RistrettoPoint {
        RistrettoPoint(self.0.basepoint())
    }
}

// ------------------------------------------------------------------------
// Constant-time conditional selection
// ------------------------------------------------------------------------

impl ConditionallySelectable for RistrettoPoint {
    /// Conditionally select between `a` and `b`: returns `a` when `choice`
    /// is 0 and `b` when it is 1.
    ///
    /// # Example
    ///
    /// ```
    /// use subtle::ConditionallySelectable;
    /// use subtle::Choice;
    /// #
    /// # use curve25519_dalek::traits::Identity;
    /// # use curve25519_dalek::ristretto::RistrettoPoint;
    /// # use curve25519_dalek::constants;
    /// # fn main() {
    ///
    /// let A = RistrettoPoint::identity();
    /// let B = constants::RISTRETTO_BASEPOINT_POINT;
    ///
    /// let mut P = A;
    ///
    /// P = RistrettoPoint::conditional_select(&A, &B, Choice::from(0));
    /// assert_eq!(P, A);
    /// P = RistrettoPoint::conditional_select(&A, &B, Choice::from(1));
    /// assert_eq!(P, B);
    /// # }
    /// ```
    fn conditional_select(
        a: &RistrettoPoint,
        b: &RistrettoPoint,
        choice: Choice,
    ) -> RistrettoPoint {
        RistrettoPoint(EdwardsPoint::conditional_select(&a.0, &b.0, choice))
    }
}
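/*
Illustrative sketch, not part of this crate: the masking idea behind
constant-time conditional selection. `conditional_select` above forwards to
`EdwardsPoint`, which presumably selects coordinate by coordinate and
ultimately word by word. The toy below uses a hypothetical `ToyChoice` in
place of `subtle::Choice` and a mask-and-XOR select on each 64-bit limb, so
the same instructions run whichever value is chosen. The real `subtle` crate
also takes precautions against the optimizer turning such code back into a
branch; this sketch does not.

```rust
/// Hypothetical stand-in for `subtle::Choice`: holds 0 or 1.
#[derive(Clone, Copy)]
struct ToyChoice(u8);

/// Branch-free select on one 64-bit word: `a` if choice is 0, `b` if it is 1.
fn select_u64(a: u64, b: u64, choice: ToyChoice) -> u64 {
    // mask is all zero bits for choice 0 and all one bits for choice 1.
    let mask = 0u64.wrapping_sub(choice.0 as u64);
    a ^ (mask & (a ^ b))
}

/// Select every limb of a multi-limb value with the same mask, so neither
/// control flow nor memory accesses depend on `choice`.
fn select_limbs<const L: usize>(a: &[u64; L], b: &[u64; L], choice: ToyChoice) -> [u64; L] {
    let mut out = [0u64; L];
    for i in 0..L {
        out[i] = select_u64(a[i], b[i], choice);
    }
    out
}
```

Selecting with a mask instead of `if choice { b } else { a }` is what keeps
the running time independent of a secret `choice`.
*/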

// ------------------------------------------------------------------------
// Debug traits
// ------------------------------------------------------------------------

impl Debug for CompressedRistretto {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "CompressedRistretto: {:?}", self.as_bytes())
    }
}

impl Debug for RistrettoPoint {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let coset = self.coset4();
        write!(
            f,
            "RistrettoPoint: coset \n{:?}\n{:?}\n{:?}\n{:?}",
            coset[0], coset[1], coset[2], coset[3]
        )
    }
}

// ------------------------------------------------------------------------
// group traits
// ------------------------------------------------------------------------

// Use the full trait path to avoid Group::identity overlapping Identity::identity in the
// rest of the module (e.g. tests).
#[cfg(feature = "group")]
impl group::Group for RistrettoPoint {
    type Scalar = Scalar;

    fn try_random<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        // NOTE: this duplicates `RistrettoPoint::random`; the body is repeated
        // because the two functions place different bounds on `rng`.
        let mut uniform_bytes = [0u8; 64];
        rng.try_fill_bytes(&mut uniform_bytes)?;
        Ok(RistrettoPoint::from_uniform_bytes(&uniform_bytes))
    }

    fn identity() -> Self {
        Identity::identity()
    }

    fn generator() -> Self {
        constants::RISTRETTO_BASEPOINT_POINT
    }

    fn is_identity(&self) -> Choice {
        self.ct_eq(&Identity::identity())
    }

    fn double(&self) -> Self {
        self + self
    }
}

#[cfg(feature = "group")]
impl GroupEncoding for RistrettoPoint {
    type Repr = [u8; 32];

    fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
        let (s_encoding_is_canonical, s_is_negative, s) =
            decompress::step_1(&CompressedRistretto(*bytes));

        let s_is_valid = s_encoding_is_canonical & !s_is_negative;

        let (ok, t_is_negative, y_is_zero, res) = decompress::step_2(s);

        CtOption::new(res, s_is_valid & ok & !t_is_negative & !y_is_zero)
    }

    fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
        // Just use the checked API; the checks we could skip aren't expensive.
        Self::from_bytes(bytes)
    }

    fn to_bytes(&self) -> Self::Repr {
        self.compress().to_bytes()
    }
}
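/*
Illustrative sketch, not part of this crate: the byte-level part of the
validity test in `from_bytes` above, `s_encoding_is_canonical & !s_is_negative`.
It assumes the standard Ristretto conventions: field elements are 32-byte
little-endian integers modulo p = 2^255 - 19, an encoding is canonical only
if its value is below p, and a field element counts as negative when the low
bit of its canonical encoding is 1. The checks in `decompress::step_2` are not
modeled, and unlike the crate this sketch returns early, so it is not
constant-time. All function names are hypothetical.

```rust
/// Little-endian bytes of p = 2^255 - 19.
fn p_bytes() -> [u8; 32] {
    let mut p = [0xffu8; 32];
    p[0] = 0xed;
    p[31] = 0x7f;
    p
}

/// True if the little-endian integer in `bytes` is strictly less than p.
fn is_canonical(bytes: &[u8; 32]) -> bool {
    let p = p_bytes();
    // Compare from the most significant byte down to the least significant.
    for i in (0..32).rev() {
        if bytes[i] != p[i] {
            return bytes[i] < p[i];
        }
    }
    false // the value equals p itself
}

/// "Negative" in the Ristretto sense: the low bit of the encoding is 1.
fn is_negative(bytes: &[u8; 32]) -> bool {
    bytes[0] & 1 == 1
}

/// Byte-level analogue of `s_encoding_is_canonical & !s_is_negative`.
fn s_is_valid(bytes: &[u8; 32]) -> bool {
    is_canonical(bytes) && !is_negative(bytes)
}
```

For example, the all-zero encoding of the identity and the basepoint encoding
from the test table below (first byte 226, which is even) both pass, while any
encoding with an odd first byte is rejected.
*/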

#[cfg(feature = "group")]
impl PrimeGroup for RistrettoPoint {}

/// Ristretto has a cofactor of 1.
#[cfg(feature = "group")]
impl CofactorGroup for RistrettoPoint {
    type Subgroup = Self;

    fn clear_cofactor(&self) -> Self::Subgroup {
        *self
    }

    fn into_subgroup(self) -> CtOption<Self::Subgroup> {
        CtOption::new(self, Choice::from(1))
    }

    fn is_torsion_free(&self) -> Choice {
        Choice::from(1)
    }
}

// ------------------------------------------------------------------------
// Zeroize traits
// ------------------------------------------------------------------------

#[cfg(feature = "zeroize")]
impl Zeroize for CompressedRistretto {
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}

#[cfg(feature = "zeroize")]
impl Zeroize for RistrettoPoint {
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}

// ------------------------------------------------------------------------
// Tests
// ------------------------------------------------------------------------

#[cfg(test)]
mod test {
    use super::*;
    use crate::edwards::CompressedEdwardsY;
    #[cfg(feature = "rand_core")]
    use getrandom::{SysRng, rand_core::UnwrapErr};
    #[cfg(feature = "group")]
    use proptest::prelude::*;

    #[test]
    #[cfg(feature = "serde")]
    fn serde_postcard_basepoint_roundtrip() {
        let encoded = postcard::to_allocvec(&constants::RISTRETTO_BASEPOINT_POINT).unwrap();
        let enc_compressed =
            postcard::to_allocvec(&constants::RISTRETTO_BASEPOINT_COMPRESSED).unwrap();
        assert_eq!(encoded, enc_compressed);

        // Check that the encoding is 32 bytes exactly
        assert_eq!(encoded.len(), 32);

        let dec_uncompressed: RistrettoPoint = postcard::from_bytes(&encoded).unwrap();
        let dec_compressed: CompressedRistretto = postcard::from_bytes(&encoded).unwrap();

        assert_eq!(dec_uncompressed, constants::RISTRETTO_BASEPOINT_POINT);
        assert_eq!(dec_compressed, constants::RISTRETTO_BASEPOINT_COMPRESSED);

        // Check that the encoding itself matches the usual one.
        // serde::Deserialize on fixed-size arrays uses tuple deserialization, and
        // postcard (de)serializes tuples element by element, with no length prefix.
        let raw_bytes = constants::RISTRETTO_BASEPOINT_COMPRESSED.as_bytes();
        let bp: RistrettoPoint = postcard::from_bytes(raw_bytes).unwrap();
        assert_eq!(bp, constants::RISTRETTO_BASEPOINT_POINT);
    }

    #[test]
    fn scalarmult_ristrettopoint_works_both_ways() {
        let P = constants::RISTRETTO_BASEPOINT_POINT;
        let s = Scalar::from(999u64);

        let P1 = P * s;
        let P2 = s * P;

        assert!(P1.compress().as_bytes() == P2.compress().as_bytes());
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn impl_sum() {
        // Test that sum works for non-empty iterators
        let BASE = constants::RISTRETTO_BASEPOINT_POINT;

        let s1 = Scalar::from(999u64);
        let P1 = BASE * s1;

        let s2 = Scalar::from(333u64);
        let P2 = BASE * s2;

        let vec = vec![P1, P2];
        let sum: RistrettoPoint = vec.iter().sum();

        assert_eq!(sum, P1 + P2);

        // Test that sum works for the empty iterator
        let empty_vector: Vec<RistrettoPoint> = vec![];
        let sum: RistrettoPoint = empty_vector.iter().sum();

        assert_eq!(sum, RistrettoPoint::identity());

        // Test that sum works on iterators that yield owned points
        let s = Scalar::from(2u64);
        let mapped = vec.iter().map(|x| x * s);
        let sum: RistrettoPoint = mapped.sum();

        assert_eq!(sum, P1 * s + P2 * s);
    }

    #[test]
    fn decompress_negative_s_fails() {
        // constants::EDWARDS_D is negative, so decompressing its encoding must fail:
        // only nonnegative s values are accepted, and |d| != d.
        let bad_compressed = CompressedRistretto(constants::EDWARDS_D.to_bytes());
        assert!(bad_compressed.decompress().is_none());
    }

    #[test]
    fn decompress_id() {
        let compressed_id = CompressedRistretto::identity();
        let id = compressed_id.decompress().unwrap();
        let mut identity_in_coset = false;
        for P in &id.coset4() {
            if P.compress() == CompressedEdwardsY::identity() {
                identity_in_coset = true;
            }
        }
        assert!(identity_in_coset);
    }

    #[test]
    fn compress_id() {
        let id = RistrettoPoint::identity();
        assert_eq!(id.compress(), CompressedRistretto::identity());
    }

    #[test]
    fn basepoint_roundtrip() {
        let bp_compressed_ristretto = constants::RISTRETTO_BASEPOINT_POINT.compress();
        let bp_recaf = bp_compressed_ristretto.decompress().unwrap().0;
        // Check that bp_recaf differs from bp by a point of order dividing 4,
        // i.e. that 4 * (bp - bp_recaf) is the identity.
        let diff = constants::RISTRETTO_BASEPOINT_POINT.0 - bp_recaf;
        let diff4 = diff.mul_by_pow_2(2);
        assert_eq!(diff4.compress(), CompressedEdwardsY::identity());
    }

    #[test]
    fn encodings_of_small_multiples_of_basepoint() {
        // Table of encodings of i*basepoint for i = 0, 1, ..., 15
        // Generated using ristretto.sage
        let compressed = [
            CompressedRistretto([
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0,
            ]),
            CompressedRistretto([
                226, 242, 174, 10, 106, 188, 78, 113, 168, 132, 169, 97, 197, 0, 81, 95, 88, 227,
                11, 106, 165, 130, 221, 141, 182, 166, 89, 69, 224, 141, 45, 118,
            ]),
            CompressedRistretto([
                106, 73, 50, 16, 247, 73, 156, 209, 127, 236, 181, 16, 174, 12, 234, 35, 161, 16,
                232, 213, 185, 1, 248, 172, 173, 211, 9, 92, 115, 163, 185, 25,
            ]),
            CompressedRistretto([
                148, 116, 31, 93, 93, 82, 117, 94, 206, 79, 35, 240, 68, 238, 39, 213, 209, 234,
                30, 43, 209, 150, 180, 98, 22, 107, 22, 21, 42, 157, 2, 89,
            ]),
            CompressedRistretto([
                218, 128, 134, 39, 115, 53, 139, 70, 111, 250, 223, 224, 179, 41, 58, 179, 217,
                253, 83, 197, 234, 108, 149, 83, 88, 245, 104, 50, 45, 175, 106, 87,
            ]),
            CompressedRistretto([
                232, 130, 177, 49, 1, 107, 82, 193, 211, 51, 112, 128, 24, 124, 247, 104, 66, 62,
                252, 203, 181, 23, 187, 73, 90, 184, 18, 196, 22, 15, 244, 78,
            ]),
            CompressedRistretto([
                246, 71, 70, 211, 201, 43, 19, 5, 14, 216, 216, 2, 54, 167, 240, 0, 124, 59, 63,
                150, 47, 91, 167, 147, 209, 154, 96, 30, 187, 29, 244, 3,
            ]),
            CompressedRistretto([
                68, 245, 53, 32, 146, 110, 200, 31, 189, 90, 56, 120, 69, 190, 183, 223, 133, 169,
                106, 36, 236, 225, 135, 56, 189, 207, 166, 167, 130, 42, 23, 109,
            ]),
            CompressedRistretto([
                144, 50, 147, 216, 242, 40, 126, 190, 16, 226, 55, 77, 193, 165, 62, 11, 200, 135,
                229, 146, 105, 159, 2, 208, 119, 213, 38, 60, 221, 85, 96, 28,
            ]),
            CompressedRistretto([
                2, 98, 42, 206, 143, 115, 3, 163, 28, 175, 198, 63, 143, 196, 143, 220, 22, 225,
                200, 200, 210, 52, 178, 240, 214, 104, 82, 130, 169, 7, 96, 49,
            ]),
            CompressedRistretto([
                32, 112, 111, 215, 136, 178, 114, 10, 30, 210, 165, 218, 212, 149, 43, 1, 244, 19,
                188, 240, 231, 86, 77, 232, 205, 200, 22, 104, 158, 45, 185, 95,
            ]),
            CompressedRistretto([
                188, 232, 63, 139, 165, 221, 47, 165, 114, 134, 76, 36, 186, 24, 16, 249, 82, 43,
                198, 0, 74, 254, 149, 135, 122, 199, 50, 65, 202, 253, 171, 66,
            ]),
            CompressedRistretto([
                228, 84, 158, 225, 107, 154, 160, 48, 153, 202, 32, 140, 103, 173, 175, 202, 250,
                76, 63, 62, 78, 83, 3, 222, 96, 38, 227, 202, 143, 248, 68, 96,
            ]),
            CompressedRistretto([
                170, 82, 224, 0, 223, 46, 22, 245, 95, 177, 3, 47, 195, 59, 196, 39, 66, 218, 214,
                189, 90, 143, 192, 190, 1, 103, 67, 108, 89, 72, 80, 31,
            ]),
            CompressedRistretto([
                70, 55, 107, 128, 244, 9, 178, 157, 194, 181, 246, 240, 197, 37, 145, 153, 8, 150,
                229, 113, 111, 65, 71, 124, 211, 0, 133, 171, 127, 16, 48, 30,
            ]),
            CompressedRistretto([
                224, 196, 24, 247, 200, 217, 196, 205, 215, 57, 91, 147, 234, 18, 79, 58, 217, 144,
                33, 187, 104, 29, 252, 51, 2, 169, 217, 154, 46, 83, 230, 78,
            ]),
        ];
        let mut bp = RistrettoPoint::identity();
        for point in compressed {
            assert_eq!(bp.compress(), point);
            bp += constants::RISTRETTO_BASEPOINT_POINT;
        }
    }

    #[test]
    fn four_torsion_basepoint() {
        let bp = constants::RISTRETTO_BASEPOINT_POINT;
        let bp_coset = bp.coset4();
        for point in bp_coset {
            assert_eq!(bp, RistrettoPoint(point));
        }
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn four_torsion_random() {
        let mut rng = UnwrapErr(SysRng);
        let P = RistrettoPoint::mul_base(&Scalar::random(&mut rng));
        let P_coset = P.coset4();
        for point in P_coset {
            assert_eq!(P, RistrettoPoint(point));
        }
    }

    #[cfg(feature = "rand_core")]
    #[test]
    fn random_roundtrip() {
        let mut rng = UnwrapErr(SysRng);
        for _ in 0..100 {
            let P = RistrettoPoint::mul_base(&Scalar::random(&mut rng));
            let compressed_P = P.compress();
            let Q = compressed_P.decompress().unwrap();
            assert_eq!(P, Q);
        }
    }

    #[test]
    #[cfg(all(feature = "alloc", feature = "rand_core", feature = "group"))]
    fn double_and_compress_1024_random_points() {
        use group::Group;
        let mut rng = SysRng;

        let mut points: Vec<RistrettoPoint> = (0..1024)
            .map(|_| RistrettoPoint::try_random(&mut rng).unwrap())
            .collect();
        points[500] = <RistrettoPoint as Group>::identity();

        let compressed = RistrettoPoint::double_and_compress_batch(&points);

        for (P, P2_compressed) in points.iter().zip(compressed.iter()) {
            assert_eq!(*P2_compressed, (P + P).compress());
        }
    }

    #[cfg(feature = "group")]
    proptest! {
        #[test]
        fn multiply_double_and_compress_random_points(
            p1 in any::<[u8; 64]>(),
            p2 in any::<[u8; 64]>(),
            s1 in any::<[u8; 32]>(),
            s2 in any::<[u8; 32]>(),
        ) {
            use group::Group;

            let scalars = [
                Scalar::from_bytes_mod_order(s1),
                Scalar::ZERO,
                Scalar::from_bytes_mod_order(s2),
            ];

            let points = [
                RistrettoPoint::from_uniform_bytes(&p1),
                <RistrettoPoint as Group>::identity(),
                RistrettoPoint::from_uniform_bytes(&p2),
            ];

            let multiplied_points: [_; 3] =
                core::array::from_fn(|i| scalars[i].div_by_2() * points[i]);
            let compressed = RistrettoPoint::double_and_compress_batch(&multiplied_points);

            for ((s, P), P2_compressed) in scalars.iter().zip(points).zip(compressed) {
                prop_assert_eq!(P2_compressed, (s * P).compress());
            }
        }
    }
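/*
Illustrative sketch, not part of this crate: the proptest above feeds
`scalars[i].div_by_2() * points[i]` to `double_and_compress_batch` and
expects `(s * P).compress()`. That only works if halving is exact modular
arithmetic, h = s * 2^(-1), which is possible because the group order is an
odd prime; doubling h * P then gives s * P. The toy below assumes an odd
modulus L = 1_000_003 standing in for the group order, and its `div_by_2` is
a hypothetical stand-in, not the crate's `Scalar::div_by_2`.

```rust
/// Odd modulus standing in for the group order (hypothetical).
const L: u64 = 1_000_003;

/// Return h with 2 * h == s (mod L). For odd s, s + L is even, so the
/// division by 2 is exact.
fn div_by_2(s: u64) -> u64 {
    let s = s % L;
    if s % 2 == 0 {
        s / 2
    } else {
        (s + L) / 2
    }
}
```

The branch on the parity of `s` makes this toy variable-time; it is meant
only to show the arithmetic the test depends on.
*/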

    #[test]
    #[cfg(all(feature = "alloc", feature = "rand_core"))]
    fn vartime_precomputed_vs_nonprecomputed_multiscalar() {
        let mut rng = UnwrapErr(SysRng);

        let static_scalars = (0..128)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        let dynamic_scalars = (0..128)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        // Every point below is s * B for its own scalar s, so the linear
        // combination sum(s * (s * B)) equals (sum of s^2) * B.
        let check_scalar: Scalar = static_scalars
            .iter()
            .chain(dynamic_scalars.iter())
            .map(|s| s * s)
            .sum();

        let static_points = static_scalars
            .iter()
            .map(RistrettoPoint::mul_base)
            .collect::<Vec<_>>();
        let dynamic_points = dynamic_scalars
            .iter()
            .map(RistrettoPoint::mul_base)
            .collect::<Vec<_>>();

        let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());

        assert_eq!(precomputation.len(), 128);
        assert!(!precomputation.is_empty());

        let P = precomputation.vartime_mixed_multiscalar_mul(
            &static_scalars,
            &dynamic_scalars,
            &dynamic_points,
        );

        use crate::traits::VartimeMultiscalarMul;
        let Q = RistrettoPoint::vartime_multiscalar_mul(
            static_scalars.iter().chain(dynamic_scalars.iter()),
            static_points.iter().chain(dynamic_points.iter()),
        );

        let R = RistrettoPoint::mul_base(&check_scalar);

        assert_eq!(P.compress(), R.compress());
        assert_eq!(Q.compress(), R.compress());
    }

    #[test]
    #[cfg(all(feature = "alloc", feature = "rand_core"))]
    fn partial_precomputed_mixed_multiscalar_empty() {
        let mut rng = UnwrapErr(SysRng);

        let n_static = 16;
        let n_dynamic = 8;

        let static_points = (0..n_static)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<_>>();

        // Provide no static scalars at all
        let static_scalars = Vec::new();

        let dynamic_points = (0..n_dynamic)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<_>>();

        let dynamic_scalars = (0..n_dynamic)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        // Compute the linear combination using precomputed multiscalar multiplication
        let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
        let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul(
            &static_scalars,
            &dynamic_scalars,
            &dynamic_points,
        );

        // Compute the linear combination manually
        let mut result_manual = RistrettoPoint::identity();
        for i in 0..static_scalars.len() {
            result_manual += static_points[i] * static_scalars[i];
        }
        for i in 0..n_dynamic {
            result_manual += dynamic_points[i] * dynamic_scalars[i];
        }

        assert_eq!(result_multiscalar, result_manual);
    }

    #[test]
    #[cfg(all(feature = "alloc", feature = "rand_core"))]
    fn partial_precomputed_mixed_multiscalar() {
        let mut rng = UnwrapErr(SysRng);

        let n_static = 16;
        let n_dynamic = 8;

        let static_points = (0..n_static)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<_>>();

        // Provide one fewer static scalar than there are static points
        let static_scalars = (0..n_static - 1)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        let dynamic_points = (0..n_dynamic)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<_>>();

        let dynamic_scalars = (0..n_dynamic)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        // Compute the linear combination using precomputed multiscalar multiplication
        let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
        let result_multiscalar = precomputation.vartime_mixed_multiscalar_mul(
            &static_scalars,
            &dynamic_scalars,
            &dynamic_points,
        );

        // Compute the linear combination manually
        let mut result_manual = RistrettoPoint::identity();
        for i in 0..static_scalars.len() {
            result_manual += static_points[i] * static_scalars[i];
        }
        for i in 0..n_dynamic {
            result_manual += dynamic_points[i] * dynamic_scalars[i];
        }

        assert_eq!(result_multiscalar, result_manual);
    }

    #[test]
    #[cfg(all(feature = "alloc", feature = "rand_core"))]
    fn partial_precomputed_multiscalar() {
        let mut rng = UnwrapErr(SysRng);

        let n_static = 16;

        let static_points = (0..n_static)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<_>>();

        // Provide one fewer static scalar than there are static points
        let static_scalars = (0..n_static - 1)
            .map(|_| Scalar::random(&mut rng))
            .collect::<Vec<_>>();

        // Compute the linear combination using precomputed multiscalar multiplication
        let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
        let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars);

        // Compute the linear combination manually
        let mut result_manual = RistrettoPoint::identity();
        for i in 0..static_scalars.len() {
            result_manual += static_points[i] * static_scalars[i];
        }

        assert_eq!(result_multiscalar, result_manual);
    }

    #[test]
    #[cfg(all(feature = "alloc", feature = "rand_core"))]
    fn partial_precomputed_multiscalar_empty() {
        let mut rng = UnwrapErr(SysRng);

        let n_static = 16;

        let static_points = (0..n_static)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<_>>();

        // Provide no static scalars at all
        let static_scalars = Vec::new();

        // Compute the linear combination using precomputed multiscalar multiplication
        let precomputation = VartimeRistrettoPrecomputation::new(static_points.iter());
        let result_multiscalar = precomputation.vartime_multiscalar_mul(&static_scalars);

        // Compute the linear combination manually
        let mut result_manual = RistrettoPoint::identity();
        for i in 0..static_scalars.len() {
            result_manual += static_points[i] * static_scalars[i];
        }

        assert_eq!(result_multiscalar, result_manual);
    }
}