AVX/TypeCasting.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TYPE_CASTING_AVX_H
11 #define EIGEN_TYPE_CASTING_AVX_H
12 
13 // IWYU pragma: private
14 #include "../../InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 #ifndef EIGEN_VECTORIZE_AVX512
// float <-> bool cast traits. Both specializations are empty and inherit everything from
// vectorized_type_casting_traits<Src, Tgt>, which is not defined in the visible part of this file.
// They are only compiled when EIGEN_VECTORIZE_AVX512 is not defined (see the enclosing #ifndef).
21 template <>
22 struct type_casting_traits<float, bool> : vectorized_type_casting_traits<float, bool> {};
23 template <>
24 struct type_casting_traits<bool, float> : vectorized_type_casting_traits<bool, float> {};
25 
// float <-> int cast traits: empty specializations deriving from vectorized_type_casting_traits
// for the matching (source, target) scalar pair, one per direction.
26 template <>
27 struct type_casting_traits<float, int> : vectorized_type_casting_traits<float, int> {};
28 template <>
29 struct type_casting_traits<int, float> : vectorized_type_casting_traits<int, float> {};
30 
// float <-> double cast traits: empty specializations deriving from vectorized_type_casting_traits
// for the matching (source, target) scalar pair, one per direction.
31 template <>
32 struct type_casting_traits<float, double> : vectorized_type_casting_traits<float, double> {};
33 template <>
34 struct type_casting_traits<double, float> : vectorized_type_casting_traits<double, float> {};
35 
36 template <>
38 template <>
40 
// half <-> float cast traits: empty specializations deriving from vectorized_type_casting_traits
// for the matching (source, target) scalar pair, one per direction.
41 template <>
42 struct type_casting_traits<half, float> : vectorized_type_casting_traits<half, float> {};
43 template <>
44 struct type_casting_traits<float, half> : vectorized_type_casting_traits<float, half> {};
45 
// bfloat16 <-> float cast traits: empty specializations deriving from vectorized_type_casting_traits
// for the matching (source, target) scalar pair, one per direction.
46 template <>
47 struct type_casting_traits<bfloat16, float> : vectorized_type_casting_traits<bfloat16, float> {};
48 template <>
49 struct type_casting_traits<float, bfloat16> : vectorized_type_casting_traits<float, bfloat16> {};
50 
51 #ifdef EIGEN_VECTORIZE_AVX2
// double <-> int64_t cast traits: empty specializations deriving from vectorized_type_casting_traits.
// Unlike the pairs above, these are additionally guarded by EIGEN_VECTORIZE_AVX2 (enclosing #ifdef).
52 template <>
53 struct type_casting_traits<double, int64_t> : vectorized_type_casting_traits<double, int64_t> {};
54 template <>
55 struct type_casting_traits<int64_t, double> : vectorized_type_casting_traits<int64_t, double> {};
56 #endif
57 #endif
58 
59 template <>
61  __m256 nonzero_a = _mm256_cmp_ps(a, pzero(a), _CMP_NEQ_UQ);
62  __m256 nonzero_b = _mm256_cmp_ps(b, pzero(b), _CMP_NEQ_UQ);
63  constexpr char kFF = '\255';
64 #ifndef EIGEN_VECTORIZE_AVX2
65  __m128i shuffle_mask128_a_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
66  __m128i shuffle_mask128_a_hi = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF);
67  __m128i shuffle_mask128_b_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
68  __m128i shuffle_mask128_b_hi = _mm_set_epi8(12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
69  __m128i a_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 1), shuffle_mask128_a_hi);
70  __m128i a_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 0), shuffle_mask128_a_lo);
71  __m128i b_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 1), shuffle_mask128_b_hi);
72  __m128i b_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 0), shuffle_mask128_b_lo);
73  __m128i merged = _mm_or_si128(_mm_or_si128(b_lo, b_hi), _mm_or_si128(a_lo, a_hi));
74  return _mm_and_si128(merged, _mm_set1_epi8(1));
75 #else
76  __m256i a_shuffle_mask = _mm256_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF,
77  kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
78  __m256i b_shuffle_mask = _mm256_set_epi8(12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF,
79  kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
80  __m256i a_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_a), a_shuffle_mask);
81  __m256i b_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_b), b_shuffle_mask);
82  __m256i a_or_b = _mm256_or_si256(a_shuff, b_shuff);
83  __m256i merged = _mm256_or_si256(a_or_b, _mm256_castsi128_si256(_mm256_extractf128_si256(a_or_b, 1)));
84  return _mm256_castsi256_si128(_mm256_and_si256(merged, _mm256_set1_epi8(1)));
85 #endif
86 }
87 
88 template <>
90  const __m256 cst_one = _mm256_set1_ps(1.0f);
91 #ifdef EIGEN_VECTORIZE_AVX2
92  __m256i a_extended = _mm256_cvtepi8_epi32(a);
93  __m256i abcd_efgh = _mm256_cmpeq_epi32(a_extended, _mm256_setzero_si256());
94 #else
95  __m128i abcd_efhg_ijkl_mnop = _mm_cmpeq_epi8(a, _mm_setzero_si128());
96  __m128i aabb_ccdd_eeff_gghh = _mm_unpacklo_epi8(abcd_efhg_ijkl_mnop, abcd_efhg_ijkl_mnop);
97  __m128i aaaa_bbbb_cccc_dddd = _mm_unpacklo_epi8(aabb_ccdd_eeff_gghh, aabb_ccdd_eeff_gghh);
98  __m128i eeee_ffff_gggg_hhhh = _mm_unpackhi_epi8(aabb_ccdd_eeff_gghh, aabb_ccdd_eeff_gghh);
99  __m256i abcd_efgh = _mm256_setr_m128i(aaaa_bbbb_cccc_dddd, eeee_ffff_gggg_hhhh);
100 #endif
101  __m256 result = _mm256_andnot_ps(_mm256_castsi256_ps(abcd_efgh), cst_one);
102  return result;
103 }
104 
105 template <>
107  return _mm256_cvttps_epi32(a);
108 }
109 
110 template <>
112  return _mm256_set_m128i(_mm256_cvttpd_epi32(b), _mm256_cvttpd_epi32(a));
113 }
114 
115 template <>
117  return _mm256_cvttpd_epi32(a);
118 }
119 
120 template <>
122  return _mm256_cvtepi32_ps(a);
123 }
124 
125 template <>
127  return _mm256_set_m128(_mm256_cvtpd_ps(b), _mm256_cvtpd_ps(a));
128 }
129 
130 template <>
132  return _mm256_cvtpd_ps(a);
133 }
134 
135 template <>
137  return _mm256_cvtepi32_pd(_mm256_castsi256_si128(a));
138 }
139 
140 template <>
142  return _mm256_cvtepi32_pd(a);
143 }
144 
145 template <>
147  return _mm256_cvtps_pd(_mm256_castps256_ps128(a));
148 }
149 
150 template <>
152  return _mm256_cvtps_pd(a);
153 }
154 
155 template <>
157  return _mm256_castps_si256(a);
158 }
159 
160 template <>
162  return _mm256_castsi256_ps(a);
163 }
164 
165 template <>
167  return Packet8ui(a);
168 }
169 
170 template <>
172  return Packet8i(a);
173 }
174 
175 // truncation operations
176 
177 template <>
179  return _mm256_castps256_ps128(a);
180 }
181 
182 template <>
184  return _mm256_castpd256_pd128(a);
185 }
186 
187 template <>
189  return _mm256_castsi256_si128(a);
190 }
191 
192 template <>
194  return _mm256_castsi256_si128(a);
195 }
196 
197 #ifdef EIGEN_VECTORIZE_AVX2
198 template <>
199 EIGEN_STRONG_INLINE Packet4l pcast<Packet4d, Packet4l>(const Packet4d& a) {
200 #if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL)
201  return _mm256_cvttpd_epi64(a);
202 #else
203 
204  // if 'a' exceeds the numerical limits of int64_t, the behavior is undefined
205 
206  // e <= 0 corresponds to |a| < 1, which should result in zero. incidentally, intel intrinsics with shift arguments
207  // greater than or equal to 64 produce zero. furthermore, negative shifts appear to be interpreted as large positive
208  // shifts (two's complement), which also result in zero. therefore, e does not need to be clamped to [0, 64)
209 
210  constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits<double>::digits - 1,
211  kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1;
212 
213  const __m256i cst_one = _mm256_set1_epi64x(1);
214  const __m256i cst_total_bits = _mm256_set1_epi64x(kTotalBits);
215  const __m256i cst_bias = _mm256_set1_epi64x(kBias);
216 
217  __m256i a_bits = _mm256_castpd_si256(a);
218  // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent
219  __m256i biased_e = _mm256_srli_epi64(_mm256_slli_epi64(a_bits, 1), kMantissaBits + 1);
220  __m256i e = _mm256_sub_epi64(biased_e, cst_bias);
221 
222  // shift to the left by kExponentBits + 1 to clear the sign and exponent bits
223  __m256i shifted_mantissa = _mm256_slli_epi64(a_bits, kExponentBits + 1);
224  // shift to the right by kTotalBits - e to convert the significand to an integer
225  __m256i result_significand = _mm256_srlv_epi64(shifted_mantissa, _mm256_sub_epi64(cst_total_bits, e));
226 
227  // add the implied bit
228  __m256i result_exponent = _mm256_sllv_epi64(cst_one, e);
229  // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero
230  __m256i result = _mm256_add_epi64(result_significand, result_exponent);
231  // handle negative arguments
232  __m256i sign_mask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), a_bits);
233  result = _mm256_sub_epi64(_mm256_xor_si256(result, sign_mask), sign_mask);
234  return result;
235 #endif
236 }
237 
238 template <>
239 EIGEN_STRONG_INLINE Packet4d pcast<Packet4l, Packet4d>(const Packet4l& a) {
240 #if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL)
241  return _mm256_cvtepi64_pd(a);
242 #else
243  EIGEN_ALIGN16 int64_t aux[4];
244  pstore(aux, a);
245  return _mm256_set_pd(static_cast<double>(aux[3]), static_cast<double>(aux[2]), static_cast<double>(aux[1]),
246  static_cast<double>(aux[0]));
247 #endif
248 }
249 
250 template <>
251 EIGEN_STRONG_INLINE Packet4d pcast<Packet2l, Packet4d>(const Packet2l& a, const Packet2l& b) {
252  return _mm256_set_m128d((pcast<Packet2l, Packet2d>(b)), (pcast<Packet2l, Packet2d>(a)));
253 }
254 
255 template <>
256 EIGEN_STRONG_INLINE Packet4ul preinterpret<Packet4ul, Packet4l>(const Packet4l& a) {
257  return Packet4ul(a);
258 }
259 
260 template <>
261 EIGEN_STRONG_INLINE Packet4l preinterpret<Packet4l, Packet4ul>(const Packet4ul& a) {
262  return Packet4l(a);
263 }
264 
265 template <>
266 EIGEN_STRONG_INLINE Packet4l preinterpret<Packet4l, Packet4d>(const Packet4d& a) {
267  return _mm256_castpd_si256(a);
268 }
269 
270 template <>
271 EIGEN_STRONG_INLINE Packet4d preinterpret<Packet4d, Packet4l>(const Packet4l& a) {
272  return _mm256_castsi256_pd(a);
273 }
274 
275 // truncation operations
276 template <>
277 EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet4l>(const Packet4l& a) {
278  return _mm256_castsi256_si128(a);
279 }
280 #endif
281 
282 template <>
284  return half2float(a);
285 }
286 
287 template <>
289  return Bf16ToF32(a);
290 }
291 
292 template <>
294  return float2half(a);
295 }
296 
297 template <>
299  return F32ToBf16(a);
300 }
301 
302 } // end namespace internal
303 
304 } // end namespace Eigen
305 
306 #endif // EIGEN_TYPE_CASTING_AVX_H
#define EIGEN_ALIGN16
Definition: ConfigureVectorization.h:142
Array< double, 1, 3 > e(1./3., 0.5, 2.)
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
Scalar * b
Definition: benchVecAdd.cpp:17
return int(ret)+1
const Scalar * a
Definition: level2_cplx_impl.h:32
EIGEN_STRONG_INLINE Packet8f pcast< Packet4d, Packet8f >(const Packet4d &a, const Packet4d &b)
Definition: AVX/TypeCasting.h:126
__m128d Packet2d
Definition: LSX/PacketMath.h:36
eigen_packet_wrapper< __m128i, 3 > Packet2l
Definition: LSX/PacketMath.h:41
EIGEN_STRONG_INLINE Packet4d pcast< Packet4f, Packet4d >(const Packet4f &a)
Definition: AVX/TypeCasting.h:151
EIGEN_STRONG_INLINE Packet8f pcast< Packet8i, Packet8f >(const Packet8i &a)
Definition: AVX/TypeCasting.h:121
EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f &a)
Definition: AVX/PacketMath.h:2283
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf &a)
Definition: AVX/PacketMath.h:2558
EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f &)
Definition: AVX/PacketMath.h:774
__vector int Packet4i
Definition: AltiVec/PacketMath.h:34
EIGEN_STRONG_INLINE Packet8bf pcast< Packet8f, Packet8bf >(const Packet8f &a)
Definition: AVX/TypeCasting.h:298
EIGEN_STRONG_INLINE Packet4f pcast< Packet4d, Packet4f >(const Packet4d &a)
Definition: AVX/TypeCasting.h:131
EIGEN_STRONG_INLINE Packet16b pcast< Packet8f, Packet16b >(const Packet8f &a, const Packet8f &b)
Definition: AVX/TypeCasting.h:60
EIGEN_STRONG_INLINE Packet4d pcast< Packet8f, Packet4d >(const Packet8f &a)
Definition: AVX/TypeCasting.h:146
EIGEN_STRONG_INLINE Packet4i preinterpret< Packet4i, Packet8i >(const Packet8i &a)
Definition: AVX/TypeCasting.h:188
__vector unsigned int Packet4ui
Definition: AltiVec/PacketMath.h:35
EIGEN_STRONG_INLINE Packet8f preinterpret< Packet8f, Packet8i >(const Packet8i &a)
Definition: AVX/TypeCasting.h:161
EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h &a)
Definition: AVX/PacketMath.h:2273
EIGEN_STRONG_INLINE Packet4f preinterpret< Packet4f, Packet8f >(const Packet8f &a)
Definition: AVX/TypeCasting.h:178
EIGEN_STRONG_INLINE Packet4d pcast< Packet8i, Packet4d >(const Packet8i &a)
Definition: AVX/TypeCasting.h:136
EIGEN_STRONG_INLINE Packet4d pcast< Packet4i, Packet4d >(const Packet4i &a)
Definition: AVX/TypeCasting.h:141
EIGEN_STRONG_INLINE Packet4i pcast< Packet4d, Packet4i >(const Packet4d &a)
Definition: AVX/TypeCasting.h:116
EIGEN_STRONG_INLINE Packet8f pcast< Packet8bf, Packet8f >(const Packet8bf &a)
Definition: AVX/TypeCasting.h:288
EIGEN_STRONG_INLINE Packet8f pcast< Packet8h, Packet8f >(const Packet8h &a)
Definition: AVX/TypeCasting.h:283
EIGEN_STRONG_INLINE Packet8f pcast< Packet16b, Packet8f >(const Packet16b &a)
Definition: AVX/TypeCasting.h:89
EIGEN_STRONG_INLINE Packet2d pcast< Packet2l, Packet2d >(const Packet2l &a)
Definition: LSX/TypeCasting.h:514
EIGEN_STRONG_INLINE Packet8i preinterpret< Packet8i, Packet8ui >(const Packet8ui &a)
Definition: AVX/TypeCasting.h:171
EIGEN_STRONG_INLINE Packet2d preinterpret< Packet2d, Packet4d >(const Packet4d &a)
Definition: AVX/TypeCasting.h:183
eigen_packet_wrapper< __m256i, 0 > Packet8i
Definition: AVX/PacketMath.h:35
EIGEN_STRONG_INLINE Packet8i pcast< Packet4d, Packet8i >(const Packet4d &a, const Packet4d &b)
Definition: AVX/TypeCasting.h:111
EIGEN_DEVICE_FUNC void pstore(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:891
eigen_packet_wrapper< __m256i, 4 > Packet8ui
Definition: AVX/PacketMath.h:41
EIGEN_STRONG_INLINE Packet8i preinterpret< Packet8i, Packet8f >(const Packet8f &a)
Definition: AVX/TypeCasting.h:156
EIGEN_STRONG_INLINE Packet8i pcast< Packet8f, Packet8i >(const Packet8f &a)
Definition: AVX/TypeCasting.h:106
EIGEN_STRONG_INLINE Packet8h pcast< Packet8f, Packet8h >(const Packet8f &a)
Definition: AVX/TypeCasting.h:293
__vector float Packet4f
Definition: AltiVec/PacketMath.h:33
__m256 Packet8f
Definition: AVX/PacketMath.h:34
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f)
Definition: AltiVec/PacketMath.h:2059
EIGEN_STRONG_INLINE Packet4ui preinterpret< Packet4ui, Packet8ui >(const Packet8ui &a)
Definition: AVX/TypeCasting.h:193
__m256d Packet4d
Definition: AVX/PacketMath.h:36
EIGEN_STRONG_INLINE Packet8ui preinterpret< Packet8ui, Packet8i >(const Packet8i &a)
Definition: AVX/TypeCasting.h:166
std::int64_t int64_t
Definition: Meta.h:43
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
Definition: Eigen_Colamd.h:49
Definition: BFloat16.h:101
Definition: Half.h:139
Definition: GenericPacketMath.h:225
Definition: GenericPacketMath.h:201
Definition: GenericPacketMath.h:212