1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use core::{mem::MaybeUninit, ptr};

#[cfg(target_arch = "x86")]
use core::arch::x86::{
    __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps,
    _MM_FROUND_TO_NEAREST_INT,
};
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
    __m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, _mm_cvtps_ph,
    _MM_FROUND_TO_NEAREST_INT,
};

#[cfg(target_arch = "x86")]
use core::arch::x86::_mm_cvtps_ph;

use super::convert_chunked_slice_8;

/////////////// x86/x86_64 f16c ////////////////

#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
    let mut vec = MaybeUninit::<__m128i>::zeroed();
    vec.as_mut_ptr().cast::<u16>().write(i);
    let retval = _mm_cvtph_ps(vec.assume_init());
    *(&retval as *const __m128).cast()
}

#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
    let mut vec = MaybeUninit::<__m128>::zeroed();
    vec.as_mut_ptr().cast::<f32>().write(f);
    let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
    *(&retval as *const __m128i).cast()
}

#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16; 4]) -> [f32; 4] {
    let mut vec = MaybeUninit::<__m128i>::zeroed();
    ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
    let retval = _mm_cvtph_ps(vec.assume_init());
    *(&retval as *const __m128).cast()
}

#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32; 4]) -> [u16; 4] {
    let mut vec = MaybeUninit::<__m128>::uninit();
    ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
    let retval = _mm_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
    *(&retval as *const __m128i).cast()
}

/// Converts four `f16` bit patterns to four `f64` values, going through an
/// f16c-accelerated `f16`→`f32` step followed by a scalar widening cast.
///
/// # Safety
/// The executing CPU must support the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16; 4]) -> [f64; 4] {
    let singles = f16x4_to_f32x4_x86_f16c(v);
    // Plain scalar widening; the compiler is free to vectorize it.
    // TODO: investigate auto-detecting sse2/avx convert features
    let mut doubles = [0f64; 4];
    for (dst, &src) in doubles.iter_mut().zip(singles.iter()) {
        *dst = f64::from(src);
    }
    doubles
}

/// Converts four `f64` values to four `f16` bit patterns, narrowing to `f32`
/// with scalar casts and then using the f16c-accelerated `f32`→`f16` step.
///
/// # Safety
/// The executing CPU must support the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64; 4]) -> [u16; 4] {
    // Plain scalar narrowing; the compiler is free to vectorize it.
    // TODO: investigate auto-detecting sse2/avx convert features
    let mut singles = [0f32; 4];
    for (dst, &src) in singles.iter_mut().zip(v.iter()) {
        *dst = src as f32;
    }
    f32x4_to_f16x4_x86_f16c(&singles)
}

#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x8_to_f32x8_x86_f16c(v: &[u16; 8]) -> [f32; 8] {
    let mut vec = MaybeUninit::<__m128i>::zeroed();
    ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 8);
    let retval = _mm256_cvtph_ps(vec.assume_init());
    *(&retval as *const __m256).cast()
}

#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f32x8_to_f16x8_x86_f16c(v: &[f32; 8]) -> [u16; 8] {
    let mut vec = MaybeUninit::<__m256>::uninit();
    ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 8);
    let retval = _mm256_cvtps_ph(vec.assume_init(), _MM_FROUND_TO_NEAREST_INT);
    *(&retval as *const __m128i).cast()
}

/// Converts eight `f16` bit patterns to eight `f64` values, going through an
/// f16c-accelerated `f16`→`f32` step followed by a scalar widening cast.
///
/// # Safety
/// The executing CPU must support the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x8_to_f64x8_x86_f16c(v: &[u16; 8]) -> [f64; 8] {
    let singles = f16x8_to_f32x8_x86_f16c(v);
    // Plain scalar widening; the compiler is free to vectorize it.
    // TODO: investigate auto-detecting sse2/avx convert features
    let mut doubles = [0f64; 8];
    for (dst, &src) in doubles.iter_mut().zip(singles.iter()) {
        *dst = f64::from(src);
    }
    doubles
}

/// Converts eight `f64` values to eight `f16` bit patterns, narrowing to
/// `f32` with scalar casts and then using the f16c-accelerated `f32`→`f16`
/// step.
///
/// # Safety
/// The executing CPU must support the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f64x8_to_f16x8_x86_f16c(v: &[f64; 8]) -> [u16; 8] {
    // Plain scalar narrowing; the compiler is free to vectorize it.
    // TODO: investigate auto-detecting sse2/avx convert features
    let mut singles = [0f32; 8];
    for (dst, &src) in singles.iter_mut().zip(v.iter()) {
        *dst = src as f32;
    }
    f32x8_to_f16x8_x86_f16c(&singles)
}