
zune_jpeg/idct/avx2.rs

/*
 * Copyright (c) 2023.
 *
 * This software is free software;
 *
 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
 */
#![cfg(any(target_arch = "x86", target_arch = "x86_64"))]
//! AVX optimised IDCT.
//!
//! Okay, not that optimised.
//!
//! # The implementation
//! The implementation is neatly broken down into two operations.
//!
//! 1. Test for zeroes
//! > There is a shortcut method for the IDCT where, when all AC values are zero, we can get the
//! > answer really quickly by scaling 1/8th of the DC coefficient of the block to the whole block
//! > and level shifting (see the scalar sketch at the end of this comment).
//!
//! 2. If the above fails, we carry out the IDCT as a two-pass, one-dimensional algorithm.
//! It does two whole scans, each carrying out a 1-D IDCT on all items. After each scan the data
//! is transposed in registers (thank you, x86 SIMD powers), so the second pass works on the
//! transposed data and the final transpose restores the original orientation.
//!
//! The code is not super optimized; it produces bit-identical results with the scalar code
//! (hence the plain `_mm256_add_epi16`-style arithmetic), which also has the advantage of
//! making this implementation easy to maintain.
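//!
//! A scalar sketch of the shortcut in step 1 (illustrative only; `dc` and `out` are
//! hypothetical names, not items from this crate):
//!
//! ```text
//! // `dc` is the dequantized DC coefficient; 1024 == 128 << 3 is the level shift and
//! // 4 == 0.5 * (1 << 3) rounds before the downscale by 8.
//! let pixel = ((dc + 4 + 1024) >> 3).clamp(0, 255) as i16;
//! out.fill(pixel); // every sample of the 8x8 block gets the same value
//! ```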
#![cfg(feature = "x86")]
#![allow(dead_code)]

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::unsafe_utils::{transpose, YmmRegister};
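// Rounding and level-shift bias folded into the second (column) IDCT pass before the
// final `>> 17`: 65536 rounds the shift and `128 << 17` is the +128 level shift
// (see the scaling comment inside `idct_avx2_4x4` below).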
const SCALE_BITS: i32 = 512 + 65536 + (128 << 17);

// Pack i32's to i16's,
// clamp them to be between 0-255,
// undo shuffling,
// store back to the array.
macro_rules! permute_store {
    ($x:tt,$y:tt,$index:tt,$out:tt,$stride:tt) => {
        let a = _mm256_packs_epi32($x, $y);

        // Clamp the values after packing; we can clamp more values at once
        let b = clamp_avx(a);
        let mut tmp = [0; 8];
        // Undo shuffling
        let c = _mm256_permute4x64_epi64(b, shuffle(3, 1, 2, 0));

        // Store the first vector
        _mm_storeu_si128(
            ($out)
                .get_mut($index..$index + 8)
                .unwrap_or(&mut tmp)
                .as_mut_ptr()
                .cast(),
            _mm256_extractf128_si256::<0>(c),
        );
        $index += $stride;
        // Store the second vector
        _mm_storeu_si128(
            ($out)
                .get_mut($index..$index + 8)
                .unwrap()
                .as_mut_ptr()
                .cast(),
            _mm256_extractf128_si256::<1>(c),
        );
        $index += $stride;
    };
}
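/// Carry out the inverse DCT on an 8x8 block of dequantized coefficients, writing the
/// clamped `0..=255` samples into `out_vector` in rows of eight, `stride` elements apart.
///
/// A minimal usage sketch (illustrative only; the buffers and the runtime feature check
/// are assumptions about the caller, not part of this file):
///
/// ```ignore
/// let mut coeffs = [0_i32; 64]; // dequantized DCT coefficients
/// let mut out = [0_i16; 64];    // 8x8 output block written with a stride of 8
/// if is_x86_feature_detected!("avx2") {
///     // Safety: AVX2 support was checked above.
///     unsafe { idct_avx2(&mut coeffs, &mut out, 8) };
/// }
/// ```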
#[target_feature(enable = "avx2")]
#[allow(
    clippy::too_many_lines,
    clippy::cast_possible_truncation,
    clippy::similar_names,
    clippy::op_ref,
    unused_assignments,
    clippy::zero_prefixed_literal
)]
pub unsafe fn idct_avx2(
    in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize,
) {
    let mut pos = 0;

    // Load into registers.
    //
    // We sign-extend i16's to i32's, calculate with extended precision, and later
    // reduce them back to i16's when we are done carrying out the IDCT.

    let rw0 = _mm256_loadu_si256(in_vector[00..].as_ptr().cast());
    let rw1 = _mm256_loadu_si256(in_vector[08..].as_ptr().cast());
    let rw2 = _mm256_loadu_si256(in_vector[16..].as_ptr().cast());
    let rw3 = _mm256_loadu_si256(in_vector[24..].as_ptr().cast());
    let rw4 = _mm256_loadu_si256(in_vector[32..].as_ptr().cast());
    let rw5 = _mm256_loadu_si256(in_vector[40..].as_ptr().cast());
    let rw6 = _mm256_loadu_si256(in_vector[48..].as_ptr().cast());
    let rw7 = _mm256_loadu_si256(in_vector[56..].as_ptr().cast());

    // Forward DCT and quantization may cause all the AC terms to be zero; for such
    // cases we can try to accelerate the IDCT.

    // The gist is that whenever the 63 AC terms are all zero, the block's IDCT is
    // (arr[0] >> 3), i.e. arr[0] / 8 (plus the +128 level shift), propagated to all
    // the elements.
    // We first test whether all the AC elements are zero, and if they are, we take
    // the short path.
    //
    // This reduces IDCT overhead from about 39% to 18%, almost half.

    // Do another load for the first row; we don't include the DC value, because
    // we only care about the AC terms.
    let rw8 = _mm256_loadu_si256(in_vector[1..].as_ptr().cast());

    let mut bitmap = _mm256_or_si256(rw1, rw2);
    bitmap = _mm256_or_si256(bitmap, rw3);
    bitmap = _mm256_or_si256(bitmap, rw4);
    bitmap = _mm256_or_si256(bitmap, rw5);
    bitmap = _mm256_or_si256(bitmap, rw6);
    bitmap = _mm256_or_si256(bitmap, rw7);
    bitmap = _mm256_or_si256(bitmap, rw8);

    if _mm256_testz_si256(bitmap, bitmap) == 1 {
        // The AC terms are all zero, so the IDCT of the block is (coeff[0] * qt[0]) / 8 + 128
        // (bias), clamped to 0..255.
        // Round by adding 0.5 * (1 << 3) and offset by adding (128 << 3) before scaling.
        let coeff = ((in_vector[0] + 4 + 1024) >> 3).clamp(0, 255) as i16;
        let idct_value = _mm_set1_epi16(coeff);

        macro_rules! store {
            ($pos:tt,$value:tt) => {
                // Store eight samples and advance by the row stride
                _mm_storeu_si128(
                    out_vector
                        .get_mut($pos..$pos + 8)
                        .unwrap()
                        .as_mut_ptr()
                        .cast(),
                    $value,
                );
                $pos += stride;
            };
        }
        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);

        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);
        store!(pos, idct_value);

        return;
    }

    let mut row0 = YmmRegister { mm256: rw0 };
    let mut row1 = YmmRegister { mm256: rw1 };
    let mut row2 = YmmRegister { mm256: rw2 };
    let mut row3 = YmmRegister { mm256: rw3 };

    let mut row4 = YmmRegister { mm256: rw4 };
    let mut row5 = YmmRegister { mm256: rw5 };
    let mut row6 = YmmRegister { mm256: rw6 };
    let mut row7 = YmmRegister { mm256: rw7 };

    macro_rules! dct_pass {
        ($SCALE_BITS:tt,$scale:tt) => {
            // There are a lot of ways to do this, but to keep it simple (and beautiful) I'll
            // make a direct translation of the scalar code, which also keeps this code fully
            // transparent (this version and the non-AVX one should produce identical results).

            // even part
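            // The integer constants below are 12-bit fixed-point DCT cosine terms
            // (real value * 4096), carried over from the scalar implementation.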
            let p1 = (row2 + row6) * 2217;

            let mut t2 = p1 + row6 * -7567;
            let mut t3 = p1 + row2 * 3135;

            let mut t0 = YmmRegister {
                mm256: _mm256_slli_epi32((row0 + row4).mm256, 12),
            };
            let mut t1 = YmmRegister {
                mm256: _mm256_slli_epi32((row0 - row4).mm256, 12),
            };

            let x0 = t0 + t3 + $SCALE_BITS;
            let x3 = t0 - t3 + $SCALE_BITS;
            let x1 = t1 + t2 + $SCALE_BITS;
            let x2 = t1 - t2 + $SCALE_BITS;

            let p3 = row7 + row3;
            let p4 = row5 + row1;
            let p1 = row7 + row1;
            let p2 = row5 + row3;
            let p5 = (p3 + p4) * 4816;

            t0 = row7 * 1223;
            t1 = row5 * 8410;
            t2 = row3 * 12586;
            t3 = row1 * 6149;

            let p1 = p5 + p1 * -3685;
            let p2 = p5 + (p2 * -10497);
            let p3 = p3 * -8034;
            let p4 = p4 * -1597;

            t3 += p1 + p4;
            t2 += p2 + p3;
            t1 += p2 + p4;
            t0 += p1 + p3;

            row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, $scale);
            row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, $scale);
            row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, $scale);
            row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, $scale);

            row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, $scale);
            row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, $scale);
            row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, $scale);
            row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, $scale);
        };
    }

    // Process rows
    dct_pass!(512, 10);
    transpose(
        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
    );

    // Process columns
    dct_pass!(SCALE_BITS, 17);
    transpose(
        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
    );
    // Pack and write the values back to the array
    permute_store!((row0.mm256), (row1.mm256), pos, out_vector, stride);
    permute_store!((row2.mm256), (row3.mm256), pos, out_vector, stride);
    permute_store!((row4.mm256), (row5.mm256), pos, out_vector, stride);
    permute_store!((row6.mm256), (row7.mm256), pos, out_vector, stride);
}
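/// Variant of [`idct_avx2`] that assumes only the top-left 4x4 coefficients of the block
/// may be non-zero: it loads just the first four rows and, after the transpose, only the
/// first four of those columns feed the second pass. Output handling matches [`idct_avx2`].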
#[target_feature(enable = "avx2")]
#[allow(
    clippy::too_many_lines,
    clippy::cast_possible_truncation,
    clippy::similar_names,
    clippy::op_ref,
    unused_assignments,
    clippy::zero_prefixed_literal
)]
pub unsafe fn idct_avx2_4x4(
    in_vector: &mut [i32; 64], out_vector: &mut [i16], stride: usize,
) {
    let rw0 = _mm256_loadu_si256(in_vector[00..].as_ptr().cast());
    let rw1 = _mm256_loadu_si256(in_vector[08..].as_ptr().cast());
    let rw2 = _mm256_loadu_si256(in_vector[16..].as_ptr().cast());
    let rw3 = _mm256_loadu_si256(in_vector[24..].as_ptr().cast());

    let mut row0 = YmmRegister { mm256: rw0 };
    let mut row1 = YmmRegister { mm256: rw1 };
    let mut row2 = YmmRegister { mm256: rw2 };
    let mut row3 = YmmRegister { mm256: rw3 };

    let mut row4 = YmmRegister { mm256: rw0 };
    let mut row5 = YmmRegister { mm256: rw0 };
    let mut row6 = YmmRegister { mm256: rw0 };
    let mut row7 = YmmRegister { mm256: rw0 };

    {
        row0.mm256 = _mm256_slli_epi32(row0.mm256, 12);
        row0 += 512;

        let i2 = row2;

        let p1 = i2 * 2217;
        let p3 = i2 * 5352;

        let x0 = row0 + p3;
        let x1 = row0 + p1;
        let x2 = row0 - p1;
        let x3 = row0 - p3;

        // odd part
        let i4 = row3;
        let i3 = row1;

        let p5 = (i4 + i3) * 4816;

        let p1 = p5 + i3 * -3685;
        let p2 = p5 + i4 * -10497;

        let t3 = p5 + i3 * 867;
        let t2 = p5 + i4 * -5945;

        let t1 = p2 + i3 * -1597;
        let t0 = p1 + i4 * -8034;

        row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, 10);
        row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, 10);
        row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, 10);
        row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, 10);

        row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, 10);
        row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, 10);
        row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, 10);
        row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, 10);
    }

    transpose(
        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
    );

    {
        let i2 = row2;
        let i0 = row0;

        row0.mm256 = _mm256_slli_epi32(i0.mm256, 12);
        let t0 = row0 + SCALE_BITS;

        let t2 = i2 * 2217;
        let t3 = i2 * 5352;

        // Constants scaled things up by 1 << 12, plus we had 1 << 2 from the first
        // pass, plus horizontal and vertical each scale by sqrt(8), so together
        // we've got an extra 1 << 3; that's 1 << 17 in total we need to remove.
        // We want to round that, which means adding 0.5 * (1 << 17), aka 65536.
        // Also, we'll end up with -128 to 127 that we want to encode as 0..255 by
        // adding 128, so we add that before the shift as well.
        // The rounding constant is already added into `t0`.
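        // (In total: 12 + 2 + 3 = 17 bits of extra scale, removed by the `>> 17` below.)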
        let x0 = t0 + t3;
        let x3 = t0 - t3;
        let x1 = t0 + t2;
        let x2 = t0 - t2;

        // odd part
        let i3 = row3;
        let i1 = row1;

        let p5 = (i3 + i1) * 4816;

        let p1 = p5 + i1 * -3685;
        let p2 = p5 + i3 * -10497;

        let t3 = p5 + i1 * 867;
        let t2 = p5 + i3 * -5945;

        let t1 = p2 + i1 * -1597;
        let t0 = p1 + i3 * -8034;

        row0.mm256 = _mm256_srai_epi32((x0 + t3).mm256, 17);
        row1.mm256 = _mm256_srai_epi32((x1 + t2).mm256, 17);
        row2.mm256 = _mm256_srai_epi32((x2 + t1).mm256, 17);
        row3.mm256 = _mm256_srai_epi32((x3 + t0).mm256, 17);
        row4.mm256 = _mm256_srai_epi32((x3 - t0).mm256, 17);
        row5.mm256 = _mm256_srai_epi32((x2 - t1).mm256, 17);
        row6.mm256 = _mm256_srai_epi32((x1 - t2).mm256, 17);
        row7.mm256 = _mm256_srai_epi32((x0 - t3).mm256, 17);
    }

    transpose(
        &mut row0, &mut row1, &mut row2, &mut row3, &mut row4, &mut row5, &mut row6, &mut row7,
    );

    let mut pos = 0;

    // Pack and write the values back to the array
    permute_store!((row0.mm256), (row1.mm256), pos, out_vector, stride);
    permute_store!((row2.mm256), (row3.mm256), pos, out_vector, stride);
    permute_store!((row4.mm256), (row5.mm256), pos, out_vector, stride);
    permute_store!((row6.mm256), (row7.mm256), pos, out_vector, stride);
}
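/// Clamp each 16-bit lane of `reg` to the `0..=255` range (the valid pixel range after packing).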
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn clamp_avx(reg: __m256i) -> __m256i {
    let min_s = _mm256_set1_epi16(0);
    let max_s = _mm256_set1_epi16(255);

    let max_v = _mm256_max_epi16(reg, min_s); // max(a, 0)
    let min_v = _mm256_min_epi16(max_v, max_s); // min(max(a, 0), 255)
    min_v
}

/// A copy of `_MM_SHUFFLE()` that doesn't require
/// a nightly compiler.
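///
/// For example, `shuffle(3, 1, 2, 0)` evaluates to `0b11_01_10_00` (`0xD8`), which
/// `_mm256_permute4x64_epi64` reads as "take 64-bit lanes 0, 2, 1 and 3 of the source".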
#[inline]
const fn shuffle(z: i32, y: i32, x: i32, w: i32) -> i32 {
    (z << 6) | (y << 4) | (x << 2) | w
}