Skip to main content

curve25519_dalek/backend/vector/avx2/
field.rs

1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2016-2021 isis lovecruft
5// Copyright (c) 2016-2019 Henry de Valence
6// See LICENSE for licensing information.
7//
8// Authors:
9// - isis agora lovecruft <[email protected]>
10// - Henry de Valence <[email protected]>
11
12// Nightly and stable currently disagree on the requirement of unsafe blocks when `unsafe_target_feature`
13// gets used.
14// See: https://github.com/rust-lang/rust/issues/132856
15#![allow(unused_unsafe)]
16
17//! An implementation of 4-way vectorized 32bit field arithmetic using
18//! AVX2.
19//!
20//! The `FieldElement2625x4` struct provides a vector of four field
21//! elements, implemented using AVX2 operations.  Its API is designed
22//! to abstract away the platform-dependent details, so that point
23//! arithmetic can be implemented only in terms of a vector of field
24//! elements.
25//!
26//! At this level, the API is optimized for speed and not safety.  The
27//! `FieldElement2625x4` does not always perform reductions.  The pre-
28//! and post-conditions on the bounds of the coefficients are
29//! documented for each method, but it is the caller's responsibility
30//! to ensure that there are no overflows.
31
32#![allow(non_snake_case)]
33
// Blend masks for a packed `u32x8` laid out as
// `(a0, b0, a1, b1, c0, d0, c1, d1)`: bit `i` of the mask selects 32-bit
// lane `i`, so each field element owns two bits (one per limb position).
const A_LANES: u8 = 0b0000_0101;
const B_LANES: u8 = 0b0000_1010;
const C_LANES: u8 = 0b0101_0000;
const D_LANES: u8 = 0b1010_0000;

