21 #ifndef EIGEN_PACKET_MATH_SYCL_H
22 #define EIGEN_PACKET_MATH_SYCL_H
23 #include <type_traits>
26 #include "../../InternalHeaderCheck.h"
31 #ifdef SYCL_DEVICE_ONLY
32 #define SYCL_PLOAD(packet_type, AlignedType) \
34 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType<packet_type>( \
35 const typename unpacket_traits<packet_type>::type* from) { \
37 cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>( \
44 SYCL_PLOAD(cl::sycl::cl_float4, u)
45 SYCL_PLOAD(cl::sycl::cl_float4, )
46 SYCL_PLOAD(cl::sycl::cl_double2, u)
47 SYCL_PLOAD(cl::sycl::cl_double2, )
54 cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>(
55 reinterpret_cast<const cl::sycl::cl_half*
>(from));
56 cl::sycl::cl_half8
res{};
65 cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>(
66 reinterpret_cast<const cl::sycl::cl_half*
>(from));
67 cl::sycl::cl_half8
res{};
72 #define SYCL_PSTORE(scalar, packet_type, alignment) \
74 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment(scalar* to, const packet_type& from) { \
76 cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>( \
81 SYCL_PSTORE(
float, cl::sycl::cl_float4, )
82 SYCL_PSTORE(
float, cl::sycl::cl_float4, u)
83 SYCL_PSTORE(
double, cl::sycl::cl_double2, )
84 SYCL_PSTORE(
double, cl::sycl::cl_double2, u)
90 cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>(
91 reinterpret_cast<cl::sycl::cl_half*
>(to));
98 cl::sycl::address_space_cast<cl::sycl::access::address_space::generic_space, cl::sycl::access::decorated::no>(
99 reinterpret_cast<cl::sycl::cl_half*
>(to));
103 #define SYCL_PSET1(packet_type) \
105 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pset1<packet_type>( \
106 const typename unpacket_traits<packet_type>::type& from) { \
107 return packet_type(from); \
111 SYCL_PSET1(cl::sycl::cl_half8)
112 SYCL_PSET1(cl::sycl::cl_float4)
113 SYCL_PSET1(cl::sycl::cl_double2)
117 template <
typename packet_type>
118 struct get_base_packet {
119 template <
typename sycl_multi_po
inter>
122 template <
typename sycl_multi_po
inter>
127 struct get_base_packet<cl::sycl::cl_half8> {
128 template <
typename sycl_multi_po
inter>
130 return cl::sycl::cl_half8(
static_cast<cl::sycl::half
>(from[0]),
static_cast<cl::sycl::half
>(from[0]),
131 static_cast<cl::sycl::half
>(from[1]),
static_cast<cl::sycl::half
>(from[1]),
132 static_cast<cl::sycl::half
>(from[2]),
static_cast<cl::sycl::half
>(from[2]),
133 static_cast<cl::sycl::half
>(from[3]),
static_cast<cl::sycl::half
>(from[3]));
135 template <
typename sycl_multi_po
inter>
137 return cl::sycl::cl_half8(
138 static_cast<cl::sycl::half
>(from[0 * stride]),
static_cast<cl::sycl::half
>(from[1 * stride]),
139 static_cast<cl::sycl::half
>(from[2 * stride]),
static_cast<cl::sycl::half
>(from[3 * stride]),
140 static_cast<cl::sycl::half
>(from[4 * stride]),
static_cast<cl::sycl::half
>(from[5 * stride]),
141 static_cast<cl::sycl::half
>(from[6 * stride]),
static_cast<cl::sycl::half
>(from[7 * stride]));
144 template <
typename sycl_multi_po
inter>
158 return cl::sycl::cl_half8(
static_cast<cl::sycl::half
>(
a),
static_cast<cl::sycl::half
>(
a + 1),
159 static_cast<cl::sycl::half
>(
a + 2),
static_cast<cl::sycl::half
>(
a + 3),
160 static_cast<cl::sycl::half
>(
a + 4),
static_cast<cl::sycl::half
>(
a + 5),
161 static_cast<cl::sycl::half
>(
a + 6),
static_cast<cl::sycl::half
>(
a + 7));
166 struct get_base_packet<cl::sycl::cl_float4> {
167 template <
typename sycl_multi_po
inter>
169 return cl::sycl::cl_float4(from[0], from[0], from[1], from[1]);
171 template <
typename sycl_multi_po
inter>
173 return cl::sycl::cl_float4(from[0 * stride], from[1 * stride], from[2 * stride], from[3 * stride]);
176 template <
typename sycl_multi_po
inter>
182 to[
tmp += stride] = from.z();
183 to[
tmp += stride] = from.w();
186 return cl::sycl::cl_float4(
static_cast<float>(
a),
static_cast<float>(
a + 1),
static_cast<float>(
a + 2),
187 static_cast<float>(
a + 3));
192 struct get_base_packet<cl::sycl::cl_double2> {
193 template <
typename sycl_multi_po
inter>
195 return cl::sycl::cl_double2(from[0], from[0]);
198 template <
typename sycl_multi_po
inter,
typename Index>
201 return cl::sycl::cl_double2(from[0 * stride], from[1 * stride]);
204 template <
typename sycl_multi_po
inter>
206 const cl::sycl::cl_double2& from,
Index stride) {
208 to[stride] = from.y();
212 return cl::sycl::cl_double2(
static_cast<double>(
a),
static_cast<double>(
a + 1));
216 #define SYCL_PLOAD_DUP_SPECILIZE(packet_type) \
218 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup<packet_type>( \
219 const typename unpacket_traits<packet_type>::type* from) { \
220 return get_base_packet<packet_type>::get_ploaddup(from); \
223 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_half8)
224 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_float4)
225 SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_double2)
227 #undef SYCL_PLOAD_DUP_SPECILIZE
229 #define SYCL_PLSET(packet_type) \
231 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type plset<packet_type>( \
232 const typename unpacket_traits<packet_type>::type& a) { \
233 return get_base_packet<packet_type>::set_plset(a); \
235 SYCL_PLSET(cl::sycl::cl_float4)
236 SYCL_PLSET(cl::sycl::cl_double2)
242 return get_base_packet<cl::sycl::cl_half8>::set_plset((
const cl::sycl::half&)
a);
245 #define SYCL_PGATHER_SPECILIZE(scalar, packet_type) \
247 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pgather<scalar, packet_type>( \
248 const typename unpacket_traits<packet_type>::type* from, Index stride) { \
249 return get_base_packet<packet_type>::get_pgather(from, stride); \
252 SYCL_PGATHER_SPECILIZE(
Eigen::half, cl::sycl::cl_half8)
253 SYCL_PGATHER_SPECILIZE(
float, cl::sycl::cl_float4)
254 SYCL_PGATHER_SPECILIZE(
double, cl::sycl::cl_double2)
255 #undef SYCL_PGATHER_SPECILIZE
257 #define SYCL_PSCATTER_SPECILIZE(scalar, packet_type) \
259 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter<scalar, packet_type>( \
260 typename unpacket_traits<packet_type>::type * to, const packet_type& from, Index stride) { \
261 get_base_packet<packet_type>::set_pscatter(to, from, stride); \
264 SYCL_PSCATTER_SPECILIZE(
Eigen::half, cl::sycl::cl_half8)
265 SYCL_PSCATTER_SPECILIZE(
float, cl::sycl::cl_float4)
266 SYCL_PSCATTER_SPECILIZE(
double, cl::sycl::cl_double2)
268 #undef SYCL_PSCATTER_SPECILIZE
270 #define SYCL_PMAD(packet_type) \
272 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pmadd(const packet_type& a, const packet_type& b, \
273 const packet_type& c) { \
274 return cl::sycl::mad(a, b, c); \
277 SYCL_PMAD(cl::sycl::cl_half8)
278 SYCL_PMAD(cl::sycl::cl_float4)
279 SYCL_PMAD(cl::sycl::cl_double2)
302 return a.x() +
a.y() +
a.z() +
a.w();
307 return a.x() +
a.y();
344 return a.x() *
a.y() *
a.z() *
a.w();
348 return a.x() *
a.y();
367 template <
typename Packet>
369 return (
a <=
b).template as<Packet>();
372 template <
typename Packet>
374 return (
a <
b).template as<Packet>();
377 template <
typename Packet>
379 return (
a ==
b).template as<Packet>();
382 #define SYCL_PCMP(OP, TYPE) \
384 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE TYPE pcmp_##OP<TYPE>(const TYPE& a, const TYPE& b) { \
385 return sycl_pcmp_##OP<TYPE>(a, b); \
388 SYCL_PCMP(le, cl::sycl::cl_half8)
389 SYCL_PCMP(lt, cl::sycl::cl_half8)
390 SYCL_PCMP(eq, cl::sycl::cl_half8)
391 SYCL_PCMP(le, cl::sycl::cl_float4)
392 SYCL_PCMP(lt, cl::sycl::cl_float4)
393 SYCL_PCMP(eq, cl::sycl::cl_float4)
394 SYCL_PCMP(le, cl::sycl::cl_double2)
395 SYCL_PCMP(lt, cl::sycl::cl_double2)
396 SYCL_PCMP(eq, cl::sycl::cl_double2)
400 cl::sycl::cl_half
tmp = kernel.
packet[0].s1();
401 kernel.packet[0].s1() = kernel.packet[1].s0();
402 kernel.packet[1].s0() =
tmp;
405 kernel.packet[0].s2() = kernel.packet[2].s0();
406 kernel.packet[2].s0() =
tmp;
409 kernel.packet[0].s3() = kernel.packet[3].s0();
410 kernel.packet[3].s0() =
tmp;
413 kernel.packet[0].s4() = kernel.packet[4].s0();
414 kernel.packet[4].s0() =
tmp;
417 kernel.packet[0].s5() = kernel.packet[5].s0();
418 kernel.packet[5].s0() =
tmp;
421 kernel.packet[0].s6() = kernel.packet[6].s0();
422 kernel.packet[6].s0() =
tmp;
425 kernel.packet[0].s7() = kernel.packet[7].s0();
426 kernel.packet[7].s0() =
tmp;
429 kernel.packet[1].s2() = kernel.packet[2].s1();
430 kernel.packet[2].s1() =
tmp;
433 kernel.packet[1].s3() = kernel.packet[3].s1();
434 kernel.packet[3].s1() =
tmp;
437 kernel.packet[1].s4() = kernel.packet[4].s1();
438 kernel.packet[4].s1() =
tmp;
441 kernel.packet[1].s5() = kernel.packet[5].s1();
442 kernel.packet[5].s1() =
tmp;
445 kernel.packet[1].s6() = kernel.packet[6].s1();
446 kernel.packet[6].s1() =
tmp;
449 kernel.packet[1].s7() = kernel.packet[7].s1();
450 kernel.packet[7].s1() =
tmp;
453 kernel.packet[2].s3() = kernel.packet[3].s2();
454 kernel.packet[3].s2() =
tmp;
457 kernel.packet[2].s4() = kernel.packet[4].s2();
458 kernel.packet[4].s2() =
tmp;
461 kernel.packet[2].s5() = kernel.packet[5].s2();
462 kernel.packet[5].s2() =
tmp;
465 kernel.packet[2].s6() = kernel.packet[6].s2();
466 kernel.packet[6].s2() =
tmp;
469 kernel.packet[2].s7() = kernel.packet[7].s2();
470 kernel.packet[7].s2() =
tmp;
473 kernel.packet[3].s4() = kernel.packet[4].s3();
474 kernel.packet[4].s3() =
tmp;
477 kernel.packet[3].s5() = kernel.packet[5].s3();
478 kernel.packet[5].s3() =
tmp;
481 kernel.packet[3].s6() = kernel.packet[6].s3();
482 kernel.packet[6].s3() =
tmp;
485 kernel.packet[3].s7() = kernel.packet[7].s3();
486 kernel.packet[7].s3() =
tmp;
489 kernel.packet[4].s5() = kernel.packet[5].s4();
490 kernel.packet[5].s4() =
tmp;
493 kernel.packet[4].s6() = kernel.packet[6].s4();
494 kernel.packet[6].s4() =
tmp;
497 kernel.packet[4].s7() = kernel.packet[7].s4();
498 kernel.packet[7].s4() =
tmp;
501 kernel.packet[5].s6() = kernel.packet[6].s5();
502 kernel.packet[6].s5() =
tmp;
505 kernel.packet[5].s7() = kernel.packet[7].s5();
506 kernel.packet[7].s5() =
tmp;
509 kernel.packet[6].s7() = kernel.packet[7].s6();
510 kernel.packet[7].s6() =
tmp;
515 kernel.packet[0].y() = kernel.packet[1].x();
516 kernel.packet[1].x() =
tmp;
519 kernel.packet[0].z() = kernel.packet[2].x();
520 kernel.packet[2].x() =
tmp;
523 kernel.packet[0].w() = kernel.packet[3].x();
524 kernel.packet[3].x() =
tmp;
527 kernel.packet[1].z() = kernel.packet[2].y();
528 kernel.packet[2].y() =
tmp;
531 kernel.packet[1].w() = kernel.packet[3].y();
532 kernel.packet[3].y() =
tmp;
535 kernel.packet[2].w() = kernel.packet[3].z();
536 kernel.packet[3].z() =
tmp;
541 kernel.packet[0].y() = kernel.packet[1].x();
542 kernel.packet[1].x() =
tmp;
548 const cl::sycl::cl_half8& elsePacket) {
549 cl::sycl::cl_short8 condition(ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1, ifPacket.select[2] ? 0 : -1,
550 ifPacket.select[3] ? 0 : -1, ifPacket.select[4] ? 0 : -1, ifPacket.select[5] ? 0 : -1,
551 ifPacket.select[6] ? 0 : -1, ifPacket.select[7] ? 0 : -1);
552 return cl::sycl::select(thenPacket, elsePacket, condition);
558 const cl::sycl::cl_float4& elsePacket) {
559 cl::sycl::cl_int4 condition(ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1, ifPacket.select[2] ? 0 : -1,
560 ifPacket.select[3] ? 0 : -1);
561 return cl::sycl::select(thenPacket, elsePacket, condition);
566 const cl::sycl::cl_double2& thenPacket,
const cl::sycl::cl_double2& elsePacket) {
567 cl::sycl::cl_long2 condition(ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1);
568 return cl::sycl::select(thenPacket, elsePacket, condition);
#define EIGEN_ALWAYS_INLINE
Definition: Macros.h:845
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
Scalar * b
Definition: benchVecAdd.cpp:17
EIGEN_STRONG_INLINE PacketScalar packet(Index rowId, Index colId) const
Definition: PlainObjectBase.h:247
const Scalar * a
Definition: level2_cplx_impl.h:32
Eigen::Matrix< Scalar, Dynamic, Dynamic, ColMajor > tmp
Definition: level3_impl.h:365
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:664
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16 &a, const bfloat16 &b)
Definition: BFloat16.h:670
EIGEN_STRONG_INLINE void ptranspose(PacketBlock< Packet2cf, 2 > &kernel)
Definition: AltiVec/Complex.h:339
EIGEN_STRONG_INLINE Packet4i pblend(const Selector< 4 > &ifPacket, const Packet4i &thenPacket, const Packet4i &elsePacket)
Definition: AltiVec/PacketMath.h:3075
EIGEN_DEVICE_FUNC void pstore(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:891
EIGEN_DEVICE_FUNC void pstoreu(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:911
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Real fabs(const Real &a)
Definition: boostmultiprec.cpp:117
Definition: Eigen_Colamd.h:49
T type
Definition: GenericPacketMath.h:135
@ size
Definition: GenericPacketMath.h:139