AVX512/TypeCasting.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TYPE_CASTING_AVX512_H
11 #define EIGEN_TYPE_CASTING_AVX512_H
12 
13 // IWYU pragma: private
14 #include "../../InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 template <>
21 struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
22 template <>
23 struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};
24 
25 template <>
26 struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
27 template <>
28 struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};
29 
30 template <>
31 struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
32 template <>
33 struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
34 
35 template <>
36 struct type_casting_traits<double, int> : vectorized_type_casting_traits<double, int> {};
37 template <>
38 struct type_casting_traits<int, double> : vectorized_type_casting_traits<int, double> {};
39 
40 template <>
42 template <>
44 
45 template <>
46 struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
47 template <>
48 struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
49 
50 template <>
51 struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
52 template <>
53 struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
54 
55 template <>
57  __mmask16 mask = _mm512_cmpneq_ps_mask(a, pzero(a));
58  return _mm512_maskz_cvtepi32_epi8(mask, _mm512_set1_epi32(1));
59 }
60 
61 template <>
63  return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
64 }
65 
66 template <>
68  return _mm512_cvttps_epi32(a);
69 }
70 
71 template <>
73  return _mm512_cvtps_pd(_mm512_castps512_ps256(a));
74 }
75 
76 template <>
78  return _mm512_cvtps_pd(a);
79 }
80 
81 template <>
83 #if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
84  return _mm512_cvttpd_epi64(a);
85 #else
86  constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
87  kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;
88 
89  const __m512i cst_one = _mm512_set1_epi64(1);
90  const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits);
91  const __m512i cst_bias = _mm512_set1_epi64(kBias);
92 
93  __m512i a_bits = _mm512_castpd_si512(a);
94  // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
95  __m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1);
96  __m512i e = _mm512_sub_epi64(biased_e, cst_bias);
97 
98  // shift to the left by kExponentBits + 1 to clear the sign and exponent bits
99  __m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1);
100  // shift to the right by kTotalBits - e to convert the significand to an integer
101  __m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e));
102 
103  // add the implied bit
104  __m512i result_exponent = _mm512_sllv_epi64(cst_one, e);
105  // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
106  __m512i result = _mm512_add_epi64(result_significand, result_exponent);
107  // handle negative arguments
108  __mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512());
109  result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result);
110  return result;
111 #endif
112 }
113 
114 template <>
116  return _mm512_cvtepi32_ps(a);
117 }
118 
119 template <>
121  return _mm512_cvtepi32_pd(_mm512_castsi512_si256(a));
122 }
123 
124 template <>
126  return _mm512_cvtepi32_pd(a);
127 }
128 
129 template <>
131 #if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL)
132  return _mm512_cvtepi64_pd(a);
133 #else
134  EIGEN_ALIGN64 int64_t aux[8];
135  pstore(aux, a);
136  return _mm512_set_pd(static_cast<double>(aux[7]), static_cast<double>(aux[6]), static_cast<double>(aux[5]),
137  static_cast<double>(aux[4]), static_cast<double>(aux[3]), static_cast<double>(aux[2]),
138  static_cast<double>(aux[1]), static_cast<double>(aux[0]));
139 #endif
140 }
141 
142 template <>
144  return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
145 }
146 
147 template <>
149  return cat256i(_mm512_cvttpd_epi32(a), _mm512_cvttpd_epi32(b));
150 }
151 
152 template <>
154  return _mm512_cvtpd_epi32(a);
155 }
156 template <>
158  return _mm512_cvtpd_ps(a);
159 }
160 
161 template <>
163  return _mm512_castps_si512(a);
164 }
165 
166 template <>
168  return _mm512_castsi512_ps(a);
169 }
170 
171 template <>
173  return _mm512_castps_pd(a);
174 }
175 
176 template <>
178  return _mm512_castsi512_pd(a);
179 }
180 
181 template <>
183  return _mm512_castpd_si512(a);
184 }
185 
186 template <>
188  return _mm512_castpd_ps(a);
189 }
190 
191 template <>
193  return _mm512_castps512_ps256(a);
194 }
195 
196 template <>
198  return _mm512_castps512_ps128(a);
199 }
200 
201 template <>
203  return _mm512_castpd512_pd256(a);
204 }
205 
206 template <>
208  return _mm512_castpd512_pd128(a);
209 }
210 
211 template <>
213  return _mm512_castps256_ps512(a);
214 }
215 
216 template <>
218  return _mm512_castps128_ps512(a);
219 }
220 
221 template <>
223  return _mm512_castpd256_pd512(a);
224 }
225 
226 template <>
228  return _mm512_castpd128_pd512(a);
229 }
230 
231 template <>
233  return _mm512_castsi512_si256(a);
234 }
235 template <>
237  return _mm512_castsi512_si128(a);
238 }
239 
240 template <>
242  return _mm256_castsi256_si128(a);
243 }
244 
245 template <>
247  return _mm256_castsi256_si128(a);
248 }
249 
250 template <>
252  return half2float(a);
253 }
254 
255 template <>
257  return float2half(a);
258 }
259 
260 template <>
262  return Bf16ToF32(a);
263 }
264 
265 template <>
267  return F32ToBf16(a);
268 }
269 
270 #ifdef EIGEN_VECTORIZE_AVX512FP16
271 
272 template <>
273 EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet32h>(const Packet32h& a) {
274  return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
275 }
276 template <>
277 EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet32h>(const Packet32h& a) {
278  return _mm256_castsi256_si128(preinterpret<Packet16h>(a));
279 }
280 
281 template <>
282 EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
283  // Discard second-half of input.
284  Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
285  return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
286 }
287 
288 template <>
289 EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
290  __m512d result = _mm512_undefined_pd();
291  result = _mm512_insertf64x4(
292  result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
293  result = _mm512_insertf64x4(
294  result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
295  return _mm512_castpd_ph(result);
296 }
297 
298 template <>
299 EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
300  // Discard second-half of input.
301  Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0));
302  return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
303 }
304 
305 template <>
306 EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
307  __m256d result = _mm256_undefined_pd();
308  result = _mm256_insertf64x2(result,
309  _mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
310  result = _mm256_insertf64x2(result,
311  _mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
312  return _mm256_castpd_si256(result);
313 }
314 
315 template <>
316 EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
317  Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a));
318  // Discard second-half of input.
319  return _mm256_extractf32x4_ps(full, 0);
320 }
321 
322 template <>
323 EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
324  __m256 result = _mm256_undefined_ps();
325  result = _mm256_insertf128_ps(result, a, 0);
326  result = _mm256_insertf128_ps(result, b, 1);
327  return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT);
328 }
329 
330 #endif
331 
332 } // end namespace internal
333 
334 } // end namespace Eigen
335 
336 #endif // EIGEN_TYPE_CASTING_AVX512_H
#define EIGEN_ALIGN64
Definition: ConfigureVectorization.h:144
Array< double, 1, 3 > e(1./3., 0.5, 2.)
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
Scalar * b
Definition: benchVecAdd.cpp:17
return int(ret)+1
const Scalar * a
Definition: level2_cplx_impl.h:32
EIGEN_STRONG_INLINE Packet16f preinterpret< Packet16f, Packet4f >(const Packet4f &a)
Definition: AVX512/TypeCasting.h:217
__m128d Packet2d
Definition: LSX/PacketMath.h:36
EIGEN_STRONG_INLINE Packet4d preinterpret< Packet4d, Packet8d >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:202
EIGEN_STRONG_INLINE Packet8d pcast< Packet16f, Packet8d >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:72
EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f &a)
Definition: AVX/PacketMath.h:2283
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf &a)
Definition: AVX/PacketMath.h:2558
EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f &)
Definition: AVX/PacketMath.h:774
EIGEN_STRONG_INLINE Packet8d pcast< Packet16i, Packet8d >(const Packet16i &a)
Definition: AVX512/TypeCasting.h:120
__vector int Packet4i
Definition: AltiVec/PacketMath.h:34
EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b)
Definition: AVX512/PacketMath.h:642
EIGEN_STRONG_INLINE Packet2d preinterpret< Packet2d, Packet8d >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:207
EIGEN_STRONG_INLINE Packet8f pcast< Packet8d, Packet8f >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:157
EIGEN_STRONG_INLINE Packet16f preinterpret< Packet16f, Packet8f >(const Packet8f &a)
Definition: AVX512/TypeCasting.h:212
EIGEN_STRONG_INLINE Packet8i preinterpret< Packet8i, Packet16i >(const Packet16i &a)
Definition: AVX512/TypeCasting.h:232
EIGEN_STRONG_INLINE Packet16f pcast< Packet8d, Packet16f >(const Packet8d &a, const Packet8d &b)
Definition: AVX512/TypeCasting.h:143
EIGEN_STRONG_INLINE Packet8d preinterpret< Packet8d, Packet2d >(const Packet2d &a)
Definition: AVX512/TypeCasting.h:227
__m512d Packet8d
Definition: AVX512/PacketMath.h:36
EIGEN_STRONG_INLINE Packet16i pcast< Packet16f, Packet16i >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:67
EIGEN_STRONG_INLINE Packet8l preinterpret< Packet8l, Packet8d >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:182
EIGEN_STRONG_INLINE Packet16f preinterpret< Packet16f, Packet16i >(const Packet16i &a)
Definition: AVX512/TypeCasting.h:167
EIGEN_STRONG_INLINE Packet8f preinterpret< Packet8f, Packet16f >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:192
EIGEN_STRONG_INLINE Packet16i preinterpret< Packet16i, Packet16f >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:162
EIGEN_STRONG_INLINE Packet8d preinterpret< Packet8d, Packet16f >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:172
EIGEN_STRONG_INLINE Packet8d pcast< Packet8l, Packet8d >(const Packet8l &a)
Definition: AVX512/TypeCasting.h:130
EIGEN_STRONG_INLINE Packet8d preinterpret< Packet8d, Packet8l >(const Packet8l &a)
Definition: AVX512/TypeCasting.h:177
EIGEN_STRONG_INLINE Packet16f pcast< Packet16bf, Packet16f >(const Packet16bf &a)
Definition: AVX512/TypeCasting.h:261
EIGEN_STRONG_INLINE Packet16f pcast< Packet16h, Packet16f >(const Packet16h &a)
Definition: AVX512/TypeCasting.h:251
EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h &a)
Definition: AVX/PacketMath.h:2273
EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b)
Definition: AVX512/PacketMath.h:646
EIGEN_STRONG_INLINE Packet16f pcast< Packet16b, Packet16f >(const Packet16b &a)
Definition: AVX512/TypeCasting.h:62
EIGEN_STRONG_INLINE Packet16i pcast< Packet8d, Packet16i >(const Packet8d &a, const Packet8d &b)
Definition: AVX512/TypeCasting.h:148
EIGEN_STRONG_INLINE Packet8d pcast< Packet8i, Packet8d >(const Packet8i &a)
Definition: AVX512/TypeCasting.h:125
EIGEN_STRONG_INLINE Packet8h preinterpret< Packet8h, Packet16h >(const Packet16h &a)
Definition: AVX512/TypeCasting.h:241
EIGEN_STRONG_INLINE Packet16b pcast< Packet16f, Packet16b >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:56
EIGEN_STRONG_INLINE Packet16bf pcast< Packet16f, Packet16bf >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:266
__m512i Packet16i
Definition: AVX512/PacketMath.h:35
EIGEN_STRONG_INLINE Packet16h pcast< Packet16f, Packet16h >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:256
EIGEN_STRONG_INLINE Packet16f preinterpret< Packet16f, Packet8d >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:187
EIGEN_DEVICE_FUNC void pstore(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:891
EIGEN_STRONG_INLINE Packet4i preinterpret< Packet4i, Packet16i >(const Packet16i &a)
Definition: AVX512/TypeCasting.h:236
EIGEN_STRONG_INLINE Packet8l pcast< Packet8d, Packet8l >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:82
EIGEN_STRONG_INLINE Packet8d preinterpret< Packet8d, Packet4d >(const Packet4d &a)
Definition: AVX512/TypeCasting.h:222
EIGEN_STRONG_INLINE Packet4f preinterpret< Packet4f, Packet16f >(const Packet16f &a)
Definition: AVX512/TypeCasting.h:197
EIGEN_STRONG_INLINE Packet8bf preinterpret< Packet8bf, Packet16bf >(const Packet16bf &a)
Definition: AVX512/TypeCasting.h:246
__vector float Packet4f
Definition: AltiVec/PacketMath.h:33
__m256 Packet8f
Definition: AVX/PacketMath.h:34
eigen_packet_wrapper< __m256i, 1 > Packet16h
Definition: AVX512/PacketMath.h:39
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f)
Definition: AltiVec/PacketMath.h:2059
EIGEN_STRONG_INLINE Packet8d pcast< Packet8f, Packet8d >(const Packet8f &a)
Definition: AVX512/TypeCasting.h:77
__m512h Packet32h
Definition: PacketMathFP16.h:20
__m256d Packet4d
Definition: AVX/PacketMath.h:36
eigen_packet_wrapper< __m128i, 2 > Packet8h
Definition: AVX/PacketMath.h:38
__m512 Packet16f
Definition: AVX512/PacketMath.h:34
EIGEN_STRONG_INLINE Packet8i pcast< Packet8d, Packet8i >(const Packet8d &a)
Definition: AVX512/TypeCasting.h:153
EIGEN_STRONG_INLINE Packet16f pcast< Packet16i, Packet16f >(const Packet16i &a)
Definition: AVX512/TypeCasting.h:115
std::int64_t int64_t
Definition: Meta.h:43
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
Definition: Eigen_Colamd.h:49
Definition: Half.h:139
Definition: GenericPacketMath.h:225
Definition: GenericPacketMath.h:201
Definition: GenericPacketMath.h:212