Skip to main content

poly1305/backend/
avx2.rs

1//! AVX2 implementation of the Poly1305 state machine.
2
// The State struct and its logic were originally derived from Goll and Gueron's AVX2 C
4// code:
5//     [Vectorization of Poly1305 message authentication code](https://ieeexplore.ieee.org/document/7113463)
6//
7// which was sourced from Bhattacharyya and Sarkar's modified variant:
8//     [Improved SIMD Implementation of Poly1305](https://eprint.iacr.org/2019/842)
9//     https://github.com/Sreyosi/Improved-SIMD-Implementation-of-Poly1305
10//
11// The logic has been extensively rewritten and documented, and several bugs in the
12// original C code were fixed.
13//
14// Note that State only implements the original Goll-Gueron algorithm, not the
15// optimisations provided by Bhattacharyya and Sarkar. The latter require the message
16// length to be known, which is incompatible with the streaming API of UniversalHash.
17#![allow(unsafe_op_in_unsafe_fn)]
18
19use universal_hash::{
20    UhfBackend,
21    array::Array,
22    common::{BlockSizeUser, ParBlocksSizeUser},
23    consts::{U4, U16},
24};
25
26use crate::{Block, Key, Tag};
27
28mod helpers;
29use self::helpers::*;
30
/// Four Poly1305 blocks (64 bytes), the unit absorbed by the AVX2 fast path
/// (`State::ParBlocksSize` is `U4`).
type ParBlocks = universal_hash::ParBlocks<State>;
33
/// Vectorized accumulator state, created by `State::process_blocks` when the
/// first group of four blocks is absorbed.
#[derive(Copy, Clone)]
struct Initialized {
    /// Four-lane polynomial accumulator P; each new group of four blocks is
    /// absorbed as `P <- R^4 * P + blocks`.
    p: Aligned4x130,
    /// Multiplier used to merge the four lanes of P down to a single value
    /// during finalization.
    m: SpacedMultiplier4x130,
    /// R^4, the multiplier applied to P before each new group is added.
    r4: PrecomputedMultiplier,
}
40
/// AVX2 Poly1305 state: keys, precomputed powers of R, the vectorized
/// accumulator, and any input buffered until it can be processed.
#[derive(Clone)]
pub(crate) struct State {
    /// Addition key, added to the polynomial result to form the tag
    /// (`p + k mod 2^128`).
    k: AdditionKey,
    /// Polynomial key R.
    r1: PrecomputedMultiplier,
    /// R^2, precomputed once in `State::new`.
    r2: PrecomputedMultiplier,
    /// Vectorized accumulator; `None` until the first group of four full
    /// blocks has been processed.
    initialized: Option<Initialized>,
    /// Full blocks buffered until four are available for one vectorized step.
    cached_blocks: [Block; 4],
    /// Number of valid blocks at the front of `cached_blocks` (0..=3 between calls).
    num_cached_blocks: usize,
    /// Final block of the message when it is shorter than a full block; it is
    /// absorbed last, in `finalize`. NOTE(review): padding of this block appears
    /// to be the caller's responsibility — confirm.
    partial_block: Option<Block>,
}
51
52impl State {
53    /// Initialize Poly1305 [`State`] with the given key
54    pub(crate) fn new(key: &Key) -> Self {
55        // Prepare addition key and polynomial key.
56        let (k, r1) = unsafe { prepare_keys(key) };
57
58        // Precompute R^2.
59        let r2 = (r1 * r1).reduce();
60
61        State {
62            k,
63            r1,
64            r2: r2.into(),
65            initialized: None,
66            cached_blocks: [Block::default(); 4],
67            num_cached_blocks: 0,
68            partial_block: None,
69        }
70    }
71
72    /// Process four Poly1305 blocks at once.
73    #[target_feature(enable = "avx2")]
74    pub(crate) unsafe fn compute_par_blocks(&mut self, blocks: &ParBlocks) {
75        assert!(self.partial_block.is_none());
76        assert_eq!(self.num_cached_blocks, 0);
77
78        self.process_blocks(Aligned4x130::from_par_blocks(blocks));
79    }
80
81    /// Compute a Poly1305 block
82    #[target_feature(enable = "avx2")]
83    pub(crate) unsafe fn compute_block(&mut self, block: &Block, partial: bool) {
84        // We can cache a single partial block.
85        if partial {
86            assert!(self.partial_block.is_none());
87            self.partial_block = Some(*block);
88            return;
89        }
90
91        self.cached_blocks[self.num_cached_blocks].copy_from_slice(block);
92        if self.num_cached_blocks < 3 {
93            self.num_cached_blocks += 1;
94            return;
95        } else {
96            self.num_cached_blocks = 0;
97        }
98
99        self.process_blocks(Aligned4x130::from_blocks(&self.cached_blocks));
100    }
101
102    /// Compute a Poly1305 block
103    #[target_feature(enable = "avx2")]
104    unsafe fn process_blocks(&mut self, blocks: Aligned4x130) {
105        if let Some(inner) = &mut self.initialized {
106            // P <-- R^4 * P + blocks
107            inner.p = (&inner.p * inner.r4).reduce() + blocks;
108        } else {
109            // Initialize the polynomial.
110            let p = blocks;
111
112            // Initialize the multiplier (used to merge down the polynomial during
113            // finalization).
114            let (m, r4) = SpacedMultiplier4x130::new(self.r1, self.r2);
115
116            self.initialized = Some(Initialized { p, m, r4 });
117        }
118    }
119
120    /// Finalize output producing a [`Tag`]
121    #[target_feature(enable = "avx2")]
122    pub(crate) unsafe fn finalize(&mut self) -> Tag {
123        assert!(self.num_cached_blocks < 4);
124        let mut data = &self.cached_blocks[..];
125
126        // T ← R◦T
127        // P = T_0 + T_1 + T_2 + T_3
128        let mut p = self
129            .initialized
130            .take()
131            .map(|inner| (inner.p * inner.m).sum().reduce());
132
133        if self.num_cached_blocks >= 2 {
134            // Compute 32 byte block (remaining data < 64 bytes)
135            let mut c = Aligned2x130::from_blocks(data[..2].try_into().unwrap());
136            if let Some(p) = p {
137                c = c + p;
138            }
139            p = Some(c.mul_and_sum(self.r1, self.r2).reduce());
140            data = &data[2..];
141            self.num_cached_blocks -= 2;
142        }
143
144        if self.num_cached_blocks == 1 {
145            // Compute 16 byte block (remaining data < 32 bytes)
146            let mut c = Aligned130::from_block(&data[0]);
147            if let Some(p) = p {
148                c = c + p;
149            }
150            p = Some((c * self.r1).reduce());
151            self.num_cached_blocks -= 1;
152        }
153
154        if let Some(block) = &self.partial_block {
155            // Compute last block (remaining data < 16 bytes)
156            let mut c = Aligned130::from_partial_block(block);
157            if let Some(p) = p {
158                c = c + p;
159            }
160            p = Some((c * self.r1).reduce());
161        }
162
163        // Compute tag: p + k mod 2^128
164        let mut tag = Array::<u8, _>::default();
165        let tag_int = if let Some(p) = p {
166            self.k + p
167        } else {
168            self.k.into()
169        };
170        tag_int.write(tag.as_mut_slice());
171
172        tag
173    }
174}
175
// Poly1305 processes the message in 16-byte blocks.
impl BlockSizeUser for State {
    type BlockSize = U16;
}
179
// The AVX2 fast path absorbs four blocks per vectorized step.
impl ParBlocksSizeUser for State {
    type ParBlocksSize = U4;
}
183
184impl UhfBackend for State {
185    fn proc_block(&mut self, block: &Block) {
186        unsafe { self.compute_block(block, false) };
187    }
188
189    fn proc_par_blocks(&mut self, blocks: &ParBlocks) {
190        if self.num_cached_blocks == 0 {
191            // Fast path.
192            unsafe { self.compute_par_blocks(blocks) };
193        } else {
194            // We are unaligned; use the slow fallback.
195            for block in blocks {
196                self.proc_block(block);
197            }
198        }
199    }
200
201    fn blocks_needed_to_align(&self) -> usize {
202        if self.num_cached_blocks == 0 {
203            // There are no cached blocks; fast path is available.
204            0
205        } else {
206            // There are cached blocks; report how many more we need.
207            self.cached_blocks.len() - self.num_cached_blocks
208        }
209    }
210}