// zune_jpeg/upsampler/avx2.rs
/*
 * Copyright (c) 2025.
 *
 * This software is free software;
 *
 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
 */

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Horizontally upsamples `input` into `output` (twice as long) using AVX2.
///
/// Implements JPEG "fancy" (triangle-filter) upsampling:
///   out[2i]   = (3*in[i] + in[i-1] + 2) >> 2
///   out[2i+1] = (3*in[i] + in[i+1] + 2) >> 2
/// with the edge pixels handled separately below.
///
/// `in_near`, `in_far` and `scratch` are not read by the SIMD path; they are
/// only forwarded to the scalar fallback — presumably kept so all upsamplers
/// share one signature (confirm against the dispatch site).
///
/// # Panics
/// If `output.len() != input.len() * 2` or `input.len() <= 2`.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe fn upsample_horizontal_avx2(
    input: &[i16],
    in_near: &[i16],
    in_far: &[i16],
    scratch: &mut [i16],
    output: &mut [i16],
) {
    assert_eq!(input.len() * 2, output.len());
    assert!(input.len() > 2);

    let len = input.len();

    // The SIMD kernel needs 18-pixel windows; shorter rows go to scalar.
    if len < 18 {
        return super::scalar::upsample_horizontal(input, in_near, in_far, scratch, output);
    }

    // First two pixels: replicate the left edge, then blend toward input[1].
    output[0] = input[0];
    output[1] = (input[0] * 3 + input[1] + 2) >> 2;

    let v_three = _mm256_set1_epi16(3);
    let v_two = _mm256_set1_epi16(2);

    // Expands 16 interior pixels (plus one neighbor on each side) into 32
    // output pixels.
    let upsample16 = |input: &[i16; 18], output: &mut [i16; 32]| {
        let in_ptr = input.as_ptr();
        let out_ptr = output.as_mut_ptr();

        // Three unaligned 16-lane loads shifted by one element each give the
        // previous/current/next neighbor of every pixel in the window.
        // SAFETY: The input is 18 * 16 bit long, so the loads are safe.
        let (v_prev, v_curr, v_next) = unsafe {
            (
                _mm256_loadu_si256(in_ptr.add(0) as *const __m256i),
                _mm256_loadu_si256(in_ptr.add(1) as *const __m256i),
                _mm256_loadu_si256(in_ptr.add(2) as *const __m256i),
            )
        };

        // 3 * current + rounding bias, shared by the even and odd outputs.
        let v_common = _mm256_add_epi16(_mm256_mullo_epi16(v_curr, v_three), v_two);

        // Even outputs lean on the previous pixel, odd outputs on the next.
        let v_even = _mm256_srai_epi16(_mm256_add_epi16(v_common, v_prev), 2);
        let v_odd = _mm256_srai_epi16(_mm256_add_epi16(v_common, v_next), 2);

        // Interleave even/odd pairs; unpack operates per 128-bit lane, so a
        // cross-lane permute is needed afterwards to restore sequential order.
        let v_res_1 = _mm256_unpacklo_epi16(v_even, v_odd);
        let v_res_2 = _mm256_unpackhi_epi16(v_even, v_odd);

        let v_final_1 = _mm256_permute2x128_si256(v_res_1, v_res_2, 0x20);
        let v_final_2 = _mm256_permute2x128_si256(v_res_1, v_res_2, 0x31);

        // SAFETY: The output is 32 * 16 bit long, so the stores are safe.
        unsafe {
            _mm256_storeu_si256(out_ptr as *mut __m256i, v_final_1);
            _mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v_final_2);
        }
    };

    // Overlapping 18-wide input windows stepped by 16 feed non-overlapping
    // 32-wide output chunks, starting after the two hand-written edge pixels.
    for (input, output) in input
        .windows(18)
        .step_by(16)
        .zip(output[2..].chunks_exact_mut(32))
    {
        upsample16(input.try_into().unwrap(), output.try_into().unwrap());
    }

    // Upsample the remainder. This may have some overlap, but that's fine.
    if let Some(rest_input) = input.last_chunk::<18>() {
        let end = output.len() - 2;
        if let Some(rest_output) = output[..end].last_chunk_mut::<32>() {
            upsample16(rest_input, rest_output);
        }
    }

    // Last two pixels: blend toward input[len - 2], then replicate the edge.
    output[output.len() - 2] = (3 * input[len - 1] + input[len - 2] + 2) >> 2;
    output[output.len() - 1] = input[len - 1];
}
90
91#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
92#[target_feature(enable = "avx2")]
93pub unsafe fn upsample_vertical_avx2(
94    input: &[i16],
95    in_near: &[i16],
96    in_far: &[i16],
97    scratch: &mut [i16],
98    output: &mut [i16],
99) {
100    assert_eq!(input.len() * 2, output.len());
101    assert_eq!(in_near.len(), input.len());
102    assert_eq!(in_far.len(), input.len());
103
104    let len = input.len();
105
106    if len < 16 {
107        return super::scalar::upsample_vertical(input, in_near, in_far, scratch, output);
108    }
109
110    let middle = output.len() / 2;
111    let (out_top, out_bottom) = output.split_at_mut(middle);
112
113    let v_three = _mm256_set1_epi16(3);
114    let v_two = _mm256_set1_epi16(2);
115
116    let upsample16 = |input: &[i16; 16],
117                      in_near: &[i16; 16],
118                      in_far: &[i16; 16],
119                      out_top: &mut [i16; 16],
120                      out_bottom: &mut [i16; 16]| {
121        // SAFETY: Inputs are all 16 * 16 bit long, so the loads are safe.
122        let (v_in, v_near, v_far) = unsafe {
123            (
124                _mm256_loadu_si256(input.as_ptr() as *const __m256i),
125                _mm256_loadu_si256(in_near.as_ptr() as *const __m256i),
126                _mm256_loadu_si256(in_far.as_ptr() as *const __m256i),
127            )
128        };
129
130        let v_common = _mm256_add_epi16(_mm256_mullo_epi16(v_in, v_three), v_two);
131
132        let v_out_top = _mm256_srai_epi16(_mm256_add_epi16(v_common, v_near), 2);
133        let v_out_bottom = _mm256_srai_epi16(_mm256_add_epi16(v_common, v_far), 2);
134
135        // SAFETY: Outputs are 16 * 16 bit long, so the stores are safe.
136        unsafe {
137            _mm256_storeu_si256(out_top.as_mut_ptr() as *mut __m256i, v_out_top);
138            _mm256_storeu_si256(out_bottom.as_mut_ptr() as *mut __m256i, v_out_bottom);
139        }
140    };
141
142    let chunks = input
143        .chunks_exact(16)
144        .zip(in_near.chunks_exact(16))
145        .zip(in_far.chunks_exact(16))
146        .zip(out_top.chunks_exact_mut(16))
147        .zip(out_bottom.chunks_exact_mut(16));
148
149    for ((((input, in_near), in_far), out_top), out_bottom) in chunks {
150        upsample16(
151            input.try_into().unwrap(),
152            in_near.try_into().unwrap(),
153            in_far.try_into().unwrap(),
154            out_top.try_into().unwrap(),
155            out_bottom.try_into().unwrap(),
156        );
157    }
158
159    // Upsample the remainder. This may have some overlap, but that's fine.
160    // Edition upgrade will fix this nested awfulness.
161    if let Some(rest) = input.last_chunk::<16>() {
162        if let Some(rest_near) = in_near.last_chunk::<16>() {
163            if let Some(rest_far) = in_far.last_chunk::<16>() {
164                if let Some(mut rest_top) = out_top.last_chunk_mut::<16>() {
165                    if let Some(mut rest_bottom) = out_bottom.last_chunk_mut::<16>() {
166                        upsample16(rest, rest_near, rest_far, &mut rest_top, &mut rest_bottom);
167                    }
168                }
169            }
170        }
171    }
172}
173
174#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
175#[target_feature(enable = "avx2")]
176pub unsafe fn upsample_hv_avx2(
177    input: &[i16],
178    in_near: &[i16],
179    in_far: &[i16],
180    scratch_space: &mut [i16],
181    output: &mut [i16],
182) {
183    assert_eq!(input.len() * 4, output.len());
184    assert!(input.len() * 2 <= scratch_space.len());
185    let scratch_space = &mut scratch_space[..input.len() * 2];
186
187
188    upsample_vertical_avx2(input, in_near, in_far, &mut [], scratch_space);
189
190    let scratch_half = scratch_space.len() / 2;
191    let output_half = output.len() / 2;
192
193    let (scratch_top, scratch_bottom) = scratch_space.split_at_mut(scratch_half);
194    let (out_top, out_bottom) = output.split_at_mut(output_half);
195
196    let mut t = [0];
197    upsample_horizontal_avx2(scratch_top, &[], &[], &mut t, out_top);
198    upsample_horizontal_avx2(scratch_bottom, &[], &[], &mut t, out_bottom);
199}