#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

// When building with dynamic dispatch, force the Power10 target for this translation unit.
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif

// Fall back to the MMA builtin names when the compiler does not provide the VSX pair builtins.
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif

#include "../../InternalHeaderCheck.h"

// Half the number of accumulator columns (used by the complex kernels).
#define accColsC (accCols / 2)
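// The helpers below move data between MMA accumulators (__vector_quad) and the
// result matrix: an accumulator holds a 4x4 tile of partial products that is
// disassembled into packets, scaled by alpha and written back through the
// DataMapper, with a partial-store path for leftover rows.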
template <typename DataMapper, typename Packet, bool full>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements,
                                          __vector_quad* acc) {
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  if (full) {
    bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore<DataMapper, Packet, 4>(tRes, data, i);
  } else {
    bload_partial<DataMapper, Packet, 0, false, 4>(tRes, data, i, elements);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
  }
}
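// Complex variant: the real and imaginary accumulators are disassembled
// separately, scaled by the complex alpha (bscalec), re-interleaved into
// complex packets (bcouple) and stored.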
template <typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal,
                                                 const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal,
                                                 __vector_quad* accImag) {
  constexpr bool full = (accCols2 > accColsC);
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);

  bstore<DataMapper, Packetc, 4>(acc1, data, i);
  if (full) {
    bstore<DataMapper, Packetc, 4>(acc2, data, i + accColsC);
  }
}
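// Rank-1 update of an MMA accumulator with the Power10 outer-product builtins.
// NegativeAccumulate selects the negating form (xvf32gernp) instead of the
// accumulating one (xvf32gerpp).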
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}
// Double-precision overload: the rhs operand is a __vector_pair.
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}
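// Complex rank-1 update: the real and imaginary accumulators are built out of
// pgerMMA calls, with conjugation handled by choosing the accumulating or
// negating form for each partial product, and the imaginary inputs skipped
// when either operand is real.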
template <typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi,
                                  const RhsPacket& rhsV, RhsPacket& rhsVi) {
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if (LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}
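// Illustrative sketch (schematic only, not a kernel from this file) of how the
// pieces above combine for a single real tile:
//
//   __vector_quad acc;
//   bsetzeroMMA(&acc);                                      // clear the 4x4 tile
//   for (Index k = 0; k < depth; k++) {
//     // lhsV/rhsV hold one column of the lhs block and one row of the rhs block
//     pgerMMA<Packet4f, Packet4f, false>(&acc, rhsV, lhsV);
//   }
//   storeAccumulator<DataMapper, Packet4f, true>(row, res, pAlpha, accCols, &acc);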
// Unaligned load of one packet of rhs values.
template <typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet)* rhs) {
  return ploadu<Packet>(rhs);
}

template <typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) {
  rhsV = ploadRhs<Packet>(rhs);
}
// For double the rhs is a __vector_pair: assembled from two Packet2d loads on
// LLVM, loaded directly otherwise.
template <>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) {
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(
      &rhsV, reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
      reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs)));
#else
  rhsV = *reinterpret_cast<__vector_pair*>(const_cast<double*>(rhs));
#endif
}

EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) {
  lhsV = *reinterpret_cast<__vector_pair*>(const_cast<double*>(lhs));
}
#define GEMM_MULTIPLE_COLS

#define VECTOR_PAIR_LOADS_LHS

// Loop peeling (unroll) factor for the MMA kernels.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_MMA 8
#endif

#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS)
#endif

#define MICRO_MMA_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
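// Unrolled micro-kernel macros. "iter" numbers the accumulator (0-7), "left"
// the lhs row block and "right" the rhs column panel; accItr is how many
// column panels are processed per iteration (1, 2 or 4), and "peel" indexes
// the unrolled depth step.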
#define MICRO_MMA_WORK(func, type, peel)                                                                    \
  if (accItr == 1) {                                                                                        \
    func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
    func(4, type, peel, 4, 0) func(5, type, peel, 5, 0) func(6, type, peel, 6, 0) func(7, type, peel, 7, 0) \
  } else if (accItr == 2) {                                                                                 \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
    func(4, type, peel, 2, 0) func(5, type, peel, 2, 1) func(6, type, peel, 3, 0) func(7, type, peel, 3, 1) \
  } else {                                                                                                  \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
    func(4, type, peel, 1, 0) func(5, type, peel, 1, 1) func(6, type, peel, 1, 2) func(7, type, peel, 1, 3) \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right)                        \
  if (unroll_factor > left) {                                                    \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV##left); \
  }
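// When VECTOR_PAIR_LOADS_LHS is enabled, two lhs packets are fetched with a
// single __vector_pair load and split with __builtin_vsx_disassemble_pair,
// halving the number of lhs load instructions in the peeled loop.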
#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right)                                          \
  if (unroll_factor > left) {                                                                      \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \
  }

#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left)                                                        \
  if (unroll_factor > left) {                                                                     \
    if (MICRO_NORMAL(left)) {                                                                     \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##left), plhsV##left);                   \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##left.packet), &plhsV##left); \
      lhs_ptr##left += accCols * 2;                                                               \
    } else {                                                                                      \
      lhsV2##left.packet[0] = ploadLhs<Packet>(lhs_ptr##left);                                    \
      lhsV2##left.packet[1] = ploadLhs<Packet>(lhs_ptr##left + accCols2);                         \
      lhs_ptr##left += accCols2 * 2;                                                              \
      EIGEN_UNUSED_VARIABLE(plhsV##left);                                                         \
    }                                                                                             \
  } else {                                                                                        \
    EIGEN_UNUSED_VARIABLE(lhsV2##left);                                                           \
    EIGEN_UNUSED_VARIABLE(plhsV##left);                                                           \
  }

#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left)
#endif
#define MICRO_MMA_UNROLL_ITER(func, val)                       \
  func(val, 0) if (accItr > 1) {                               \
    func(val, 1) if (accItr > 2) { func(val, 2) func(val, 3) } \
  }

#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]);

#define MICRO_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel)
#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel)              \
  if (PEEL_MMA > peel) {                                           \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    MICRO_MMA_LOAD_ONE_RHS(peel)                                   \
    MICRO_MMA_UNROLL(funcl)                                        \
    MICRO_MMA_WORK(funcw, type, peel)                              \
  }
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type)                                                  \
  type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 1)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 2)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 3)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 4)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 5)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 6) MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 7)
#else
#define MICRO_MMA_LOAD_TWO_RHS(peel1, right)                                                      \
  ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);
#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2)           \
  if (PEEL_MMA > peel2) {                                                                  \
    PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7;          \
    if (sizeof(type) == 16) {                                                              \
      MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1)                                 \
    } else {                                                                               \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                                 \
      MICRO_MMA_LOAD_ONE_RHS(peel1)                                                        \
      MICRO_MMA_LOAD_ONE_RHS(peel2)                                                        \
    }                                                                                      \
    MICRO_MMA_UNROLL(funcl2)                                                               \
    MICRO_MMA_WORK(funcw2, type, peel1)                                                    \
    MICRO_MMA_WORK(funcw2, type, peel2)                                                    \
  } else {                                                                                 \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                                   \
    MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1)                                       \
  }
#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type)                               \
  type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
  __vector_pair prhsV0, prhsV2, prhsV4, prhsV6;                                                         \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1)                                      \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3)                                      \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 4, 5)                                      \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 6, 7)
#endif
#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
  type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1];        \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0)

#define MICRO_MMA_UPDATE_RHS1(size, right) rhs_ptr##right += (accRows * size);

#define MICRO_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size)

#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size)             \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
  MICRO_MMA_UPDATE_RHS(size)
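// MICRO_MMA_ONE_PEEL advances PEEL_MMA depth steps per call, MICRO_MMA_ONE a
// single depth step; both load the rhs, load the lhs for every active row
// block and issue the corresponding pgerMMA updates.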
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
#else
#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size)                                                    \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
  MICRO_MMA_UPDATE_RHS(size)

#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
#endif

#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_MMA_DST_PTR_ONE(iter)       \
  if (unroll_factor * accItr > iter) {    \
    bsetzeroMMA(&accZero##iter);          \
  } else {                                \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)
#define MICRO_MMA_STORE_ONE(iter, left, right)                                                                 \
  if (unroll_factor > left) {                                                                                  \
    storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(left)>(row + left * accCols, res##right, pAlpha, \
                                                                     accCols2, &accZero##iter);                \
  }

#define MICRO_MMA_ITER_UNROLL(func)                                                                                 \
  if (accItr == 1) {                                                                                                \
    func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) func(4, 4, 0) func(5, 5, 0) func(6, 6, 0) func(7, 7, 0) \
  } else if (accItr == 2) {                                                                                         \
    func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) func(4, 2, 0) func(5, 2, 1) func(6, 3, 0) func(7, 3, 1) \
  } else {                                                                                                          \
    func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) func(4, 1, 0) func(5, 1, 1) func(6, 1, 2) func(7, 1, 3) \
  }

#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE)
#define MICRO_MMA_EXTRA_ROWS(right)                                                                           \
  gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(                                               \
      res3##right, blockA, rhs_base + right * accRows * strideB, depth, strideA, offsetA, strideB, row, rows, \
      remaining_rows, pAlpha, pMask);

#define MICRO_MMA_EXTRA_ROWS1(val, right) MICRO_MMA_EXTRA_ROWS(right);
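// One unrolled GEMM iteration: accumulates unroll_factor row blocks of accCols
// rows against accItr column panels of accRows columns, then scales by alpha
// and stores the accumulators back through the result mappers.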
template <int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper,
          const Index accRows, const Index accCols, bool full, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
                                                     const DataMapper& res2, const DataMapper& res3,
                                                     const Scalar* lhs_base, const Scalar* rhs_base, Index depth,
                                                     Index strideA, Index strideB, Index offsetA, Index& row,
                                                     const Packet& pAlpha, Index accCols2) {
  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL, *rhs_ptr3 = NULL;
  const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
               *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  if (accItr > 1) {
    rhs_ptr1 = rhs_base + (accRows * strideB);
  }
  if (accItr > 2) {
    rhs_ptr2 = rhs_base + (2 * accRows * strideB);
    rhs_ptr3 = rhs_base + (3 * accRows * strideB);
  }
  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_MMA;
  for (; k <= depth2; k += PEEL_MMA) {
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for (; k < depth; k++) {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  MICRO_UPDATE
}
#define MICRO_MMA_UNROLL_ITER2(N, M)                                                                                 \
  gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M, accItr>( \
      res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha,                 \
      M ? remaining_rows : accCols);                                                                                 \
  if (M) return;

#define MICRO_MMA_ROWS(n)             \
  while (row + n * accCols <= rows) { \
    MICRO_MMA_UNROLL_ITER2(n, 0);     \
  }
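// Processes one group of accItr column panels: sub-mappers are created for
// each panel, the row space is covered with the largest unroll factor that
// fits, and leftover rows are handled by gemm_extra_row.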
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
                                      Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
                                      Index rows, Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;

  const Scalar* rhs_base = blockB + col * strideB + accRows * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;
#define MAX_MMA_UNROLL 7

  // Cover the bulk of the rows with the largest unroll factor that the
  // 8-accumulator budget allows for this accItr.
#if MAX_MMA_UNROLL < 2
  if (1) {
#elif MAX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_MMA_ROWS(MAX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_MMA_ROWS(4);
  } else {
    MICRO_MMA_ROWS(2);
  }

  // Finish the remaining full row panels with a matching smaller unroll factor
  // (one MICRO_MMA_UNROLL_ITER2 invocation per case, guarded by MAX_MMA_UNROLL).
  switch ((rows - row) / accCols) {
#if MAX_MMA_UNROLL > 7
#endif
#if MAX_MMA_UNROLL > 6
#endif
#if MAX_MMA_UNROLL > 5
#endif
#if MAX_MMA_UNROLL > 4
#endif
#if MAX_MMA_UNROLL > 3
#endif
#if MAX_MMA_UNROLL > 2
#endif
#if MAX_MMA_UNROLL > 1
#endif
  }
#undef MAX_MMA_UNROLL
  if (remaining_rows > 0) {
    MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0)
  }
}
#define MICRO_MMA_COLS(n)                                                                                          \
  for (; col + n * accRows <= cols; col += n * accRows) {                                                          \
    gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols, n>(                                     \
        res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \
  }
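// Top-level real GEMM entry point for MMA: walks the columns in groups of
// panels (MICRO_MMA_COLS) and falls back to gemm_extra_cols for the columns
// that do not fill a complete panel group.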
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols,
             Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>(remaining_rows);

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
#ifdef GEMM_MULTIPLE_COLS
  MICRO_MMA_COLS(4);
  MICRO_MMA_COLS(2);
#endif
  MICRO_MMA_COLS(1);

  if (col != cols) {
    gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
                                                         col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}
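// The complex kernels mirror the real ones, but track separate real/imaginary
// accumulators and rhs pointers; advanceRows/advanceCols account for packed
// real-only operands, which occupy one scalar per entry instead of two.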
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// Loop peeling (unroll) factor for the complex MMA kernels.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_COMPLEX_MMA 4
#else
#define PEEL_COMPLEX_MMA 3
#endif
#define MICRO_COMPLEX_MMA_UNROLL(func) func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_WORK(func, type, peel)                                                            \
  if (accItr == 1) {                                                                                        \
    func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
  } else if (accItr == 2) {                                                                                 \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
  } else {                                                                                                  \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right)                                        \
  if (unroll_factor > left) {                                                                            \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(                            \
        &accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \
  }
#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right)                                    \
  if (unroll_factor > left) {                                                                        \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(                        \
        &accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], \
        rhsV##right[peel], rhsVi##right[peel]);                                                      \
  }

#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)                                                  \
  if (!LhsIsReal && (unroll_factor > left)) {                                                       \
    if (MICRO_NORMAL(left)) {                                                                       \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##left + imag_delta), plhsVi##left);  \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##left.packet), &plhsVi##left); \
    } else {                                                                                        \
      lhsVi2##left.packet[0] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2);                  \
      lhsVi2##left.packet[1] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2 + accCols2);       \
      EIGEN_UNUSED_VARIABLE(plhsVi##left);                                                          \
    }                                                                                               \
  } else {                                                                                          \
    EIGEN_UNUSED_VARIABLE(lhsVi2##left);                                                            \
    EIGEN_UNUSED_VARIABLE(plhsVi##left);                                                            \
  }                                                                                                 \
  MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left)

#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)
#endif
#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right)                             \
  ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]);    \
  if (!RhsIsReal) {                                                          \
    ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \
  }

#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel)
#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) {                              \
    Packet lhsV0, lhsV1, lhsV2, lhsV3;                        \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3;                    \
    MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel)                      \
    MICRO_COMPLEX_MMA_UNROLL(funcl)                           \
    MICRO_COMPLEX_MMA_WORK(funcw, type, peel)                 \
  }
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type)                                                      \
  type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
      rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1];                      \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0)                                                                \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 1)                                                                \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 3)
#else
#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right)                                                        \
  ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1);      \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);              \
  if (!RhsIsReal) {                                                                                         \
    ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1);   \
    __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi##right[peel1]), &prhsVi##peel1);          \
  } else {                                                                                                  \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1);                                                                   \
  }
#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
  if (PEEL_COMPLEX_MMA > peel2) {                                                        \
    PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23;                               \
    PacketBlock<Packet, 2> lhsVi20, lhsVi21, lhsVi22, lhsVi23;                           \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3;                                        \
    __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3;                                    \
    if (sizeof(type) == 16) {                                                            \
      MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1)                       \
    } else {                                                                             \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                               \
      EIGEN_UNUSED_VARIABLE(prhsVi##peel1);                                              \
      MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1);                                             \
      MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2);                                             \
    }                                                                                    \
    MICRO_COMPLEX_MMA_UNROLL(funcl2)                                                     \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1)                                          \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2)                                          \
  } else {                                                                               \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                                 \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1);                                                \
    MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1)                             \
  }
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type)                                   \
  type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
      rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1];                      \
  __vector_pair prhsV0, prhsV2;                                                                                     \
  __vector_pair prhsVi0, prhsVi2;                                                                                   \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1)                                          \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3)
#endif
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type)                              \
  type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0)

#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \
  rhs_ptr_real##right += (accRows * size);         \
  if (!RhsIsReal) rhs_ptr_imag##right += (accRows * size);

#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size)

#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size)                     \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
  MICRO_COMPLEX_MMA_UPDATE_RHS(size);
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
#else
#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size)                                      \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO,  \
                         MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket)                                           \
  MICRO_COMPLEX_MMA_UPDATE_RHS(size);

#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
#endif

#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor * accItr > iter) {      \
    bsetzeroMMA(&accReal##iter);            \
    bsetzeroMMA(&accImag##iter);            \
  } else {                                  \
    EIGEN_UNUSED_VARIABLE(accReal##iter);   \
    EIGEN_UNUSED_VARIABLE(accImag##iter);   \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right)                                                                  \
  if (unroll_factor > left) {                                                                                           \
    storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (left + 1)) ? accCols : accCols2>(  \
        row + left * accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter);               \
  }

#define MICRO_COMPLEX_MMA_ITER_UNROLL(func)                 \
  if (accItr == 1) {                                        \
    func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) \
  } else if (accItr == 2) {                                 \
    func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) \
  } else {                                                  \
    func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right)                                                                            \
  gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
                         RhsIsReal>(res3##right, blockA, rhs_base + right * accRows * (RhsIsReal ? 1 : 2) * strideB,   \
                                    depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal,           \
                                    pAlphaImag, pMask);

#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) MICRO_COMPLEX_MMA_EXTRA_ROWS(right);
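// One unrolled complex GEMM iteration: separate real and imaginary rhs
// pointers are set up per column panel (imaginary ones only when the rhs is
// truly complex), and each row block drives a pair of accumulators
// (accReal/accImag).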
template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket,
          typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
                                                             const DataMapper& res2, const DataMapper& res3,
                                                             const Scalar* lhs_base, const Scalar* rhs_base,
                                                             Index depth, Index strideA, Index offsetA, Index strideB,
                                                             Index& row, const Packet& pAlphaReal,
                                                             const Packet& pAlphaImag, const Packet& pMask) {
  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL, *rhs_ptr_real3 = NULL;
  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL, *rhs_ptr_imag3 = NULL;
  const Index imag_delta = accCols * strideA;
  const Index imag_delta2 = accCols2 * strideA;

  if (!RhsIsReal) {
    rhs_ptr_imag0 = rhs_base + accRows * strideB;
  }
  if (accItr > 1) {
    if (!RhsIsReal) {
      rhs_ptr_real1 = rhs_base + (2 * accRows * strideB);
      rhs_ptr_imag1 = rhs_base + (3 * accRows * strideB);
    } else {
      rhs_ptr_real1 = rhs_base + accRows * strideB;
    }
  }
  if (accItr > 2) {
    if (!RhsIsReal) {
      rhs_ptr_real2 = rhs_base + (4 * accRows * strideB);
      rhs_ptr_imag2 = rhs_base + (5 * accRows * strideB);
      rhs_ptr_real3 = rhs_base + (6 * accRows * strideB);
      rhs_ptr_imag3 = rhs_base + (7 * accRows * strideB);
    } else {
      rhs_ptr_real2 = rhs_base + (2 * accRows * strideB);
      rhs_ptr_real3 = rhs_base + (3 * accRows * strideB);
    }
  }

  const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
  const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA;
  for (; k <= depth2; k += PEEL_COMPLEX_MMA) {
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for (; k < depth; k++) {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  MICRO_COMPLEX_UPDATE
}
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M)                                                                      \
  gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows,   \
                                      accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, \
                                      accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA,     \
                                              offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask);              \
  if (M) return;

#define MICRO_COMPLEX_MMA_ROWS(n)         \
  while (row + n * accCols <= rows) {     \
    MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \
  }
template <typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper,
          const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
          bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                              Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
                                              Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                              const Packet& pAlphaImag, const Packet& pMask) {
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;

  const Scalar* rhs_base = blockB + advanceCols * col * strideB + accRows * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;
#define MAX_COMPLEX_MMA_UNROLL 4

  // Cover the bulk of the rows with the largest unroll factor that the
  // accumulator budget (real + imaginary) allows for this accItr.
#if MAX_COMPLEX_MMA_UNROLL < 2
  if (1) {
#elif MAX_COMPLEX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_COMPLEX_MMA_ROWS(2);
  } else {
    MICRO_COMPLEX_MMA_ROWS(1);
  }

  // Finish the remaining full row panels with a matching smaller unroll factor
  // (one MICRO_COMPLEX_MMA_UNROLL_ITER2 invocation per case).
  switch ((rows - row) / accCols) {
#if MAX_COMPLEX_MMA_UNROLL > 3
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
#endif
  }
#undef MAX_COMPLEX_MMA_UNROLL
  if (remaining_rows > 0) {
    MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0)
  }
}
#define MICRO_COMPLEX_MMA_COLS(n)                                                                                      \
  for (; col + n * accRows <= cols; col += n * accRows) {                                                              \
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs,              \
                         ConjugateRhs, LhsIsReal, RhsIsReal, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, \
                                                                offsetB, col, rows, remaining_rows, pAlphaReal,        \
                                                                pAlphaImag, pMask);                                    \
  }
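// Top-level complex GEMM entry point for MMA: mirrors gemmMMA, broadcasting
// the real and imaginary parts of alpha and dispatching leftover columns to
// gemm_complex_extra_cols.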
template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
          typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth,
                     Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  const Scalar* blockA = (Scalar*)blockAc;
  const Scalar* blockB = (Scalar*)blockBc;

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
#ifdef GEMM_MULTIPLE_COLS
  MICRO_COMPLEX_MMA_COLS(4);
  MICRO_COMPLEX_MMA_COLS(2);
#endif
  MICRO_COMPLEX_MMA_COLS(1);

  if (col != cols) {
    gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                            RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
                                       remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC pop_options
#endif

#endif  // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H