// Blend masks for unpacked 64-bit lanes `(a, b, c, d)`, still expressed in
// units of 32-bit lanes: each element covers two adjacent bits.
// NOTE(review): only `D_LANES64` is referenced in the visible part of this
// file; the others are presumably kept for symmetry, hence `allow(unused)`.
#[allow(unused)]
const A_LANES64: u8 = 0b00_00_00_11;
#[allow(unused)]
const B_LANES64: u8 = 0b00_00_11_00;
#[allow(unused)]
const C_LANES64: u8 = 0b00_11_00_00;
#[allow(unused)]
const D_LANES64: u8 = 0b11_00_00_00;
47
48use crate::backend::vector::packed_simd::{u32x8, u64x4};
49use core::ops::{Add, Mul, Neg};
50
51use crate::backend::serial::u64::field::FieldElement51;
52use crate::backend::vector::avx2::constants::{
53    P_TIMES_2_HI, P_TIMES_2_LO, P_TIMES_16_HI, P_TIMES_16_LO,
54};
55
56use curve25519_dalek_derive::unsafe_target_feature;
57
58/// Unpack 32-bit lanes into 64-bit lanes:
59/// ```ascii,no_run
60/// (a0, b0, a1, b1, c0, d0, c1, d1)
61/// ```
62/// into
63/// ```ascii,no_run
64/// (a0, 0, b0, 0, c0, 0, d0, 0)
65/// (a1, 0, b1, 0, c1, 0, d1, 0)
66/// ```
67#[unsafe_target_feature("avx2")]
68#[inline(always)]
69fn unpack_pair(src: u32x8) -> (u32x8, u32x8) {
70    let a: u32x8;
71    let b: u32x8;
72    let zero = u32x8::splat(0);
73    unsafe {
74        use core::arch::x86_64::_mm256_unpackhi_epi32;
75        use core::arch::x86_64::_mm256_unpacklo_epi32;
76        a = _mm256_unpacklo_epi32(src.into(), zero.into()).into();
77        b = _mm256_unpackhi_epi32(src.into(), zero.into()).into();
78    }
79    (a, b)
80}
81
82/// Repack 64-bit lanes into 32-bit lanes:
83/// ```ascii,no_run
84/// (a0, 0, b0, 0, c0, 0, d0, 0)
85/// (a1, 0, b1, 0, c1, 0, d1, 0)
86/// ```
87/// into
88/// ```ascii,no_run
89/// (a0, b0, a1, b1, c0, d0, c1, d1)
90/// ```
91#[unsafe_target_feature("avx2")]
92#[inline(always)]
93fn repack_pair(x: u32x8, y: u32x8) -> u32x8 {
94    unsafe {
95        use core::arch::x86_64::_mm256_blend_epi32;
96        use core::arch::x86_64::_mm256_shuffle_epi32;
97
98        // Input: x = (a0, 0, b0, 0, c0, 0, d0, 0)
99        // Input: y = (a1, 0, b1, 0, c1, 0, d1, 0)
100
101        let x_shuffled = _mm256_shuffle_epi32(x.into(), 0b11_01_10_00);
102        let y_shuffled = _mm256_shuffle_epi32(y.into(), 0b10_00_11_01);
103
104        // x' = (a0, b0,  0,  0, c0, d0,  0,  0)
105        // y' = ( 0,  0, a1, b1,  0,  0, c1, d1)
106
107        _mm256_blend_epi32(x_shuffled, y_shuffled, 0b11001100).into()
108    }
109}
110
/// The `Lanes` enum represents a subset of the lanes `A,B,C,D` of a
/// `FieldElement2625x4`.
///
/// It's used to specify blend operations without
/// having to know details about the data layout of the
/// `FieldElement2625x4`.
///
/// Each variant is named after the lanes it contains; only the subsets
/// listed here are available.
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, Debug)]
pub enum Lanes {
    /// Lane `C` only.
    C,
    /// Lane `D` only.
    D,
    /// Lanes `A` and `B`.
    AB,
    /// Lanes `A` and `C`.
    AC,
    /// Lanes `C` and `D`.
    CD,
    /// Lanes `A` and `D`.
    AD,
    /// Lanes `B` and `C`.
    BC,
    /// All four lanes.
    ABCD,
}
129
/// The `Shuffle` enum represents a shuffle of a `FieldElement2625x4`.
///
/// The enum variants are named by what they do to a vector \\(
/// (A,B,C,D) \\); for instance, `Shuffle::BADC` turns \\( (A, B, C,
/// D) \\) into \\( (B, A, D, C) \\).
#[allow(clippy::upper_case_acronyms)]
#[derive(Copy, Clone, Debug)]
pub enum Shuffle {
    /// `(A, B, C, D)` becomes `(A, A, A, A)`.
    AAAA,
    /// `(A, B, C, D)` becomes `(B, B, B, B)`.
    BBBB,
    /// `(A, B, C, D)` becomes `(C, A, C, A)`.
    CACA,
    /// `(A, B, C, D)` becomes `(D, B, B, D)`.
    DBBD,
    /// `(A, B, C, D)` becomes `(A, D, D, A)`.
    ADDA,
    /// `(A, B, C, D)` becomes `(C, B, C, B)`.
    CBCB,
    /// `(A, B, C, D)` becomes `(A, B, A, B)`.
    ABAB,
    /// `(A, B, C, D)` becomes `(B, A, D, C)`.
    BADC,
    /// `(A, B, C, D)` becomes `(B, A, C, D)`.
    BACD,
    /// `(A, B, C, D)` becomes `(A, B, D, C)`.
    ABDC,
}
149
/// A vector of four field elements.
///
/// Each operation on a `FieldElement2625x4` has documented effects on
/// the bounds of the coefficients.  This API is designed for speed
/// and not safety; it is the caller's responsibility to ensure that
/// the post-conditions of one operation are compatible with the
/// pre-conditions of the next.
///
/// # Layout
///
/// The four elements `A, B, C, D` are each split into ten limbs of
/// alternating 26 and 25 bits.  Entry `i` of the inner array holds limbs
/// `2i` and `2i+1` of all four elements, ordered as
///
/// ```ascii,no_run
/// (a_2i, b_2i, a_2i+1, b_2i+1, c_2i, d_2i, c_2i+1, d_2i+1)
/// ```
#[derive(Clone, Copy, Debug)]
pub struct FieldElement2625x4(pub(crate) [u32x8; 5]);
159
160use subtle::Choice;
161use subtle::ConditionallySelectable;
162
163#[unsafe_target_feature("avx2")]
164impl ConditionallySelectable for FieldElement2625x4 {
165    fn conditional_select(
166        a: &FieldElement2625x4,
167        b: &FieldElement2625x4,
168        choice: Choice,
169    ) -> FieldElement2625x4 {
170        let mask = (-(choice.unwrap_u8() as i32)) as u32;
171        let mask_vec = u32x8::splat(mask);
172        FieldElement2625x4([
173            a.0[0] ^ (mask_vec & (a.0[0] ^ b.0[0])),
174            a.0[1] ^ (mask_vec & (a.0[1] ^ b.0[1])),
175            a.0[2] ^ (mask_vec & (a.0[2] ^ b.0[2])),
176            a.0[3] ^ (mask_vec & (a.0[3] ^ b.0[3])),
177            a.0[4] ^ (mask_vec & (a.0[4] ^ b.0[4])),
178        ])
179    }
180
181    fn conditional_assign(&mut self, other: &FieldElement2625x4, choice: Choice) {
182        let mask = (-(choice.unwrap_u8() as i32)) as u32;
183        let mask_vec = u32x8::splat(mask);
184        self.0[0] ^= mask_vec & (self.0[0] ^ other.0[0]);
185        self.0[1] ^= mask_vec & (self.0[1] ^ other.0[1]);
186        self.0[2] ^= mask_vec & (self.0[2] ^ other.0[2]);
187        self.0[3] ^= mask_vec & (self.0[3] ^ other.0[3]);
188        self.0[4] ^= mask_vec & (self.0[4] ^ other.0[4]);
189    }
190}
191
192#[unsafe_target_feature("avx2")]
193impl FieldElement2625x4 {
    /// The vector whose limbs are all zero, i.e. all four lanes hold the
    /// field element zero.
    pub const ZERO: FieldElement2625x4 = FieldElement2625x4([u32x8::splat_const::<0>(); 5]);
195
196    /// Split this vector into an array of four (serial) field
197    /// elements.
198    #[rustfmt::skip] // keep alignment of extracted lanes
199    pub fn split(&self) -> [FieldElement51; 4] {
200        let mut out = [FieldElement51::ZERO; 4];
201        for i in 0..5 {
202            let a_2i   = self.0[i].extract::<0>() as u64; //
203            let b_2i   = self.0[i].extract::<1>() as u64; //
204            let a_2i_1 = self.0[i].extract::<2>() as u64; // `.
205            let b_2i_1 = self.0[i].extract::<3>() as u64; //  | pre-swapped to avoid
206            let c_2i   = self.0[i].extract::<4>() as u64; //  | a cross lane shuffle
207            let d_2i   = self.0[i].extract::<5>() as u64; // .'
208            let c_2i_1 = self.0[i].extract::<6>() as u64; //
209            let d_2i_1 = self.0[i].extract::<7>() as u64; //
210
211            out[0].0[i] = a_2i + (a_2i_1 << 26);
212            out[1].0[i] = b_2i + (b_2i_1 << 26);
213            out[2].0[i] = c_2i + (c_2i_1 << 26);
214            out[3].0[i] = d_2i + (d_2i_1 << 26);
215        }
216
217        out
218    }
219
220    /// Rearrange the elements of this vector according to `control`.
221    ///
222    /// The `control` parameter should be a compile-time constant, so
223    /// that when this function is inlined, LLVM is able to lower the
224    /// shuffle using an immediate.
225    #[inline]
226    pub fn shuffle(&self, control: Shuffle) -> FieldElement2625x4 {
227        #[inline(always)]
228        fn shuffle_lanes(x: u32x8, control: Shuffle) -> u32x8 {
229            unsafe {
230                use core::arch::x86_64::_mm256_permutevar8x32_epi32;
231
232                let c: u32x8 = match control {
233                    Shuffle::AAAA => u32x8::new(0, 0, 2, 2, 0, 0, 2, 2),
234                    Shuffle::BBBB => u32x8::new(1, 1, 3, 3, 1, 1, 3, 3),
235                    Shuffle::CACA => u32x8::new(4, 0, 6, 2, 4, 0, 6, 2),
236                    Shuffle::DBBD => u32x8::new(5, 1, 7, 3, 1, 5, 3, 7),
237                    Shuffle::ADDA => u32x8::new(0, 5, 2, 7, 5, 0, 7, 2),
238                    Shuffle::CBCB => u32x8::new(4, 1, 6, 3, 4, 1, 6, 3),
239                    Shuffle::ABAB => u32x8::new(0, 1, 2, 3, 0, 1, 2, 3),
240                    Shuffle::BADC => u32x8::new(1, 0, 3, 2, 5, 4, 7, 6),
241                    Shuffle::BACD => u32x8::new(1, 0, 3, 2, 4, 5, 6, 7),
242                    Shuffle::ABDC => u32x8::new(0, 1, 2, 3, 5, 4, 7, 6),
243                };
244                // Note that this gets turned into a generic LLVM
245                // shuffle-by-constants, which can be lowered to a simpler
246                // instruction than a generic permute.
247                _mm256_permutevar8x32_epi32(x.into(), c.into()).into()
248            }
249        }
250
251        FieldElement2625x4([
252            shuffle_lanes(self.0[0], control),
253            shuffle_lanes(self.0[1], control),
254            shuffle_lanes(self.0[2], control),
255            shuffle_lanes(self.0[3], control),
256            shuffle_lanes(self.0[4], control),
257        ])
258    }
259
260    /// Blend `self` with `other`, taking lanes specified in `control` from `other`.
261    ///
262    /// The `control` parameter should be a compile-time constant, so
263    /// that this function can be inlined and LLVM can lower it to a
264    /// blend instruction using an immediate.
265    #[inline]
266    pub fn blend(&self, other: FieldElement2625x4, control: Lanes) -> FieldElement2625x4 {
267        #[inline(always)]
268        fn blend_lanes(x: u32x8, y: u32x8, control: Lanes) -> u32x8 {
269            unsafe {
270                use core::arch::x86_64::_mm256_blend_epi32;
271
272                // This would be much cleaner if we could factor out the match
273                // statement on the control. Unfortunately, rustc forgets
274                // constant-info very quickly, so we can't even write
275                // ```
276                // match control {
277                //     Lanes::C => {
278                //         let imm = C_LANES as i32;
279                //         _mm256_blend_epi32(..., imm)
280                // ```
281                // let alone
282                // ```
283                // let imm = match control {
284                //     Lanes::C => C_LANES as i32,
285                // }
286                // _mm256_blend_epi32(..., imm)
287                // ```
288                // even though both of these would be constant-folded by LLVM
289                // at a lower level (as happens in the shuffle implementation,
290                // which does not require a shuffle immediate but *is* lowered
291                // to immediate shuffles anyways).
292                match control {
293                    Lanes::C => _mm256_blend_epi32(x.into(), y.into(), C_LANES as i32).into(),
294                    Lanes::D => _mm256_blend_epi32(x.into(), y.into(), D_LANES as i32).into(),
295                    Lanes::AD => {
296                        _mm256_blend_epi32(x.into(), y.into(), (A_LANES | D_LANES) as i32).into()
297                    }
298                    Lanes::AB => {
299                        _mm256_blend_epi32(x.into(), y.into(), (A_LANES | B_LANES) as i32).into()
300                    }
301                    Lanes::AC => {
302                        _mm256_blend_epi32(x.into(), y.into(), (A_LANES | C_LANES) as i32).into()
303                    }
304                    Lanes::CD => {
305                        _mm256_blend_epi32(x.into(), y.into(), (C_LANES | D_LANES) as i32).into()
306                    }
307                    Lanes::BC => {
308                        _mm256_blend_epi32(x.into(), y.into(), (B_LANES | C_LANES) as i32).into()
309                    }
310                    Lanes::ABCD => _mm256_blend_epi32(
311                        x.into(),
312                        y.into(),
313                        (A_LANES | B_LANES | C_LANES | D_LANES) as i32,
314                    )
315                    .into(),
316                }
317            }
318        }
319
320        FieldElement2625x4([
321            blend_lanes(self.0[0], other.0[0], control),
322            blend_lanes(self.0[1], other.0[1], control),
323            blend_lanes(self.0[2], other.0[2], control),
324            blend_lanes(self.0[3], other.0[3], control),
325            blend_lanes(self.0[4], other.0[4], control),
326        ])
327    }
328
    /// Broadcast a single field element into all four lanes.
    ///
    /// Convenience wrapper around `new(x,x,x,x)`; the result therefore
    /// satisfies the same postcondition as `new`.
    pub fn splat(x: &FieldElement51) -> FieldElement2625x4 {
        FieldElement2625x4::new(x, x, x, x)
    }
333
334    /// Create a `FieldElement2625x4` from four `FieldElement51`s.
335    ///
336    /// # Postconditions
337    ///
338    /// The resulting `FieldElement2625x4` is bounded with \\( b < 0.0002 \\).
339    #[rustfmt::skip] // keep alignment of computed lanes
340    pub fn new(
341        x0: &FieldElement51,
342        x1: &FieldElement51,
343        x2: &FieldElement51,
344        x3: &FieldElement51,
345    ) -> FieldElement2625x4 {
346        let mut buf = [u32x8::splat(0); 5];
347        let low_26_bits = (1 << 26) - 1;
348        #[allow(clippy::needless_range_loop)]
349        for i in 0..5 {
350            let a_2i   = (x0.0[i] & low_26_bits) as u32;
351            let a_2i_1 = (x0.0[i] >> 26) as u32;
352            let b_2i   = (x1.0[i] & low_26_bits) as u32;
353            let b_2i_1 = (x1.0[i] >> 26) as u32;
354            let c_2i   = (x2.0[i] & low_26_bits) as u32;
355            let c_2i_1 = (x2.0[i] >> 26) as u32;
356            let d_2i   = (x3.0[i] & low_26_bits) as u32;
357            let d_2i_1 = (x3.0[i] >> 26) as u32;
358
359            buf[i] = u32x8::new(a_2i, b_2i, a_2i_1, b_2i_1, c_2i, d_2i, c_2i_1, d_2i_1);
360        }
361
362        // We don't know that the original `FieldElement51`s were
363        // fully reduced, so the odd limbs may exceed 2^25.
364        // Reduce them to be sure.
365        FieldElement2625x4(buf).reduce()
366    }
367
368    /// Given \\((A,B,C,D)\\), compute \\((-A,-B,-C,-D)\\), without
369    /// performing a reduction.
370    ///
371    /// # Preconditions
372    ///
373    /// The coefficients of `self` must be bounded with \\( b < 0.999 \\).
374    ///
375    /// # Postconditions
376    ///
377    /// The coefficients of the result are bounded with \\( b < 1 \\).
378    #[inline]
379    pub fn negate_lazy(&self) -> FieldElement2625x4 {
380        // The limbs of self are bounded with b < 0.999, while the
381        // smallest limb of 2*p is 67108845 > 2^{26+0.9999}, so
382        // underflows are not possible.
383        FieldElement2625x4([
384            P_TIMES_2_LO - self.0[0],
385            P_TIMES_2_HI - self.0[1],
386            P_TIMES_2_HI - self.0[2],
387            P_TIMES_2_HI - self.0[3],
388            P_TIMES_2_HI - self.0[4],
389        ])
390    }
391
392    /// Given `self = (A,B,C,D)`, compute `(B - A, B + A, D - C, D + C)`.
393    ///
394    /// # Preconditions
395    ///
396    /// The coefficients of `self` must be bounded with \\( b < 0.01 \\).
397    ///
398    /// # Postconditions
399    ///
400    /// The coefficients of the result are bounded with \\( b < 1.6 \\).
401    #[inline]
402    pub fn diff_sum(&self) -> FieldElement2625x4 {
403        // tmp1 = (B, A, D, C)
404        let tmp1 = self.shuffle(Shuffle::BADC);
405        // tmp2 = (-A, B, -C, D)
406        let tmp2 = self.blend(self.negate_lazy(), Lanes::AC);
407        // (B - A, B + A, D - C, D + C) bounded with b < 1.6
408        tmp1 + tmp2
409    }
410
411    /// Reduce this vector of field elements \\(\mathrm{mod} p\\).
412    ///
413    /// # Postconditions
414    ///
415    /// The coefficients of the result are bounded with \\( b < 0.0002 \\).
416    #[inline]
417    pub fn reduce(&self) -> FieldElement2625x4 {
418        let shifts = u32x8::new(26, 26, 25, 25, 26, 26, 25, 25);
419        let masks = u32x8::new(
420            (1 << 26) - 1,
421            (1 << 26) - 1,
422            (1 << 25) - 1,
423            (1 << 25) - 1,
424            (1 << 26) - 1,
425            (1 << 26) - 1,
426            (1 << 25) - 1,
427            (1 << 25) - 1,
428        );
429
430        // Let c(x) denote the carryout of the coefficient x.
431        //
432        // Given    (   x0,    y0,    x1,    y1,    z0,    w0,    z1,    w1),
433        // compute  (c(x1), c(y1), c(x0), c(y0), c(z1), c(w1), c(z0), c(w0)).
434        //
435        // The carryouts are bounded by 2^(32 - 25) = 2^7.
436        let rotated_carryout = |v: u32x8| -> u32x8 {
437            unsafe {
438                use core::arch::x86_64::_mm256_shuffle_epi32;
439                use core::arch::x86_64::_mm256_srlv_epi32;
440
441                let c = _mm256_srlv_epi32(v.into(), shifts.into());
442                _mm256_shuffle_epi32(c, 0b01_00_11_10).into()
443            }
444        };
445
446        // Combine (lo, lo, lo, lo, lo, lo, lo, lo)
447        //    with (hi, hi, hi, hi, hi, hi, hi, hi)
448        //      to (lo, lo, hi, hi, lo, lo, hi, hi)
449        //
450        // This allows combining carryouts, e.g.,
451        //
452        // lo  (c(x1), c(y1), c(x0), c(y0), c(z1), c(w1), c(z0), c(w0))
453        // hi  (c(x3), c(y3), c(x2), c(y2), c(z3), c(w3), c(z2), c(w2))
454        // ->  (c(x1), c(y1), c(x2), c(y2), c(z1), c(w1), c(z2), c(w2))
455        //
456        // which is exactly the vector of carryins for
457        //
458        //     (   x2,    y2,    x3,    y3,    z2,    w2,    z3,    w3).
459        //
460        let combine = |v_lo: u32x8, v_hi: u32x8| -> u32x8 {
461            unsafe {
462                use core::arch::x86_64::_mm256_blend_epi32;
463                _mm256_blend_epi32(v_lo.into(), v_hi.into(), 0b11_00_11_00).into()
464            }
465        };
466
467        let mut v = self.0;
468
469        let c10 = rotated_carryout(v[0]);
470        v[0] = (v[0] & masks) + combine(u32x8::splat(0), c10);
471
472        let c32 = rotated_carryout(v[1]);
473        v[1] = (v[1] & masks) + combine(c10, c32);
474
475        let c54 = rotated_carryout(v[2]);
476        v[2] = (v[2] & masks) + combine(c32, c54);
477
478        let c76 = rotated_carryout(v[3]);
479        v[3] = (v[3] & masks) + combine(c54, c76);
480
481        let c98 = rotated_carryout(v[4]);
482        v[4] = (v[4] & masks) + combine(c76, c98);
483
484        let c9_19: u32x8 = unsafe {
485            use core::arch::x86_64::_mm256_mul_epu32;
486            use core::arch::x86_64::_mm256_shuffle_epi32;
487
488            // Need to rearrange c98, since vpmuludq uses the low
489            // 32-bits of each 64-bit lane to compute the product:
490            //
491            // c98       = (c(x9), c(y9), c(x8), c(y8), c(z9), c(w9), c(z8), c(w8));
492            // c9_spread = (c(x9), c(x8), c(y9), c(y8), c(z9), c(z8), c(w9), c(w8)).
493            let c9_spread = _mm256_shuffle_epi32(c98.into(), 0b11_01_10_00);
494
495            // Since the carryouts are bounded by 2^7, their products with 19
496            // are bounded by 2^11.25.  This means that
497            //
498            // c9_19_spread = (19*c(x9), 0, 19*c(y9), 0, 19*c(z9), 0, 19*c(w9), 0).
499            let c9_19_spread = _mm256_mul_epu32(c9_spread, u64x4::splat(19).into());
500
501            // Unshuffle:
502            // c9_19 = (19*c(x9), 19*c(y9), 0, 0, 19*c(z9), 19*c(w9), 0, 0).
503            _mm256_shuffle_epi32(c9_19_spread, 0b11_01_10_00).into()
504        };
505
506        // Add the final carryin.
507        v[0] += c9_19;
508
509        // Each output coefficient has exactly one carryin, which is
510        // bounded by 2^11.25, so they are bounded as
511        //
512        // c_even < 2^26 + 2^11.25 < 26.00006 < 2^{26+b}
513        // c_odd  < 2^25 + 2^11.25 < 25.0001  < 2^{25+b}
514        //
515        // where b = 0.0002.
516        FieldElement2625x4(v)
517    }
518
519    /// Given an array of wide coefficients, reduce them to a `FieldElement2625x4`.
520    ///
521    /// # Postconditions
522    ///
523    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
524    #[inline]
525    #[rustfmt::skip] // keep alignment of carry chain
526    fn reduce64(mut z: [u64x4; 10]) -> FieldElement2625x4 {
527        // These aren't const because splat isn't a const fn
528        let LOW_25_BITS: u64x4 = u64x4::splat((1 << 25) - 1);
529        let LOW_26_BITS: u64x4 = u64x4::splat((1 << 26) - 1);
530
531        // Carry the value from limb i = 0..8 to limb i+1
532        let carry = |z: &mut [u64x4; 10], i: usize| {
533            debug_assert!(i < 9);
534            if i % 2 == 0 {
535                // Even limbs have 26 bits
536                z[i + 1] += z[i].shr::<26>();
537                z[i] &= LOW_26_BITS;
538            } else {
539                // Odd limbs have 25 bits
540                z[i + 1] += z[i].shr::<25>();
541                z[i] &= LOW_25_BITS;
542            }
543        };
544
545        // Perform two halves of the carry chain in parallel.
546        carry(&mut z, 0); carry(&mut z, 4);
547        carry(&mut z, 1); carry(&mut z, 5);
548        carry(&mut z, 2); carry(&mut z, 6);
549        carry(&mut z, 3); carry(&mut z, 7);
550        // Since z[3] < 2^64, c < 2^(64-25) = 2^39,
551        // so    z[4] < 2^26 + 2^39 < 2^39.0002
552        carry(&mut z, 4); carry(&mut z, 8);
553        // Now z[4] < 2^26
554        // and z[5] < 2^25 + 2^13.0002 < 2^25.0004 (good enough)
555
556        // Last carry has a multiplication by 19.  In the serial case we
557        // do a 64-bit multiplication by 19, but here we want to do a
558        // 32-bit multiplication.  However, if we only know z[9] < 2^64,
559        // the carry is bounded as c < 2^(64-25) = 2^39, which is too
560        // big.  To ensure c < 2^32, we would need z[9] < 2^57.
561        // Instead, we split the carry in two, with c = c_0 + c_1*2^26.
562
563        let c = z[9].shr::<25>();
564        z[9] &= LOW_25_BITS;
565        let mut c0: u64x4 = c & LOW_26_BITS; // c0 < 2^26;
566        let mut c1: u64x4 = c.shr::<26>();         // c1 < 2^(39-26) = 2^13;
567
568        let x19 = u64x4::splat(19);
569        c0 = u32x8::from(c0).mul32(u32x8::from(x19));
570        c1 = u32x8::from(c1).mul32(u32x8::from(x19));
571
572        z[0] += c0; // z0 < 2^26 + 2^30.25 < 2^30.33
573        z[1] += c1; // z1 < 2^25 + 2^17.25 < 2^25.0067
574        carry(&mut z, 0); // z0 < 2^26, z1 < 2^25.0067 + 2^4.33 = 2^25.007
575
576        // The output coefficients are bounded with
577        //
578        // b = 0.007  for z[1]
579        // b = 0.0004 for z[5]
580        // b = 0      for other z[i].
581        //
582        // So the packed result is bounded with b = 0.007.
583        FieldElement2625x4([
584            repack_pair(z[0].into(), z[1].into()),
585            repack_pair(z[2].into(), z[3].into()),
586            repack_pair(z[4].into(), z[5].into()),
587            repack_pair(z[6].into(), z[7].into()),
588            repack_pair(z[8].into(), z[9].into()),
589        ])
590    }
591
    /// Square this field element, and negate the result's \\(D\\) value.
    ///
    /// That is, given \\((A,B,C,D)\\), compute \\((A^2, B^2, C^2, -D^2)\\).
    /// The negation is applied to the wide (64-bit) coefficients before
    /// the final reduction.
    ///
    /// # Preconditions
    ///
    /// The coefficients of `self` must be bounded with \\( b < 1.5 \\).
    ///
    /// # Postconditions
    ///
    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
    #[rustfmt::skip] // keep alignment of z* calculations
    pub fn square_and_negate_D(&self) -> FieldElement2625x4 {
        // Widening product, kept as four 64-bit lanes.
        // NOTE(review): `mul32` is defined in `packed_simd` (not visible
        // here); presumably it multiplies the low 32 bits of each 64-bit lane.
        #[inline(always)]
        fn m(x: u32x8, y: u32x8) -> u64x4 {
            x.mul32(y)
        }

        // Same product, reinterpreted as `u32x8` so it can be fed back into
        // `m`; used only for the small `19 * x` pre-multiplications.
        #[inline(always)]
        fn m_lo(x: u32x8, y: u32x8) -> u32x8 {
            x.mul32(y).into()
        }

        // 19 in the low half of every 64-bit lane, matching the unpacked
        // layout (a, 0, b, 0, c, 0, d, 0).
        let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0);

        // x0..x9 are the ten limbs, each unpacked to one limb per 64-bit lane.
        let (x0, x1) = unpack_pair(self.0[0]);
        let (x2, x3) = unpack_pair(self.0[1]);
        let (x4, x5) = unpack_pair(self.0[2]);
        let (x6, x7) = unpack_pair(self.0[3]);
        let (x8, x9) = unpack_pair(self.0[4]);

        // Doubled limbs: xi_2 = 2 * xi.
        let x0_2 = x0.shl::<1>();
        let x1_2 = x1.shl::<1>();
        let x2_2 = x2.shl::<1>();
        let x3_2 = x3.shl::<1>();
        let x4_2 = x4.shl::<1>();
        let x5_2 = x5.shl::<1>();
        let x6_2 = x6.shl::<1>();
        let x7_2 = x7.shl::<1>();

        // Limbs pre-multiplied by 19: xi_19 = 19 * xi, used for the terms
        // that wrap around the top of the ten-limb representation.
        let x5_19 = m_lo(v19, x5);
        let x6_19 = m_lo(v19, x6);
        let x7_19 = m_lo(v19, x7);
        let x8_19 = m_lo(v19, x8);
        let x9_19 = m_lo(v19, x9);

        // Schoolbook squaring with the symmetric cross terms merged (hence
        // the doubled operands).  The parenthesised groups that are shifted
        // left by one collect the terms needing an additional factor of 2.
        let mut z0 = m(x0,   x0) + m(x2_2, x8_19) + m(x4_2, x6_19) + ((m(x1_2, x9_19) +   m(x3_2, x7_19) +    m(x5,   x5_19)).shl::<1>());
        let mut z1 = m(x0_2, x1) + m(x3_2, x8_19) + m(x5_2, x6_19) +                    ((m(x2,   x9_19) +    m(x4,   x7_19)).shl::<1>());
        let mut z2 = m(x0_2, x2) + m(x1_2,    x1) + m(x4_2, x8_19) +   m(x6,   x6_19) + ((m(x3_2, x9_19) +    m(x5_2, x7_19)).shl::<1>());
        let mut z3 = m(x0_2, x3) + m(x1_2,    x2) + m(x5_2, x8_19) +                    ((m(x4,   x9_19) +    m(x6,   x7_19)).shl::<1>());
        let mut z4 = m(x0_2, x4) + m(x1_2,  x3_2) + m(x2,      x2) +   m(x6_2, x8_19) + ((m(x5_2, x9_19) +    m(x7,   x7_19)).shl::<1>());
        let mut z5 = m(x0_2, x5) + m(x1_2,    x4) + m(x2_2,    x3) +   m(x7_2, x8_19)                    +  ((m(x6,   x9_19)).shl::<1>());
        let mut z6 = m(x0_2, x6) + m(x1_2,  x5_2) + m(x2_2,    x4) +   m(x3_2,    x3) +   m(x8,   x8_19) +  ((m(x7_2, x9_19)).shl::<1>());
        let mut z7 = m(x0_2, x7) + m(x1_2,    x6) + m(x2_2,    x5) +   m(x3_2,    x4)                    +  ((m(x8,   x9_19)).shl::<1>());
        let mut z8 = m(x0_2, x8) + m(x1_2,  x7_2) + m(x2_2,    x6) +   m(x3_2,  x5_2) +   m(x4,      x4) +  ((m(x9,   x9_19)).shl::<1>());
        let mut z9 = m(x0_2, x9) + m(x1_2,    x8) + m(x2_2,    x7) +   m(x3_2,    x6) +   m(x4_2,    x5)                                 ;

        // The biggest z_i is bounded as z_i < 249*2^(51 + 2*b);
        // if b < 1.5 we get z_i < 4485585228861014016.
        //
        // The limbs of the multiples of p are bounded above by
        //
        // 0x3ffffff << 37 = 9223371899415822336 < 2^63
        //
        // and below by
        //
        // 0x1ffffff << 37 = 4611685880988434432
        //                 > 4485585228861014016
        //
        // So these multiples of p are big enough to avoid underflow
        // in subtraction, and small enough to fit within u64
        // with room for a carry.

        // Limbs of a multiple of p, scaled by 2^37: the lowest limb, the
        // other even (26-bit) limbs, and the odd (25-bit) limbs.
        let low__p37 = u64x4::splat(0x3ffffed << 37);
        let even_p37 = u64x4::splat(0x3ffffff << 37);
        let odd__p37 = u64x4::splat(0x1ffffff << 37);

        // Replace only the D lane of `x` with `p - x`; lanes A, B and C
        // pass through unchanged.
        let negate_D = |x: u64x4, p: u64x4| -> u64x4 {
            unsafe {
                use core::arch::x86_64::_mm256_blend_epi32;
                _mm256_blend_epi32(x.into(), (p - x).into(), D_LANES64 as i32).into()
            }
        };

        z0 = negate_D(z0, low__p37);
        z1 = negate_D(z1, odd__p37);
        z2 = negate_D(z2, even_p37);
        z3 = negate_D(z3, odd__p37);
        z4 = negate_D(z4, even_p37);
        z5 = negate_D(z5, odd__p37);
        z6 = negate_D(z6, even_p37);
        z7 = negate_D(z7, odd__p37);
        z8 = negate_D(z8, even_p37);
        z9 = negate_D(z9, odd__p37);

        FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9])
    }
