#ifndef EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H

#include "../../InternalHeaderCheck.h"
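// GEMV (matrix-vector product) kernels for Power: real and complex column- and
// row-major paths plus bfloat16 helpers. On Power10 the MMA paths accumulate
// rank-1 ("ger") updates into __vector_quad accumulators; elsewhere the VSX
// paths use ordinary pmadd packet arithmetic.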
#if defined(__MMA__) && !EIGEN_ALTIVEC_DISABLE_MMA
#if EIGEN_COMP_LLVM || (__GNUC__ > 10 || __GNUC_MINOR__ >= 3)
#define USE_GEMV_MMA
#endif

#if !EIGEN_COMP_LLVM && (__GNUC__ < 11)
// gcc 10.x only handles a single __vector_pair reliably
#define GCC_ONE_VECTORPAIR_BUG
#endif
#endif

#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
#define EIGEN_POWER_GEMV_PREFETCH(p) prefetch(p)
#else
#define EIGEN_POWER_GEMV_PREFETCH(p)
#endif

#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif

#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif
#if EIGEN_COMP_LLVM
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
#else
#if (__GNUC__ <= 10)
#if (__GNUC_MINOR__ > 3)
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
#else
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
#endif
#else
#define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
  __builtin_vsx_build_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
#endif
#endif
#define GEMV_IS_COMPLEX_COMPLEX ((sizeof(LhsPacket) == 16) && (sizeof(RhsPacket) == 16))
#define GEMV_IS_FLOAT (ResPacketSize == (16 / sizeof(float)))
#define GEMV_IS_SCALAR (sizeof(ResPacket) != 16)
#define GEMV_IS_COMPLEX_FLOAT (ResPacketSize == (16 / sizeof(std::complex<float>)))
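// Compile-time predicates on packet width: GEMV_IS_FLOAT holds when ResPacket
// carries four 32-bit floats per 16-byte vector, GEMV_IS_COMPLEX_FLOAT when it
// carries two std::complex<float>, and GEMV_IS_SCALAR when the result is not
// vectorized at all.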
// res += alpha * data (packet and scalar forms)
template <typename ResPacket, typename ResScalar>
EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResPacket& palpha, ResPacket& data) {
  pstoreu(res, pmadd(data, palpha, ploadu<ResPacket>(res)));
}

template <typename ResScalar>
EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScalar& data) {
  *res += (alpha * data);
}
#define GEMV_UNROLL(func, N) func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)

#define GEMV_UNROLL_HALF(func, N) func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)

#define GEMV_GETN(N) (((N) * ResPacketSize) >> 2)

#define GEMV_LOADPACKET_COL(iter) lhs.template load<LhsPacket, LhsAlignment>(i + ((iter) * LhsPacketSize), j)
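// Example: GEMV_UNROLL(func, N) emits func(0, N) through func(7, N); each body
// then tests GEMV_GETN(N) (result packets scaled by packet width) so only the
// iterations actually needed for N result packets generate real work.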
#define GEMV_UNROLL3(func, N, which)                                                                          \
  func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) func(4, N, which) func(5, N, which) \
      func(6, N, which) func(7, N, which)

#define GEMV_UNUSED_VAR(iter, N, which) \
  if (GEMV_GETN(N) <= iter) {           \
    EIGEN_UNUSED_VARIABLE(which##iter); \
  }

#define GEMV_UNUSED_EXTRA_VAR(iter, N, which) \
  if (N <= iter) {                            \
    EIGEN_UNUSED_VARIABLE(which##iter);       \
  }

#define GEMV_UNUSED_EXTRA(N, which) GEMV_UNROLL3(GEMV_UNUSED_EXTRA_VAR, N, which)

#define GEMV_UNUSED(N, which) GEMV_UNROLL3(GEMV_UNUSED_VAR, N, which)
#define GEMV_INIT_MMA(iter, N)           \
  if (GEMV_GETN(N) > iter) {             \
    __builtin_mma_xxsetaccz(&e##iter);   \
  }
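// __builtin_mma_xxsetaccz zeroes one 512-bit MMA accumulator (__vector_quad,
// i.e. four 128-bit VSX registers); each e##iter collects a tile of partial sums.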
#if EIGEN_COMP_LLVM
#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));
#else
#define GEMV_LOADPAIR_COL_MMA(iter1, iter2)                                     \
  const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
  b##iter1 = *reinterpret_cast<__vector_pair*>(const_cast<LhsScalar*>(&src##iter1));
#endif
#define GEMV_LOAD1A_COL_MMA(iter, N)           \
  if (GEMV_GETN(N) > iter) {                   \
    if (GEMV_IS_FLOAT) {                       \
      g##iter = GEMV_LOADPACKET_COL(iter);     \
      EIGEN_UNUSED_VARIABLE(b##iter);          \
    } else {                                   \
      GEMV_LOADPAIR_COL_MMA(iter, iter << 1)   \
      EIGEN_UNUSED_VARIABLE(g##iter);          \
    }                                          \
  } else {                                     \
    EIGEN_UNUSED_VARIABLE(b##iter);            \
    EIGEN_UNUSED_VARIABLE(g##iter);            \
  }
#define GEMV_WORK1A_COL_MMA(iter, N)                                        \
  if (GEMV_GETN(N) > iter) {                                                \
    if (GEMV_IS_FLOAT) {                                                    \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, a0, g##iter);   \
    } else {                                                                \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, b##iter, a0);   \
    }                                                                       \
  }
#define GEMV_LOAD1B_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) {                       \
    if (GEMV_IS_FLOAT) {                            \
      GEMV_LOADPAIR_COL_MMA(iter2, iter2)           \
      EIGEN_UNUSED_VARIABLE(b##iter3);              \
    } else {                                        \
      GEMV_LOADPAIR_COL_MMA(iter2, iter2 << 1)      \
      GEMV_LOADPAIR_COL_MMA(iter3, iter3 << 1)      \
    }                                               \
  } else {                                          \
    EIGEN_UNUSED_VARIABLE(b##iter2);                \
    EIGEN_UNUSED_VARIABLE(b##iter3);                \
  }                                                 \
  EIGEN_UNUSED_VARIABLE(g##iter2);                  \
  EIGEN_UNUSED_VARIABLE(g##iter3);
#define GEMV_WORK1B_COL_MMA(iter1, iter2, iter3, N)                            \
  if (GEMV_GETN(N) > iter1) {                                                  \
    if (GEMV_IS_FLOAT) {                                                       \
      LhsPacket h[2];                                                          \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(h), &b##iter2);   \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, a0, h[0]);        \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, a0, h[1]);        \
    } else {                                                                   \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, b##iter2, a0);    \
      pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, b##iter3, a0);    \
    }                                                                          \
  }
#ifndef GCC_ONE_VECTORPAIR_BUG
#define GEMV_LOAD_COL_MMA(N)                          \
  if (GEMV_GETN(N) > 1) {                             \
    GEMV_UNROLL_HALF(GEMV_LOAD1B_COL_MMA, (N >> 1))   \
  } else {                                            \
    GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N)               \
  }

#define GEMV_WORK_COL_MMA(N)                          \
  if (GEMV_GETN(N) > 1) {                             \
    GEMV_UNROLL_HALF(GEMV_WORK1B_COL_MMA, (N >> 1))   \
  } else {                                            \
    GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N)               \
  }
#else
#define GEMV_LOAD_COL_MMA(N) GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N)

#define GEMV_WORK_COL_MMA(N) GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N)
#endif
#define GEMV_DISASSEMBLE_MMA(iter, N)                                \
  if (GEMV_GETN(N) > iter) {                                         \
    __builtin_mma_disassemble_acc(&result##iter.packet, &e##iter);   \
    if (!GEMV_IS_FLOAT) {                                            \
      result##iter.packet[0][1] = result##iter.packet[1][0];         \
      result##iter.packet[2][1] = result##iter.packet[3][0];         \
    }                                                                \
  }
#define GEMV_LOADPAIR2_COL_MMA(iter1, iter2) \
  b##iter1 = *reinterpret_cast<__vector_pair*>(res + i + ((iter2) * ResPacketSize));
#define GEMV_LOAD2_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) {                      \
    if (GEMV_IS_FLOAT) {                           \
      GEMV_LOADPAIR2_COL_MMA(iter2, iter2);        \
      EIGEN_UNUSED_VARIABLE(b##iter3);             \
    } else {                                       \
      GEMV_LOADPAIR2_COL_MMA(iter2, iter2 << 1);   \
      GEMV_LOADPAIR2_COL_MMA(iter3, iter3 << 1);   \
    }                                              \
  } else {                                         \
    EIGEN_UNUSED_VARIABLE(b##iter2);               \
    EIGEN_UNUSED_VARIABLE(b##iter3);               \
  }
#if EIGEN_COMP_LLVM
#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4)                                          \
  ResPacket f##iter2[2];                                                                     \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(f##iter2), &b##iter2);              \
  f##iter2[0] = pmadd(result##iter2.packet[0], palpha, f##iter2[0]);                         \
  f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, f##iter2[1]);  \
  GEMV_BUILDPAIR_MMA(b##iter2, f##iter2[0], f##iter2[1]);
#else
#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4)                                          \
  if (GEMV_IS_FLOAT) {                                                                       \
    __asm__("xvmaddasp %0,%x1,%x3\n\txvmaddasp %L0,%x2,%x3"                                  \
            : "+&d"(b##iter2)                                                                \
            : "wa"(result##iter3.packet[0]), "wa"(result##iter2.packet[0]), "wa"(palpha));   \
  } else {                                                                                   \
    __asm__("xvmaddadp %0,%x1,%x3\n\txvmaddadp %L0,%x2,%x3"                                  \
            : "+&d"(b##iter2)                                                                \
            : "wa"(result##iter2.packet[2]), "wa"(result##iter2.packet[0]), "wa"(palpha));   \
  }
#endif
#define GEMV_WORK2_COL_MMA(iter1, iter2, iter3, N)      \
  if (GEMV_GETN(N) > iter1) {                           \
    if (GEMV_IS_FLOAT) {                                \
      GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter2);      \
    } else {                                            \
      GEMV_WORKPAIR2_COL_MMA(iter2, iter2, iter2 << 1); \
      GEMV_WORKPAIR2_COL_MMA(iter3, iter3, iter3 << 1); \
    }                                                   \
  }

#define GEMV_STOREPAIR2_COL_MMA(iter1, iter2) \
  *reinterpret_cast<__vector_pair*>(res + i + ((iter2) * ResPacketSize)) = b##iter1;
#define GEMV_STORE_COL_MMA(iter, N)                                                                           \
  if (GEMV_GETN(N) > iter) {                                                                                  \
    if (GEMV_IS_FLOAT) {                                                                                      \
      storeMaddData<ResPacket, ResScalar>(res + i + (iter * ResPacketSize), palpha, result##iter.packet[0]);  \
    } else {                                                                                                  \
      GEMV_LOADPAIR2_COL_MMA(iter, iter << 1)                                                                 \
      GEMV_WORKPAIR2_COL_MMA(iter, iter, iter << 1)                                                           \
      GEMV_STOREPAIR2_COL_MMA(iter, iter << 1)                                                                \
    }                                                                                                         \
  }

#define GEMV_STORE2_COL_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN(N) > iter1) {                       \
    if (GEMV_IS_FLOAT) {                            \
      GEMV_STOREPAIR2_COL_MMA(iter2, iter2);        \
    } else {                                        \
      GEMV_STOREPAIR2_COL_MMA(iter2, iter2 << 1)    \
      GEMV_STOREPAIR2_COL_MMA(iter3, iter3 << 1)    \
    }                                               \
  }
#define GEMV_PROCESS_COL_ONE_MMA(N)                   \
  GEMV_UNROLL(GEMV_INIT_MMA, N)                       \
  Index j = j2;                                       \
  __vector_pair b0, b1, b2, b3, b4, b5, b6, b7;       \
  do {                                                \
    LhsPacket g0, g1, g2, g3, g4, g5, g6, g7;         \
    RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0));      \
    GEMV_UNROLL(GEMV_PREFETCH, N)                     \
    GEMV_LOAD_COL_MMA(N)                              \
    GEMV_WORK_COL_MMA(N)                              \
  } while (++j < jend);                               \
  GEMV_UNROLL(GEMV_DISASSEMBLE_MMA, N)                \
  if (GEMV_GETN(N) <= 1) {                            \
    GEMV_UNROLL(GEMV_STORE_COL_MMA, N)                \
  } else {                                            \
    GEMV_UNROLL_HALF(GEMV_LOAD2_COL_MMA, (N >> 1))    \
    GEMV_UNROLL_HALF(GEMV_WORK2_COL_MMA, (N >> 1))    \
    GEMV_UNROLL_HALF(GEMV_STORE2_COL_MMA, (N >> 1))   \
  }                                                   \
  i += (ResPacketSize * N);
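// One column-major MMA strip: zero the accumulators, then for each column j in
// the block broadcast rhs(j, 0) and fold lhs column packets in with rank-1 ger
// updates; finally disassemble the accumulators, scale by alpha and merge the
// partial sums into res.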
#define GEMV_INIT(iter, N)                      \
  if (N > iter) {                               \
    c##iter = pset1<ResPacket>(ResScalar(0));   \
  } else {                                      \
    EIGEN_UNUSED_VARIABLE(c##iter);             \
  }

#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
#define GEMV_PREFETCH(iter, N)                                     \
  if (GEMV_GETN(N) > ((iter >> 1) + ((N >> 1) * (iter & 1)))) {    \
    lhs.prefetch(i + (iter * LhsPacketSize) + prefetch_dist, j);   \
  }
#else
#define GEMV_PREFETCH(iter, N)
#endif
#define GEMV_WORK_COL(iter, N)                                     \
  if (N > iter) {                                                  \
    c##iter = pcj.pmadd(GEMV_LOADPACKET_COL(iter), a0, c##iter);   \
  }

#define GEMV_STORE_COL(iter, N)                                                                \
  if (N > iter) {                                                                              \
    pstoreu(res + i + (iter * ResPacketSize),                                                  \
            pmadd(c##iter, palpha, ploadu<ResPacket>(res + i + (iter * ResPacketSize))));      \
  }

#define GEMV_PROCESS_COL_ONE(N)                  \
  GEMV_UNROLL(GEMV_INIT, N)                      \
  Index j = j2;                                  \
  do {                                           \
    RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
    GEMV_UNROLL(GEMV_PREFETCH, N)                \
    GEMV_UNROLL(GEMV_WORK_COL, N)                \
  } while (++j < jend);                          \
  GEMV_UNROLL(GEMV_STORE_COL, N)                 \
  i += (ResPacketSize * N);

#ifdef USE_GEMV_MMA
#define GEMV_PROCESS_COL(N) GEMV_PROCESS_COL_ONE_MMA(N)
#else
#define GEMV_PROCESS_COL(N) GEMV_PROCESS_COL_ONE(N)
#endif
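// Wrappers over the Power10 MMA rank-1 update builtins: xvf32ger/xvf32gerpp
// multiply two 4-float vectors into a 4x4 accumulator tile (initializing or
// accumulating); the f64 forms take a __vector_pair holding the lhs data.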
template <typename LhsPacket, typename RhsPacket, bool accumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
  if (accumulate) {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32ger(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}

template <typename LhsPacket, typename RhsPacket, bool accumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, __vector_pair& a, const LhsPacket& b) {
  if (accumulate) {
    __builtin_mma_xvf64gerpp(acc, a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64ger(acc, a, (__vector unsigned char)b);
  }
}
template <typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res,
                                  Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  conj_helper<LhsScalar, RhsScalar, false, false> cj;
  conj_helper<LhsPacket, RhsPacket, false, false> pcj;

  const Index lhsStride = lhs.stride();

  enum {
    ResPacketSize = Traits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    RhsPacketSize = Traits::RhsPacketSize,
  };
#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = rows - 8 * ResPacketSize + 1;
  const Index n4 = rows - 4 * ResPacketSize + 1;
  const Index n2 = rows - 2 * ResPacketSize + 1;
#endif
  const Index n1 = rows - 1 * ResPacketSize + 1;
#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
  const Index prefetch_dist = 64 * LhsPacketSize;
#endif

  // Process the columns in blocks so the lhs strip stays cache-resident.
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
  Index i = 0;
  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);

    ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
#ifdef USE_GEMV_MMA
    __vector_quad e0, e1, e2, e3, e4, e5, e6, e7;
    PacketBlock<ResPacket, 4> result0, result1, result2, result3, result4, result5, result6, result7;
    GEMV_UNUSED(8, e)
    GEMV_UNUSED(8, result)
    GEMV_UNUSED_EXTRA(1, c)
#endif
#ifndef GCC_ONE_VECTORPAIR_BUG
    while (i < n8) {
      GEMV_PROCESS_COL(8)
    }
    while (i < n4) {
      GEMV_PROCESS_COL(4)
    }
    while (i < n2) {
      GEMV_PROCESS_COL(2)
    }
#endif
    while (i < n1) {
      GEMV_PROCESS_COL(1)
    }
    for (; i < rows; i++) {
      ResScalar d0(0);
      Index j = j2;
      do {
        d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
      } while (++j < jend);
      res[i] += alpha * d0;
    }
  }
}
template <bool extraRows>
EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float* result, Packet4f pAlpha, Index extra_rows) {
  Packet4f d0 = ploadu<Packet4f>(result);
  d0 = pmadd(acc, pAlpha, d0);
  if (extraRows) {
    pstoreu_partial(result, d0, extra_rows);
  } else {
    pstoreu(result, d0);
  }
}
template <Index num_acc, bool extraRows, Index size>
EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha,
                                             Index extra_rows) {
  constexpr Index real_acc = (num_acc - (extraRows ? 1 : 0));
  for (Index k = 0; k < real_acc; k++) {
    outputVecCol<false>(acc[k][0], result + k * 4, pAlpha, extra_rows);
  }
  if (extraRows) {
    outputVecCol<true>(acc[real_acc][0], result + real_acc * 4, pAlpha, extra_rows);
  }
}
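// Byte permutes used when widening bfloat16 lanes to float32: each bf16 value
// supplies the high 16 bits of a float32 lane, which is exactly the
// bf16 -> f32 conversion.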
static Packet16uc p16uc_MERGE16_32_V1 = {0, 1, 16, 17, 0, 1, 16, 17, 0, 1, 16, 17, 0, 1, 16, 17};
static Packet16uc p16uc_MERGE16_32_V2 = {2, 3, 18, 19, 2, 3, 18, 19, 2, 3, 18, 19, 2, 3, 18, 19};
template <Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoopVSX(Index k, LhsMapper& lhs, Packet4f (&a0)[num_acc][2]) {
  Packet8bf c0 = lhs.template loadPacket<Packet8bf>(k * 4, 0);
  Packet8bf b1;
  if (!zero) {
    b1 = lhs.template loadPacket<Packet8bf>(k * 4, 1);
  }

  if (num_acc > (k + 1)) {

  }
}
template <Index num_acc, bool zero>
EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[num_acc][2], Packet4f (&b0)[2]) {
  for (Index k = 0; k < num_acc; k++) {
    for (Index i = 0, j = (zero ? 1 : 2); i < j; i++) {
      acc[k][i] = pmadd(b0[i], a0[k][i], acc[k][i]);
    }
  }
}
template <typename RhsMapper, bool linear>
struct loadColData_impl {
  static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j) {
    const Index n = unpacket_traits<Packet8bf>::size;
    EIGEN_ALIGN16 bfloat16 to[n];
    for (Index i = 0; i < n; i++) {
      to[i] = rhs(j + i, 0);
    }
    return pload<Packet8bf>(to);
  }
};

template <typename RhsMapper>
struct loadColData_impl<RhsMapper, true> {
  static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j) {
    return rhs.template loadPacket<Packet8bf>(j + 0, 0);
  }
};

template <typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j) {
  return loadColData_impl<RhsMapper, linear>::run(rhs, j);
}
template <Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2]) {
  Packet4f a0[num_acc][2], b0[2];
  Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);

  using LhsSubMapper = typename LhsMapper::SubMapper;
  LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
  for (Index k = 0; k < num_acc; k += 2) {
    loadVecLoopVSX<num_acc, LhsSubMapper, zero>(k, lhs2, a0);
  }

  multVecVSX<num_acc, zero>(acc, a0, b0);
}
template <Index num_acc>
EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2]) {
  for (Index i = 0; i < num_acc; i++) {
    acc[i][0] = acc[i][0] + acc[i][1];
  }
}
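// Upper bound on the number of VSX accumulators carried per block in the
// bfloat16 kernels below.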
#define MAX_BFLOAT16_VEC_ACC_VSX 8
template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                          float* result) {
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);

  do {
    Packet4f acc[num_acc][2];

    zeroAccumulators<num_acc, 2>(acc);

    using LhsSubMapper = typename LhsMapper::SubMapper;
    LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
    for (Index j = 0; j + 2 <= cend; j += 2) {
      vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
    }
    if (cend & 1) {
      vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
    }

    addResultsVSX<num_acc>(acc);

    outputVecColResults<num_acc, extraRows, 2>(acc, result, pAlpha, extra_rows);

    result += step;
  } while (multiIters && (step <= rows - (row += step)));
}
template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                    const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs,
                                                                                                 rhs, pAlpha, result);
  }
}
template <typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                   const Packet4f pAlpha, float* result) {
  switch ((rows - row) >> 2) {
    case 7:
      colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
        colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
}
template <typename LhsMapper, typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                            const Packet4f pAlpha, float* result) {
  Index row = 0;
  if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs,
                                                                                        pAlpha, result);
    result += row;
  }
  if (rows & 3) {
    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  } else {
    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}
template <const Index size, bool inc, Index delta>
EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra = 0);

template <const Index size, bool inc = false>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst,
                                                    Index resInc = 1) {
  while (i + size <= rows) {

    storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc, rows & 7);
    if (size >= 16) {
      storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
    }

    dst += extra * resInc;
    if (size != 32) break;
  }
}
template <bool inc = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float* result, Index rows, bfloat16* dst, Index resInc = 1) {
  Index i = 0;
  convertPointerF32toBF16VSX<32, inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16VSX<16, inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16VSX<8, inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16VSX<1, inc>(i, result, rows, dst, resInc);
}
template <typename RhsMapper, typename LhsMapper, typename = void>

  using RhsSubMapper = typename RhsMapper::SubMapper;

  RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
  calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
template <typename RhsMapper, typename LhsMapper>
          std::enable_if_t<std::is_member_function_pointer<decltype(&RhsMapper::stride)>::value>>

  using RhsSubMapper = typename RhsMapper::SubMapper;

  RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
  if (rhs.stride() == 1) {
    calcVSXVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
  } else {
    calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
  }
template <typename LhsMapper, typename RhsMapper>

  const Index lhsStride = lhs.stride();

  // Process the columns in blocks so the lhs strip stays cache-resident.
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
template <Index num_acc, Index size>
EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha) {
  constexpr Index extra = num_acc & 3;

  for (Index k = 0; k < num_acc; k += 4) {
    Packet4f d0 = ploadu<Packet4f>(result + k);
    d0 = pmadd(acc[k + 0][0], pAlpha, d0);

    if (num_acc > (k + 3)) {
      pstoreu(result + k, d0);
    } else {
      memcpy((void*)(result + k), (void*)(&d0), sizeof(float) * extra);
    }
  }
}
template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2VSX(Packet4f (&acc)[num_acc][2], Index k) {
  if (num_acc > (k + 1)) {
    acc[k][1] = vec_mergel(acc[k + 0][0], acc[k + 1][0]);
    acc[k][0] = vec_mergeh(acc[k + 0][0], acc[k + 1][0]);
    acc[k][0] = acc[k][0] + acc[k][1];
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
  } else {
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
#ifdef _BIG_ENDIAN
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
#else
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
#endif
  }
}
template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResultsVSX(Packet4f (&acc)[num_acc][2]) {
  for (Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2VSX<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2VSX<num_acc>(acc, k + 2);
#ifdef EIGEN_VECTORIZE_VSX
      acc[k + 0][0] = reinterpret_cast<Packet4f>(
          vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
#endif
    }
  }
}
template <Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVSXVecLoop(Packet4f (&acc)[num_acc][2], const LhsMapper& lhs, RhsMapper& rhs, Index j,
                                        Index extra_cols) {
  Packet4f a0[num_acc][2], b0[2];
  Packet8bf a1, b1;
  if (extra) {
    b1 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
  } else {
    b1 = rhs.template loadPacket<Packet8bf>(j);
  }

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  for (Index k = 0; k < num_acc; k++) {
    if (extra) {
      a1 = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
    } else {
      a1 = lhs2.template loadPacket<Packet8bf>(k, 0);
    }
  }

  multVecVSX<num_acc, false>(acc, a0, b0);
}
template <Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecVSXLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2],
                                    Index extra_cols) {
  Index j = 0;
  for (; j + 8 <= cols; j += 8) {
    multVSXVecLoop<num_acc, LhsMapper, RhsMapper, false>(acc, lhs, rhs, j, extra_cols);
  }

  if (extra_cols) {
    multVSXVecLoop<num_acc, LhsMapper, RhsMapper, true>(acc, lhs, rhs, j, extra_cols);
  }
}
template <const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVSXVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                       float* result) {
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);
  const Index extra_cols = (cols & 7);

  do {
    Packet4f acc[num_acc][2];

    zeroAccumulators<num_acc, 2>(acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecVSXLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, acc, extra_cols);

    addResultsVSX<num_acc>(acc);

    preduxVecResultsVSX<num_acc>(acc);

    outputVecResults<num_acc, 2>(acc, result, pAlpha);

    result += num_acc;
  } while (multiIters && (num_acc <= rows - (row += num_acc)));
}
template <const Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                 const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
    colVSXVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
  }
}
template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                const Packet4f pAlpha, float* result) {
  switch (rows - row) {
    case 7:
      colVSXVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVSXVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVSXVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVSXVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVSXVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVSXVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVSXVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
  }
}
template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVSXVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                                         float* result) {
  Index row = 0;
  if (rows >= MAX_BFLOAT16_VEC_ACC_VSX) {
    colVSXVecLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  colVSXVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
}
template <typename LhsMapper, typename RhsMapper>

  typedef typename RhsMapper::LinearMapper LinearMapper;

  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);

  calcVSXVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);

  convertArrayPointerF32toBF16VSX<true>(result, rows, res, resIncr);
#undef MAX_BFLOAT16_VEC_ACC_VSX
    0xcc, 0xdd, 0xee, 0xff, 0x88, 0x99, 0xaa, 0xbb};
    0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77};

    0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00};
    0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
    0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
    0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00};
    0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};

    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
    0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00};
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
    0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80};
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
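// The 0x80-patterned constants above are IEEE sign-bit XOR masks: applied to
// selected lanes they implement conjugation and sign flips for the complex
// helpers (pconj2, pconjinv, and the pcplxflip variants).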
#ifdef _BIG_ENDIAN
#define COMPLEX_DELTA 0
#else
#define COMPLEX_DELTA 2
#endif
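// COMPLEX_DELTA is the float-lane offset at which a doubleword load (lxsdx)
// deposits a std::complex<float>: lanes 0-1 on big-endian, lanes 2-3 on
// little-endian.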
#ifdef __POWER8_VECTOR__

#if defined(_ARCH_PWR8) && (!EIGEN_COMP_LLVM || __clang_major__ >= 12)
#define PERMXOR_GOOD
#endif
#ifdef __POWER8_VECTOR__
  return Packet2cf(vec_neg(a.v));

#ifdef __POWER8_VECTOR__
  return Packet1cd(vec_neg(a.v));
#ifdef EIGEN_VECTORIZE_VSX
  return Packet1cd(__builtin_vsx_xxpermdi(a.v, a.v, 2));

#ifdef EIGEN_VECTORIZE_VSX
  __asm__("lxsdx %x0,%y1" : "=wa"(t) : "Z"(*src));
#else
  *reinterpret_cast<std::complex<float>*>(reinterpret_cast<float*>(&t) + COMPLEX_DELTA) = *src;
template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet4f& r, Packet4f& i) {
#ifdef _ARCH_PWR9
  __asm__("lxvwsx %x0,%y1" : "=wa"(r) : "Z"(*(reinterpret_cast<float*>(src) + 0)));
  __asm__("lxvwsx %x0,%y1" : "=wa"(i) : "Z"(*(reinterpret_cast<float*>(src) + 1)));
template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet2d& r, Packet2d& i) {
#ifdef EIGEN_VECTORIZE_VSX
  __asm__("lxvdsx %x0,%y1" : "=wa"(r) : "Z"(*(reinterpret_cast<double*>(src) + 0)));
  __asm__("lxvdsx %x0,%y1" : "=wa"(i) : "Z"(*(reinterpret_cast<double*>(src) + 1)));
#else
  Packet2d t = ploadu<Packet2d>(reinterpret_cast<double*>(src));
  r = vec_splat(t, 0);
  i = vec_splat(t, 1);
#endif
}
#ifndef __POWER8_VECTOR__
    0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B};
    0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F};
template <typename RhsScalar>
EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet4f& r, Packet4f& i) {
  Packet4f t = ploadu<Packet4f>(reinterpret_cast<float*>(src));
#ifdef __POWER8_VECTOR__
  r = vec_mergee(t, t);
  i = vec_mergeo(t, t);
#else

#endif
}
template <typename RhsScalar>

#ifdef EIGEN_VECTORIZE_VSX
  __asm__("lxvdsx %x0,%y1" : "=wa"(ret) : "Z"(*(reinterpret_cast<double*>(src) + 0)));
template <typename ResPacket>

template <typename ResPacket>

template <typename ResPacket>

template <typename ResPacket>

  return vec_mergeh(ret, ret);
template <typename ResPacket>

template <typename ResPacket>

template <typename Scalar, typename ResScalar>
template <typename Scalar, typename ResScalar, typename ResPacket, int which>
EIGEN_ALWAYS_INLINE Packet2cf pset1_complex(std::complex<float>& alpha) {
  Packet2cf ret;
  ret.v[COMPLEX_DELTA + 0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
  ret.v[COMPLEX_DELTA + 1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
  return ret;
}

template <typename Scalar, typename ResScalar, typename ResPacket, int which>
EIGEN_ALWAYS_INLINE Packet1cd pset1_complex(std::complex<double>& alpha) {
  Packet1cd ret;
  ret.v[0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
  ret.v[1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
  return ret;
}
template <typename Packet>

template <typename Packet, typename LhsPacket, typename RhsPacket>

  return pset_zero<Packet>();
template <typename PResPacket, typename ResPacket, typename ResScalar, typename Scalar>
struct alpha_store {
  alpha_store(ResScalar& alpha) {
    separate.r = pset1_complex<Scalar, ResScalar, ResPacket, 0x3>(alpha);
    separate.i = pset1_complex<Scalar, ResScalar, ResPacket, 0x0>(alpha);
  }
  struct ri {
    PResPacket r;
    PResPacket i;
  } separate;
};
template <typename ScalarPacket, typename AlphaData>
EIGEN_ALWAYS_INLINE ScalarPacket pmadd_complex(ScalarPacket& c0, ScalarPacket& c2, ScalarPacket& c4, AlphaData& b0) {
  return pmadd(c2, b0.separate.i.v, pmadd(c0, b0.separate.r.v, c4));
}
template <typename Scalar, typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar,
          typename AlphaData>
EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, AlphaData& b0, ResScalar* res) {
  PResPacket c2 = pcplxflipconj(c0);
  if (GEMV_IS_SCALAR) {
    ScalarPacket c4 = ploadu<ScalarPacket>(reinterpret_cast<Scalar*>(res));
    ScalarPacket c3 = pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0);
    pstoreu(reinterpret_cast<Scalar*>(res), c3);
  } else {
    ScalarPacket c4 = pload_complex<ResPacket>(res);
    PResPacket c3 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
    pstoreu(res, c3);
  }
}
template <typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData,
          Index ResPacketSize, Index iter2>
EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, AlphaData& b0, ResScalar* res) {
  PResPacket c2 = pcplxflipconj(c0);
  PResPacket c3 = pcplxflipconj(c1);
#if !defined(_ARCH_PWR10)
  ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
  ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
  PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
  PResPacket c7 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c5, b0));
  pstoreu(res + (iter2 * ResPacketSize), c6);
  pstoreu(res + ((iter2 + 1) * ResPacketSize), c7);
#else
  __vector_pair a = *reinterpret_cast<__vector_pair*>(res + (iter2 * ResPacketSize));
#if EIGEN_COMP_LLVM
  PResPacket c6[2];
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(c6), &a);
  c6[0] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c6[0].v, b0));
  c6[1] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c6[1].v, b0));
  GEMV_BUILDPAIR_MMA(a, c6[0].v, c6[1].v);
#else
  if (GEMV_IS_COMPLEX_FLOAT) {
    __asm__("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3"
            : "+&d"(a)
            : "wa"(b0.separate.r.v), "wa"(c0.v), "wa"(c1.v));
    __asm__("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3"
            : "+&d"(a)
            : "wa"(b0.separate.i.v), "wa"(c2.v), "wa"(c3.v));
  } else {
    __asm__("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3"
            : "+&d"(a)
            : "wa"(b0.separate.r.v), "wa"(c0.v), "wa"(c1.v));
    __asm__("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3"
            : "+&d"(a)
            : "wa"(b0.separate.i.v), "wa"(c2.v), "wa"(c3.v));
  }
#endif
  *reinterpret_cast<__vector_pair*>(res + (iter2 * ResPacketSize)) = a;
#endif
}
template <typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket>
EIGEN_ALWAYS_INLINE LhsPacket loadLhsPacket(LhsMapper& lhs, Index i, Index j) {
  if (sizeof(Scalar) == sizeof(LhsScalar)) {
    const LhsScalar& src = lhs(i + 0, j);

  }
  return lhs.template load<LhsPacket, Unaligned>(i + 0, j);
}
template <typename ComplexPacket, typename RealPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_complex(RealPacket& a, RealPacket& b, RealPacket& c) {
  if (ConjugateLhs && ConjugateRhs) {
    return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
  } else if (Negate && !ConjugateLhs && ConjugateRhs) {
    return vec_nmsub(a, b, c);
  } else {
    return vec_madd(a, b, c);
  }
}
template <typename ComplexPacket, typename RealPacket, bool Conjugate>
EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_real(RealPacket& a, RealPacket& b, RealPacket& c) {
  if (Conjugate) {
    return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
  } else {
    return vec_madd(a, b, c);
  }
}
template <typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, bool ConjugateLhs,
          bool ConjugateRhs, int StorageOrder>

  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
  RhsPacket b0;
  if (StorageOrder == ColMajor) {
    b0 = pset1<RhsPacket>(*b);
  } else {
    b0 = ploadu<RhsPacket>(b);
  }
  c0 = pcj.pmadd(a0, b0, c0);
}
template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
          typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0, ResPacket& c1) {
  ScalarPacket br, bi;
  if (StorageOrder == ColMajor) {
    pload_realimag<RhsScalar>(b, br, bi);
  } else {
    pload_realimag_row<RhsScalar>(b, br, bi);
  }
  if (ConjugateLhs && !ConjugateRhs) a0 = pconj2(a0);

  ScalarPacket cr = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, false>(a0.v, br, c0.v);
  ScalarPacket ci = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, true>(a1.v, bi, c1.v);

  c0 = PResPacket(cr);
}
template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
          typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_real_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {

  ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateRhs>(a0, b0, c0.v);
  c0 = PResPacket(cri);
}
template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
          typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_real(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
  ScalarPacket a1 = pload_complex<ResPacket>(&a0);
  ScalarPacket b0;
  if (StorageOrder == ColMajor) {
    b0 = pload_real<ResPacket>(b);
  } else {
    b0 = pload_real_row<ResPacket>(b);
  }
  ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateLhs>(a1, b0, c0.v);
  c0 = PResPacket(cri);
}
#define GEMV_MULT_COMPLEX_COMPLEX(LhsType, RhsType, ResType)                                                         \
  template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
            typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>                              \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, ResType& c1) {                    \
    gemv_mult_complex_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs,    \
                              ConjugateRhs, StorageOrder>(a0, b, c0, c1);                                            \
  }

#define GEMV_MULT_REAL_COMPLEX(LhsType, RhsType, ResType)                                                            \
  template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
            typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>                              \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, RhsType&) {                       \
    gemv_mult_real_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs,       \
                           ConjugateRhs, StorageOrder>(a0, b, c0);                                                   \
  }

#define GEMV_MULT_COMPLEX_REAL(LhsType, RhsType, ResType1, ResType2)                                                 \
  template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
            typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>                              \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType1& c0, ResType2&) {                     \
    gemv_mult_complex_real<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs,       \
                           ConjugateRhs, StorageOrder>(a0, b, c0);                                                   \
  }
template <typename T>

template <typename T>
template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
EIGEN_ALWAYS_INLINE void pload_complex_MMA(SLhsPacket& a) {
  a = SLhsPacket(pload_complex<ResPacket>(&a));
}

template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, RhsPacket& a, LhsPacket& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, __vector_pair& a, LhsPacket& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
template <typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c) {
  if (ConjugateLhs && ConjugateRhs) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a.v);
  } else if (Negate && !ConjugateLhs && ConjugateRhs) {
    return pger_vecMMA<RealPacket, RealPacket, true>(c, b, a.v);
  } else {
    return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a.v);
  }
}
template <typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c) {
  if (ConjugateLhs && ConjugateRhs) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
  } else if (Negate && !ConjugateLhs && ConjugateRhs) {
    return pger_vecMMA<RealPacket, __vector_pair, true>(c, a, b);
  } else {
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
  }
}
template <typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c) {
  RealPacket a2 = convertReal(a);
  if (Conjugate) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    if (StorageOrder == ColMajor) {
      return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a2);
    } else {
      return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b2);
    }
  } else {
    if (StorageOrder == ColMajor) {
      return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a2);
    } else {
      return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b);
    }
  }
}
template <typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c) {
  if (Conjugate) {
    RealPacket b2 = pconj2(convertComplex(b)).v;
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
  } else {
    return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
  }
}
template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket,
          bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {

  pmadd_complex_complex_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ConjugateRhs, false>(a0, b0, c0);
}
template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket,
          bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_real_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
  pload_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, ResPacket>(a0);
  ScalarPacket b0;
  if (StorageOrder == ColMajor) {
    b0 = pload_real<ResPacket>(b);
  } else {
    b0 = pload_real_row<ResPacket>(b);
  }
  pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ColMajor>(a0, b0, c0);
}
template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket,
          bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_real_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {

  pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateRhs,
                         (sizeof(RhsScalar) == sizeof(std::complex<float>)) ? StorageOrder : ColMajor>(a0, b0, c0);
}
#define GEMV_MULT_COMPLEX_COMPLEX_MMA(LhsType, RhsType)                                                              \
  template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
            typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>          \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) {                       \
    gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs,           \
                                  ConjugateRhs, StorageOrder>(a0, b, c0);                                            \
  }

GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet2cf, std::complex<float>)
GEMV_MULT_COMPLEX_COMPLEX_MMA(__vector_pair, std::complex<float>)
GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet1cd, std::complex<double>)
template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar,
          typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(__vector_pair& a0, std::complex<double>* b, __vector_quad* c0) {
  if (sizeof(LhsScalar) == 16) {
    gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs,
                                  ConjugateRhs, StorageOrder>(a0, b, c0);
  } else {
    gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs,
                               StorageOrder>(a0, b, c0);
  }
}
#define GEMV_MULT_REAL_COMPLEX_MMA(LhsType, RhsType)                                                                 \
  template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
            typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>          \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) {                       \
    gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs,              \
                               ConjugateRhs, StorageOrder>(a0, b, c0);                                               \
  }

GEMV_MULT_REAL_COMPLEX_MMA(Packet4f, std::complex<float>)
GEMV_MULT_REAL_COMPLEX_MMA(Packet2d, std::complex<double>)

#define GEMV_MULT_COMPLEX_REAL_MMA(LhsType, RhsType)                                                                 \
  template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
            typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>          \
  EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) {                       \
    gemv_mult_complex_real_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs,              \
                               ConjugateRhs, StorageOrder>(a0, b, c0);                                               \
  }

GEMV_MULT_COMPLEX_REAL_MMA(Packet2cf, float)
GEMV_MULT_COMPLEX_REAL_MMA(Packet1cd, double)
GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, float)
GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, double)
template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void disassembleResults2(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
  __builtin_mma_disassemble_acc(&result0.packet, c0);
  if (sizeof(LhsPacket) == 16) {
    if (sizeof(RhsPacket) == 16) {
      ScalarPacket tmp0, tmp2;
      tmp2 = vec_mergeh(result0.packet[2], result0.packet[3]);
      tmp0 = vec_mergeh(result0.packet[0], result0.packet[1]);
      result0.packet[3] = vec_mergel(result0.packet[3], result0.packet[2]);
      result0.packet[1] = vec_mergel(result0.packet[1], result0.packet[0]);
      result0.packet[2] = tmp2;
      result0.packet[0] = tmp0;
    }
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
    } else if (ConjugateRhs) {
      result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
    } else {
      result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
    result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
  } else {
    result0.packet[0][1] = result0.packet[1][1];
    result0.packet[2][1] = result0.packet[3][1];
  }
}
template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void disassembleResults4(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
  __builtin_mma_disassemble_acc(&result0.packet, c0);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
    } else if (ConjugateRhs) {
      result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
    } else {
      result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
  } else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
    }
    result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
  }
}
template <typename Scalar, typename ScalarPacket, int ResPacketSize, typename LhsPacket, typename RhsPacket,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
  if (!GEMV_IS_COMPLEX_FLOAT) {
    disassembleResults2<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
  } else {
    disassembleResults4<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
  }
}
#define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)

#define GEMV_LOADPACKET_COL_COMPLEX(iter) \
  loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + ((iter) * ResPacketSize), j)

#define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
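// GEMV_GETN_COMPLEX mirrors GEMV_GETN but shifts by 1 instead of 2 because a
// complex packet carries half as many scalars per 16-byte vector.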
#define GEMV_INIT_COL_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) {       \
    __builtin_mma_xxsetaccz(&e0##iter);    \
  }
#if EIGEN_COMP_LLVM
#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2)                     \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), \
                     GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));    \
  EIGEN_UNUSED_VARIABLE(f##iter1);
#else
#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2)                                                         \
  if (sizeof(LhsPacket) == 16) {                                                                            \
    const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j);                                  \
    a##iter1 = *reinterpret_cast<__vector_pair*>(const_cast<LhsScalar*>(&src));                             \
    EIGEN_UNUSED_VARIABLE(f##iter1);                                                                        \
  } else {                                                                                                  \
    f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j);                  \
    GEMV_BUILDPAIR_MMA(a##iter1, vec_splat(convertReal(f##iter1), 0), vec_splat(convertReal(f##iter1), 1)); \
  }
#endif
#define GEMV_LOAD1_COL_COMPLEX_MMA(iter, N)          \
  if (GEMV_GETN_COMPLEX(N) > iter) {                 \
    if (GEMV_IS_COMPLEX_FLOAT) {                     \
      f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter);   \
      EIGEN_UNUSED_VARIABLE(a##iter);                \
    } else {                                         \
      GEMV_LOADPAIR_COL_COMPLEX_MMA(iter, iter << 1) \
    }                                                \
  } else {                                           \
    EIGEN_UNUSED_VARIABLE(a##iter);                  \
    EIGEN_UNUSED_VARIABLE(f##iter);                  \
  }

#define GEMV_WORK1_COL_COMPLEX_MMA(iter, N)                                                                       \
  if (GEMV_GETN_COMPLEX(N) > iter) {                                                                              \
    if (GEMV_IS_COMPLEX_FLOAT) {                                                                                  \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket,     \
                            ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, &e0##iter);                         \
    } else {                                                                                                      \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket,  \
                            ConjugateLhs, ConjugateRhs, ColMajor>(a##iter, b, &e0##iter);                         \
    }                                                                                                             \
  }
#define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));

#define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
  if (GEMV_GETN_COMPLEX(N) > iter1) {                      \
    if (GEMV_IS_COMPLEX_FLOAT) {                           \
      GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2);        \
      EIGEN_UNUSED_VARIABLE(a##iter3)                      \
    } else {                                               \
      GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2 << 1);   \
      GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter3, iter3 << 1);   \
    }                                                      \
  } else {                                                 \
    EIGEN_UNUSED_VARIABLE(a##iter2);                       \
    EIGEN_UNUSED_VARIABLE(a##iter3);                       \
  }                                                        \
  EIGEN_UNUSED_VARIABLE(f##iter2);                         \
  EIGEN_UNUSED_VARIABLE(f##iter3);
#define GEMV_WORK2_COL_COMPLEX_MMA(iter1, iter2, iter3, N)                                                        \
  if (GEMV_GETN_COMPLEX(N) > iter1) {                                                                             \
    if (GEMV_IS_COMPLEX_FLOAT) {                                                                                  \
      PLhsPacket g[2];                                                                                            \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(g), &a##iter2);                                      \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket,     \
                            ConjugateLhs, ConjugateRhs, ColMajor>(g[0], b, &e0##iter2);                           \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket,     \
                            ConjugateLhs, ConjugateRhs, ColMajor>(g[1], b, &e0##iter3);                           \
    } else {                                                                                                      \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket,  \
                            ConjugateLhs, ConjugateRhs, ColMajor>(a##iter2, b, &e0##iter2);                       \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket,  \
                            ConjugateLhs, ConjugateRhs, ColMajor>(a##iter3, b, &e0##iter3);                       \
    }                                                                                                             \
  }
#ifndef GCC_ONE_VECTORPAIR_BUG
#define GEMV_LOAD_COL_COMPLEX_MMA(N)                          \
  if (GEMV_GETN_COMPLEX(N) > 1) {                             \
    GEMV_UNROLL_HALF(GEMV_LOAD2_COL_COMPLEX_MMA, (N >> 1))    \
  } else {                                                    \
    GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N)                \
  }

#define GEMV_WORK_COL_COMPLEX_MMA(N)                          \
  if (GEMV_GETN_COMPLEX(N) > 1) {                             \
    GEMV_UNROLL_HALF(GEMV_WORK2_COL_COMPLEX_MMA, (N >> 1))    \
  } else {                                                    \
    GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N)                \
  }
#else
#define GEMV_LOAD_COL_COMPLEX_MMA(N) GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N)

#define GEMV_WORK_COL_COMPLEX_MMA(N) GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N)
#endif
#define GEMV_DISASSEMBLE_COMPLEX_MMA(iter)                                                                   \
  disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>( \
      &e0##iter, result0##iter);
#define GEMV_STORE_COL_COMPLEX_MMA(iter, N)                                                     \
  if (GEMV_GETN_COMPLEX(N) > iter) {                                                            \
    GEMV_DISASSEMBLE_COMPLEX_MMA(iter);                                                         \
    c0##iter = PResPacket(result0##iter.packet[0]);                                             \
    if (GEMV_IS_COMPLEX_FLOAT) {                                                                \
      pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
          c0##iter, alpha_data, res + i + (iter * ResPacketSize));                              \
    } else {                                                                                    \
      pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
          c0##iter, alpha_data, res + i + ((iter << 1) * ResPacketSize));                       \
      c0##iter = PResPacket(result0##iter.packet[2]);                                           \
      pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
          c0##iter, alpha_data, res + i + (((iter << 1) + 1) * ResPacketSize));                 \
    }                                                                                           \
  }
#define GEMV_STORE2_COL_COMPLEX_MMA(iter1, iter2, iter3, N)                                                        \
  if (GEMV_GETN_COMPLEX(N) > iter1) {                                                                              \
    GEMV_DISASSEMBLE_COMPLEX_MMA(iter2);                                                                           \
    GEMV_DISASSEMBLE_COMPLEX_MMA(iter3);                                                                           \
    c0##iter2 = PResPacket(result0##iter2.packet[0]);                                                              \
    if (GEMV_IS_COMPLEX_FLOAT) {                                                                                   \
      c0##iter3 = PResPacket(result0##iter3.packet[0]);                                                            \
      pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>(      \
          c0##iter2, c0##iter3, alpha_data, res + i);                                                              \
    } else {                                                                                                       \
      c0##iter3 = PResPacket(result0##iter2.packet[2]);                                                            \
      pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>( \
          c0##iter2, c0##iter3, alpha_data, res + i);                                                              \
      c0##iter2 = PResPacket(result0##iter3.packet[0]);                                                            \
      c0##iter3 = PResPacket(result0##iter3.packet[2]);                                                            \
      pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>( \
          c0##iter2, c0##iter3, alpha_data, res + i);                                                              \
    }                                                                                                              \
  }
#define GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)                 \
  GEMV_UNROLL(GEMV_INIT_COL_COMPLEX_MMA, N)                 \
  Index j = j2;                                             \
  do {                                                      \
    const RhsScalar& b1 = rhs2(j, 0);                       \
    RhsScalar* b = const_cast<RhsScalar*>(&b1);             \
    GEMV_UNROLL(GEMV_PREFETCH, N)                           \
    GEMV_LOAD_COL_COMPLEX_MMA(N)                            \
    GEMV_WORK_COL_COMPLEX_MMA(N)                            \
  } while (++j < jend);                                     \
  if (GEMV_GETN(N) <= 2) {                                  \
    GEMV_UNROLL(GEMV_STORE_COL_COMPLEX_MMA, N)              \
  } else {                                                  \
    GEMV_UNROLL_HALF(GEMV_STORE2_COL_COMPLEX_MMA, (N >> 1)) \
  }                                                         \
  i += (ResPacketSize * N);
#define GEMV_INIT_COMPLEX(iter, N)                                   \
  if (N > iter) {                                                    \
    c0##iter = pset_zero<PResPacket>();                              \
    c1##iter = pset_init<ResPacket, LhsPacket, RhsPacket>(c1##iter); \
  } else {                                                           \
    EIGEN_UNUSED_VARIABLE(c0##iter);                                 \
    EIGEN_UNUSED_VARIABLE(c1##iter);                                 \
  }

#define GEMV_WORK_COL_COMPLEX(iter, N)                                                                     \
  if (N > iter) {                                                                                          \
    f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter);                                                           \
    gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                      ConjugateRhs, ColMajor>(f##iter, b, c0##iter, c1##iter);                             \
  } else {                                                                                                 \
    EIGEN_UNUSED_VARIABLE(f##iter);                                                                        \
  }

#define GEMV_STORE_COL_COMPLEX(iter, N)                                                       \
  if (N > iter) {                                                                             \
    if (GEMV_IS_COMPLEX_COMPLEX) {                                                            \
      c0##iter = padd(c0##iter, c1##iter);                                                    \
    }                                                                                         \
    pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
        c0##iter, alpha_data, res + i + (iter * ResPacketSize));                              \
  }
#define GEMV_PROCESS_COL_COMPLEX_ONE(N)         \
  GEMV_UNROLL(GEMV_INIT_COMPLEX, N)             \
  Index j = j2;                                 \
  do {                                          \
    const RhsScalar& b1 = rhs2(j, 0);           \
    RhsScalar* b = const_cast<RhsScalar*>(&b1); \
    GEMV_UNROLL(GEMV_PREFETCH, N)               \
    GEMV_UNROLL(GEMV_WORK_COL_COMPLEX, N)       \
  } while (++j < jend);                         \
  GEMV_UNROLL(GEMV_STORE_COL_COMPLEX, N)        \
  i += (ResPacketSize * N);
#if defined(USE_GEMV_MMA) && (EIGEN_COMP_LLVM || defined(USE_SLOWER_GEMV_MMA))
#define USE_GEMV_COL_COMPLEX_MMA
#endif

#ifdef USE_GEMV_COL_COMPLEX_MMA
#define GEMV_PROCESS_COL_COMPLEX(N) GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)
#else
#if defined(USE_GEMV_MMA) && (__GNUC__ > 10)
#define GEMV_PROCESS_COL_COMPLEX(N)          \
  if (sizeof(Scalar) != sizeof(LhsPacket)) { \
    GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)      \
  } else {                                   \
    GEMV_PROCESS_COL_COMPLEX_ONE(N)          \
  }
#else
#define GEMV_PROCESS_COL_COMPLEX(N) GEMV_PROCESS_COL_COMPLEX_ONE(N)
#endif
#endif
template <typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
          typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_complex_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                          ResScalar* res, Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef gemv_traits<ResPacket, ResPacket> PTraits;

  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;

  const Index lhsStride = lhs.stride();

  enum {
    ResPacketSize = PTraits::ResPacketSize,
    LhsPacketSize = PTraits::LhsPacketSize,
    RhsPacketSize = PTraits::RhsPacketSize,
  };
#ifdef EIGEN_POWER_USE_GEMV_PREFETCH
  const Index prefetch_dist = 64 * LhsPacketSize;
#endif

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = rows - 8 * ResPacketSize + 1;
  const Index n4 = rows - 4 * ResPacketSize + 1;
  const Index n2 = rows - 2 * ResPacketSize + 1;
#endif
  const Index n1 = rows - 1 * ResPacketSize + 1;

  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
  AlphaData alpha_data(alpha);

  Index i = 0;
  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);

    PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
    ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
    PLhsPacket f0, f1, f2, f3, f4, f5, f6, f7;
#ifdef USE_GEMV_MMA
    __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
    __vector_pair a0, a1, a2, a3, a4, a5, a6, a7;
    PacketBlock<ScalarPacket, 4> result00, result01, result02, result03, result04, result05, result06, result07;
    GEMV_UNUSED(8, result0)
#endif
#if !defined(GCC_ONE_VECTORPAIR_BUG) && defined(USE_GEMV_COL_COMPLEX_MMA)

#ifndef GCC_ONE_VECTORPAIR_BUG

    for (; i < rows; i++) {
      ResScalar d0(0);
      Index j = j2;
      do {
        d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
      } while (++j < jend);
      res[i] += alpha * d0;
    }
  }
}
template <typename Scalar, int N>
struct ScalarBlock {
  Scalar scalar[N];
};

static Packet16uc p16uc_ELEMENT_3 = {0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f,
                                     0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f};
template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1) {
  PacketBlock<ResPacket, 4> result0, result1;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  __builtin_mma_disassemble_acc(&result1.packet, acc1);
  result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
  result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
  result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
  result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
  result0.packet[0] =
      vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
  return *reinterpret_cast<ScalarBlock<ResScalar, 2>*>(&result0.packet[0]);
}

EIGEN_ALWAYS_INLINE ScalarBlock<double, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1) {
  PacketBlock<Packet2d, 4> result0, result1;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  __builtin_mma_disassemble_acc(&result1.packet, acc1);
  result0.packet[0] =
      vec_add(vec_mergeh(result0.packet[0], result1.packet[0]), vec_mergel(result0.packet[1], result1.packet[1]));
  return *reinterpret_cast<ScalarBlock<double, 2>*>(&result0.packet[0]);
}
template <typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<float>, 2> addComplexResults(PacketBlock<Packet4f, 4>& result0,
                                                                          PacketBlock<Packet4f, 4>& result1) {
  ScalarBlock<std::complex<float>, 2> cc0;
  result0.packet[0] = reinterpret_cast<Packet4f>(
      vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[0]), reinterpret_cast<Packet2d>(result1.packet[0])));
  result0.packet[2] = reinterpret_cast<Packet4f>(
      vec_mergel(reinterpret_cast<Packet2d>(result0.packet[2]), reinterpret_cast<Packet2d>(result1.packet[2])));
  result0.packet[0] = vec_add(result0.packet[0], result0.packet[2]);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    result0.packet[1] = reinterpret_cast<Packet4f>(
        vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[1]), reinterpret_cast<Packet2d>(result1.packet[1])));
    result0.packet[3] = reinterpret_cast<Packet4f>(
        vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
    result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
    if (ConjugateLhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
    } else if (ConjugateRhs) {
      result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
    } else {
      result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
  } else {
    if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
    }
  }
  cc0.scalar[0].real(result0.packet[0][0]);
  cc0.scalar[0].imag(result0.packet[0][1]);
  cc0.scalar[1].real(result0.packet[0][2]);
  cc0.scalar[1].imag(result0.packet[0][3]);
  return cc0;
}

template <typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<double>, 2> addComplexResults(PacketBlock<Packet2d, 4>&,
                                                                           PacketBlock<Packet2d, 4>&) {

}
template <typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0, __vector_quad* acc1) {
  PacketBlock<ResPacket, 4> result0, result1;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  __builtin_mma_disassemble_acc(&result1.packet, acc1);
  return addComplexResults<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(result0, result1);
}

template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0) {
  PacketBlock<ResPacket, 4> result0;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  result0.packet[0] =
      vec_add(vec_mergeh(result0.packet[0], result0.packet[2]), vec_mergel(result0.packet[1], result0.packet[3]));
  return *reinterpret_cast<ScalarBlock<ResScalar, 2>*>(&result0.packet[0]);
}
template <typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
          bool ConjugateRhs>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0) {
  ScalarBlock<ResScalar, 2> cc0;
  PacketBlock<ResPacket, 4> result0;
  __builtin_mma_disassemble_acc(&result0.packet, acc0);
  if (GEMV_IS_COMPLEX_COMPLEX) {
    if (ConjugateLhs) {
      result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
    } else if (ConjugateRhs) {
      result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
      result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
    } else {
      result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
      result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
    }
    result0.packet[0] = vec_add(result0.packet[0], __builtin_vsx_xxpermdi(result0.packet[1], result0.packet[1], 2));
    result0.packet[2] = vec_add(result0.packet[2], __builtin_vsx_xxpermdi(result0.packet[3], result0.packet[3], 2));
  } else {
    result0.packet[0] = __builtin_vsx_xxpermdi(result0.packet[0], result0.packet[1], 1);
    result0.packet[2] = __builtin_vsx_xxpermdi(result0.packet[2], result0.packet[3], 1);
  }
  cc0.scalar[0].real(result0.packet[0][0]);
  cc0.scalar[0].imag(result0.packet[0][1]);
  cc0.scalar[1].real(result0.packet[2][0]);
  cc0.scalar[1].imag(result0.packet[2][1]);
  return cc0;
}
template <typename ResScalar, typename ResPacket>

template <typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(ResPacket& a, ResPacket& b) {
  return predux_real<ResScalar, ResPacket>(a, b);
}
#define GEMV_UNROLL_ROW(func, N) func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)

#define GEMV_UNROLL_ROW_HALF(func, N) func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)

#define GEMV_LOADPACKET_ROW(iter) lhs.template load<LhsPacket, Unaligned>(i + (iter), j)

#define GEMV_UNROLL3_ROW(func, N, which)                                                                      \
  func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) func(4, N, which) func(5, N, which) \
      func(6, N, which) func(7, N, which)

#define GEMV_UNUSED_ROW(N, which) GEMV_UNROLL3_ROW(GEMV_UNUSED_VAR, N, which)
#ifdef USE_GEMV_MMA
#define GEMV_INIT_ROW(iter, N)           \
  if (GEMV_GETN(N) > iter) {             \
    __builtin_mma_xxsetaccz(&c##iter);   \
  }

#define GEMV_LOADPAIR_ROW(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_ROW(iter2), GEMV_LOADPACKET_ROW((iter2) + 1));
2296 #define GEMV_WORK_ROW(iter, N) \
2297 if (GEMV_GETN(N) > iter) { \
2298 if (GEMV_IS_FLOAT) { \
2299 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, a0, GEMV_LOADPACKET_ROW(iter)); \
2301 __vector_pair b##iter; \
2302 GEMV_LOADPAIR_ROW(iter, iter << 1) \
2303 pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, b##iter, a0); \
#define GEMV_PREDUX2(iter1, iter2, iter3, N)                               \
  if (N > iter1) {                                                         \
    if (GEMV_IS_FLOAT) {                                                   \
      cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter2, &c##iter3); \
    } else {                                                               \
      cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter1);            \
    }                                                                      \
  } else {                                                                 \
    EIGEN_UNUSED_VARIABLE(cc##iter1);                                      \
  }
#else
#define GEMV_INIT_ROW(iter, N)                \
  if (N > iter) {                             \
    c##iter = pset1<ResPacket>(ResScalar(0)); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(c##iter);           \
  }

#define GEMV_WORK_ROW(iter, N)                                   \
  if (N > iter) {                                                \
    c##iter = pcj.pmadd(GEMV_LOADPACKET_ROW(iter), a0, c##iter); \
  }

#define GEMV_PREDUX2(iter1, iter2, iter3, N)                             \
  if (N > iter1) {                                                       \
    cc##iter1 = predux_real<ResScalar, ResPacket>(c##iter2, c##iter3);   \
  } else {                                                               \
    EIGEN_UNUSED_VARIABLE(cc##iter1);                                    \
  }
#endif
#define GEMV_MULT(iter1, iter2, iter3, N)                  \
  if (N > iter1) {                                         \
    cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), a0); \
    cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), a0); \
  }

#define GEMV_STORE_ROW(iter1, iter2, iter3, N)                                           \
  if (N > iter1) {                                                                       \
    storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
    storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
  }
#define GEMV_PROCESS_ROW(N)                                         \
  for (; i < n##N; i += N) {                                        \
    GEMV_UNROLL_ROW(GEMV_INIT_ROW, N)                               \
    j = 0;                                                          \
    for (; j + LhsPacketSize <= cols; j += LhsPacketSize) {         \
      RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j);   \
      GEMV_UNROLL_ROW(GEMV_WORK_ROW, N)                             \
    }                                                               \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1))                    \
    for (; j < cols; ++j) {                                         \
      RhsScalar a0 = rhs2(j);                                       \
      GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1))                     \
    }                                                               \
    GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1))                  \
  }
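// gemv_row drives the macros above: rows are handled in blocks of 8 (skipped
// when GCC_ONE_VECTORPAIR_BUG is defined), each block accumulating packet-wide
// sums over the columns before a pairwise horizontal reduction, with scalar
// cleanup for the remaining columns and rows.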
template <typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res,
                                  Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  LhsMapper lhs(alhs);
  typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  conj_helper<LhsScalar, RhsScalar, false, false> cj;
  conj_helper<LhsPacket, RhsPacket, false, false> pcj;

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
#endif

  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = Traits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    RhsPacketSize = Traits::RhsPacketSize,
  };

  Index i = 0, j;
#ifdef USE_GEMV_MMA
  __vector_quad c0, c1, c2, c3, c4, c5, c6, c7;
  GEMV_UNUSED_ROW(8, c)
#else
  ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
#endif
#ifndef GCC_ONE_VECTORPAIR_BUG
  GEMV_PROCESS_ROW(8)
#endif
  for (; i < rows; ++i) {
    ResPacket d0 = pset1<ResPacket>(ResScalar(0));
    j = 0;
    for (; j + LhsPacketSize <= cols; j += LhsPacketSize) {
      RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j);

      d0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, d0);
    }
    ResScalar dd0 = predux(d0);
    for (; j < cols; ++j) {
      dd0 += cj.pmul(lhs(i, j), rhs2(j));
    }
    res[i * resIncr] += alpha * dd0;
  }
}
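// Hook the kernels into Eigen's dispatch: each macro below specializes
// general_matrix_vector_product for one storage order so that run() forwards
// to gemv_col/gemv_row for the supported real scalar types.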
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar)                                                                    \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper,            \
                                       ConjugateRhs, Version> {                                                        \
    typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar;                                       \
                                                                                                                        \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,                  \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr,           \
                                                        ResScalar alpha) {                                             \
      gemv_col<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha);            \
    }                                                                                                                   \
  };

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar)                                                                    \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper,            \
                                       ConjugateRhs, Version> {                                                        \
    typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar;                                       \
                                                                                                                        \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,                  \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr,           \
                                                        ResScalar alpha) {                                             \
      gemv_row<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha);            \
    }                                                                                                                   \
  };

EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
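// bfloat16 GEMV is computed in float precision (inputs widened, results
// converted back), with the MMA-specific entry points selected when
// USE_GEMV_MMA is defined.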
#ifdef USE_GEMV_MMA
#define gemv_bf16_col gemvMMA_bfloat16_col
#define gemv_bf16_row gemvMMA_bfloat16_row
#else
#define gemv_bf16_col gemv_bfloat16_col
#define gemv_bf16_row gemv_bfloat16_row
#endif
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()                                                                 \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper,        \
                                       ConjugateRhs, Version> {                                                        \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,                  \
                                                        const RhsMapper& rhs, bfloat16* res, Index resIncr,            \
                                                        bfloat16 alpha) {                                              \
      gemv_bf16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha);                                  \
    }                                                                                                                   \
  };

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()                                                                 \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper,        \
                                       ConjugateRhs, Version> {                                                        \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,                  \
                                                        const RhsMapper& rhs, bfloat16* res, Index resIncr,            \
                                                        bfloat16 alpha) {                                              \
      gemv_bf16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha);                                  \
    }                                                                                                                   \
  };

EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
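// Complex row-major helpers. Two accumulators per row (c0*/c1*, or one MMA
// __vector_quad e0*) carry the partial real/imaginary products, and
// predux_complex folds each pair into two scalar results.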
template <typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1,
                                                             ResPacket& b1) {
  if (GEMV_IS_COMPLEX_COMPLEX) {
    a0 = padd(a0, a1);
    b0 = padd(b0, b1);
  }
  return predux_complex<ResScalar, PResPacket>(a0, b0);
}
#define GEMV_LOADPACKET_ROW_COMPLEX(iter) loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + (iter), j)

#define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N)    \
  j = 0;                                                  \
  for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
    const RhsScalar& b1 = rhs2(j);                        \
    RhsScalar* b = const_cast<RhsScalar*>(&b1);           \
    GEMV_UNROLL_ROW(which, N)                             \
  }

#define GEMV_PROCESS_END_ROW_COMPLEX(N)               \
  for (; j < cols; ++j) {                             \
    RhsScalar b0 = rhs2(j);                           \
    GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \
  }                                                   \
  GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1))
#ifdef USE_GEMV_MMA
#define GEMV_INIT_ROW_COMPLEX_MMA(iter, N) \
  if (GEMV_GETN_COMPLEX(N) > iter) {       \
    __builtin_mma_xxsetaccz(&e0##iter);    \
  }

#define GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter1, iter2) \
  GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter2), GEMV_LOADPACKET_ROW_COMPLEX_DATA((iter2) + 1));

#define GEMV_WORK_ROW_COMPLEX_MMA(iter, N)                                                                       \
  if (GEMV_GETN_COMPLEX(N) > iter) {                                                                             \
    if (GEMV_IS_COMPLEX_FLOAT) {                                                                                 \
      PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter);                                                    \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket,    \
                            ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter);                        \
    } else {                                                                                                     \
      __vector_pair a##iter;                                                                                     \
      GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter, iter << 1)                                                             \
      gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
                            ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter);                        \
    }                                                                                                            \
  }
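// With MMA, GEMV_WORK_ROW_COMPLEX_MMA issues one rank-1 accumulator update per
// row: single-precision complex feeds one packet, double-precision a
// __vector_pair built from two consecutive packets.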
#define GEMV_PREDUX4_COMPLEX_MMA(iter1, iter2, iter3, N)                                                          \
  if (N > iter1) {                                                                                                \
    if (GEMV_IS_COMPLEX_FLOAT) {                                                                                  \
      cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(      \
          &e0##iter2, &e0##iter3);                                                                                \
    } else {                                                                                                      \
      cc##iter1 =                                                                                                 \
          predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter1);  \
    }                                                                                                             \
  } else {                                                                                                        \
    EIGEN_UNUSED_VARIABLE(cc##iter1);                                                                             \
  }
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N)  \
  GEMV_UNROLL_ROW(GEMV_INIT_ROW_COMPLEX_MMA, N) \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX_MMA, N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N)                    \
  for (; i < n##N; i += N) {                                   \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N)                     \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_MMA, (N >> 1))   \
    GEMV_PROCESS_END_ROW_COMPLEX(N);                           \
  }
#endif
#define GEMV_WORK_ROW_COMPLEX(iter, N)                                                                     \
  if (N > iter) {                                                                                          \
    PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter);                                                \
    gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
                      ConjugateRhs, RowMajor>(a##iter, b, c0##iter, c1##iter);                             \
  }

#define GEMV_PREDUX4_COMPLEX(iter1, iter2, iter3, N)                                                         \
  if (N > iter1) {                                                                                           \
    cc##iter1 = predux_complex<ResScalar, PResPacket, ResPacket, LhsPacket, RhsPacket>(c0##iter2, c0##iter3, \
                                                                                       c1##iter2, c1##iter3); \
  } else {                                                                                                   \
    EIGEN_UNUSED_VARIABLE(cc##iter1);                                                                        \
  }

#define GEMV_MULT_COMPLEX(iter1, iter2, iter3, N)          \
  if (N > iter1) {                                         \
    cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), b0); \
    cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), b0); \
  }

#define GEMV_STORE_ROW_COMPLEX(iter1, iter2, iter3, N)                                   \
  if (N > iter1) {                                                                       \
    storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
    storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
  }
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
  GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX, N)        \
  GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX, N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)              \
  for (; i < n##N; i += N) {                             \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N)               \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N);                     \
  }

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
  if (GEMV_IS_COMPLEX_COMPLEX) {                  \
    c0##iter = padd(c0##iter, c1##iter);          \
  }                                               \
  dd0 = predux(c0##iter);

#if EIGEN_COMP_LLVM
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N)

#define GEMV_PROCESS_ROW_COMPLEX_ONE(N) GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter)
#else
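// Clang compiles the packetized (NEW) kernels well for every scalar type, so
// it uses them unconditionally; for GCC a scalar-reduction (OLD) fallback is
// kept and chosen per type via GEMV_PROCESS_ROW_COMPLEX_IS_NEW below.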
#define GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter) lhs.template load<LhsPacket, LhsAlignment>(i + (iter), j)

#define GEMV_INIT_COMPLEX_OLD(iter, N) \
  EIGEN_UNUSED_VARIABLE(c0##iter);     \
  if (N > iter) {                      \
    c1##iter = pset_zero<ResPacket>(); \
  } else {                             \
    EIGEN_UNUSED_VARIABLE(c1##iter);   \
  }

#define GEMV_WORK_ROW_COMPLEX_OLD(iter, N)                     \
  if (N > iter) {                                              \
    LhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter); \
    c1##iter = pcj.pmadd(a##iter, b0, c1##iter);               \
  }

#define GEMV_PREDUX4_COMPLEX_OLD(iter1, iter2, iter3, N) \
  if (N > iter1) {                                       \
    cc##iter1.scalar[0] = predux(c1##iter2);             \
    cc##iter1.scalar[1] = predux(c1##iter3);             \
  } else {                                               \
    EIGEN_UNUSED_VARIABLE(cc##iter1);                    \
  }

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N)                  \
  GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N)                     \
  j = 0;                                                        \
  for (; j + LhsPacketSize <= cols; j += LhsPacketSize) {       \
    RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j); \
    GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N)               \
  }

#define GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N)                  \
  for (; i < n##N; i += N) {                                 \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N)                   \
    GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_OLD, (N >> 1)) \
    GEMV_PROCESS_END_ROW_COMPLEX(N)                          \
  }

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) dd0 = predux(c1##iter);
#ifdef USE_GEMV_MMA
#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW 1
#else
#define GEMV_PROCESS_ROW_COMPLEX_IS_NEW (sizeof(Scalar) == sizeof(float)) || GEMV_IS_COMPLEX_COMPLEX
#endif

#define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) {   \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
  } else {                                 \
    GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
  }

#define GEMV_PROCESS_ROW_COMPLEX_ONE(N)  \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
    GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)  \
  } else {                               \
    GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N)  \
  }

#define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
  if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) {      \
    GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
  } else {                                    \
    GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \
  }
#endif
#ifdef USE_GEMV_MMA
#define GEMV_PROCESS_ROW_COMPLEX(N) GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N)
#else
#define GEMV_PROCESS_ROW_COMPLEX(N) GEMV_PROCESS_ROW_COMPLEX_ONE(N)
#endif
template <typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
          typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
EIGEN_STRONG_INLINE void gemv_complex_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                          ResScalar* res, Index resIncr, ResScalar alpha) {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename packet_traits<Scalar>::type ScalarPacket;
  typedef typename packet_traits<LhsScalar>::type PLhsPacket;
  typedef typename packet_traits<ResScalar>::type PResPacket;
  typedef gemv_traits<ResPacket, ResPacket> PTraits;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  LhsMapper lhs(alhs);
  typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
#if !EIGEN_COMP_LLVM
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
#endif

#ifndef GCC_ONE_VECTORPAIR_BUG
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
#endif

  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = PTraits::ResPacketSize,
    LhsPacketSize = PTraits::LhsPacketSize,
    RhsPacketSize = PTraits::RhsPacketSize,
  };

  Index i = 0, j;
  PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
  ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
#ifdef USE_GEMV_MMA
  __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
  GEMV_UNUSED_ROW(8, e0)
  GEMV_UNUSED_EXTRA(1, c0)
  GEMV_UNUSED_EXTRA(1, c1)
#endif
  ResScalar dd0;
#ifndef GCC_ONE_VECTORPAIR_BUG
  GEMV_PROCESS_ROW_COMPLEX(8)
#endif
  for (; i < rows; ++i) {
    GEMV_PROCESS_ROW_COMPLEX_SINGLE(1)
    GEMV_PROCESS_ROW_COMPLEX_PREDUX(0)
    for (; j < cols; ++j) {
      dd0 += cj.pmul(lhs(i, j), rhs2(j));
    }
    res[i * resIncr] += alpha * dd0;
  }
}
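// Complex specializations mirror the real ones: Scalar is the underlying real
// type, and LhsIsReal/RhsIsReal are deduced from sizeof so that mixed
// real/complex products reuse the same kernels.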
#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar)                                           \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper,      \
                                       ConjugateRhs, Version> {                                                        \
    typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;                                 \
                                                                                                                        \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,                  \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr,           \
                                                        ResScalar alpha) {                                             \
      gemv_complex_col<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar,     \
                       RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs,  \
                                                                                                res, resIncr, alpha);  \
    }                                                                                                                   \
  };

#define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar)                                           \
  template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
  struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper,      \
                                       ConjugateRhs, Version> {                                                        \
    typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;                                 \
                                                                                                                        \
    EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,                  \
                                                        const RhsMapper& rhs, ResScalar* res, Index resIncr,           \
                                                        ResScalar alpha) {                                             \
      gemv_complex_row<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar,     \
                       RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs,  \
                                                                                                res, resIncr, alpha);  \
    }                                                                                                                   \
  };