MatrixProductMMA.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H

// If using dynamic dispatch, set the CPU target.
#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC push_options
#pragma GCC target("cpu=power10,htm")
#endif

#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif
#endif

// IWYU pragma: private
#include "../../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

#define accColsC (accCols / 2)

EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) { __builtin_mma_xxsetaccz(acc); }

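// Scale the result tile held in an MMA accumulator (four packets after disassembly) by alpha, add
// the values already stored in the destination (via the DataMapper) and write it back. When `full`
// is false only `elements` entries per column are loaded/stored, covering the final, partially
// filled block of rows.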
template <typename DataMapper, typename Packet, bool full>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements,
                                          __vector_quad* acc) {
  PacketBlock<Packet, 4> result;
  __builtin_mma_disassemble_acc(&result.packet, acc);

  PacketBlock<Packet, 4> tRes;
  if (full) {
    EIGEN_UNUSED_VARIABLE(elements);
    bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore<DataMapper, Packet, 4>(tRes, data, i);
  } else {
    bload_partial<DataMapper, Packet, 0, false, 4>(tRes, data, i, elements);
    bscale<Packet, 4>(tRes, result, alpha);
    bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
  }
}

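// Complex variant of storeAccumulator: the real and imaginary accumulators are disassembled, scaled
// by the complex alpha (bscalec), coupled back into interleaved complex packets (bcouple) and
// stored; pMask masks out the lanes of a partial block (accCols2 != accCols).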
template <typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal,
                                                 const Packet& alphaImag, const Packet& pMask, __vector_quad* accReal,
                                                 __vector_quad* accImag) {
  constexpr bool full = (accCols2 > accColsC);
  PacketBlock<Packet, 4> resultReal, resultImag;
  __builtin_mma_disassemble_acc(&resultReal.packet, accReal);
  __builtin_mma_disassemble_acc(&resultImag.packet, accImag);

  PacketBlock<Packetc, 8> tRes;
  bload<DataMapper, Packetc, accColsC, ColMajor, true, 4, full>(tRes, data, i, 0);

  PacketBlock<Packet, 4> taccReal, taccImag;
  bscalec<Packet, 4, (accCols != accCols2)>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag, pMask);

  PacketBlock<Packetc, 4> acc1, acc2;
  bcouple<Packet, Packetc, 4, full>(taccReal, taccImag, tRes, acc1, acc2);

  bstore<DataMapper, Packetc, 4>(acc1, data, i);
  if (full) {
    bstore<DataMapper, Packetc, 4>(acc2, data, i + accColsC);
  }
}

// Defaults to float32; since Eigen still supports C++03 we can't use default template arguments.
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
  }
}

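// Double-precision variant: with MMA the packed RHS operand is a __vector_pair (four doubles), and
// the outer-product accumulate maps to the xvf64ger instructions.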
template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) {
  if (NegativeAccumulate) {
    __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
  } else {
    __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
  }
}

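// Complex outer-product accumulate built on pgerMMA: real*real (and, when present, imag*imag) terms
// go to accReal, the cross terms to accImag; conjugation is expressed by selecting the
// negative-accumulate ("np") form of the underlying ger instruction for the affected terms.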
template <typename Packet, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, Packet& lhsVi,
                                  const RhsPacket& rhsV, RhsPacket& rhsVi) {
  pgerMMA<Packet, RhsPacket, false>(accReal, rhsV, lhsV);
  if (LhsIsReal) {
    pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    EIGEN_UNUSED_VARIABLE(lhsVi);
  } else {
    if (!RhsIsReal) {
      pgerMMA<Packet, RhsPacket, ConjugateLhs == ConjugateRhs>(accReal, rhsVi, lhsVi);
      pgerMMA<Packet, RhsPacket, ConjugateRhs>(accImag, rhsVi, lhsV);
    } else {
      EIGEN_UNUSED_VARIABLE(rhsVi);
    }
    pgerMMA<Packet, RhsPacket, ConjugateLhs>(accImag, rhsV, lhsVi);
  }
}

// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
template <typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadRhs(const __UNPACK_TYPE__(Packet)* rhs) {
  return ploadu<Packet>(rhs);
}

template <typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) {
  rhsV = ploadRhs<Packet>(rhs);
}

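// Load two consecutive Packet2d values from the RHS into a __vector_pair. Clang assembles the pair
// from two vector loads; otherwise the pair is loaded directly from memory.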
template <>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) {
#if EIGEN_COMP_LLVM
  __builtin_vsx_assemble_pair(
      &rhsV, reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs + (sizeof(Packet2d) / sizeof(double)))),
      reinterpret_cast<__vector unsigned char>(ploadRhs<Packet2d>(rhs)));
#else
  rhsV = *reinterpret_cast<__vector_pair*>(const_cast<double*>(rhs));
#endif
}

EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV) { ploadRhsMMA(lhs, lhsV); }

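// GEMM_MULTIPLE_COLS enables working on several RHS column panels per kernel invocation
// (accItr = 1, 2 or 4 below), so the packed LHS loads are reused across those panels.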
#define GEMM_MULTIPLE_COLS

// Disable in GCC until unnecessary register moves are fixed
// #if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#if EIGEN_COMP_LLVM
#define VECTOR_PAIR_LOADS_LHS
#endif

// PEEL_MMA loop factor.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_MMA 8
#else
// Register spillage with GCC12+
#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS)
#define PEEL_MMA 7
#else
#define PEEL_MMA 6
#endif
#endif

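// In the MICRO_MMA_* macros, `iter` selects one of up to eight MMA accumulators, `left` the LHS row
// block and `right` the RHS column panel it maps to (depending on accItr), and `peel` indexes the
// unrolled step along the depth (k) dimension.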
#define MICRO_MMA_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_WORK(func, type, peel)                                                                    \
  if (accItr == 1) {                                                                                        \
    func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
    func(4, type, peel, 4, 0) func(5, type, peel, 5, 0) func(6, type, peel, 6, 0) func(7, type, peel, 7, 0) \
  } else if (accItr == 2) {                                                                                 \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
    func(4, type, peel, 2, 0) func(5, type, peel, 2, 1) func(6, type, peel, 3, 0) func(7, type, peel, 3, 1) \
  } else {                                                                                                  \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
    func(4, type, peel, 1, 0) func(5, type, peel, 1, 1) func(6, type, peel, 1, 2) func(7, type, peel, 1, 3) \
  }

#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right)                        \
  if (unroll_factor > left) {                                                    \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV##left); \
  }

#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right)                                          \
  if (unroll_factor > left) {                                                                      \
    pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \
  }

#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left)                                                        \
  if (unroll_factor > left) {                                                                     \
    if (MICRO_NORMAL(left)) {                                                                     \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##left), plhsV##left);                   \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##left.packet), &plhsV##left); \
      lhs_ptr##left += accCols * 2;                                                               \
    } else {                                                                                      \
      lhsV2##left.packet[0] = ploadLhs<Packet>(lhs_ptr##left);                                    \
      lhsV2##left.packet[1] = ploadLhs<Packet>(lhs_ptr##left + accCols2);                         \
      lhs_ptr##left += accCols2 * 2;                                                              \
      EIGEN_UNUSED_VARIABLE(plhsV##left);                                                         \
    }                                                                                             \
  } else {                                                                                        \
    EIGEN_UNUSED_VARIABLE(lhsV2##left);                                                           \
    EIGEN_UNUSED_VARIABLE(plhsV##left);                                                           \
  }

#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left)
#endif

#define MICRO_MMA_UNROLL_ITER(func, val)                       \
  func(val, 0) if (accItr > 1) {                               \
    func(val, 1) if (accItr > 2) { func(val, 2) func(val, 3) } \
  }

#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]);

#define MICRO_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel)

#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel)              \
  if (PEEL_MMA > peel) {                                           \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    MICRO_MMA_LOAD_ONE_RHS(peel)                                   \
    MICRO_MMA_UNROLL(funcl)                                        \
    MICRO_MMA_WORK(funcw, type, peel)                              \
  }

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type)                                                  \
  type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 1)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 2)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 3)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 4)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 5)                                                            \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 6) MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 7)
#else
#define MICRO_MMA_LOAD_TWO_RHS(peel1, right)                                                      \
  ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);

#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2)           \
  if (PEEL_MMA > peel2) {                                                                  \
    PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7;          \
    if (sizeof(type) == 16) {                                                              \
      MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1)                                 \
    } else {                                                                               \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                                 \
      MICRO_MMA_LOAD_ONE_RHS(peel1)                                                        \
      MICRO_MMA_LOAD_ONE_RHS(peel2)                                                        \
    }                                                                                      \
    MICRO_MMA_UNROLL(funcl2)                                                               \
    MICRO_MMA_WORK(funcw2, type, peel1)                                                    \
    MICRO_MMA_WORK(funcw2, type, peel2)                                                    \
  } else {                                                                                 \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                                   \
    MICRO_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1)                                       \
  }

#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type)                               \
  type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
  __vector_pair prhsV0, prhsV2, prhsV4, prhsV6;                                                         \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1)                                      \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3)                                      \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 4, 5)                                      \
  MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 6, 7)
#endif

#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
  type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1];        \
  MICRO_MMA_TYPE_PEEL(funcw, funcl, type, 0)

#define MICRO_MMA_UPDATE_RHS1(size, right) rhs_ptr##right += (accRows * size);

#define MICRO_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size)

#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size)             \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
  MICRO_MMA_UPDATE_RHS(size)

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
#else
#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size)                                                    \
  MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
  MICRO_MMA_UPDATE_RHS(size)

#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
#endif

#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)

#define MICRO_MMA_DST_PTR_ONE(iter)       \
  if (unroll_factor * accItr > iter) {    \
    bsetzeroMMA(&accZero##iter);          \
  } else {                                \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE)

#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)

#define MICRO_MMA_STORE_ONE(iter, left, right)                                                                 \
  if (unroll_factor > left) {                                                                                  \
    storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(left)>(row + left * accCols, res##right, pAlpha, \
                                                                     accCols2, &accZero##iter);                \
  }

#define MICRO_MMA_ITER_UNROLL(func)                                                                                 \
  if (accItr == 1) {                                                                                                \
    func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) func(4, 4, 0) func(5, 5, 0) func(6, 6, 0) func(7, 7, 0) \
  } else if (accItr == 2) {                                                                                         \
    func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) func(4, 2, 0) func(5, 2, 1) func(6, 3, 0) func(7, 3, 1) \
  } else {                                                                                                          \
    func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) func(4, 1, 0) func(5, 1, 1) func(6, 1, 2) func(7, 1, 3) \
  }

#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE)

#define MICRO_MMA_EXTRA_ROWS(right)                                                                           \
  gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(                                               \
      res3##right, blockA, rhs_base + right * accRows * strideB, depth, strideA, offsetA, strideB, row, rows, \
      remaining_rows, pAlpha, pMask);

#define MICRO_MMA_EXTRA_ROWS1(val, right) MICRO_MMA_EXTRA_ROWS(right);

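// A single unrolled micro-kernel iteration: accumulates an (unroll_factor * accCols) x
// (accItr * accRows) tile in MMA accumulators while sweeping the depth dimension in PEEL_MMA-sized
// chunks, then scales by alpha and stores the tile through the DataMapper(s).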
template <int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper,
          const Index accRows, const Index accCols, bool full, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
                                                     const DataMapper& res2, const DataMapper& res3,
                                                     const Scalar* lhs_base, const Scalar* rhs_base, Index depth,
                                                     Index strideA, Index strideB, Index offsetA, Index& row,
                                                     const Packet& pAlpha, Index accCols2) {
  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL, *rhs_ptr3 = NULL;
  const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
               *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
  __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

  if (accItr > 1) {
    rhs_ptr1 = rhs_base + (accRows * strideB);
  } else {
    EIGEN_UNUSED_VARIABLE(strideB);
    EIGEN_UNUSED_VARIABLE(rhs_ptr1);
    EIGEN_UNUSED_VARIABLE(res1);
  }
  if (accItr > 2) {
    rhs_ptr2 = rhs_base + (2 * accRows * strideB);
    rhs_ptr3 = rhs_base + (3 * accRows * strideB);
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr2);
    EIGEN_UNUSED_VARIABLE(rhs_ptr3);
    EIGEN_UNUSED_VARIABLE(res2);
    EIGEN_UNUSED_VARIABLE(res3);
  }

  MICRO_MMA_SRC_PTR
  MICRO_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_MMA;
  for (; k <= depth2; k += PEEL_MMA) {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_MMA_PREFETCH
    MICRO_MMA_ONE_PEEL
  }
  for (; k < depth; k++) {
    MICRO_MMA_ONE
  }
  MICRO_MMA_STORE

  MICRO_UPDATE
}

#define MICRO_MMA_UNROLL_ITER2(N, M)                                                                                 \
  gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M, accItr>( \
      res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha,                 \
      M ? remaining_rows : accCols);                                                                                 \
  if (M) return;

#define MICRO_MMA_ROWS(n)             \
  while (row + n * accCols <= rows) { \
    MICRO_MMA_UNROLL_ITER2(n, 0);     \
  }

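// Processes one group of accItr column panels: run the widest unrolled micro-kernel that fits down
// the rows, mop up smaller row blocks via the switch below, and finally handle the remaining
// (partial) rows.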
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
                                      Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows,
                                      Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;

  const Scalar* rhs_base = blockB + col * strideB + accRows * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;

#define MAX_MMA_UNROLL 7

#if MAX_MMA_UNROLL < 2
  if (1) {
#elif MAX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_MMA_ROWS(MAX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_MMA_ROWS(4);
  } else {
    MICRO_MMA_ROWS(2);
  }
  switch ((rows - row) / accCols) {
#if MAX_MMA_UNROLL > 7
    case 7:
      if (accItr == 1) {
        MICRO_MMA_UNROLL_ITER2(7, 0);
      }
      break;
#endif
#if MAX_MMA_UNROLL > 6
    case 6:
      if (accItr == 1) {
        MICRO_MMA_UNROLL_ITER2(6, 0);
      }
      break;
#endif
#if MAX_MMA_UNROLL > 5
    case 5:
      if (accItr == 1) {
        MICRO_MMA_UNROLL_ITER2(5, 0);
      }
      break;
#endif
#if MAX_MMA_UNROLL > 4
    case 4:
      if (accItr == 1) {
        MICRO_MMA_UNROLL_ITER2(4, 0);
      }
      break;
#endif
#if MAX_MMA_UNROLL > 3
    case 3:
      if (accItr <= 2) {
        MICRO_MMA_UNROLL_ITER2(3, 0);
      }
      break;
#endif
#if MAX_MMA_UNROLL > 2
    case 2:
      if (accItr <= 2) {
        MICRO_MMA_UNROLL_ITER2(2, 0);
      }
      break;
#endif
#if MAX_MMA_UNROLL > 1
    case 1:
      MICRO_MMA_UNROLL_ITER2(1, 0);
      break;
#endif
    default:
      break;
  }
#undef MAX_MMA_UNROLL

  if (remaining_rows > 0) {
    MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0)
  }
}

#define MICRO_MMA_COLS(n)                                                                                          \
  for (; col + n * accRows <= cols; col += n * accRows) {                                                          \
    gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols, n>(                                     \
        res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \
  }

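// Top-level real-valued MMA GEMM entry point: columns are consumed in groups of 4, 2 and 1 panels,
// and any leftover columns fall back to gemm_extra_cols.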
template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols,
             Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlpha = pset1<Packet>(alpha);
  const Packet pMask = bmask<Packet>(remaining_rows);

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
#ifdef GEMM_MULTIPLE_COLS
  MICRO_MMA_COLS(4);
  MICRO_MMA_COLS(2);
#endif
  MICRO_MMA_COLS(1);

  if (col != cols) {
    gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
                                                         col, rows, cols, remaining_rows, pAlpha, pMask);
  }
}

#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop factor.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_COMPLEX_MMA 4
#else
#define PEEL_COMPLEX_MMA 3
#endif

#define MICRO_COMPLEX_MMA_UNROLL(func) func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_WORK(func, type, peel)                                                            \
  if (accItr == 1) {                                                                                        \
    func(0, type, peel, 0, 0) func(1, type, peel, 1, 0) func(2, type, peel, 2, 0) func(3, type, peel, 3, 0) \
  } else if (accItr == 2) {                                                                                 \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 1, 0) func(3, type, peel, 1, 1) \
  } else {                                                                                                  \
    func(0, type, peel, 0, 0) func(1, type, peel, 0, 1) func(2, type, peel, 0, 2) func(3, type, peel, 0, 3) \
  }

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right)                                        \
  if (unroll_factor > left) {                                                                            \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(                            \
        &accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \
  }

#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right)                                    \
  if (unroll_factor > left) {                                                                        \
    pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(                        \
        &accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], \
        rhsV##right[peel], rhsVi##right[peel]);                                                      \
  }

#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)                                                  \
  if (!LhsIsReal && (unroll_factor > left)) {                                                       \
    if (MICRO_NORMAL(left)) {                                                                       \
      ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##left + imag_delta), plhsVi##left);  \
      __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##left.packet), &plhsVi##left); \
    } else {                                                                                        \
      lhsVi2##left.packet[0] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2);                  \
      lhsVi2##left.packet[1] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2 + accCols2);       \
      EIGEN_UNUSED_VARIABLE(plhsVi##left);                                                          \
    }                                                                                               \
  } else {                                                                                          \
    EIGEN_UNUSED_VARIABLE(lhsVi2##left);                                                            \
    EIGEN_UNUSED_VARIABLE(plhsVi##left);                                                            \
  }                                                                                                 \
  MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left)

#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)
#endif

#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right)                             \
  ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]);    \
  if (!RhsIsReal) {                                                          \
    ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \
  }

#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel)

#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
  if (PEEL_COMPLEX_MMA > peel) {                              \
    Packet lhsV0, lhsV1, lhsV2, lhsV3;                        \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3;                    \
    MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel)                      \
    MICRO_COMPLEX_MMA_UNROLL(funcl)                           \
    MICRO_COMPLEX_MMA_WORK(funcw, type, peel)                 \
  }

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type)                                                      \
  type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
      rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1];                      \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0)                                                                \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 1)                                                                \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 3)
#else
#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right)                                                      \
  ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1);    \
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);            \
  if (!RhsIsReal) {                                                                                       \
    ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1); \
    __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi##right[peel1]), &prhsVi##peel1);        \
  } else {                                                                                                \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1);                                                                 \
  }

#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
  if (PEEL_COMPLEX_MMA > peel2) {                                                        \
    PacketBlock<Packet, 2> lhsV20, lhsV21, lhsV22, lhsV23;                               \
    PacketBlock<Packet, 2> lhsVi20, lhsVi21, lhsVi22, lhsVi23;                           \
    __vector_pair plhsV0, plhsV1, plhsV2, plhsV3;                                        \
    __vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3;                                    \
    if (sizeof(type) == 16) {                                                            \
      MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1)                       \
    } else {                                                                             \
      EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                               \
      EIGEN_UNUSED_VARIABLE(prhsVi##peel1);                                              \
      MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1);                                             \
      MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2);                                             \
    }                                                                                    \
    MICRO_COMPLEX_MMA_UNROLL(funcl2)                                                     \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1)                                          \
    MICRO_COMPLEX_MMA_WORK(funcw2, type, peel2)                                          \
  } else {                                                                               \
    EIGEN_UNUSED_VARIABLE(prhsV##peel1);                                                 \
    EIGEN_UNUSED_VARIABLE(prhsVi##peel1);                                                \
    MICRO_COMPLEX_MMA_TYPE_PEEL(funcw1, funcl1, type, peel1)                             \
  }

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type)                                   \
  type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], \
      rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1];                      \
  __vector_pair prhsV0, prhsV2;                                                                                     \
  __vector_pair prhsVi0, prhsVi2;                                                                                   \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 0, 1)                                          \
  MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, 2, 3)
#endif

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type)                              \
  type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \
  MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, 0)

#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \
  rhs_ptr_real##right += (accRows * size);         \
  if (!RhsIsReal) rhs_ptr_imag##right += (accRows * size);

#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size)

#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size)                     \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
  MICRO_COMPLEX_MMA_UPDATE_RHS(size);

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
#else
#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size)                                      \
  MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO,  \
                         MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket)                                           \
  MICRO_COMPLEX_MMA_UPDATE_RHS(size);

#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
#endif

#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
  if (unroll_factor * accItr > iter) {      \
    bsetzeroMMA(&accReal##iter);            \
    bsetzeroMMA(&accImag##iter);            \
  } else {                                  \
    EIGEN_UNUSED_VARIABLE(accReal##iter);   \
    EIGEN_UNUSED_VARIABLE(accImag##iter);   \
  }

#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE)

#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right)                                                                 \
  if (unroll_factor > left) {                                                                                          \
    storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (left + 1)) ? accCols : accCols2>( \
        row + left * accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter);              \
  }

#define MICRO_COMPLEX_MMA_ITER_UNROLL(func)                 \
  if (accItr == 1) {                                        \
    func(0, 0, 0) func(1, 1, 0) func(2, 2, 0) func(3, 3, 0) \
  } else if (accItr == 2) {                                 \
    func(0, 0, 0) func(1, 0, 1) func(2, 1, 0) func(3, 1, 1) \
  } else {                                                  \
    func(0, 0, 0) func(1, 0, 1) func(2, 0, 2) func(3, 0, 3) \
  }

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right)                                                                            \
  gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
                         RhsIsReal>(res3##right, blockA, rhs_base + right * accRows * (RhsIsReal ? 1 : 2) * strideB,   \
                                    depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal,           \
                                    pAlphaImag, pMask);

#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) MICRO_COMPLEX_MMA_EXTRA_ROWS(right);

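// Complex micro-kernel iteration: real and imaginary parts are kept in separate MMA accumulators
// (accReal*/accImag*) and only combined with the complex alpha when the tile is stored.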
template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket,
          typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(const DataMapper& res0, const DataMapper& res1,
                                                             const DataMapper& res2, const DataMapper& res3,
                                                             const Scalar* lhs_base, const Scalar* rhs_base,
                                                             Index depth, Index strideA, Index offsetA, Index strideB,
                                                             Index& row, const Packet& pAlphaReal,
                                                             const Packet& pAlphaImag, const Packet& pMask) {
  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL, *rhs_ptr_real3 = NULL;
  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL, *rhs_ptr_imag3 = NULL;
  const Index imag_delta = accCols * strideA;
  const Index imag_delta2 = accCols2 * strideA;

  if (!RhsIsReal) {
    rhs_ptr_imag0 = rhs_base + accRows * strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag0);
  }
  if (accItr > 1) {
    if (!RhsIsReal) {
      rhs_ptr_real1 = rhs_base + (2 * accRows * strideB);
      rhs_ptr_imag1 = rhs_base + (3 * accRows * strideB);
    } else {
      rhs_ptr_real1 = rhs_base + accRows * strideB;
      EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
    }
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_real1);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
    EIGEN_UNUSED_VARIABLE(res1);
  }
  if (accItr > 2) {
    if (!RhsIsReal) {
      rhs_ptr_real2 = rhs_base + (4 * accRows * strideB);
      rhs_ptr_imag2 = rhs_base + (5 * accRows * strideB);
      rhs_ptr_real3 = rhs_base + (6 * accRows * strideB);
      rhs_ptr_imag3 = rhs_base + (7 * accRows * strideB);
    } else {
      rhs_ptr_real2 = rhs_base + (2 * accRows * strideB);
      rhs_ptr_real3 = rhs_base + (3 * accRows * strideB);
      EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
      EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
    }
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_real2);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_real3);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
    EIGEN_UNUSED_VARIABLE(res2);
    EIGEN_UNUSED_VARIABLE(res3);
  }
  const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
  const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
  __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;

  MICRO_COMPLEX_MMA_SRC_PTR
  MICRO_COMPLEX_MMA_DST_PTR

  Index k = 0, depth2 = depth - PEEL_COMPLEX_MMA;
  for (; k <= depth2; k += PEEL_COMPLEX_MMA) {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if (!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    MICRO_COMPLEX_MMA_PREFETCH
    MICRO_COMPLEX_MMA_ONE_PEEL
  }
  for (; k < depth; k++) {
    MICRO_COMPLEX_MMA_ONE
  }
  MICRO_COMPLEX_MMA_STORE

  MICRO_COMPLEX_UPDATE
}

#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M)                                                                           \
  gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows,        \
                                      accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal,      \
                                      accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, offsetA, \
                                              strideB, row, pAlphaReal, pAlphaImag, pMask);                            \
  if (M) return;

#define MICRO_COMPLEX_MMA_ROWS(n)         \
  while (row + n * accCols <= rows) {     \
    MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \
  }

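// Complex analogue of gemmMMA_cols: pick the largest unrolled complex micro-kernel that fits,
// handle the leftover row blocks via the switch, then the remaining partial rows.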
template <typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper,
          const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
          bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                              Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
                                              Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                              const Packet& pAlphaImag, const Packet& pMask) {
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows * 1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows * 2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows * 3) : res30;

  const Scalar* rhs_base = blockB + advanceCols * col * strideB + accRows * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4

#if MAX_COMPLEX_MMA_UNROLL < 2
  if (1) {
#elif MAX_COMPLEX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_COMPLEX_MMA_ROWS(2);
  } else {
    MICRO_COMPLEX_MMA_ROWS(1);
  }
  switch ((rows - row) / accCols) {
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      if (accItr == 1) {
        MICRO_COMPLEX_MMA_UNROLL_ITER2(3, 0);
      }
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      if (accItr == 1) {
        MICRO_COMPLEX_MMA_UNROLL_ITER2(2, 0);
      }
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      if (accItr <= 2) {
        MICRO_COMPLEX_MMA_UNROLL_ITER2(1, 0);
      }
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_MMA_UNROLL

  if (remaining_rows > 0) {
    MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0)
  }
}

#define MICRO_COMPLEX_MMA_COLS(n)                                                                                      \
  for (; col + n * accRows <= cols; col += n * accRows) {                                                              \
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs,              \
                         ConjugateRhs, LhsIsReal, RhsIsReal, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, \
                                                                offsetB, col, rows, remaining_rows, pAlphaReal,        \
                                                                pAlphaImag, pMask);                                    \
  }

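// Top-level complex MMA GEMM entry point: the packed complex blocks are reinterpreted as arrays of
// the underlying real scalar, columns are consumed in groups of 4, 2 and 1 panels, and leftover
// columns fall back to gemm_complex_extra_cols.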
template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
          typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth,
                     Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  const Scalar* blockA = (Scalar*)blockAc;
  const Scalar* blockB = (Scalar*)blockBc;

  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
#ifdef GEMM_MULTIPLE_COLS
  MICRO_COMPLEX_MMA_COLS(4);
  MICRO_COMPLEX_MMA_COLS(2);
#endif
  MICRO_COMPLEX_MMA_COLS(1);

  if (col != cols) {
    gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                            RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
                                       remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}

#undef accColsC
#undef advanceRows
#undef advanceCols

}  // end namespace internal

}  // end namespace Eigen

#if defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#pragma GCC pop_options
#endif

#endif  // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H