687}
688
689#[unsafe_target_feature("avx2")]
690impl Neg for FieldElement2625x4 {
691    type Output = FieldElement2625x4;
692
693    /// Negate this field element, performing a reduction.
694    ///
695    /// If the coefficients are known to be small, use `negate_lazy`
696    /// to avoid performing a reduction.
697    ///
698    /// # Preconditions
699    ///
700    /// The coefficients of `self` must be bounded with \\( b < 4.0 \\).
701    ///
702    /// # Postconditions
703    ///
704    /// The coefficients of the result are bounded with \\( b < 0.0002 \\).
705    #[inline]
706    fn neg(self) -> FieldElement2625x4 {
707        FieldElement2625x4([
708            P_TIMES_16_LO - self.0[0],
709            P_TIMES_16_HI - self.0[1],
710            P_TIMES_16_HI - self.0[2],
711            P_TIMES_16_HI - self.0[3],
712            P_TIMES_16_HI - self.0[4],
713        ])
714        .reduce()
715    }
716}
717
718#[unsafe_target_feature("avx2")]
719impl Add<FieldElement2625x4> for FieldElement2625x4 {
720    type Output = FieldElement2625x4;
721    /// Add two `FieldElement2625x4`s, without performing a reduction.
722    #[inline]
723    fn add(self, rhs: FieldElement2625x4) -> FieldElement2625x4 {
724        FieldElement2625x4([
725            self.0[0] + rhs.0[0],
726            self.0[1] + rhs.0[1],
727            self.0[2] + rhs.0[2],
728            self.0[3] + rhs.0[3],
729            self.0[4] + rhs.0[4],
730        ])
731    }
732}
733
734#[unsafe_target_feature("avx2")]
735impl Mul<(u32, u32, u32, u32)> for FieldElement2625x4 {
736    type Output = FieldElement2625x4;
737    /// Perform a multiplication by a vector of small constants.
738    ///
739    /// # Postconditions
740    ///
741    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
742    #[inline]
743    fn mul(self, scalars: (u32, u32, u32, u32)) -> FieldElement2625x4 {
744        let consts = u32x8::new(scalars.0, 0, scalars.1, 0, scalars.2, 0, scalars.3, 0);
745
746        let (b0, b1) = unpack_pair(self.0[0]);
747        let (b2, b3) = unpack_pair(self.0[1]);
748        let (b4, b5) = unpack_pair(self.0[2]);
749        let (b6, b7) = unpack_pair(self.0[3]);
750        let (b8, b9) = unpack_pair(self.0[4]);
751
752        FieldElement2625x4::reduce64([
753            b0.mul32(consts),
754            b1.mul32(consts),
755            b2.mul32(consts),
756            b3.mul32(consts),
757            b4.mul32(consts),
758            b5.mul32(consts),
759            b6.mul32(consts),
760            b7.mul32(consts),
761            b8.mul32(consts),
762            b9.mul32(consts),
763        ])
764    }
765}
766
767#[unsafe_target_feature("avx2")]
768impl Mul<&FieldElement2625x4> for &FieldElement2625x4 {
769    type Output = FieldElement2625x4;
770    /// Multiply `self` by `rhs`.
771    ///
772    /// # Preconditions
773    ///
774    /// The coefficients of `self` must be bounded with \\( b < 2.5 \\).
775    ///
776    /// The coefficients of `rhs` must be bounded with \\( b < 1.75 \\).
777    ///
778    /// # Postconditions
779    ///
780    /// The coefficients of the result are bounded with \\( b < 0.007 \\).
781    ///
782    #[rustfmt::skip] // keep alignment of z* calculations
783    #[inline]
784    fn mul(self, rhs: &FieldElement2625x4) -> FieldElement2625x4 {
785        #[inline(always)]
786        fn m(x: u32x8, y: u32x8) -> u64x4 {
787            x.mul32(y)
788        }
789
790        #[inline(always)]
791        fn m_lo(x: u32x8, y: u32x8) -> u32x8 {
792            x.mul32(y).into()
793        }
794
795        let (x0, x1) = unpack_pair(self.0[0]);
796        let (x2, x3) = unpack_pair(self.0[1]);
797        let (x4, x5) = unpack_pair(self.0[2]);
798        let (x6, x7) = unpack_pair(self.0[3]);
799        let (x8, x9) = unpack_pair(self.0[4]);
800
801        let (y0, y1) = unpack_pair(rhs.0[0]);
802        let (y2, y3) = unpack_pair(rhs.0[1]);
803        let (y4, y5) = unpack_pair(rhs.0[2]);
804        let (y6, y7) = unpack_pair(rhs.0[3]);
805        let (y8, y9) = unpack_pair(rhs.0[4]);
806
807        let v19 = u32x8::new(19, 0, 19, 0, 19, 0, 19, 0);
808
809        let y1_19 = m_lo(v19, y1); // This fits in a u32
810        let y2_19 = m_lo(v19, y2); // iff 26 + b + lg(19) < 32
811        let y3_19 = m_lo(v19, y3); // if  b < 32 - 26 - 4.248 = 1.752
812        let y4_19 = m_lo(v19, y4);
813        let y5_19 = m_lo(v19, y5);
814        let y6_19 = m_lo(v19, y6);
815        let y7_19 = m_lo(v19, y7);
816        let y8_19 = m_lo(v19, y8);
817        let y9_19 = m_lo(v19, y9);
818
819        let x1_2 = x1 + x1; // This fits in a u32 iff 25 + b + 1 < 32
820        let x3_2 = x3 + x3; //                    iff b < 6
821        let x5_2 = x5 + x5;
822        let x7_2 = x7 + x7;
823        let x9_2 = x9 + x9;
824
825        let z0 = m(x0, y0) + m(x1_2, y9_19) + m(x2, y8_19) + m(x3_2, y7_19) + m(x4, y6_19) + m(x5_2, y5_19) + m(x6, y4_19) + m(x7_2, y3_19) + m(x8, y2_19) + m(x9_2, y1_19);
826        let z1 = m(x0, y1) + m(x1,      y0) + m(x2, y9_19) + m(x3,   y8_19) + m(x4, y7_19) + m(x5,   y6_19) + m(x6, y5_19) + m(x7,   y4_19) + m(x8, y3_19) + m(x9,   y2_19);
827        let z2 = m(x0, y2) + m(x1_2,    y1) + m(x2,    y0) + m(x3_2, y9_19) + m(x4, y8_19) + m(x5_2, y7_19) + m(x6, y6_19) + m(x7_2, y5_19) + m(x8, y4_19) + m(x9_2, y3_19);
828        let z3 = m(x0, y3) + m(x1,      y2) + m(x2,    y1) + m(x3,      y0) + m(x4, y9_19) + m(x5,   y8_19) + m(x6, y7_19) + m(x7,   y6_19) + m(x8, y5_19) + m(x9,   y4_19);
829        let z4 = m(x0, y4) + m(x1_2,    y3) + m(x2,    y2) + m(x3_2,    y1) + m(x4,    y0) + m(x5_2, y9_19) + m(x6, y8_19) + m(x7_2, y7_19) + m(x8, y6_19) + m(x9_2, y5_19);
830        let z5 = m(x0, y5) + m(x1,      y4) + m(x2,    y3) + m(x3,      y2) + m(x4,    y1) + m(x5,      y0) + m(x6, y9_19) + m(x7,   y8_19) + m(x8, y7_19) + m(x9,   y6_19);
831        let z6 = m(x0, y6) + m(x1_2,    y5) + m(x2,    y4) + m(x3_2,    y3) + m(x4,    y2) + m(x5_2,    y1) + m(x6,    y0) + m(x7_2, y9_19) + m(x8, y8_19) + m(x9_2, y7_19);
832        let z7 = m(x0, y7) + m(x1,      y6) + m(x2,    y5) + m(x3,      y4) + m(x4,    y3) + m(x5,      y2) + m(x6,    y1) + m(x7,      y0) + m(x8, y9_19) + m(x9,   y8_19);
833        let z8 = m(x0, y8) + m(x1_2,    y7) + m(x2,    y6) + m(x3_2,    y5) + m(x4,    y4) + m(x5_2,    y3) + m(x6,    y2) + m(x7_2,    y1) + m(x8,    y0) + m(x9_2, y9_19);
834        let z9 = m(x0, y9) + m(x1,      y8) + m(x2,    y7) + m(x3,      y6) + m(x4,    y5) + m(x5,      y4) + m(x6,    y3) + m(x7,      y2) + m(x8,    y1) + m(x9,      y0);
835
836        // The bounds on z[i] are the same as in the serial 32-bit code
837        // and the comment below is copied from there:
838
839        // How big is the contribution to z[i+j] from x[i], y[j]?
840        //
841        // Using the bounds above, we get:
842        //
843        // i even, j even:   x[i]*y[j] <   2^(26+b)*2^(26+b) = 2*2^(51+2*b)
844        // i  odd, j even:   x[i]*y[j] <   2^(25+b)*2^(26+b) = 1*2^(51+2*b)
845        // i even, j  odd:   x[i]*y[j] <   2^(26+b)*2^(25+b) = 1*2^(51+2*b)
846        // i  odd, j  odd: 2*x[i]*y[j] < 2*2^(25+b)*2^(25+b) = 1*2^(51+2*b)
847        //
848        // We perform inline reduction mod p by replacing 2^255 by 19
849        // (since 2^255 - 19 = 0 mod p).  This adds a factor of 19, so
850        // we get the bounds (z0 is the biggest one, but calculated for
851        // posterity here in case finer estimation is needed later):
852        //
853        //  z0 < ( 2 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 249*2^(51 + 2*b)
854        //  z1 < ( 1 +  1   + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 154*2^(51 + 2*b)
855        //  z2 < ( 2 +  1   +  2   + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 195*2^(51 + 2*b)
856        //  z3 < ( 1 +  1   +  1   +  1   + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) = 118*2^(51 + 2*b)
857        //  z4 < ( 2 +  1   +  2   +  1   +  2   + 1*19 + 2*19 + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) = 141*2^(51 + 2*b)
858        //  z5 < ( 1 +  1   +  1   +  1   +  1   +  1   + 1*19 + 1*19 + 1*19 + 1*19 )*2^(51 + 2b) =  82*2^(51 + 2*b)
859        //  z6 < ( 2 +  1   +  2   +  1   +  2   +  1   +  2   + 1*19 + 2*19 + 1*19 )*2^(51 + 2b) =  87*2^(51 + 2*b)
860        //  z7 < ( 1 +  1   +  1   +  1   +  1   +  1   +  1   +  1   + 1*19 + 1*19 )*2^(51 + 2b) =  46*2^(51 + 2*b)
861        //  z8 < ( 2 +  1   +  2   +  1   +  2   +  1   +  2   +  1   +  2   + 1*19 )*2^(51 + 2b) =  33*2^(51 + 2*b)
862        //  z9 < ( 1 +  1   +  1   +  1   +  1   +  1   +  1   +  1   +  1   +  1   )*2^(51 + 2b) =  10*2^(51 + 2*b)
863        //
864        // So z[0] fits into a u64 if 51 + 2*b + lg(249) < 64
865        //                         if b < 2.5.
866
867        // In fact this bound is slightly sloppy, since it treats both
868        // inputs x and y as being bounded by the same parameter b,
869        // while they are in fact bounded by b_x and b_y, and we
870        // already require that b_y < 1.75 in order to fit the
871        // multiplications by 19 into a u32.  The tighter bound on b_y
872        // means we could get a tighter bound on the outputs, or a
873        // looser bound on b_x.
874        FieldElement2625x4::reduce64([z0, z1, z2, z3, z4, z5, z6, z7, z8, z9])
875    }
876}
877
#[cfg(target_feature = "avx2")]
#[cfg(test)]
mod test {
    use super::*;

