half/binary16/arch/
x86.rs

1use core::{mem::MaybeUninit, ptr};
2use zerocopy::transmute;
3
4#[cfg(target_arch = "x86")]
5use core::arch::x86::{
6    __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps,
7    _MM_FROUND_TO_NEAREST_INT,
8};
9#[cfg(target_arch = "x86_64")]
10use core::arch::x86_64::{
11    __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, _mm_cvtps_ph,
12    _MM_FROUND_TO_NEAREST_INT,
13};
14
15#[cfg(target_arch = "x86")]
16use core::arch::x86::_mm_cvtps_ph;
17
18use super::convert_chunked_slice_8;
19
20/////////////// x86/x86_64 f16c ////////////////
21
22#[target_feature(enable = "f16c")]
23#[inline]
24pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
25    let vec: __m128i = transmute!([i, 0, 0, 0, 0, 0, 0, 0]);
26    let retval: [f32; 4] = transmute!(_mm_cvtph_ps(vec));
27    retval[0]
28}
29
30#[target_feature(enable = "f16c")]
31#[inline]
32pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
33    let vec: __m128 = transmute!([f, 0.0, 0.0, 0.0]);
34    let retval = _mm_cvtps_ph(vec, _MM_FROUND_TO_NEAREST_INT);
35    let retval: [u16; 8] = transmute!(retval);
36    retval[0]
37}
38
39#[target_feature(enable = "f16c")]
40#[inline]
41pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16; 4]) -> [f32; 4] {
42    let vec: __m128i = transmute!([*v, [0, 0, 0, 0]]);
43    transmute!(_mm_cvtph_ps(vec))
44}
45
46#[target_feature(enable = "f16c")]
47#[inline]
48pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32; 4]) -> [u16; 4] {
49    let vec: __m128 = zerocopy::transmute!(*v);
50    let retval = _mm_cvtps_ph(vec, _MM_FROUND_TO_NEAREST_INT);
51    let retval: [[u16; 4]; 2] = transmute!(retval);
52    retval[0]
53}
54
55#[target_feature(enable = "f16c")]
56#[inline]
57pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16; 4]) -> [f64; 4] {
58    let array = f16x4_to_f32x4_x86_f16c(v);
59    // Let compiler vectorize this regular cast for now.
60    // TODO: investigate auto-detecting sse2/avx convert features
61    [
62        array[0] as f64,
63        array[1] as f64,
64        array[2] as f64,
65        array[3] as f64,
66    ]
67}
68
69#[target_feature(enable = "f16c")]
70#[inline]
71pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64; 4]) -> [u16; 4] {
72    // Let compiler vectorize this regular cast for now.
73    // TODO: investigate auto-detecting sse2/avx convert features
74    let v = [v[0] as f32, v[1] as f32, v[2] as f32, v[3] as f32];
75    f32x4_to_f16x4_x86_f16c(&v)
76}
77
78#[target_feature(enable = "f16c")]
79#[inline]
80pub(super) unsafe fn f16x8_to_f32x8_x86_f16c(v: &[u16; 8]) -> [f32; 8] {
81    let vec: __m128i = transmute!(*v);
82    transmute!(_mm256_cvtph_ps(vec))
83}
84
85#[target_feature(enable = "f16c")]
86#[inline]
87pub(super) unsafe fn f32x8_to_f16x8_x86_f16c(v: &[f32; 8]) -> [u16; 8] {
88    let vec: __m256 = transmute!(*v);
89    let retval = _mm256_cvtps_ph(vec, _MM_FROUND_TO_NEAREST_INT);
90    transmute!(retval)
91}
92
93#[target_feature(enable = "f16c")]
94#[inline]
95pub(super) unsafe fn f16x8_to_f64x8_x86_f16c(v: &[u16; 8]) -> [f64; 8] {
96    let array = f16x8_to_f32x8_x86_f16c(v);
97    // Let compiler vectorize this regular cast for now.
98    // TODO: investigate auto-detecting sse2/avx convert features
99    [
100        array[0] as f64,
101        array[1] as f64,
102        array[2] as f64,
103        array[3] as f64,
104        array[4] as f64,
105        array[5] as f64,
106        array[6] as f64,
107        array[7] as f64,
108    ]
109}
110
111#[target_feature(enable = "f16c")]
112#[inline]
113pub(super) unsafe fn f64x8_to_f16x8_x86_f16c(v: &[f64; 8]) -> [u16; 8] {
114    // Let compiler vectorize this regular cast for now.
115    // TODO: investigate auto-detecting sse2/avx convert features
116    let v = [
117        v[0] as f32,
118        v[1] as f32,
119        v[2] as f32,
120        v[3] as f32,
121        v[4] as f32,
122        v[5] as f32,
123        v[6] as f32,
124        v[7] as f32,
125    ];
126    f32x8_to_f16x8_x86_f16c(&v)
127}