    /// Four distinct serial field elements with small limbs, used as the
    /// per-lane inputs of the vector-vs-serial comparison tests.
    fn sample_lanes() -> [FieldElement51; 4] {
        [
            FieldElement51([10000, 10001, 10002, 10003, 10004]),
            FieldElement51([10100, 10101, 10102, 10103, 10104]),
            FieldElement51([10200, 10201, 10202, 10203, 10204]),
            FieldElement51([10300, 10301, 10302, 10303, 10304]),
        ]
    }

    /// Scaling a vector of ones by a tuple of scalars must yield those
    /// scalars, one per lane.
    #[test]
    fn scale_by_curve_constants() {
        let mut x = FieldElement2625x4::splat(&FieldElement51::ONE);

        x = x * (121666, 121666, 2 * 121666, 2 * 121665);

        let xs = x.split();
        assert_eq!(xs[0], FieldElement51([121666, 0, 0, 0, 0]));
        assert_eq!(xs[1], FieldElement51([121666, 0, 0, 0, 0]));
        assert_eq!(xs[2], FieldElement51([2 * 121666, 0, 0, 0, 0]));
        assert_eq!(xs[3], FieldElement51([2 * 121665, 0, 0, 0, 0]));
    }

    /// `diff_sum` must match the serial differences and sums of the
    /// (x0, x1) and (x2, x3) pairs.
    #[test]
    fn diff_sum_vs_serial() {
        let [x0, x1, x2, x3] = sample_lanes();

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3).diff_sum();

        let result = vec.split();

        assert_eq!(result[0], &x1 - &x0);
        assert_eq!(result[1], &x1 + &x0);
        assert_eq!(result[2], &x3 - &x2);
        assert_eq!(result[3], &x3 + &x2);
    }

    /// `square_and_negate_D` must square every lane and additionally
    /// negate the last one.
    #[test]
    fn square_vs_serial() {
        let [x0, x1, x2, x3] = sample_lanes();

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        let result = vec.square_and_negate_D().split();

        assert_eq!(result[0], &x0 * &x0);
        assert_eq!(result[1], &x1 * &x1);
        assert_eq!(result[2], &x2 * &x2);
        assert_eq!(result[3], -&(&x3 * &x3));
    }

    /// Vector multiplication must agree with serial multiplication lane
    /// by lane, both when the operands are equal and when they differ.
    #[test]
    fn multiply_vs_serial() {
        let [x0, x1, x2, x3] = sample_lanes();
        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        // Equal operands: the product must be the square of each lane.
        let squares = (&vec * &vec).split();

        assert_eq!(squares[0], &x0 * &x0);
        assert_eq!(squares[1], &x1 * &x1);
        assert_eq!(squares[2], &x2 * &x2);
        assert_eq!(squares[3], &x3 * &x3);

        // Distinct operands: the vector multiplication handles `self` and
        // `rhs` differently (doubling limbs of one, multiplying limbs of
        // the other by 19), so a self-product alone cannot detect the two
        // roles being mixed up.
        let y0 = FieldElement51([20000, 20001, 20002, 20003, 20004]);
        let y1 = FieldElement51([20100, 20101, 20102, 20103, 20104]);
        let y2 = FieldElement51([20200, 20201, 20202, 20203, 20204]);
        let y3 = FieldElement51([20300, 20301, 20302, 20303, 20304]);
        let other = FieldElement2625x4::new(&y0, &y1, &y2, &y3);

        let products = (&vec * &other).split();

        assert_eq!(products[0], &x0 * &y0);
        assert_eq!(products[1], &x1 * &y1);
        assert_eq!(products[2], &x2 * &y2);
        assert_eq!(products[3], &x3 * &y3);
    }

    /// `unpack_pair` must separate the two limbs stored in a packed
    /// vector, and `repack_pair` must restore the packed form.
    #[test]
    fn test_unpack_repack_pair() {
        // Only the lowest 51-bit limb is populated: a low 26-bit part
        // plus a second part shifted up by 26 bits.
        let x0 = FieldElement51([10000 + (10001 << 26), 0, 0, 0, 0]);
        let x1 = FieldElement51([10100 + (10101 << 26), 0, 0, 0, 0]);
        let x2 = FieldElement51([10200 + (10201 << 26), 0, 0, 0, 0]);
        let x3 = FieldElement51([10300 + (10301 << 26), 0, 0, 0, 0]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        let src = vec.0[0];

        let (a, b) = unpack_pair(src);

        let expected_a = u32x8::new(10000, 0, 10100, 0, 10200, 0, 10300, 0);
        let expected_b = u32x8::new(10001, 0, 10101, 0, 10201, 0, 10301, 0);

        assert_eq!(a, expected_a);
        assert_eq!(b, expected_b);

        let expected_src = repack_pair(a, b);

        assert_eq!(src, expected_src);
    }

    /// Packing four field elements with `new` and unpacking them with
    /// `split` must return the original elements.
    #[test]
    fn new_split_roundtrips() {
        let x0 = FieldElement51::from_bytes(&[0x10; 32]);
        let x1 = FieldElement51::from_bytes(&[0x11; 32]);
        let x2 = FieldElement51::from_bytes(&[0x12; 32]);
        let x3 = FieldElement51::from_bytes(&[0x13; 32]);

        let vec = FieldElement2625x4::new(&x0, &x1, &x2, &x3);

        let splits = vec.split();

        assert_eq!(x0, splits[0]);
        assert_eq!(x1, splits[1]);
        assert_eq!(x2, splits[2]);
        assert_eq!(x3, splits[3]);
    }
}