MatrixProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
5 // Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
12 #define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
13 
14 #ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
15 #define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
16 #endif
17 
18 #if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
19 #define EIGEN_ALTIVEC_DISABLE_MMA 0
20 #endif
21 
22 // Check for MMA builtin support.
23 #if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
24 #if __has_builtin(__builtin_mma_assemble_acc)
25 #define EIGEN_ALTIVEC_MMA_SUPPORT
26 #endif
27 #endif
28 
29 // Check if and how we should actually use MMA if supported.
30 #if defined(EIGEN_ALTIVEC_MMA_SUPPORT)
31 
32 #if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
33 #define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
34 #endif
35 
36 // Check if we want to enable dynamic dispatch. Not supported by LLVM.
37 #if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
38 #define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
39 // Otherwise, use MMA by default if available.
40 #elif defined(__MMA__)
41 #define EIGEN_ALTIVEC_MMA_ONLY 1
42 #endif
43 
44 #endif // EIGEN_ALTIVEC_MMA_SUPPORT
45 
46 #include "MatrixProductCommon.h"
47 
48 #if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
49 #include "MatrixProductMMA.h"
50 #endif
51 
52 // IWYU pragma: private
53 #include "../../InternalHeaderCheck.h"
54 
55 namespace Eigen {
56 
57 namespace internal {
58 
59 /**************************
60  * Constants and typedefs *
61  **************************/
62 template <typename Scalar>
63 struct quad_traits {
68 };
69 
70 template <>
76 };
77 
78 template <>
84 };
85 
86 // MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector, this turned out
87 // to be faster than Eigen's usual approach of having real/imaginary pairs on a single vector. This constants then
88 // are responsible to extract from convert between Eigen's and MatrixProduct approach.
89 
90 const static Packet16uc p16uc_GETREAL32 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
91 
92 const static Packet16uc p16uc_GETIMAG32 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
93 
94 const static Packet16uc p16uc_GETREAL32b = {0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
95 
96 const static Packet16uc p16uc_GETIMAG32b = {4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
97 
98 /*********************************************
99  * Single precision real and complex packing *
100  * *******************************************/
101 
116 template <typename Scalar, int StorageOrder>
117 EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(
118  Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt) {
119  std::complex<Scalar> v;
120  if (i < j) {
121  v.real(dt(j, i).real());
122  v.imag(-dt(j, i).imag());
123  } else if (i > j) {
124  v.real(dt(i, j).real());
125  v.imag(dt(i, j).imag());
126  } else {
127  v.real(dt(i, j).real());
128  v.imag((Scalar)0.0);
129  }
130  return v;
131 }
132 
133 template <typename Scalar, int StorageOrder, int N>
134 EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs,
135  Index rhsStride, Index rows, Index cols, Index k2) {
136  const Index depth = k2 + rows;
137  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
138  const Index vectorSize = N * quad_traits<Scalar>::vectorsize;
139  const Index vectorDelta = vectorSize * rows;
140  Scalar* blockBf = reinterpret_cast<Scalar*>(blockB);
141 
142  Index rir = 0, rii, j = 0;
143  for (; j + vectorSize <= cols; j += vectorSize) {
144  rii = rir + vectorDelta;
145 
146  for (Index i = k2; i < depth; i++) {
147  for (Index k = 0; k < vectorSize; k++) {
148  std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);
149 
150  blockBf[rir + k] = v.real();
151  blockBf[rii + k] = v.imag();
152  }
153  rir += vectorSize;
154  rii += vectorSize;
155  }
156 
157  rir += vectorDelta;
158  }
159 
160  for (; j < cols; j++) {
161  rii = rir + rows;
162 
163  for (Index i = k2; i < depth; i++) {
164  std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);
165 
166  blockBf[rir] = v.real();
167  blockBf[rii] = v.imag();
168 
169  rir += 1;
170  rii += 1;
171  }
172 
173  rir += rows;
174  }
175 }
176 
177 template <typename Scalar, int StorageOrder>
178 EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs,
179  Index lhsStride, Index cols, Index rows) {
180  const Index depth = cols;
181  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
182  const Index vectorSize = quad_traits<Scalar>::vectorsize;
183  const Index vectorDelta = vectorSize * depth;
184  Scalar* blockAf = reinterpret_cast<Scalar*>(blockA);
185 
186  Index rir = 0, rii, j = 0;
187  for (; j + vectorSize <= rows; j += vectorSize) {
188  rii = rir + vectorDelta;
189 
190  for (Index i = 0; i < depth; i++) {
191  for (Index k = 0; k < vectorSize; k++) {
192  std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j + k, i, lhs);
193 
194  blockAf[rir + k] = v.real();
195  blockAf[rii + k] = v.imag();
196  }
197  rir += vectorSize;
198  rii += vectorSize;
199  }
200 
201  rir += vectorDelta;
202  }
203 
204  if (j < rows) {
205  rii = rir + ((rows - j) * depth);
206 
207  for (Index i = 0; i < depth; i++) {
208  Index k = j;
209  for (; k < rows; k++) {
210  std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);
211 
212  blockAf[rir] = v.real();
213  blockAf[rii] = v.imag();
214 
215  rir += 1;
216  rii += 1;
217  }
218  }
219  }
220 }
221 
222 template <typename Scalar, int StorageOrder, int N>
223 EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows,
224  Index cols, Index k2) {
225  const Index depth = k2 + rows;
227  const Index vectorSize = quad_traits<Scalar>::vectorsize;
228 
229  Index ri = 0, j = 0;
230  for (; j + N * vectorSize <= cols; j += N * vectorSize) {
231  Index i = k2;
232  for (; i < depth; i++) {
233  for (Index k = 0; k < N * vectorSize; k++) {
234  if (i <= j + k)
235  blockB[ri + k] = rhs(j + k, i);
236  else
237  blockB[ri + k] = rhs(i, j + k);
238  }
239  ri += N * vectorSize;
240  }
241  }
242 
243  for (; j < cols; j++) {
244  for (Index i = k2; i < depth; i++) {
245  if (j <= i)
246  blockB[ri] = rhs(i, j);
247  else
248  blockB[ri] = rhs(j, i);
249  ri += 1;
250  }
251  }
252 }
253 
254 template <typename Scalar, int StorageOrder>
255 EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols,
256  Index rows) {
257  const Index depth = cols;
259  const Index vectorSize = quad_traits<Scalar>::vectorsize;
260 
261  Index ri = 0, j = 0;
262  for (; j + vectorSize <= rows; j += vectorSize) {
263  Index i = 0;
264 
265  for (; i < depth; i++) {
266  for (Index k = 0; k < vectorSize; k++) {
267  if (i <= j + k)
268  blockA[ri + k] = lhs(j + k, i);
269  else
270  blockA[ri + k] = lhs(i, j + k);
271  }
272  ri += vectorSize;
273  }
274  }
275 
276  if (j < rows) {
277  for (Index i = 0; i < depth; i++) {
278  Index k = j;
279  for (; k < rows; k++) {
280  if (i <= k)
281  blockA[ri] = lhs(k, i);
282  else
283  blockA[ri] = lhs(i, k);
284  ri += 1;
285  }
286  }
287  }
288 }
289 
290 template <typename Index, int nr, int StorageOrder>
291 struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder> {
292  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols,
293  Index k2) {
294  symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
295  }
296 };
297 
298 template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
299 struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder> {
300  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols,
301  Index rows) {
302  symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
303  }
304 };
305 
306 // *********** symm_pack std::complex<float64> ***********
307 
308 template <typename Index, int nr, int StorageOrder>
309 struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder> {
310  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows,
311  Index cols, Index k2) {
312  symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
313  }
314 };
315 
316 template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
317 struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder> {
318  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols,
319  Index rows) {
320  symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
321  }
322 };
323 
324 // *********** symm_pack float32 ***********
325 template <typename Index, int nr, int StorageOrder>
326 struct symm_pack_rhs<float, Index, nr, StorageOrder> {
327  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
328  symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
329  }
330 };
331 
332 template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
333 struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder> {
334  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) {
335  symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
336  }
337 };
338 
339 // *********** symm_pack float64 ***********
340 template <typename Index, int nr, int StorageOrder>
341 struct symm_pack_rhs<double, Index, nr, StorageOrder> {
342  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
343  symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
344  }
345 };
346 
347 template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
348 struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder> {
349  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) {
350  symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
351  }
352 };
353 
365 template <typename Scalar, typename Packet, int N>
367  const Index size = 16 / sizeof(Scalar);
368  pstore<Scalar>(to + (0 * size), block.packet[0]);
369  pstore<Scalar>(to + (1 * size), block.packet[1]);
370  if (N > 2) {
371  pstore<Scalar>(to + (2 * size), block.packet[2]);
372  }
373  if (N > 3) {
374  pstore<Scalar>(to + (3 * size), block.packet[3]);
375  }
376 }
377 
378 // General template for lhs & rhs complex packing.
379 template <typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate,
380  bool PanelMode, bool UseLhs>
381 struct dhs_cpack {
382  template <bool transpose>
384  Packet16uc permute) {
385  if (transpose) {
386  block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute);
387  block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute);
388  block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute);
389  block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute);
390 
391  Packet4f t0, t1, t2, t3;
392 #ifdef EIGEN_VECTORIZE_VSX
393  t0 = reinterpret_cast<Packet>(
394  vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
395  t1 = reinterpret_cast<Packet>(
396  vec_mergel(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
397  t2 = reinterpret_cast<Packet>(
398  vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
399  t3 = reinterpret_cast<Packet>(
400  vec_mergel(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
401 #else
402  t0 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI));
403  t1 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO));
404  t2 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI));
405  t3 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO));
406 #endif
407 
408  block.packet[0] = t0;
409  block.packet[1] = t1;
410  block.packet[2] = t2;
411  block.packet[3] = t3;
412  } else {
413  block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute);
414  block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute);
415  block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute);
416  block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute);
417  }
418  }
419 
420  EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
421  Index depth, const Index vectorSize) {
422  PacketBlock<Packet, 4> blockr, blocki;
424 
425  for (; i + vectorSize <= depth; i += vectorSize) {
426  if (UseLhs) {
427  bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
428  } else {
429  bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
430  }
431 
432  if (((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) {
433  dhs_cblock<true>(cblock, blockr, p16uc_GETREAL32b);
434  dhs_cblock<true>(cblock, blocki, p16uc_GETIMAG32b);
435  } else {
436  dhs_cblock<false>(cblock, blockr, p16uc_GETREAL32);
437  dhs_cblock<false>(cblock, blocki, p16uc_GETIMAG32);
438  }
439 
440  if (Conjugate) {
441  blocki.packet[0] = -blocki.packet[0];
442  blocki.packet[1] = -blocki.packet[1];
443  blocki.packet[2] = -blocki.packet[2];
444  blocki.packet[3] = -blocki.packet[3];
445  }
446 
447  storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
448  storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);
449 
450  rir += 4 * vectorSize;
451  rii += 4 * vectorSize;
452  }
453  }
454 
455  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows,
456  Index stride, Index offset) {
457  const Index vectorSize = quad_traits<Scalar>::vectorsize;
458  const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
459  Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
460  Scalar* blockAt = reinterpret_cast<Scalar*>(blockA);
461  Index j = 0;
462 
463  for (; j + vectorSize <= rows; j += vectorSize) {
464  const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
465  Index i = 0;
466 
467  rii = rir + vectorDelta;
468 
469  dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
470 
471  for (; i < depth; i++) {
472  PacketBlock<Packet, 1> blockr, blocki;
474 
475  if (((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) {
476  if (UseLhs) {
477  cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
478  cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
479  } else {
480  cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
481  cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);
482  }
483  } else {
484  if (UseLhs) {
485  cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i));
486  cblock.packet[1] = pload2(lhs2(2, i), lhs2(3, i));
487  } else {
488  cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1));
489  cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3));
490  }
491  }
492 
493  blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
494  blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);
495 
496  if (Conjugate) {
497  blocki.packet[0] = -blocki.packet[0];
498  }
499 
500  pstore<Scalar>(blockAt + rir, blockr.packet[0]);
501  pstore<Scalar>(blockAt + rii, blocki.packet[0]);
502 
503  rir += vectorSize;
504  rii += vectorSize;
505  }
506 
507  rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
508  }
509 
510  if (!UseLhs) {
511  if (PanelMode) rir -= (offset * (vectorSize - 1));
512 
513  for (; j < rows; j++) {
514  const DataMapper lhs2 = lhs.getSubMapper(0, j);
515  rii = rir + ((PanelMode) ? stride : depth);
516 
517  for (Index i = 0; i < depth; i++) {
518  blockAt[rir] = lhs2(i, 0).real();
519 
520  if (Conjugate)
521  blockAt[rii] = -lhs2(i, 0).imag();
522  else
523  blockAt[rii] = lhs2(i, 0).imag();
524 
525  rir += 1;
526  rii += 1;
527  }
528 
529  rir += ((PanelMode) ? (2 * stride - depth) : depth);
530  }
531  } else {
532  if (j < rows) {
533  if (PanelMode) rir += (offset * (rows - j - vectorSize));
534  rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
535 
536  for (Index i = 0; i < depth; i++) {
537  Index k = j;
538  for (; k < rows; k++) {
539  blockAt[rir] = lhs(k, i).real();
540 
541  if (Conjugate)
542  blockAt[rii] = -lhs(k, i).imag();
543  else
544  blockAt[rii] = lhs(k, i).imag();
545 
546  rir += 1;
547  rii += 1;
548  }
549  }
550  }
551  }
552  }
553 };
554 
555 // General template for lhs & rhs packing.
556 template <typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
557 struct dhs_pack {
558  template <Index n>
559  EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
560  const Index vectorSize) {
562 
563  for (; i + n * vectorSize <= depth; i += n * vectorSize) {
564  for (Index k = 0; k < n; k++) {
565  if (UseLhs) {
566  bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k * vectorSize);
567  } else {
568  bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k * vectorSize, 0);
569  }
570  }
571 
572  if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
573  for (Index k = 0; k < n; k++) {
574  ptranspose(block[k]);
575  }
576  }
577 
578  for (Index k = 0; k < n; k++) {
579  storeBlock<Scalar, Packet, 4>(blockA + ri + k * 4 * vectorSize, block[k]);
580  }
581 
582  ri += n * 4 * vectorSize;
583  }
584  }
585 
586  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
587  Index offset) {
588  const Index vectorSize = quad_traits<Scalar>::vectorsize;
589  Index ri = 0, j = 0;
590 
591  for (; j + vectorSize <= rows; j += vectorSize) {
592  const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
593  Index i = 0;
594 
595  if (PanelMode) ri += vectorSize * offset;
596 
597  dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
598  dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
599  dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);
600 
601  for (; i < depth; i++) {
602  if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
603  if (UseLhs) {
604  blockA[ri + 0] = lhs2(0, i);
605  blockA[ri + 1] = lhs2(1, i);
606  blockA[ri + 2] = lhs2(2, i);
607  blockA[ri + 3] = lhs2(3, i);
608  } else {
609  blockA[ri + 0] = lhs2(i, 0);
610  blockA[ri + 1] = lhs2(i, 1);
611  blockA[ri + 2] = lhs2(i, 2);
612  blockA[ri + 3] = lhs2(i, 3);
613  }
614  } else {
615  Packet lhsV;
616  if (UseLhs) {
617  lhsV = lhs2.template loadPacket<Packet>(0, i);
618  } else {
619  lhsV = lhs2.template loadPacket<Packet>(i, 0);
620  }
621  pstore<Scalar>(blockA + ri, lhsV);
622  }
623 
624  ri += vectorSize;
625  }
626 
627  if (PanelMode) ri += vectorSize * (stride - offset - depth);
628  }
629 
630  if (!UseLhs) {
631  if (PanelMode) ri += offset;
632 
633  for (; j < rows; j++) {
634  const DataMapper lhs2 = lhs.getSubMapper(0, j);
635  for (Index i = 0; i < depth; i++) {
636  blockA[ri] = lhs2(i, 0);
637  ri += 1;
638  }
639 
640  if (PanelMode) ri += stride - depth;
641  }
642  } else {
643  if (j < rows) {
644  if (PanelMode) ri += offset * (rows - j);
645 
646  for (Index i = 0; i < depth; i++) {
647  Index k = j;
648  for (; k < rows; k++) {
649  blockA[ri] = lhs(k, i);
650  ri += 1;
651  }
652  }
653  }
654  }
655  }
656 };
657 
658 // General template for lhs packing, float64 specialization.
659 template <typename DataMapper, int StorageOrder, bool PanelMode>
660 struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true> {
661  template <Index n>
662  EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
663  const Index vectorSize) {
665 
666  for (; i + n * vectorSize <= depth; i += n * vectorSize) {
667  for (Index k = 0; k < n; k++) {
668  if (StorageOrder == RowMajor) {
669  block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize);
670  block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k * vectorSize);
671  } else {
672  block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 0);
673  block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 1);
674  }
675  }
676 
677  if (StorageOrder == RowMajor) {
678  for (Index k = 0; k < n; k++) {
679  ptranspose(block[k]);
680  }
681  }
682 
683  for (Index k = 0; k < n; k++) {
684  storeBlock<double, Packet2d, 2>(blockA + ri + k * 2 * vectorSize, block[k]);
685  }
686 
687  ri += n * 2 * vectorSize;
688  }
689  }
690 
691  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
692  Index offset) {
693  const Index vectorSize = quad_traits<double>::vectorsize;
694  Index ri = 0, j = 0;
695 
696  for (; j + vectorSize <= rows; j += vectorSize) {
697  const DataMapper lhs2 = lhs.getSubMapper(j, 0);
698  Index i = 0;
699 
700  if (PanelMode) ri += vectorSize * offset;
701 
702  dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
703  dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
704  dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);
705 
706  for (; i < depth; i++) {
707  if (StorageOrder == RowMajor) {
708  blockA[ri + 0] = lhs2(0, i);
709  blockA[ri + 1] = lhs2(1, i);
710  } else {
711  Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);
712  pstore<double>(blockA + ri, lhsV);
713  }
714 
715  ri += vectorSize;
716  }
717 
718  if (PanelMode) ri += vectorSize * (stride - offset - depth);
719  }
720 
721  if (j < rows) {
722  if (PanelMode) ri += offset * (rows - j);
723 
724  for (Index i = 0; i < depth; i++) {
725  Index k = j;
726  for (; k < rows; k++) {
727  blockA[ri] = lhs(k, i);
728  ri += 1;
729  }
730  }
731  }
732  }
733 };
734 
735 // General template for rhs packing, float64 specialization.
736 template <typename DataMapper, int StorageOrder, bool PanelMode>
737 struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false> {
738  template <Index n>
739  EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& ri, Index depth,
740  const Index vectorSize) {
741  PacketBlock<Packet2d, 2> block1[n], block2[n];
742  PacketBlock<Packet2d, 4> block3[n];
743 
744  for (; i + n * vectorSize <= depth; i += n * vectorSize) {
745  for (Index k = 0; k < n; k++) {
746  if (StorageOrder == ColMajor) {
747  block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 0);
748  block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 1);
749  block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 2);
750  block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 3);
751  } else {
752  block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 0); //[a1 a2]
753  block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 2); //[a3 a4]
754  block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 0); //[b1 b2]
755  block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 2); //[b3 b4]
756  }
757  }
758 
759  if (StorageOrder == ColMajor) {
760  for (Index k = 0; k < n; k++) {
761  ptranspose(block1[k]);
762  ptranspose(block2[k]);
763  }
764  }
765 
766  for (Index k = 0; k < n; k++) {
767  if (StorageOrder == ColMajor) {
768  pstore<double>(blockB + ri + k * 4 * vectorSize, block1[k].packet[0]);
769  pstore<double>(blockB + ri + k * 4 * vectorSize + 2, block2[k].packet[0]);
770  pstore<double>(blockB + ri + k * 4 * vectorSize + 4, block1[k].packet[1]);
771  pstore<double>(blockB + ri + k * 4 * vectorSize + 6, block2[k].packet[1]);
772  } else {
773  storeBlock<double, Packet2d, 4>(blockB + ri + k * 4 * vectorSize, block3[k]);
774  }
775  }
776 
777  ri += n * 4 * vectorSize;
778  }
779  }
780 
781  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
782  Index offset) {
783  const Index vectorSize = quad_traits<double>::vectorsize;
784  Index ri = 0, j = 0;
785 
786  for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
787  const DataMapper rhs2 = rhs.getSubMapper(0, j);
788  Index i = 0;
789 
790  if (PanelMode) ri += offset * (2 * vectorSize);
791 
792  dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
793  dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
794  dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);
795 
796  for (; i < depth; i++) {
797  if (StorageOrder == ColMajor) {
798  blockB[ri + 0] = rhs2(i, 0);
799  blockB[ri + 1] = rhs2(i, 1);
800 
801  ri += vectorSize;
802 
803  blockB[ri + 0] = rhs2(i, 2);
804  blockB[ri + 1] = rhs2(i, 3);
805  } else {
806  Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
807  pstore<double>(blockB + ri, rhsV);
808 
809  ri += vectorSize;
810 
811  rhsV = rhs2.template loadPacket<Packet2d>(i, 2);
812  pstore<double>(blockB + ri, rhsV);
813  }
814  ri += vectorSize;
815  }
816 
817  if (PanelMode) ri += (2 * vectorSize) * (stride - offset - depth);
818  }
819 
820  if (PanelMode) ri += offset;
821 
822  for (; j < cols; j++) {
823  const DataMapper rhs2 = rhs.getSubMapper(0, j);
824  for (Index i = 0; i < depth; i++) {
825  blockB[ri] = rhs2(i, 0);
826  ri += 1;
827  }
828 
829  if (PanelMode) ri += stride - depth;
830  }
831  }
832 };
833 
834 // General template for lhs packing, bfloat16 specialization.
835 template <typename DataMapper, int StorageOrder, bool PanelMode>
836 struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true> {
837  EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
838  Index offset) {
839  const Index vectorSize = quad_traits<bfloat16>::vectorsize;
840  Index ri = 0, j = 0;
841 
842  for (; j + 2 * vectorSize <= rows; j += 2 * vectorSize) {
843  const DataMapper lhs2 = lhs.getSubMapper(j, 0);
844  Index i = 0;
845 
846  if (PanelMode) ri += 2 * vectorSize * offset;
847 
848  if (StorageOrder == ColMajor) {
849  for (; i + 2 <= depth; i += 2) {
851 
852  block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
853  block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
854  block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
855  block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);
856 
857  Packet8bf t0, t1;
858  t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
859  t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
860  block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
861  block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
862  block.packet[0] = t0;
863  block.packet[1] = t1;
864 
865  storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);
866 
867  ri += 2 * 2 * vectorSize;
868  }
869  if (depth & 1) {
871 
872  block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
873  block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
874 
875  storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
876 
877  ri += 2 * vectorSize;
878  }
879  } else {
880  for (; i + vectorSize <= depth; i += vectorSize) {
881  PacketBlock<Packet8bf, 8> block1, block2;
882 
883  bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
884  bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);
885 
886  Packet4ui v1[8], v2[8];
887 
888  v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
889  reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
890  v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
891  reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
892  v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
893  reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
894  v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
895  reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
896  v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
897  reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
898  v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
899  reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
900  v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
901  reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
902  v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
903  reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
904  v2[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
905  reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
906  v2[1] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
907  reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
908  v2[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
909  reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
910  v2[3] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
911  reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
912  v2[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
913  reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
914  v2[5] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
915  reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
916  v2[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
917  reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
918  v2[7] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
919  reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
920 
921 #ifdef EIGEN_VECTORIZE_VSX
922  block1.packet[0] = reinterpret_cast<Packet8us>(
923  vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
924  block1.packet[2] = reinterpret_cast<Packet8us>(
925  vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
926  block1.packet[4] = reinterpret_cast<Packet8us>(
927  vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
928  block1.packet[6] = reinterpret_cast<Packet8us>(
929  vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
930  block1.packet[1] = reinterpret_cast<Packet8us>(
931  vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
932  block1.packet[3] = reinterpret_cast<Packet8us>(
933  vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
934  block1.packet[5] = reinterpret_cast<Packet8us>(
935  vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
936  block1.packet[7] = reinterpret_cast<Packet8us>(
937  vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
938  block2.packet[0] = reinterpret_cast<Packet8us>(
939  vec_mergeh(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
940  block2.packet[2] = reinterpret_cast<Packet8us>(
941  vec_mergel(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
942  block2.packet[4] = reinterpret_cast<Packet8us>(
943  vec_mergeh(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
944  block2.packet[6] = reinterpret_cast<Packet8us>(
945  vec_mergel(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
946  block2.packet[1] = reinterpret_cast<Packet8us>(
947  vec_mergeh(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
948  block2.packet[3] = reinterpret_cast<Packet8us>(
949  vec_mergel(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
950  block2.packet[5] = reinterpret_cast<Packet8us>(
951  vec_mergeh(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
952  block2.packet[7] = reinterpret_cast<Packet8us>(
953  vec_mergel(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
954 #else
955  block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
956  block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
957  block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
958  block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
959  block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
960  block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
961  block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
962  block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
963  block2.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_HI));
964  block2.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_LO));
965  block2.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_HI));
966  block2.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_LO));
967  block2.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_HI));
968  block2.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_LO));
969  block2.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_HI));
970  block2.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_LO));
971 #endif
972 
973  for (Index M = 0; M < 8; M += 2) {
974  pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 0]);
975  pstore<bfloat16>(blockA + ri + (1 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 1]);
976  pstore<bfloat16>(blockA + ri + (2 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 0]);
977  pstore<bfloat16>(blockA + ri + (3 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 1]);
978  }
979 
980  ri += 2 * vectorSize * vectorSize;
981  }
982  for (; i + 2 <= depth; i += 2) {
983  for (Index M = 0; M < 2 * vectorSize; M++) {
984  blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
985  blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
986  }
987 
988  ri += 2 * 2 * vectorSize;
989  }
990  if (depth & 1) {
991  for (Index M = 0; M < 2 * vectorSize; M++) {
992  blockA[ri + M] = lhs2(M, i);
993  }
994  ri += 2 * vectorSize;
995  }
996  }
997 
998  if (PanelMode) ri += 2 * vectorSize * (stride - offset - depth);
999  }
1000  for (; j + vectorSize <= rows; j += vectorSize) {
1001  const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1002  Index i = 0;
1003 
1004  if (PanelMode) ri += vectorSize * offset;
1005 
1006  if (StorageOrder == ColMajor) {
1007  for (; i + 2 <= depth; i += 2) {
1009 
1010  block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
1011  block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
1012 
1013  Packet8bf t0;
1014  t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1015  block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
1016  block.packet[0] = t0;
1017 
1018  storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
1019 
1020  ri += 2 * vectorSize;
1021  }
1022  if (depth & 1) {
1023  Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
1024  pstore<bfloat16>(blockA + ri, lhsV);
1025 
1026  ri += vectorSize;
1027  }
1028  } else {
1029  for (; i + vectorSize <= depth; i += vectorSize) {
1031 
1032  bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
1033 
1034  Packet4ui v1[8];
1035 
1036  // This is transposing and interleaving data
1037  v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
1038  reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
1039  v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
1040  reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
1041  v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
1042  reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
1043  v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
1044  reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
1045  v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
1046  reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
1047  v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
1048  reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
1049  v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
1050  reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
1051  v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
1052  reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
1053 
1054 #ifdef EIGEN_VECTORIZE_VSX
1055  block1.packet[0] = reinterpret_cast<Packet8us>(
1056  vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
1057  block1.packet[2] = reinterpret_cast<Packet8us>(
1058  vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
1059  block1.packet[4] = reinterpret_cast<Packet8us>(
1060  vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
1061  block1.packet[6] = reinterpret_cast<Packet8us>(
1062  vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
1063  block1.packet[1] = reinterpret_cast<Packet8us>(
1064  vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
1065  block1.packet[3] = reinterpret_cast<Packet8us>(
1066  vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
1067  block1.packet[5] = reinterpret_cast<Packet8us>(
1068  vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
1069  block1.packet[7] = reinterpret_cast<Packet8us>(
1070  vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
1071 #else
1072  block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
1073  block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
1074  block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
1075  block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
1076  block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
1077  block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
1078  block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
1079  block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
1080 #endif
1081 
1082  for (Index M = 0; M < 8; M++) {
1083  pstore<bfloat16>(blockA + ri + (vectorSize * M), block1.packet[M]);
1084  }
1085 
1086  ri += vectorSize * vectorSize;
1087  }
1088  for (; i + 2 <= depth; i += 2) {
1089  for (Index M = 0; M < vectorSize; M++) {
1090  blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
1091  blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
1092  }
1093 
1094  ri += 2 * vectorSize;
1095  }
1096  if (depth & 1) {
1097  for (Index M = 0; M < vectorSize; M++) {
1098  blockA[ri + M] = lhs2(M, i);
1099  }
1100 
1101  ri += vectorSize;
1102  }
1103  }
1104 
1105  if (PanelMode) ri += vectorSize * (stride - offset - depth);
1106  }
1107  if (j + 4 <= rows) {
1108  const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1109  Index i = 0;
1110 
1111  if (PanelMode) ri += 4 * offset;
1112 
1113  for (; i + 2 <= depth; i += 2) {
1114  if (StorageOrder == ColMajor) {
1116 
1117  block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
1118  block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);
1119 
1120  block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1121 
1122  pstore<bfloat16>(blockA + ri, block.packet[0]);
1123  } else {
1124  blockA[ri + 0] = lhs2(0, i + 0);
1125  blockA[ri + 1] = lhs2(0, i + 1);
1126  blockA[ri + 2] = lhs2(1, i + 0);
1127  blockA[ri + 3] = lhs2(1, i + 1);
1128  blockA[ri + 4] = lhs2(2, i + 0);
1129  blockA[ri + 5] = lhs2(2, i + 1);
1130  blockA[ri + 6] = lhs2(3, i + 0);
1131  blockA[ri + 7] = lhs2(3, i + 1);
1132  }
1133 
1134  ri += 2 * 4;
1135  }
1136  if (depth & 1) {
1137  if (StorageOrder == ColMajor) {
1138  Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
1139 
1140  pstore_partial<bfloat16>(blockA + ri, lhsV, 4);
1141  } else {
1142  blockA[ri + 0] = lhs2(0, i);
1143  blockA[ri + 1] = lhs2(1, i);
1144  blockA[ri + 2] = lhs2(2, i);
1145  blockA[ri + 3] = lhs2(3, i);
1146  }
1147 
1148  ri += 4;
1149  }
1150 
1151  if (PanelMode) ri += 4 * (stride - offset - depth);
1152  j += 4;
1153  }
1154 
1155  if (j < rows) {
1156  if (PanelMode) ri += offset * (rows - j);
1157 
1158  Index i = 0;
1159  for (; i + 2 <= depth; i += 2) {
1160  Index k = j;
1161  for (; k < rows; k++) {
1162  blockA[ri + 0] = lhs(k, i + 0);
1163  blockA[ri + 1] = lhs(k, i + 1);
1164  ri += 2;
1165  }
1166  }
1167  if (depth & 1) {
1168  for (; j < rows; j++) {
1169  blockA[ri] = lhs(j, i);
1170  ri += 1;
1171  }
1172  }
1173  }
1174  }
1175 };
1176 
1177 // General template for rhs packing, bfloat16 specialization.
1178 template <typename DataMapper, int StorageOrder, bool PanelMode>
1179 struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false> {
1180  EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
1181  Index offset) {
1182  const Index vectorSize = quad_traits<bfloat16>::vectorsize;
1183  Index ri = 0, j = 0;
1184 
1185  for (; j + 4 <= cols; j += 4) {
1186  const DataMapper rhs2 = rhs.getSubMapper(0, j);
1187  Index i = 0;
1188 
1189  if (PanelMode) ri += 4 * offset;
1190 
1191  for (; i + vectorSize <= depth; i += vectorSize) {
1192  if (StorageOrder == ColMajor) {
1194 
1195  bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);
1196 
1197  Packet4ui t0, t1, t2, t3;
1198 
1199  t0 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
1200  reinterpret_cast<Packet4ui>(block.packet[1].m_val));
1201  t1 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
1202  reinterpret_cast<Packet4ui>(block.packet[1].m_val));
1203  t2 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
1204  reinterpret_cast<Packet4ui>(block.packet[3].m_val));
1205  t3 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
1206  reinterpret_cast<Packet4ui>(block.packet[3].m_val));
1207 
1208 #ifdef EIGEN_VECTORIZE_VSX
1209  block.packet[0] =
1210  reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
1211  block.packet[1] =
1212  reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
1213  block.packet[2] =
1214  reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
1215  block.packet[3] =
1216  reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
1217 #else
1218  block.packet[0] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_HI));
1219  block.packet[1] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_LO));
1220  block.packet[2] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_HI));
1221  block.packet[3] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_LO));
1222 #endif
1223 
1224  storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
1225  } else {
1227 
1228  for (int M = 0; M < 8; M++) {
1229  block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
1230  }
1231 
1232  block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1233  block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
1234  block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
1235  block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);
1236 
1237  const Index size = 16 / sizeof(bfloat16);
1238 
1239  for (int M = 0; M < 4; M++) {
1240  pstore<bfloat16>(blockB + ri + (M * size), block.packet[M]);
1241  }
1242  }
1243 
1244  ri += 4 * vectorSize;
1245  }
1246  for (; i + 2 <= depth; i += 2) {
1247  if (StorageOrder == ColMajor) {
1248  blockB[ri + 0] = rhs2(i + 0, 0);
1249  blockB[ri + 1] = rhs2(i + 1, 0);
1250  blockB[ri + 2] = rhs2(i + 0, 1);
1251  blockB[ri + 3] = rhs2(i + 1, 1);
1252  blockB[ri + 4] = rhs2(i + 0, 2);
1253  blockB[ri + 5] = rhs2(i + 1, 2);
1254  blockB[ri + 6] = rhs2(i + 0, 3);
1255  blockB[ri + 7] = rhs2(i + 1, 3);
1256  } else {
1258 
1259  for (int M = 0; M < 2; M++) {
1260  block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
1261  }
1262 
1263  block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1264 
1265  pstore<bfloat16>(blockB + ri, block.packet[0]);
1266  }
1267 
1268  ri += 4 * 2;
1269  }
1270  if (depth & 1) {
1271  blockB[ri + 0] = rhs2(i, 0);
1272  blockB[ri + 1] = rhs2(i, 1);
1273  blockB[ri + 2] = rhs2(i, 2);
1274  blockB[ri + 3] = rhs2(i, 3);
1275 
1276  ri += 4;
1277  }
1278 
1279  if (PanelMode) ri += 4 * (stride - offset - depth);
1280  }
1281 
1282  if (j < cols) {
1283  if (PanelMode) ri += offset * (cols - j);
1284 
1285  Index i = 0;
1286  for (; i + 2 <= depth; i += 2) {
1287  Index k = j;
1288  for (; k < cols; k++) {
1289  blockB[ri + 0] = rhs(i + 0, k);
1290  blockB[ri + 1] = rhs(i + 1, k);
1291  ri += 2;
1292  }
1293  }
1294  if (depth & 1) {
1295  for (; j < cols; j++) {
1296  blockB[ri] = rhs(i, j);
1297  ri += 1;
1298  }
1299  }
1300  }
1301  }
1302 };
1303 
1304 // General template for lhs complex packing, float64 specialization.
1305 template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
1306 struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true> {
1307  EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
1308  Index depth, const Index vectorSize) {
1309  PacketBlock<Packet, 2> blockr, blocki;
1310  PacketBlock<PacketC, 4> cblock;
1311 
1312  for (; i + vectorSize <= depth; i += vectorSize) {
1313  if (StorageOrder == ColMajor) {
1314  cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
1315  cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
1316 
1317  cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
1318  cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
1319 
1320  blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
1321  blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]
1322 
1323  blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
1324  blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
1325  } else {
1326  cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
1327  cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]
1328 
1329  cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
1330  cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i
1331 
1332  blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
1333  blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]
1334 
1335  blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1336  blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
1337  }
1338 
1339  if (Conjugate) {
1340  blocki.packet[0] = -blocki.packet[0];
1341  blocki.packet[1] = -blocki.packet[1];
1342  }
1343 
1344  storeBlock<double, Packet, 2>(blockAt + rir, blockr);
1345  storeBlock<double, Packet, 2>(blockAt + rii, blocki);
1346 
1347  rir += 2 * vectorSize;
1348  rii += 2 * vectorSize;
1349  }
1350  }
1351 
1352  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
1353  Index stride, Index offset) {
1354  const Index vectorSize = quad_traits<double>::vectorsize;
1355  const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
1356  Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
1357  double* blockAt = reinterpret_cast<double*>(blockA);
1358  Index j = 0;
1359 
1360  for (; j + vectorSize <= rows; j += vectorSize) {
1361  const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1362  Index i = 0;
1363 
1364  rii = rir + vectorDelta;
1365 
1366  dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
1367 
1368  for (; i < depth; i++) {
1369  PacketBlock<Packet, 1> blockr, blocki;
1370  PacketBlock<PacketC, 2> cblock;
1371 
1372  cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
1373  cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);
1374 
1375  blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
1376  blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1377 
1378  if (Conjugate) {
1379  blocki.packet[0] = -blocki.packet[0];
1380  }
1381 
1382  pstore<double>(blockAt + rir, blockr.packet[0]);
1383  pstore<double>(blockAt + rii, blocki.packet[0]);
1384 
1385  rir += vectorSize;
1386  rii += vectorSize;
1387  }
1388 
1389  rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
1390  }
1391 
1392  if (j < rows) {
1393  if (PanelMode) rir += (offset * (rows - j - vectorSize));
1394  rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
1395 
1396  for (Index i = 0; i < depth; i++) {
1397  Index k = j;
1398  for (; k < rows; k++) {
1399  blockAt[rir] = lhs(k, i).real();
1400 
1401  if (Conjugate)
1402  blockAt[rii] = -lhs(k, i).imag();
1403  else
1404  blockAt[rii] = lhs(k, i).imag();
1405 
1406  rir += 1;
1407  rii += 1;
1408  }
1409  }
1410  }
1411  }
1412 };
1413 
1414 // General template for rhs complex packing, float64 specialization.
1415 template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
1416 struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false> {
1417  EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii,
1418  Index depth, const Index vectorSize) {
1419  for (; i < depth; i++) {
1420  PacketBlock<PacketC, 4> cblock;
1421  PacketBlock<Packet, 2> blockr, blocki;
1422 
1423  bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);
1424 
1425  blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
1426  blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);
1427 
1428  blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1429  blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
1430 
1431  if (Conjugate) {
1432  blocki.packet[0] = -blocki.packet[0];
1433  blocki.packet[1] = -blocki.packet[1];
1434  }
1435 
1436  storeBlock<double, Packet, 2>(blockBt + rir, blockr);
1437  storeBlock<double, Packet, 2>(blockBt + rii, blocki);
1438 
1439  rir += 2 * vectorSize;
1440  rii += 2 * vectorSize;
1441  }
1442  }
1443 
1444  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols,
1445  Index stride, Index offset) {
1446  const Index vectorSize = quad_traits<double>::vectorsize;
1447  const Index vectorDelta = 2 * vectorSize * ((PanelMode) ? stride : depth);
1448  Index rir = ((PanelMode) ? (2 * vectorSize * offset) : 0), rii;
1449  double* blockBt = reinterpret_cast<double*>(blockB);
1450  Index j = 0;
1451 
1452  for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
1453  const DataMapper rhs2 = rhs.getSubMapper(0, j);
1454  Index i = 0;
1455 
1456  rii = rir + vectorDelta;
1457 
1458  dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);
1459 
1460  rir += ((PanelMode) ? (2 * vectorSize * (2 * stride - depth)) : vectorDelta);
1461  }
1462 
1463  if (PanelMode) rir -= (offset * (2 * vectorSize - 1));
1464 
1465  for (; j < cols; j++) {
1466  const DataMapper rhs2 = rhs.getSubMapper(0, j);
1467  rii = rir + ((PanelMode) ? stride : depth);
1468 
1469  for (Index i = 0; i < depth; i++) {
1470  blockBt[rir] = rhs2(i, 0).real();
1471 
1472  if (Conjugate)
1473  blockBt[rii] = -rhs2(i, 0).imag();
1474  else
1475  blockBt[rii] = rhs2(i, 0).imag();
1476 
1477  rir += 1;
1478  rii += 1;
1479  }
1480 
1481  rir += ((PanelMode) ? (2 * stride - depth) : depth);
1482  }
1483  }
1484 };
1485 
1486 /**************
1487  * GEMM utils *
1488  **************/
1489 
1490 // 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm).
1491 template <typename Packet, bool NegativeAccumulate, int N>
1493  if (NegativeAccumulate) {
1494  for (int M = 0; M < N; M++) {
1495  acc->packet[M] = vec_nmsub(lhsV, rhsV[M], acc->packet[M]);
1496  }
1497  } else {
1498  for (int M = 0; M < N; M++) {
1499  acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]);
1500  }
1501  }
1502 }
1503 
1504 template <int N, typename Scalar, typename Packet, bool NegativeAccumulate>
1505 EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet, N>* acc, const Scalar* lhs, const Packet* rhsV) {
1506  Packet lhsV = pload<Packet>(lhs);
1507 
1508  pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
1509 }
1510 
1511 // 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types
1512 // real * complex and complex * real.
1513 template <int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1515  const Packet& lhsV, Packet& lhsVi, const Packet* rhsV, const Packet* rhsVi) {
1516  pger_common<Packet, false, N>(accReal, lhsV, rhsV);
1517  if (LhsIsReal) {
1518  pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
1519  EIGEN_UNUSED_VARIABLE(lhsVi);
1520  } else {
1521  if (!RhsIsReal) {
1522  pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
1523  pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
1524  } else {
1525  EIGEN_UNUSED_VARIABLE(rhsVi);
1526  }
1527  pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
1528  }
1529 }
1530 
1531 template <int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1533  const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) {
1534  Packet lhsV = ploadLhs<Packet>(lhs_ptr);
1535  Packet lhsVi;
1536  if (!LhsIsReal)
1537  lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
1538  else
1539  EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1540 
1541  pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
1542 }
1543 
1544 template <typename Packet>
1546  return ploadu<Packet>(lhs);
1547 }
1548 
1549 // Zero the accumulator on PacketBlock.
1550 template <typename Packet, int N>
1552  for (int M = 0; M < N; M++) {
1553  acc.packet[M] = pset1<Packet>((__UNPACK_TYPE__(Packet))0);
1554  }
1555 }
1556 
1557 template <typename Packet, int N>
1559  const Packet& pAlpha) {
1560  for (int M = 0; M < N; M++) {
1561  acc.packet[M] = vec_mul(accZ.packet[M], pAlpha);
1562  }
1563 }
1564 
1565 template <typename Packet, int N>
1567  for (int M = 0; M < N; M++) {
1568  acc.packet[M] = pand<Packet>(acc.packet[M], pMask);
1569  }
1570 }
1571 
1572 // Complex version of PacketBlock scaling.
1573 template <typename Packet, int N, bool mask>
1575  const Packet& bImag, PacketBlock<Packet, N>& cReal, PacketBlock<Packet, N>& cImag,
1576  const Packet& pMask) {
1577  if (mask && (sizeof(__UNPACK_TYPE__(Packet)) == sizeof(float))) {
1578  band<Packet, N>(aReal, pMask);
1579  band<Packet, N>(aImag, pMask);
1580  } else {
1581  EIGEN_UNUSED_VARIABLE(pMask);
1582  }
1583 
1584  bscalec_common<Packet, N>(cReal, aReal, bReal);
1585 
1586  bscalec_common<Packet, N>(cImag, aImag, bReal);
1587 
1588  pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
1589 
1590  pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
1591 }
1592 
1593 // Load a PacketBlock, the N parameters make tuning gemm easier so we can add more accumulators as needed.
1594 //
1595 // full = operate (load) on the entire PacketBlock or only half
1596 template <typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
1597 EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
1598  Index col) {
1599  if (StorageOrder == RowMajor) {
1600  for (int M = 0; M < N; M++) {
1601  acc.packet[M] = res.template loadPacket<Packet>(row + M, col);
1602  }
1603  if (Complex) {
1604  for (int M = 0; M < N; M++) {
1605  acc.packet[M + N] = res.template loadPacket<Packet>(row + M, col + accCols);
1606  }
1607  }
1608  } else {
1609  for (int M = 0; M < N; M++) {
1610  acc.packet[M] = res.template loadPacket<Packet>(row, col + M);
1611  }
1612  if (Complex && full) {
1613  for (int M = 0; M < N; M++) {
1614  acc.packet[M + N] = res.template loadPacket<Packet>(row + accCols, col + M);
1615  }
1616  }
1617  }
1618 }
1619 
1620 template <typename DataMapper, typename Packet, int N>
1622  for (int M = 0; M < N; M++) {
1623  res.template storePacket<Packet>(row, M, acc.packet[M]);
1624  }
1625 }
1626 
1627 #ifdef USE_PARTIAL_PACKETS
1628 template <typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full>
1629 EIGEN_ALWAYS_INLINE void bload_partial(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
1630  Index elements) {
1631  for (Index M = 0; M < N; M++) {
1632  acc.packet[M] = res.template loadPacketPartial<Packet>(row, M, elements);
1633  }
1634  if (Complex && full) {
1635  for (Index M = 0; M < N; M++) {
1636  acc.packet[M + N] = res.template loadPacketPartial<Packet>(row + accCols, M, elements);
1637  }
1638  }
1639 }
1640 
1641 template <typename DataMapper, typename Packet, Index N>
1642 EIGEN_ALWAYS_INLINE void bstore_partial(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row, Index elements) {
1643  for (Index M = 0; M < N; M++) {
1644  res.template storePacketPartial<Packet>(row, M, acc.packet[M], elements);
1645  }
1646 }
1647 #endif
1648 
1649 #ifdef _ARCH_PWR10
1650 #define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
1651 #else
1652 #define USE_P10_AND_PVIPR2_0 0
1653 #endif
1654 
1655 #if !USE_P10_AND_PVIPR2_0
1656 const static Packet4i mask4[4] = {{0, 0, 0, 0}, {-1, 0, 0, 0}, {-1, -1, 0, 0}, {-1, -1, -1, 0}};
1657 #endif
1658 
1659 template <typename Packet>
1660 EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows) {
1661 #if USE_P10_AND_PVIPR2_0
1662 #ifdef _BIG_ENDIAN
1663  return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
1664 #else
1665  return Packet(vec_genwm((1 << remaining_rows) - 1));
1666 #endif
1667 #else
1668  return Packet(mask4[remaining_rows]);
1669 #endif
1670 }
1671 
1672 template <>
1674 #if USE_P10_AND_PVIPR2_0
1675  Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
1676 #ifdef _BIG_ENDIAN
1677  return preverse(mask2);
1678 #else
1679  return mask2;
1680 #endif
1681 #else
1682  Packet2l ret = {-remaining_rows, 0};
1683  return Packet2d(ret);
1684 #endif
1685 }
1686 
1687 template <typename Packet, int N>
1689  for (int M = 0; M < N; M++) {
1690  acc.packet[M] = pmadd<Packet>(pAlpha, accZ.packet[M], acc.packet[M]);
1691  }
1692 }
1693 
1694 // Scale the PacketBlock vectors by alpha.
1695 template <typename Packet, int N, bool mask>
1697  const Packet& pMask) {
1698  if (mask) {
1699  band<Packet, N>(accZ, pMask);
1700  } else {
1701  EIGEN_UNUSED_VARIABLE(pMask);
1702  }
1703 
1704  bscale<Packet, N>(acc, accZ, pAlpha);
1705 }
1706 
1707 template <typename Packet, int N, bool real>
1709  const __UNPACK_TYPE__(Packet) * ap2, Packet& a0, Packet& a1, Packet& a2,
1710  Packet& a3) {
1711  a0 = pset1<Packet>(ap0[0]);
1712  if (N == 4) {
1713  a1 = pset1<Packet>(ap0[1]);
1714  a2 = pset1<Packet>(ap0[2]);
1715  a3 = pset1<Packet>(ap0[3]);
1716  EIGEN_UNUSED_VARIABLE(ap1);
1717  EIGEN_UNUSED_VARIABLE(ap2);
1718  } else {
1719  if (N > 1) {
1720  a1 = pset1<Packet>(ap1[0]);
1721  } else {
1723  EIGEN_UNUSED_VARIABLE(ap1);
1724  }
1725  if (N > 2) {
1726  a2 = pset1<Packet>(ap2[0]);
1727  } else {
1729  EIGEN_UNUSED_VARIABLE(ap2);
1730  }
1731  }
1732 }
1733 
1734 template <>
1735 EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, true>(const float* ap0, const float*, const float*, Packet4f& a0,
1736  Packet4f& a1, Packet4f& a2, Packet4f& a3) {
1737  pbroadcast4<Packet4f>(ap0, a0, a1, a2, a3);
1738 }
1739 
1740 template <>
1741 EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, false>(const float* ap0, const float* ap1, const float* ap2,
1742  Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) {
1743  pbroadcastN<Packet4f, 4, true>(ap0, ap1, ap2, a0, a1, a2, a3);
1744 }
1745 
1746 template <>
1747 EIGEN_ALWAYS_INLINE void pbroadcastN<Packet2d, 4, false>(const double* ap0, const double*, const double*, Packet2d& a0,
1748  Packet2d& a1, Packet2d& a2, Packet2d& a3) {
1749  a1 = pload<Packet2d>(ap0);
1750  a3 = pload<Packet2d>(ap0 + 2);
1751  a0 = vec_splat(a1, 0);
1752  a1 = vec_splat(a1, 1);
1753  a2 = vec_splat(a3, 0);
1754  a3 = vec_splat(a3, 1);
1755 }
1756 
1757 // Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
1758 template <typename Packet, typename Packetc, int N, bool full>
1761  for (int M = 0; M < N; M++) {
1762  acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]);
1763  }
1764 
1765  if (full) {
1766  for (int M = 0; M < N; M++) {
1767  acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]);
1768  }
1769  }
1770 }
1771 
1772 template <typename Packet, typename Packetc, int N, bool full>
1775  PacketBlock<Packetc, N>& acc2) {
1776  bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);
1777 
1778  for (int M = 0; M < N; M++) {
1779  acc1.packet[M] = padd<Packetc>(tRes.packet[M], acc1.packet[M]);
1780  }
1781 
1782  if (full) {
1783  for (int M = 0; M < N; M++) {
1784  acc2.packet[M] = padd<Packetc>(tRes.packet[M + N], acc2.packet[M]);
1785  }
1786  }
1787 }
1788 
1789 // PEEL loop factor.
1790 #define PEEL 7
1791 #define PEEL_ROW 7
1792 
1793 #define MICRO_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
1794 
1795 #define MICRO_NORMAL_ROWS accRows == quad_traits<Scalar>::rows || accRows == 1
1796 
1797 #define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)
1798 
1799 #define MICRO_RHS(ptr, N) rhs_##ptr##N
1800 
1801 #define MICRO_ZERO_PEEL(peel) \
1802  if ((PEEL_ROW > peel) && (peel != 0)) { \
1803  bsetzero<Packet, accRows>(accZero##peel); \
1804  } else { \
1805  EIGEN_UNUSED_VARIABLE(accZero##peel); \
1806  }
1807 
1808 #define MICRO_ADD(ptr, N) \
1809  if (MICRO_NORMAL_ROWS) { \
1810  MICRO_RHS(ptr, 0) += (accRows * N); \
1811  } else { \
1812  MICRO_RHS(ptr, 0) += N; \
1813  MICRO_RHS(ptr, 1) += N; \
1814  if (accRows == 3) { \
1815  MICRO_RHS(ptr, 2) += N; \
1816  } \
1817  }
1818 
1819 #define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)
1820 
1821 #define MICRO_BROADCAST1(peel, ptr, rhsV, real) \
1822  if (MICRO_NORMAL_ROWS) { \
1823  pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + (accRows * peel), MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 0), \
1824  rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1825  } else { \
1826  pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + peel, MICRO_RHS(ptr, 1) + peel, MICRO_RHS(ptr, 2) + peel, \
1827  rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1828  }
1829 
1830 #define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)
1831 
1832 #define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real) \
1833  pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 1), MICRO_RHS(ptr, 2), rhsV[0], rhsV[1], \
1834  rhsV[2], rhsV[3]);
1835 
1836 #define MICRO_BROADCAST_EXTRA \
1837  Packet rhsV[4]; \
1838  MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
1839  MICRO_ADD_ROWS(1)
1840 
1841 #define MICRO_SRC2(ptr, N, M) \
1842  if (MICRO_NORMAL_ROWS) { \
1843  EIGEN_UNUSED_VARIABLE(strideB); \
1844  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 1)); \
1845  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
1846  } else { \
1847  MICRO_RHS(ptr, 1) = rhs_base + N + M; \
1848  if (accRows == 3) { \
1849  MICRO_RHS(ptr, 2) = rhs_base + N * 2 + M; \
1850  } else { \
1851  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
1852  } \
1853  }
1854 
1855 #define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)
1856 
1857 #define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)
1858 
1859 #define MICRO_WORK_PEEL(peel) \
1860  if (PEEL_ROW > peel) { \
1861  MICRO_BROADCAST(peel) \
1862  pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
1863  } else { \
1864  EIGEN_UNUSED_VARIABLE(rhsV##peel); \
1865  }
1866 
1867 #define MICRO_WORK_PEEL_ROW \
1868  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
1869  MICRO_UNROLL(MICRO_WORK_PEEL) \
1870  lhs_ptr += (remaining_rows * PEEL_ROW); \
1871  MICRO_ADD_ROWS(PEEL_ROW)
1872 
1873 #define MICRO_ADD_PEEL(peel, sum) \
1874  if (PEEL_ROW > peel) { \
1875  for (Index i = 0; i < accRows; i++) { \
1876  accZero##sum.packet[i] += accZero##peel.packet[i]; \
1877  } \
1878  }
1879 
1880 #define MICRO_ADD_PEEL_ROW \
1881  MICRO_ADD_PEEL(4, 0) \
1882  MICRO_ADD_PEEL(5, 1) \
1883  MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
1884 
1885 #define MICRO_PREFETCHN1(ptr, N) \
1886  EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 0)); \
1887  if (N == 2 || N == 3) { \
1888  EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 1)); \
1889  if (N == 3) { \
1890  EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 2)); \
1891  } \
1892  }
1893 
1894 #define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)
1895 
1896 #define MICRO_COMPLEX_PREFETCHN(N) \
1897  MICRO_PREFETCHN1(ptr_real, N); \
1898  if (!RhsIsReal) { \
1899  MICRO_PREFETCHN1(ptr_imag, N); \
1900  }
1901 
1902 template <typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
1903 EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(const Scalar*& lhs_ptr, const Scalar*& rhs_ptr0, const Scalar*& rhs_ptr1,
1904  const Scalar*& rhs_ptr2, PacketBlock<Packet, accRows>& accZero) {
1906  pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
1907  lhs_ptr += remaining_rows;
1908 }
1909 
1910 template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols,
1911  const Index remaining_rows>
1912 EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(const DataMapper& res, const Scalar* lhs_base,
1913  const Scalar* rhs_base, Index depth, Index strideA, Index offsetA,
1914  Index strideB, Index row, Index rows, const Packet& pAlpha,
1915  const Packet& pMask) {
1916  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
1917  const Scalar* lhs_ptr = lhs_base + row * strideA + remaining_rows * offsetA;
1918  PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
1919 
1921  bsetzero<Packet, accRows>(accZero0);
1922 
1923  Index remaining_depth = depth & -quad_traits<Scalar>::rows;
1924  Index k = 0;
1925  if (remaining_depth >= PEEL_ROW) {
1927  do {
1928  MICRO_PREFETCHN(accRows)
1929  EIGEN_POWER_PREFETCH(lhs_ptr);
1931  } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
1933  }
1934  for (; k < depth; k++) {
1935  MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
1936  }
1937 
1938 #ifdef USE_PARTIAL_PACKETS
1940  EIGEN_UNUSED_VARIABLE(pMask);
1941  bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row, remaining_rows);
1942  bscale<Packet, accRows>(acc, accZero0, pAlpha);
1943  bstore_partial<DataMapper, Packet, accRows>(acc, res, row, remaining_rows);
1944 #else
1945  bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
1946  if ((accRows == 1) || (rows >= accCols)) {
1947  bscale<Packet, accRows, true>(acc, accZero0, pAlpha, pMask);
1948  bstore<DataMapper, Packet, accRows>(acc, res, row);
1949  } else {
1950  bscale<Packet, accRows, false>(acc, accZero0, pAlpha, pMask);
1951  for (Index j = 0; j < accRows; j++) {
1952  for (Index i = 0; i < remaining_rows; i++) {
1953  res(row + i, j) = acc.packet[j][i];
1954  }
1955  }
1956  }
1957 #endif
1958 }
1959 
1960 #define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
1961  switch (value) { \
1962  default: \
1963  MICRO_EXTRA_UNROLL(1) \
1964  break; \
1965  case 2: \
1966  if (is_col || (sizeof(Scalar) == sizeof(float))) { \
1967  MICRO_EXTRA_UNROLL(2) \
1968  } \
1969  break; \
1970  case 3: \
1971  if (is_col || (sizeof(Scalar) == sizeof(float))) { \
1972  MICRO_EXTRA_UNROLL(3) \
1973  } \
1974  break; \
1975  }
1976 
1977 #define MICRO_EXTRA_ROWS(N) \
1978  gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>( \
1979  res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
1980 
1981 template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
1982 EIGEN_ALWAYS_INLINE void gemm_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
1983  Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows,
1984  Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
1985  MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false)
1986 }
1987 
1988 #define MICRO_UNROLL_WORK(func, func2, peel) \
1989  MICRO_UNROLL(func2); \
1990  func(0, peel) func(1, peel) func(2, peel) func(3, peel) func(4, peel) func(5, peel) func(6, peel) func(7, peel)
1991 
1992 #define MICRO_WORK_ONE(iter, peel) \
1993  if (unroll_factor > iter) { \
1994  pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
1995  }
1996 
1997 #define MICRO_TYPE_PEEL4(func, func2, peel) \
1998  if (PEEL > peel) { \
1999  Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
2000  MICRO_BROADCAST(peel) \
2001  MICRO_UNROLL_WORK(func, func2, peel) \
2002  } else { \
2003  EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2004  }
2005 
2006 #define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2007  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
2008  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3) func(func1, func2, 4) \
2009  func(func1, func2, 5) func(func1, func2, 6) func(func1, func2, 7)
2010 
2011 #define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
2012  Packet rhsV0[M]; \
2013  func(func1, func2, 0)
2014 
2015 #define MICRO_UNROLL_TYPE(MICRO_TYPE, size) \
2016  MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
2017  MICRO_ADD_ROWS(size)
2018 
2019 #define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)
2020 
2021 #define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)
2022 
2023 #define MICRO_DST_PTR_ONE(iter) \
2024  if (unroll_factor > iter) { \
2025  bsetzero<Packet, accRows>(accZero##iter); \
2026  } else { \
2027  EIGEN_UNUSED_VARIABLE(accZero##iter); \
2028  }
2029 
2030 #define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
2031 
2032 #define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
2033 
2034 #define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
2035 
2036 #ifdef USE_PARTIAL_PACKETS
2037 #define MICRO_STORE_ONE(iter) \
2038  if (unroll_factor > iter) { \
2039  if (MICRO_NORMAL_PARTIAL(iter)) { \
2040  bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
2041  bscale<Packet, accRows>(acc, accZero##iter, pAlpha); \
2042  bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols); \
2043  } else { \
2044  bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row + iter * accCols, accCols2); \
2045  bscale<Packet, accRows>(acc, accZero##iter, pAlpha); \
2046  bstore_partial<DataMapper, Packet, accRows>(acc, res, row + iter * accCols, accCols2); \
2047  } \
2048  }
2049 #else
2050 #define MICRO_STORE_ONE(iter) \
2051  if (unroll_factor > iter) { \
2052  bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
2053  bscale<Packet, accRows, !(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
2054  bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols); \
2055  }
2056 #endif
2057 
2058 #define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
2059 
2060 #ifdef USE_PARTIAL_PACKETS
2061 template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
2062  const Index accCols, bool full>
2063 #else
2064 template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
2065  const Index accCols, const Index accCols2>
2066 #endif
2067 EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
2068  Index depth, Index strideA, Index offsetA, Index strideB, Index& row,
2069  const Packet& pAlpha,
2070 #ifdef USE_PARTIAL_PACKETS
2071  Index accCols2
2072 #else
2073  const Packet& pMask
2074 #endif
2075 ) {
2076  const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
2077  const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
2078  *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
2079  PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
2081 
2085 
2086  Index k = 0;
2087  for (; k + PEEL <= depth; k += PEEL) {
2088  MICRO_PREFETCHN(accRows)
2091  }
2092  for (; k < depth; k++) {
2093  MICRO_ONE4
2094  }
2095  MICRO_STORE
2096 
2097  MICRO_UPDATE
2098 }
2099 
2100 #ifdef USE_PARTIAL_PACKETS
2101 #define MICRO_UNROLL_ITER2(N, M) \
2102  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, !M>( \
2103  res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, M ? remaining_rows : accCols); \
2104  if (M) return;
2105 #else
2106 #define MICRO_UNROLL_ITER2(N, M) \
2107  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>( \
2108  res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
2109  if (M) return;
2110 #endif
2111 
2112 template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
2113 EIGEN_ALWAYS_INLINE void gemm_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
2114  Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows,
2115  Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
2116  const DataMapper res3 = res.getSubMapper(0, col);
2117 
2118  const Scalar* rhs_base = blockB + col * strideB + MICRO_NEW_ROWS * offsetB;
2119  const Scalar* lhs_base = blockA + accCols * offsetA;
2120  Index row = 0;
2121 
2122 #define MAX_UNROLL 7
2123  while (row + MAX_UNROLL * accCols <= rows) {
2125  }
2126  switch ((rows - row) / accCols) {
2127 #if MAX_UNROLL > 7
2128  case 7:
2130  break;
2131 #endif
2132 #if MAX_UNROLL > 6
2133  case 6:
2135  break;
2136 #endif
2137 #if MAX_UNROLL > 5
2138  case 5:
2140  break;
2141 #endif
2142 #if MAX_UNROLL > 4
2143  case 4:
2145  break;
2146 #endif
2147 #if MAX_UNROLL > 3
2148  case 3:
2150  break;
2151 #endif
2152 #if MAX_UNROLL > 2
2153  case 2:
2155  break;
2156 #endif
2157 #if MAX_UNROLL > 1
2158  case 1:
2160  break;
2161 #endif
2162  default:
2163  break;
2164  }
2165 #undef MAX_UNROLL
2166 
2167  if (remaining_rows > 0) {
2168  gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA,
2169  strideB, row, rows, remaining_rows, pAlpha, pMask);
2170  }
2171 }
2172 
2173 #define MICRO_EXTRA_COLS(N) \
2174  gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, \
2175  col, rows, remaining_rows, pAlpha, pMask);
2176 
2177 template <typename Scalar, typename Packet, typename DataMapper, const Index accCols>
2178 EIGEN_ALWAYS_INLINE void gemm_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
2179  Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
2180  Index rows, Index cols, Index remaining_rows, const Packet& pAlpha,
2181  const Packet& pMask) {
2183 }
2184 
2185 /****************
2186  * GEMM kernels *
2187  * **************/
2188 template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
2189  const Index accCols>
2190 EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows,
2191  Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA,
2192  Index offsetB) {
2193  const Index remaining_rows = rows % accCols;
2194 
2195  if (strideA == -1) strideA = depth;
2196  if (strideB == -1) strideB = depth;
2197 
2198  const Packet pAlpha = pset1<Packet>(alpha);
2199  const Packet pMask = bmask<Packet>(remaining_rows);
2200 
2201  Index col = 0;
2202  for (; col + accRows <= cols; col += accRows) {
2203  gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB,
2204  offsetB, col, rows, remaining_rows, pAlpha, pMask);
2205  }
2206 
2207  if (col != cols) {
2208  gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
2209  col, rows, cols, remaining_rows, pAlpha, pMask);
2210  }
2211 }
2212 
2213 #define accColsC (accCols / 2)
2214 #define advanceRows ((LhsIsReal) ? 1 : 2)
2215 #define advanceCols ((RhsIsReal) ? 1 : 2)
2216 
2217 // PEEL_COMPLEX loop factor.
2218 #define PEEL_COMPLEX 3
2219 #define PEEL_COMPLEX_ROW 3
2220 
2221 #define MICRO_COMPLEX_UNROLL(func) func(0) func(1) func(2) func(3)
2222 
2223 #define MICRO_COMPLEX_ZERO_PEEL(peel) \
2224  if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
2225  bsetzero<Packet, accRows>(accReal##peel); \
2226  bsetzero<Packet, accRows>(accImag##peel); \
2227  } else { \
2228  EIGEN_UNUSED_VARIABLE(accReal##peel); \
2229  EIGEN_UNUSED_VARIABLE(accImag##peel); \
2230  }
2231 
2232 #define MICRO_COMPLEX_ADD_ROWS(N, used) \
2233  MICRO_ADD(ptr_real, N) \
2234  if (!RhsIsReal) { \
2235  MICRO_ADD(ptr_imag, N) \
2236  } else if (used) { \
2237  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
2238  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
2239  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
2240  }
2241 
2242 #define MICRO_COMPLEX_BROADCAST(peel) \
2243  MICRO_BROADCAST1(peel, ptr_real, rhsV, false) \
2244  if (!RhsIsReal) { \
2245  MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
2246  } else { \
2247  EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2248  }
2249 
2250 #define MICRO_COMPLEX_BROADCAST_EXTRA \
2251  Packet rhsV[4], rhsVi[4]; \
2252  MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false) \
2253  if (!RhsIsReal) { \
2254  MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
2255  } else { \
2256  EIGEN_UNUSED_VARIABLE(rhsVi); \
2257  } \
2258  MICRO_COMPLEX_ADD_ROWS(1, true)
2259 
2260 #define MICRO_COMPLEX_SRC2_PTR \
2261  MICRO_SRC2(ptr_real, strideB* advanceCols, 0) \
2262  if (!RhsIsReal) { \
2263  MICRO_RHS(ptr_imag, 0) = rhs_base + MICRO_NEW_ROWS * strideB; \
2264  MICRO_SRC2(ptr_imag, strideB* advanceCols, strideB) \
2265  } else { \
2266  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
2267  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
2268  EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
2269  }
2270 
2271 #define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)
2272 
2273 #define MICRO_COMPLEX_WORK_PEEL(peel) \
2274  if (PEEL_COMPLEX_ROW > peel) { \
2275  MICRO_COMPLEX_BROADCAST(peel) \
2276  pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2277  &accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), \
2278  lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
2279  } else { \
2280  EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2281  EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2282  }
2283 
2284 #define MICRO_COMPLEX_ADD_COLS(size) \
2285  lhs_ptr_real += (remaining_rows * size); \
2286  if (!LhsIsReal) \
2287  lhs_ptr_imag += (remaining_rows * size); \
2288  else \
2289  EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
2290 
2291 #define MICRO_COMPLEX_WORK_PEEL_ROW \
2292  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
2293  Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
2294  MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL) \
2295  MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW) \
2296  MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)
2297 
2298 #define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
2299  if (PEEL_COMPLEX_ROW > peel) { \
2300  for (Index i = 0; i < accRows; i++) { \
2301  accReal##sum.packet[i] += accReal##peel.packet[i]; \
2302  accImag##sum.packet[i] += accImag##peel.packet[i]; \
2303  } \
2304  }
2305 
2306 #define MICRO_COMPLEX_ADD_PEEL_ROW \
2307  MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) MICRO_COMPLEX_ADD_PEEL(1, 0)
2308 
2309 template <typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
2310  bool RhsIsReal, const Index remaining_rows>
2311 EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(const Scalar*& lhs_ptr_real, const Scalar*& lhs_ptr_imag,
2312  const Scalar*& rhs_ptr_real0, const Scalar*& rhs_ptr_real1,
2313  const Scalar*& rhs_ptr_real2, const Scalar*& rhs_ptr_imag0,
2314  const Scalar*& rhs_ptr_imag1, const Scalar*& rhs_ptr_imag2,
2316  PacketBlock<Packet, accRows>& accImag) {
2318  pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real,
2319  lhs_ptr_imag, rhsV, rhsVi);
2321 }
2322 
2323 template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2324  const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal,
2325  const Index remaining_rows>
2326 EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(const DataMapper& res, const Scalar* lhs_base,
2327  const Scalar* rhs_base, Index depth, Index strideA,
2328  Index offsetA, Index strideB, Index row, Index rows,
2329  const Packet& pAlphaReal, const Packet& pAlphaImag,
2330  const Packet& pMask) {
2331  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
2332  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
2333  const Scalar* lhs_ptr_real = lhs_base + advanceRows * row * strideA + remaining_rows * offsetA;
2334  const Scalar* lhs_ptr_imag = NULL;
2335  if (!LhsIsReal)
2336  lhs_ptr_imag = lhs_ptr_real + remaining_rows * strideA;
2337  else
2338  EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
2339  PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
2340  PacketBlock<Packet, accRows> taccReal, taccImag;
2341  PacketBlock<Packetc, accRows> acc0, acc1;
2343 
2345 
2346  bsetzero<Packet, accRows>(accReal0);
2347  bsetzero<Packet, accRows>(accImag0);
2348 
2349  Index remaining_depth = depth & -quad_traits<Scalar>::rows;
2350  Index k = 0;
2351  if (remaining_depth >= PEEL_COMPLEX_ROW) {
2353  do {
2354  MICRO_COMPLEX_PREFETCHN(accRows)
2355  EIGEN_POWER_PREFETCH(lhs_ptr_real);
2356  if (!LhsIsReal) {
2357  EIGEN_POWER_PREFETCH(lhs_ptr_imag);
2358  }
2360  } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
2362  }
2363  for (; k < depth; k++) {
2364  MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(
2365  lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1,
2366  rhs_ptr_imag2, accReal0, accImag0);
2367  }
2368 
2369  constexpr bool full = (remaining_rows > accColsC);
2370  bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
2371  if ((accRows == 1) || (rows >= accCols)) {
2372  bscalec<Packet, accRows, true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
2373  bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
2374  bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
2375  if (full) {
2376  bstore<DataMapper, Packetc, accRows>(acc1, res, row + accColsC);
2377  }
2378  } else {
2379  bscalec<Packet, accRows, false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
2380  bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
2381 
2382  if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) {
2383  for (Index j = 0; j < accRows; j++) {
2384  res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
2385  }
2386  } else {
2387  bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
2388  if (full) {
2389  for (Index j = 0; j < accRows; j++) {
2390  res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
2391  }
2392  }
2393  }
2394  }
2395 }
2396 
2397 #define MICRO_COMPLEX_EXTRA_ROWS(N) \
2398  gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, \
2399  ConjugateRhs, LhsIsReal, RhsIsReal, N>( \
2400  res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
2401 
2402 template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2403  const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2404 EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
2405  Index depth, Index strideA, Index offsetA, Index strideB, Index row,
2406  Index rows, Index remaining_rows, const Packet& pAlphaReal,
2407  const Packet& pAlphaImag, const Packet& pMask) {
2408  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false)
2409 }
2410 
2411 #define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2412  MICRO_COMPLEX_UNROLL(func2); \
2413  func(0, peel) func(1, peel) func(2, peel) func(3, peel)
2414 
2415 #define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
2416  if (unroll_factor > iter) { \
2417  pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2418  &accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
2419  }
2420 
2421 #define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
2422  if (PEEL_COMPLEX > peel) { \
2423  Packet lhsV0, lhsV1, lhsV2, lhsV3; \
2424  Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
2425  MICRO_COMPLEX_BROADCAST(peel) \
2426  MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
2427  } else { \
2428  EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2429  EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2430  }
2431 
2432 #define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2433  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
2434  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
2435  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3)
2436 
2437 #define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
2438  Packet rhsV0[M], rhsVi0[M]; \
2439  func(func1, func2, 0)
2440 
2441 #define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size) \
2442  MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \
2443  MICRO_COMPLEX_ADD_ROWS(size, false)
2444 
2445 #define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)
2446 
2447 #define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)
2448 
2449 #define MICRO_COMPLEX_DST_PTR_ONE(iter) \
2450  if (unroll_factor > iter) { \
2451  bsetzero<Packet, accRows>(accReal##iter); \
2452  bsetzero<Packet, accRows>(accImag##iter); \
2453  } else { \
2454  EIGEN_UNUSED_VARIABLE(accReal##iter); \
2455  EIGEN_UNUSED_VARIABLE(accImag##iter); \
2456  }
2457 
2458 #define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)
2459 
2460 #define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
2461 
2462 #define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
2463 
2464 #define MICRO_COMPLEX_STORE_ONE(iter) \
2465  if (unroll_factor > iter) { \
2466  constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC)); \
2467  bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter * accCols, 0); \
2468  bscalec<Packet, accRows, !(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, \
2469  taccImag, pMask); \
2470  bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1); \
2471  bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter * accCols + 0); \
2472  if (full) { \
2473  bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter * accCols + accColsC); \
2474  } \
2475  }
2476 
2477 #define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
2478 
2479 template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper,
2480  const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs,
2481  bool LhsIsReal, bool RhsIsReal>
2482 EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base,
2483  const Scalar* rhs_base, Index depth, Index strideA,
2484  Index offsetA, Index strideB, Index& row,
2485  const Packet& pAlphaReal, const Packet& pAlphaImag,
2486  const Packet& pMask) {
2487  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
2488  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
2489  const Index imag_delta = accCols * strideA;
2490  const Index imag_delta2 = accCols2 * strideA;
2491  const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
2492  const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
2493  PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1;
2494  PacketBlock<Packet, accRows> accReal2, accImag2, accReal3, accImag3;
2495  PacketBlock<Packet, accRows> taccReal, taccImag;
2496  PacketBlock<Packetc, accRows> acc0, acc1;
2498 
2502 
2503  Index k = 0;
2504  for (; k + PEEL_COMPLEX <= depth; k += PEEL_COMPLEX) {
2505  MICRO_COMPLEX_PREFETCHN(accRows)
2508  }
2509  for (; k < depth; k++) {
2511  }
2513 
2515 }
2516 
2517 #define MICRO_COMPLEX_UNROLL_ITER2(N, M) \
2518  gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, \
2519  M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2520  res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
2521  if (M) return;
2522 
2523 template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2524  const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2525 EIGEN_ALWAYS_INLINE void gemm_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
2526  Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
2527  Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
2528  const Packet& pAlphaImag, const Packet& pMask) {
2529  const DataMapper res3 = res.getSubMapper(0, col);
2530 
2531  const Scalar* rhs_base = blockB + advanceCols * col * strideB + MICRO_NEW_ROWS * offsetB;
2532  const Scalar* lhs_base = blockA + accCols * offsetA;
2533  Index row = 0;
2534 
2535 #define MAX_COMPLEX_UNROLL 4
2536  while (row + MAX_COMPLEX_UNROLL * accCols <= rows) {
2538  }
2539  switch ((rows - row) / accCols) {
2540 #if MAX_COMPLEX_UNROLL > 4
2541  case 4:
2543  break;
2544 #endif
2545 #if MAX_COMPLEX_UNROLL > 3
2546  case 3:
2548  break;
2549 #endif
2550 #if MAX_COMPLEX_UNROLL > 2
2551  case 2:
2553  break;
2554 #endif
2555 #if MAX_COMPLEX_UNROLL > 1
2556  case 1:
2558  break;
2559 #endif
2560  default:
2561  break;
2562  }
2563 #undef MAX_COMPLEX_UNROLL
2564 
2565  if (remaining_rows > 0) {
2566  gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
2567  RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows,
2568  remaining_rows, pAlphaReal, pAlphaImag, pMask);
2569  }
2570 }
2571 
2572 #define MICRO_COMPLEX_EXTRA_COLS(N) \
2573  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
2574  RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, \
2575  remaining_rows, pAlphaReal, pAlphaImag, pMask);
2576 
2577 template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols,
2578  bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2579 EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
2580  Index depth, Index strideA, Index offsetA, Index strideB,
2581  Index offsetB, Index col, Index rows, Index cols, Index remaining_rows,
2582  const Packet& pAlphaReal, const Packet& pAlphaImag,
2583  const Packet& pMask) {
2585 }
2586 
2587 template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
2588  typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
2589  bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
2590 EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc,
2591  Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB,
2592  Index offsetA, Index offsetB) {
2593  const Index remaining_rows = rows % accCols;
2594 
2595  if (strideA == -1) strideA = depth;
2596  if (strideB == -1) strideB = depth;
2597 
2598  const Packet pAlphaReal = pset1<Packet>(alpha.real());
2599  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
2600  const Packet pMask = bmask<Packet>(remaining_rows);
2601 
2602  const Scalar* blockA = (Scalar*)blockAc;
2603  const Scalar* blockB = (Scalar*)blockBc;
2604 
2605  Index col = 0;
2606  for (; col + accRows <= cols; col += accRows) {
2607  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
2608  RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,
2609  remaining_rows, pAlphaReal, pAlphaImag, pMask);
2610  }
2611 
2612  if (col != cols) {
2613  gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
2614  RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
2615  remaining_rows, pAlphaReal, pAlphaImag, pMask);
2616  }
2617 }
2618 
2619 #undef accColsC
2620 #undef advanceCols
2621 #undef advanceRows
2622 
2624 #if defined(EIGEN_ALTIVEC_MMA_ONLY)
2625  return true;
2626 #elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
2627  return __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma");
2628 #else
2629  return false; // No dynamic dispatch for LLVM or older GCC
2630 #endif
2631 }
2632 
2634  Packet4f result_block = ploadu<Packet4f>(result);
2635  return pmadd(acc, pAlpha, result_block);
2636 }
2637 
2638 template <bool lhsExtraRows>
2639 EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) {
2640  if (lhsExtraRows) {
2641  pstoreu_partial(result, result_block, extra_rows);
2642  } else {
2643  pstoreu(result, result_block);
2644  }
2645  result += rows;
2646 }
2647 
2648 template <bool rhsExtraCols, bool lhsExtraRows>
2649 EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result,
2650  Index extra_cols, Index extra_rows) {
2651  Index x = 0;
2652  if (rhsExtraCols) {
2653  do {
2654  Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
2655  storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
2656  } while (++x < extra_cols);
2657  } else {
2658  Packet4f result_block[4];
2659  float* result2 = result;
2660  do {
2661  result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
2662  result += rows;
2663  } while (++x < 4);
2664  x = 0;
2665  do {
2666  storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
2667  } while (++x < 4);
2668  }
2669 }
2670 
2672  Packet8us z = pset1<Packet8us>(0);
2673 #ifdef _BIG_ENDIAN
2674  return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
2675 #else
2676  return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
2677 #endif
2678 }
2679 
2681  Packet8us z = pset1<Packet8us>(0);
2682 #ifdef _BIG_ENDIAN
2683  return reinterpret_cast<Packet4f>(vec_mergel(data, z));
2684 #else
2685  return reinterpret_cast<Packet4f>(vec_mergel(z, data));
2686 #endif
2687 }
2688 
2689 template <Index N, Index M>
2690 EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra = 0) {
2691  if (N < 4) {
2692  pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[0].m_val), extra);
2693  } else if (N >= (M * 8 + 4)) {
2694  pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val));
2695  if (N >= 8) {
2696  pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val));
2697  }
2698  }
2699 }
2700 
2701 template <Index N>
2703  storeConvertTwoBF16<N, 0>(to + 0, block, extra);
2704  if (N >= 16) {
2705  storeConvertTwoBF16<N, 1>(to + 8, block);
2706  }
2707  if (N >= 32) {
2708  storeConvertTwoBF16<N, 2>(to + 16, block);
2709  storeConvertTwoBF16<N, 3>(to + 24, block);
2710  }
2711 }
2712 
2713 template <bool non_unit_stride, Index delta>
2715  if (non_unit_stride) {
2716  return pgather<bfloat16, Packet8bf>(src + delta * resInc, resInc);
2717  } else {
2718  return ploadu<Packet8bf>(src + delta);
2719  }
2720 }
2721 
2722 static Packet16uc p16uc_MERGE16_32_1 = {0, 1, 16, 17, 2, 3, 18, 19, 0, 1, 16, 17, 2, 3, 18, 19};
2723 static Packet16uc p16uc_MERGE16_32_2 = {4, 5, 20, 21, 6, 7, 22, 23, 4, 5, 20, 21, 6, 7, 22, 23};
2724 static Packet16uc p16uc_MERGE16_32_3 = {8, 9, 24, 25, 10, 11, 26, 27, 8, 9, 24, 25, 10, 11, 26, 27};
2725 static Packet16uc p16uc_MERGE16_32_4 = {12, 13, 28, 29, 14, 15, 30, 31, 12, 13, 28, 29, 14, 15, 30, 31};
2726 
2727 static Packet16uc p16uc_MERGE16_32_5 = {0, 1, 16, 17, 16, 17, 16, 17, 0, 1, 16, 17, 16, 17, 16, 17};
2728 static Packet16uc p16uc_MERGE16_32_6 = {2, 3, 18, 19, 18, 19, 18, 19, 2, 3, 18, 19, 18, 19, 18, 19};
2729 static Packet16uc p16uc_MERGE16_32_7 = {4, 5, 20, 21, 20, 21, 20, 21, 4, 5, 20, 21, 20, 21, 20, 21};
2730 static Packet16uc p16uc_MERGE16_32_8 = {6, 7, 22, 23, 22, 23, 22, 23, 6, 7, 22, 23, 22, 23, 22, 23};
2731 
2733  Packet8us z = pset1<Packet8us>(0);
2734 #ifdef _BIG_ENDIAN
2735  return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
2736 #else
2737  return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
2738 #endif
2739 }
2740 
2741 template <bool lhsExtraRows, bool odd, Index size>
2743  Index extra_rows) {
2744  Packet4f dup[4 * 4];
2745  Packet8bf data[4];
2746 
2747  for (Index i = 0; i < size; i++) {
2748  data[i] = ploadu<Packet8bf>(src + rows * i);
2749  }
2750 
2751  for (Index i = 0, j = 0; i < size; i++, j += 4) {
2752  dup[j + 0] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_5 : p16uc_MERGE16_32_1);
2753  dup[j + 1] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_6 : p16uc_MERGE16_32_2);
2754  dup[j + 2] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_7 : p16uc_MERGE16_32_3);
2755  dup[j + 3] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_8 : p16uc_MERGE16_32_4);
2756  }
2757 
2758  for (Index j = 0; j < 4 * size; j += 4) {
2759  if (lhsExtraRows) {
2760  Packet4f z = pset1<Packet4f>(float(0));
2761  Index i = 0;
2762  do {
2763  pstoreu(result + (j + i) * 4, dup[j + i]);
2764  } while (++i < extra_rows);
2765  do {
2766  pstoreu(result + (j + i) * 4, z);
2767  } while (++i < 4);
2768  } else {
2769  for (Index i = 0; i < 4; i++) {
2770  pstoreu(result + (j + i) * 4, dup[j + i]);
2771  }
2772  }
2773  }
2774 }
2775 
2776 template <bool lhsExtraRows>
2778  Index delta, Index extra_rows) {
2779  Index col = 0;
2780  src += delta * 2;
2781  for (; col + 4 * 2 <= cols; col += 4 * 2, result += 4 * 4 * 4, src += 4 * rows) {
2782  convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 4>(result, rows, src, extra_rows);
2783  }
2784  for (; col + 2 <= cols; col += 2, result += 4 * 4, src += rows) {
2785  convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 1>(result, rows, src, extra_rows);
2786  }
2787  if (cols & 1) {
2788  convertArrayPointerBF16toF32DupOne<lhsExtraRows, true, 1>(result, rows, src - delta, extra_rows);
2789  }
2790 }
2791 
2792 template <const Index size, bool non_unit_stride>
2794  constexpr Index extra = ((size < 4) ? 4 : size);
2795  while (i + size <= rows) {
2796  PacketBlock<Packet8bf, (size + 7) / 8> r32;
2797  r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
2798  if (size >= 16) {
2799  r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
2800  }
2801  if (size >= 32) {
2802  r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
2803  r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
2804  }
2805  storeConvertBlockBF16<size>(result + i, r32, rows & 3);
2806  i += extra;
2807  src += extra * resInc;
2808  if (size != 32) break;
2809  }
2810 }
2811 
2812 template <bool non_unit_stride>
2814  Index resInc) {
2815  for (Index col = 0; col < cols; col++, src += (rows * resInc), result += rows) {
2816  Index i = 0;
2817  bfloat16* src2 = src;
2818  convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
2819  convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
2820  convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
2821  convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
2822  convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
2823  }
2824 }
2825 
2826 template <Index num_acc, Index size = 4>
2828  Packet4f z = pset1<Packet4f>(float(0));
2829 
2830  for (Index k = 0; k < num_acc; k++) {
2831  for (Index j = 0; j < size; j++) {
2832  acc[k][j] = z;
2833  }
2834  }
2835 }
2836 
2837 template <Index num_acc>
2839  for (Index i = 0; i < num_acc; i++) {
2840  Packet4ui t0, t1, t2, t3;
2841  t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
2842  t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
2843  t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
2844  t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
2845  acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
2846  acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
2847  acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
2848  acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
2849  }
2850 }
2851 
2852 template <Index num_acc>
2853 EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4]) {
2854  for (Index i = 0, j = 0; j < num_acc; i++, j += 2) {
2855  for (Index x = 0, y = 0; x < 2; x++, y += 2) {
2856  for (Index w = 0, z = 0; w < 2; w++, z += 2) {
2857  acc[i][y + w] = acc[j + x][z + 0] + acc[j + x][z + 1];
2858  }
2859  }
2860  }
2861 }
2862 
2863 template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
2864 EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result,
2865  const Index extra_cols, Index extra_rows) {
2866  tranposeResults<num_acc>(acc);
2867  addResults<num_acc>(acc);
2868 
2869  constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));
2870  Index k = 0;
2871  for (Index i = 0; i < real_rhs; i++, result += 4 * rows, k++) {
2872  storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
2873  }
2874  if (rhsExtraCols) {
2875  storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
2876  }
2877 }
2878 
2879 template <bool zero>
2880 EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0, Packet4f& dhs1) {
2881  dhs0 = ploadu<Packet4f>(block + strideB * i + 0);
2882  if (zero) {
2883  Packet4f dhs2 = pset1<Packet4f>(float(0));
2884  dhs1 = vec_mergel(dhs0, dhs2);
2885  dhs0 = vec_mergeh(dhs0, dhs2);
2886  } else {
2887  dhs1 = ploadu<Packet4f>(block + strideB * i + 4);
2888  }
2889 }
2890 
2891 template <Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
2892 EIGEN_ALWAYS_INLINE void KLoop(const float* indexA, const float* indexB, Packet4f (&acc)[num_acc][4], Index strideB,
2893  Index k, Index offsetB, Index extra_cols) {
2894  constexpr Index num_lhs = 4;
2895  Packet4f lhs[num_lhs], rhs[num_rhs];
2896 
2897  constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
2898  for (Index i = 0; i < real_rhs; i += 2) {
2899  loadTwoRhsFloat32<zero>(indexB + k * 4, strideB, i, rhs[i + 0], rhs[i + 1]);
2900  }
2901  if (rhsExtraCols) {
2902  loadTwoRhsFloat32<zero>(indexB + k * extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);
2903  }
2904 
2905  indexA += 2 * k * 4;
2906  for (Index j = 0; j < num_lhs; j++) {
2907  lhs[j] = ploadu<Packet4f>(indexA + j * 4);
2908  }
2909 
2910  for (Index j = 0; j < num_rhs; j++) {
2911  for (Index i = 0; i < num_lhs; i++) {
2912  acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]);
2913  }
2914  }
2915 }
2916 
2917 template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2918 EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA,
2919  const float* indexB, Index strideB, Index offsetB, float* result,
2920  const Index extra_cols, const Index extra_rows) {
2921  constexpr Index num_rhs = num_acc;
2922 
2923  Packet4f acc[num_acc][4];
2924 
2925  zeroAccumulators<num_acc>(acc);
2926 
2927  Index k;
2928  for (k = 0; k + 2 <= depth; k += 2) {
2929  KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
2930  }
2931  if (depth & 1) {
2932  KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
2933  }
2934 
2935  outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
2936 }
2937 
2938 // No more than 4 (uses 2X the accumulators or 8X the number of VSX registers)
2939 #define MAX_BFLOAT16_ACC_VSX 4
2940 
2941 template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2942 void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
2943  const float* indexB, Index strideB, Index offsetB, float* result) {
2944  constexpr Index step = (num_acc * 4); // each accumulator has 4 elements
2945  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
2946  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
2947  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX);
2948 
2949  do {
2950  colVSXLoopBodyIter<num_acc * 2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB,
2951  result, extra_cols, extra_rows);
2952 
2953  indexB += strideB * (num_acc * 2);
2954  result += rows * step;
2955  } while (multiIters && (step <= cols - (col += step)));
2956 }
2957 
2958 template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2960  const float* indexA, const float* blockB, Index strideB, Index offsetB,
2961  float* result) {
2962  if (MAX_BFLOAT16_ACC_VSX > num_acc) {
2963  colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA,
2964  blockB, strideB, offsetB, result);
2965  }
2966 }
2967 
2968 template <bool rhsExtraCols, bool lhsExtraRows>
2969 void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
2970  const float* blockB, Index strideB, Index offsetB, float* result) {
2971  switch ((cols - col) >> 2) {
2972  case 3:
2973  colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2974  offsetB, result);
2975  break;
2976  case 2:
2977  colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2978  offsetB, result);
2979  break;
2980  case 1:
2981  colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2982  offsetB, result);
2983  break;
2984  default:
2985  if (rhsExtraCols) {
2986  colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
2987  }
2988  break;
2989  }
2990 }
2991 
2992 template <Index size, bool lhsExtraRows = false>
2993 EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
2994  const float* indexA2, const float* blockB2, Index strideA, Index strideB,
2995  Index offsetB, float* result2) {
2996  Index delta_rows = 2 * (lhsExtraRows ? (rows & 3) : size);
2997  for (Index row = 0; row < size; row += 4) {
2998  convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float*>(indexA2), strideA, delta_rows, indexA, row,
2999  rows & 3);
3000 
3001  const float* blockB = blockB2;
3002  float* result = result2 + row;
3003 
3004  Index col = 0;
3005  if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
3006  colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB,
3007  strideB, 0, result);
3008  blockB += (strideB >> 1) * col;
3009  result += rows * col;
3010  }
3011  if (cols & 3) {
3012  colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB,
3013  result);
3014  } else {
3015  colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
3016  }
3017  }
3018 }
3019 
3020 template <Index size>
3021 EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth,
3022  Index cols, Index rows, const Packet4f pAlpha, const float* indexB,
3023  Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix,
3024  float* result) {
3025  if ((size == 16) || (rows & size)) {
3026  indexA += size * offsetA;
3027  colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
3028  row += size;
3029  indexA += bigSuffix * size / 16;
3030  }
3031 }
3032 
3033 template <const Index size, typename DataMapper>
3034 EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float* result, Index rows, const DataMapper& src) {
3035  constexpr Index extra = ((size < 4) ? 4 : size);
3036  while (i + size <= rows) {
3037  PacketBlock<Packet8bf, (size + 7) / 8> r32;
3038  r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
3039  if (size >= 16) {
3040  r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
3041  }
3042  if (size >= 32) {
3043  r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
3044  r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
3045  }
3046  storeConvertBlockBF16<size>(result + i, r32, rows & 3);
3047  i += extra;
3048  if (size != 32) break;
3049  }
3050 }
3051 
3052 template <typename DataMapper>
3053 EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float* result, Index cols, Index rows, const DataMapper& src) {
3054  typedef typename DataMapper::LinearMapper LinearMapper;
3055  for (Index j = 0; j < cols; j++, result += rows) {
3056  const LinearMapper src2 = src.getLinearMapper(0, j);
3057  Index i = 0;
3058  convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
3059  convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
3060  convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
3061  convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
3062  convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
3063  }
3064 }
3065 
3068 }
3069 
3070 template <typename DataMapper, const Index size>
3071 EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float* result, Index col, Index rows, const DataMapper& res) {
3072  const DataMapper res2 = res.getSubMapper(0, col);
3073  Index row;
3074  float* result2 = result + col * rows;
3075  for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
3076  // get and save block
3078  for (Index j = 0; j < size; j++) {
3079  block.packet[j] = convertF32toBF16VSX(result2 + j * rows);
3080  }
3081  res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
3082  }
3083  // extra rows
3084  if (row < rows) {
3085  for (Index j = 0; j < size; j++) {
3086  Packet8bf fp16 = convertF32toBF16VSX(result2 + j * rows);
3087  res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
3088  }
3089  }
3090 }
3091 
3092 template <typename DataMapper>
3093 EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float* result, Index cols, Index rows, const DataMapper& res) {
3094  Index col;
3095  for (col = 0; col + 4 <= cols; col += 4) {
3096  convertArrayF32toBF16ColVSX<DataMapper, 4>(result, col, rows, res);
3097  }
3098  // extra cols
3099  switch (cols - col) {
3100  case 1:
3101  convertArrayF32toBF16ColVSX<DataMapper, 1>(result, col, rows, res);
3102  break;
3103  case 2:
3104  convertArrayF32toBF16ColVSX<DataMapper, 2>(result, col, rows, res);
3105  break;
3106  case 3:
3107  convertArrayF32toBF16ColVSX<DataMapper, 3>(result, col, rows, res);
3108  break;
3109  }
3110 }
3111 
3112 template <typename DataMapper>
3113 void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
3114  Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3116  const Packet4f pAlpha = pset1<Packet4f>(falpha);
3117 
3118  if (strideA == -1) strideA = depth;
3119  if (strideB == -1) strideB = depth;
3120 
3122  ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB* cols, 0);
3123  ei_declare_aligned_stack_constructed_variable(float, indexA2, ((strideA + 1) & -2) * 4 * 2, 0);
3124 
3125  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
3126  convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast<bfloat16*>(indexB));
3127 
3128  Index bigSuffix = 2 * 8 * (strideA - offsetA);
3129  float* indexBF32 = indexB2 + 4 * offsetB;
3130  offsetB *= 3;
3131  strideB *= 2;
3132 
3133  Index row = 0;
3134  // LHS (8x16) block
3135  while (row + 16 <= rows) {
3136  calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
3137  bigSuffix, result);
3138  }
3139  // LHS (8x8) block
3140  calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
3141  bigSuffix, result);
3142  // LHS (8x4) block
3143  calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
3144  bigSuffix, result);
3145  // extra rows
3146  if (rows & 3) {
3147  // This index is the beginning of remaining block.
3148  colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB,
3149  result + row);
3150  }
3151 
3152  // Convert back to bfloat16
3153  convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
3154 }
3155 
3156 #undef MAX_BFLOAT16_ACC_VSX
3157 
3158 #include "MatrixVectorProduct.h"
3159 
3160 /************************************
3161  * ppc64le template specializations *
3162  * **********************************/
3163 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3164 struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3165  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3166 };
3167 
3168 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3170  double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3172  pack(blockA, lhs, depth, rows, stride, offset);
3173 }
3174 
3175 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3176 struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3177  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3178 };
3179 
3180 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3182  double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3184  pack(blockA, lhs, depth, rows, stride, offset);
3185 }
3186 
3187 #if EIGEN_ALTIVEC_USE_CUSTOM_PACK
3188 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3189 struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3190  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3191 };
3192 
3193 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3195  double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3197  pack(blockB, rhs, depth, cols, stride, offset);
3198 }
3199 
3200 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3201 struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3202  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3203 };
3204 
3205 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3207  double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3209  pack(blockB, rhs, depth, cols, stride, offset);
3210 }
3211 
3212 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3213 struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3214  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3215 };
3216 
3217 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3219  bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3221  pack(blockB, rhs, depth, cols, stride, offset);
3222 }
3223 
3224 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3225 struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3226  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3227 };
3228 
3229 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3231  bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3233  pack(blockB, rhs, depth, cols, stride, offset);
3234 }
3235 #endif
3236 
3237 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3238 struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3239  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3240 };
3241 
3242 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3244  bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3246  pack(blockA, lhs, depth, rows, stride, offset);
3247 }
3248 
3249 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3250 struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3251  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3252 };
3253 
3254 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3256  bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3258  pack(blockA, lhs, depth, rows, stride, offset);
3259 }
3260 
3261 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3262 struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3263  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3264 };
3265 
3266 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3268  float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3270  pack(blockA, lhs, depth, rows, stride, offset);
3271 }
3272 
3273 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3274 struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3275  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3276 };
3277 
3278 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3280  float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3282  pack(blockA, lhs, depth, rows, stride, offset);
3283 }
3284 
3285 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3286 struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3287  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3288  Index offset = 0);
3289 };
3290 
3291 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3292 void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
3293  PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
3294  Index stride, Index offset) {
3296  pack(blockA, lhs, depth, rows, stride, offset);
3297 }
3298 
3299 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3300 struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3301  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3302  Index offset = 0);
3303 };
3304 
3305 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3306 void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
3307  PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
3308  Index stride, Index offset) {
3310  pack(blockA, lhs, depth, rows, stride, offset);
3311 }
3312 
3313 #if EIGEN_ALTIVEC_USE_CUSTOM_PACK
3314 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3315 struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3316  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3317 };
3318 
3319 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3321  float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3323  pack(blockB, rhs, depth, cols, stride, offset);
3324 }
3325 
3326 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3327 struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3328  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3329 };
3330 
3331 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3333  float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3335  pack(blockB, rhs, depth, cols, stride, offset);
3336 }
3337 #endif
3338 
3339 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3340 struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3341  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3342  Index offset = 0);
3343 };
3344 
3345 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3346 void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3347  std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3349  pack(blockB, rhs, depth, cols, stride, offset);
3350 }
3351 
3352 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3353 struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3354  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3355  Index offset = 0);
3356 };
3357 
3358 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3359 void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3360  std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3362  pack(blockB, rhs, depth, cols, stride, offset);
3363 }
3364 
3365 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3366 struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3367  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3368  Index offset = 0);
3369 };
3370 
3371 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3372 void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
3373  PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
3374  Index stride, Index offset) {
3376  pack(blockA, lhs, depth, rows, stride, offset);
3377 }
3378 
3379 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3380 struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3381  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3382  Index offset = 0);
3383 };
3384 
3385 template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3386 void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
3387  PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
3388  Index stride, Index offset) {
3390  pack(blockA, lhs, depth, rows, stride, offset);
3391 }
3392 
3393 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3394 struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3395  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3396  Index offset = 0);
3397 };
3398 
3399 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3400 void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3401  std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3403  pack(blockB, rhs, depth, cols, stride, offset);
3404 }
3405 
3406 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3407 struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3408  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3409  Index offset = 0);
3410 };
3411 
3412 template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3413 void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3414  std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3416  pack(blockB, rhs, depth, cols, stride, offset);
3417 }
3418 
3419 // ********* gebp specializations *********
3420 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3421 struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3424 
3425  void operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols,
3426  float alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3427 };
3428 
3429 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3431  const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha,
3432  Index strideA, Index strideB, Index offsetA, Index offsetB) {
3433  const Index accRows = quad_traits<float>::rows;
3434  const Index accCols = quad_traits<float>::size;
3435  static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index,
3436  Index, Index) =
3437 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3438  (supportsMMA()) ? &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
3439 #endif
3440  &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
3441  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3442 }
3443 
3444 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3445 struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3446  typedef Packet4f Packet;
3449 
3450  void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
3451  Index rows, Index depth, Index cols, std::complex<float> alpha, Index strideA = -1,
3452  Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3453 };
3454 
3455 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3456 void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs,
3457  ConjugateRhs>::operator()(const DataMapper& res, const std::complex<float>* blockA,
3458  const std::complex<float>* blockB, Index rows, Index depth, Index cols,
3459  std::complex<float> alpha, Index strideA, Index strideB, Index offsetA,
3460  Index offsetB) {
3461  const Index accRows = quad_traits<float>::rows;
3462  const Index accCols = quad_traits<float>::size;
3463  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*, Index, Index,
3464  Index, std::complex<float>, Index, Index, Index, Index) =
3465 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3466  (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>,
3467  float, Packet, Packetc, RhsPacket, DataMapper, accRows,
3468  accCols, ConjugateLhs, ConjugateRhs, false, false>
3469  :
3470 #endif
3471  &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>,
3472  float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3473  ConjugateLhs, ConjugateRhs, false, false>;
3474  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3475 }
3476 
3477 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3478 struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3479  typedef Packet4f Packet;
3482 
3483  void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows,
3484  Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
3485  Index offsetA = 0, Index offsetB = 0);
3486 };
3487 
3488 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3489 void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3490  const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows, Index depth, Index cols,
3491  std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3492  const Index accRows = quad_traits<float>::rows;
3493  const Index accCols = quad_traits<float>::size;
3494  static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*, Index, Index, Index,
3495  std::complex<float>, Index, Index, Index, Index) =
3496 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3497  (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float,
3498  Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3499  ConjugateLhs, ConjugateRhs, true, false>
3500  :
3501 #endif
3502  &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet,
3503  Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3504  ConjugateRhs, true, false>;
3505  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3506 }
3507 
3508 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3509 struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3510  typedef Packet4f Packet;
3513 
3514  void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows,
3515  Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
3516  Index offsetA = 0, Index offsetB = 0);
3517 };
3518 
3519 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3520 void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3521  const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows, Index depth, Index cols,
3522  std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3523  const Index accRows = quad_traits<float>::rows;
3524  const Index accCols = quad_traits<float>::size;
3525  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*, Index, Index, Index,
3526  std::complex<float>, Index, Index, Index, Index) =
3527 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3528  (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float,
3529  Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3530  ConjugateLhs, ConjugateRhs, false, true>
3531  :
3532 #endif
3533  &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet,
3534  Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3535  ConjugateRhs, false, true>;
3536  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3537 }
3538 
3539 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3540 struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3543 
3544  void operator()(const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth,
3545  Index cols, double alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
3546  Index offsetB = 0);
3547 };
3548 
3549 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3551  const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth, Index cols,
3552  double alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3553  const Index accRows = quad_traits<double>::rows;
3554  const Index accCols = quad_traits<double>::size;
3555  static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index,
3556  Index, Index, Index) =
3557 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3558  (supportsMMA()) ? &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
3559 #endif
3560  &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
3561  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3562 }
3563 
3564 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3565 struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3569 
3570  void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
3571  Index rows, Index depth, Index cols, std::complex<double> alpha, Index strideA = -1,
3572  Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3573 };
3574 
3575 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3576 void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs,
3577  ConjugateRhs>::operator()(const DataMapper& res, const std::complex<double>* blockA,
3578  const std::complex<double>* blockB, Index rows, Index depth, Index cols,
3579  std::complex<double> alpha, Index strideA, Index strideB, Index offsetA,
3580  Index offsetB) {
3581  const Index accRows = quad_traits<double>::rows;
3582  const Index accCols = quad_traits<double>::size;
3583  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*, Index,
3584  Index, Index, std::complex<double>, Index, Index, Index, Index) =
3585 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3586  (supportsMMA())
3587  ? &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double,
3588  Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3589  ConjugateRhs, false, false>
3590  :
3591 #endif
3592  &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double,
3593  Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3594  ConjugateRhs, false, false>;
3595  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3596 }
3597 
3598 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3599 struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3603 
3604  void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows,
3605  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
3606  Index offsetA = 0, Index offsetB = 0);
3607 };
3608 
3609 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3610 void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3611  const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows, Index depth,
3612  Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3613  const Index accRows = quad_traits<double>::rows;
3614  const Index accCols = quad_traits<double>::size;
3615  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*, Index, Index, Index,
3616  std::complex<double>, Index, Index, Index, Index) =
3617 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3618  (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double,
3619  Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3620  ConjugateLhs, ConjugateRhs, false, true>
3621  :
3622 #endif
3623  &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet,
3624  Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3625  ConjugateRhs, false, true>;
3626  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3627 }
3628 
3629 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3630 struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3634 
3635  void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows,
3636  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
3637  Index offsetA = 0, Index offsetB = 0);
3638 };
3639 
3640 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3641 void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3642  const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows, Index depth,
3643  Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3644  const Index accRows = quad_traits<double>::rows;
3645  const Index accCols = quad_traits<double>::size;
3646  static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*, Index, Index, Index,
3647  std::complex<double>, Index, Index, Index, Index) =
3648 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3649  (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double,
3650  Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3651  ConjugateLhs, ConjugateRhs, true, false>
3652  :
3653 #endif
3654  &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet,
3655  Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3656  ConjugateRhs, true, false>;
3657  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3658 }
3659 
3660 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3661 struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3664 
3665  void operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth,
3666  Index cols, bfloat16 alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
3667  Index offsetB = 0);
3668 };
3669 
3670 template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3672  const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols,
3673  bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3674  static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16,
3675  Index, Index, Index, Index) =
3676 #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3677  (supportsMMA()) ? &Eigen::internal::gemmMMAbfloat16<DataMapper> :
3678 #endif
3679  &Eigen::internal::gemmbfloat16<DataMapper>;
3680  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3681 }
3682 } // end namespace internal
3683 
3684 } // end namespace Eigen
3685 
3686 #endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H
Array< int, Dynamic, 1 > v
Definition: Array_initializer_list_vector_cxx11.cpp:1
int i
Definition: BiCGSTAB_step_by_step.cpp:9
const unsigned n
Definition: CG3DPackingUnitTest.cpp:11
#define EIGEN_ALWAYS_INLINE
Definition: Macros.h:845
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:966
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
int data[]
Definition: Map_placement_new.cpp:1
m col(1)
m row(1)
#define MICRO_COMPLEX_UNROLL_ITER(func, N)
Definition: MatrixProductCommon.h:142
#define MICRO_UPDATE
Definition: MatrixProductCommon.h:191
#define MICRO_COMPLEX_UPDATE
Definition: MatrixProductCommon.h:198
#define EIGEN_POWER_PREFETCH(p)
Definition: MatrixProductCommon.h:5
#define MICRO_UNROLL_ITER(func, N)
Definition: MatrixProductCommon.h:139
#define MICRO_COMPLEX_EXTRA_COLS(N)
Definition: MatrixProduct.h:2572
#define MICRO_COMPLEX_DST_PTR
Definition: MatrixProduct.h:2458
#define MICRO_COMPLEX_PREFETCHN(N)
Definition: MatrixProduct.h:1896
#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col)
Definition: MatrixProduct.h:1960
#define advanceCols
Definition: MatrixProduct.h:2215
#define MICRO_COMPLEX_SRC2_PTR
Definition: MatrixProduct.h:2260
#define MICRO_PREFETCHN(N)
Definition: MatrixProduct.h:1894
#define MICRO_NEW_ROWS
Definition: MatrixProduct.h:1797
#define PEEL_COMPLEX_ROW
Definition: MatrixProduct.h:2219
#define MICRO_ONE_PEEL4
Definition: MatrixProduct.h:2019
#define PEEL_ROW
Definition: MatrixProduct.h:1791
#define MICRO_STORE
Definition: MatrixProduct.h:2058
#define MICRO_WORK_PEEL_ROW
Definition: MatrixProduct.h:1867
#define accColsC
Definition: MatrixProduct.h:2213
#define MICRO_DST_PTR
Definition: MatrixProduct.h:2030
#define MAX_UNROLL
#define advanceRows
Definition: MatrixProduct.h:2214
#define MICRO_EXTRA_ROWS(N)
Definition: MatrixProduct.h:1977
#define MICRO_EXTRA_COLS(N)
Definition: MatrixProduct.h:2173
#define MICRO_COMPLEX_ONE_PEEL4
Definition: MatrixProduct.h:2445
#define MICRO_COMPLEX_PREFETCH
Definition: MatrixProduct.h:2462
#define MICRO_COMPLEX_EXTRA_ROWS(N)
Definition: MatrixProduct.h:2397
#define PEEL_COMPLEX
Definition: MatrixProduct.h:2218
#define MICRO_COMPLEX_BROADCAST_EXTRA
Definition: MatrixProduct.h:2250
#define MICRO_COMPLEX_WORK_PEEL_ROW
Definition: MatrixProduct.h:2291
#define MICRO_ADD_PEEL_ROW
Definition: MatrixProduct.h:1880
#define MICRO_ZERO_PEEL_ROW
Definition: MatrixProduct.h:1857
#define MICRO_COMPLEX_ZERO_PEEL_ROW
Definition: MatrixProduct.h:2271
#define PEEL
Definition: MatrixProduct.h:1790
#define MICRO_SRC2_PTR
Definition: MatrixProduct.h:1855
#define MICRO_COMPLEX_ONE4
Definition: MatrixProduct.h:2447
#define MICRO_COMPLEX_ADD_PEEL_ROW
Definition: MatrixProduct.h:2306
#define MICRO_PREFETCH
Definition: MatrixProduct.h:2034
#define MICRO_ONE4
Definition: MatrixProduct.h:2021
#define MICRO_UNROLL_ITER2(N, M)
Definition: MatrixProduct.h:2106
#define MAX_BFLOAT16_ACC_VSX
Definition: MatrixProduct.h:2939
#define MICRO_COMPLEX_UNROLL_ITER2(N, M)
Definition: MatrixProduct.h:2517
#define MICRO_COMPLEX_STORE
Definition: MatrixProduct.h:2477
#define MICRO_SRC_PTR
Definition: MatrixProduct.h:2032
#define MICRO_COMPLEX_SRC_PTR
Definition: MatrixProduct.h:2460
#define MAX_COMPLEX_UNROLL
#define MICRO_BROADCAST_EXTRA
Definition: MatrixProduct.h:1836
#define MICRO_COMPLEX_ADD_COLS(size)
Definition: MatrixProduct.h:2284
RowVector3d w
Definition: Matrix_resize_int.cpp:3
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:806
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
Map< RowVectorXf > v2(M2.data(), M2.size())
M1<< 1, 2, 3, 4, 5, 6, 7, 8, 9;Map< RowVectorXf > v1(M1.data(), M1.size())
m m block(1, 0, 2, 2)<< 4
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
Matrix< RealScalar, Dynamic, Dynamic > M
Definition: bench_gemm.cpp:50
internal::packet_traits< Scalar >::type Packet
Definition: benchmark-blocking-sizes.cpp:54
Definition: ForwardDeclarations.h:102
The matrix class, also used for vectors and row-vectors.
Definition: Eigen/Eigen/src/Core/Matrix.h:186
EIGEN_STRONG_INLINE PacketScalar packet(Index rowId, Index colId) const
Definition: PlainObjectBase.h:247
Definition: BlasUtil.h:443
std::complex< RealScalar > Complex
Definition: common.h:71
@ N
Definition: constructor.cpp:22
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
Eigen::DenseIndex ret
Definition: level1_cplx_impl.h:43
RealScalar alpha
Definition: level1_cplx_impl.h:151
char char char int int * k
Definition: level2_impl.h:374
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h)
Definition: BFloat16.h:581
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data)
Definition: MatrixProduct.h:2680
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask)
Definition: MatrixProduct.h:2732
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper &res, const LhsScalar *blockAc, const RhsScalar *blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
Definition: MatrixProduct.h:2590
EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16 *indexA, const float *indexA2, const float *blockB2, Index strideA, Index strideB, Index offsetB, float *result2)
Definition: MatrixProduct.h:2993
__m128d Packet2d
Definition: LSX/PacketMath.h:36
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet8bf pgather< bfloat16, Packet8bf >(const bfloat16 *from, Index stride)
Definition: AltiVec/PacketMath.h:874
EIGEN_ALWAYS_INLINE void storeResults(Packet4f(&acc)[4], Index rows, const Packet4f pAlpha, float *result, Index extra_cols, Index extra_rows)
Definition: MatrixProduct.h:2649
EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows)
Definition: MatrixProduct.h:1660
__vector int Packet4i
Definition: AltiVec/PacketMath.h:34
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(const DataMapper &res, const Scalar *lhs_base, const Scalar *rhs_base, Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows, Index remaining_rows, const Packet &pAlphaReal, const Packet &pAlphaImag, const Packet &pMask)
Definition: MatrixProduct.h:2404
EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(const DataMapper &res, const Scalar *lhs_base, const Scalar *rhs_base, Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows, const Packet &pAlphaReal, const Packet &pAlphaImag, const Packet &pMask)
Definition: MatrixProduct.h:2326
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock< Packet, N > *accReal, PacketBlock< Packet, N > *accImag, const Packet &lhsV, Packet &lhsVi, const Packet *rhsV, const Packet *rhsVi)
Definition: MatrixProduct.h:1514
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex< Scalar > *blockB, const std::complex< Scalar > *_rhs, Index rhsStride, Index rows, Index cols, Index k2)
Definition: MatrixProduct.h:134
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16 *src, Index resInc)
Definition: MatrixProduct.h:2714
EIGEN_STRONG_INLINE void gemm(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
Definition: MatrixProduct.h:2190
__vector unsigned char Packet16uc
Definition: AltiVec/PacketMath.h:41
void colVSXLoopBody(Index &col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float *indexA, const float *indexB, Index strideB, Index offsetB, float *result)
Definition: MatrixProduct.h:2942
EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex< float > &from0, const std::complex< float > &from1)
Definition: AltiVec/Complex.h:185
void gemmbfloat16(const DataMapper &res, const bfloat16 *indexA, const bfloat16 *indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
Definition: MatrixProduct.h:3113
EIGEN_ALWAYS_INLINE void pgerc(PacketBlock< Packet, N > *accReal, PacketBlock< Packet, N > *accImag, const Scalar *lhs_ptr, const Scalar *lhs_ptr_imag, const Packet *rhsV, const Packet *rhsVi)
Definition: MatrixProduct.h:1532
EIGEN_ALWAYS_INLINE void storeF32(float *&result, Packet4f result_block, Index rows, Index extra_rows)
Definition: MatrixProduct.h:2639
EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f(&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float *result, const Index extra_cols, Index extra_rows)
Definition: MatrixProduct.h:2864
EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock< Packet, N > &acc)
Definition: MatrixProduct.h:1551
EIGEN_STRONG_INLINE void ptranspose(PacketBlock< Packet2cf, 2 > &kernel)
Definition: AltiVec/Complex.h:339
static const Packet16uc p16uc_GETIMAG32b
Definition: MatrixProduct.h:96
static Packet16uc p16uc_MERGE16_32_8
Definition: MatrixProduct.h:2730
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(const DataMapper &res, const Scalar *lhs_base, const Scalar *rhs_base, Index depth, Index strideA, Index offsetA, Index strideB, Index &row, const Packet &pAlphaReal, const Packet &pAlphaImag, const Packet &pMask)
Definition: MatrixProduct.h:2482
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data)
Definition: MatrixProduct.h:2671
const Scalar & y
Definition: RandomImpl.h:36
EIGEN_ALWAYS_INLINE void pstore_partial< bfloat16 >(bfloat16 *to, const Packet8bf &from, const Index n, const Index offset)
Definition: AltiVec/PacketMath.h:737
EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock< Packet, N > &acc, PacketBlock< Packet, N > &accZ, const Packet &pAlpha)
Definition: MatrixProduct.h:1558
static Packet16uc p16uc_TRANSPOSE64_LO
Definition: AltiVec/PacketMath.h:145
EIGEN_ALWAYS_INLINE std::complex< Scalar > getAdjointVal(Index i, Index j, const_blas_data_mapper< std::complex< Scalar >, Index, StorageOrder > &dt)
Definition: MatrixProduct.h:117
EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f(&acc)[num_acc][4])
Definition: MatrixProduct.h:2838
EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar *blockB, const Scalar *_rhs, Index rhsStride, Index rows, Index cols, Index k2)
Definition: MatrixProduct.h:223
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16 *src, Index resInc)
Definition: MatrixProduct.h:2813
__vector unsigned short int Packet8us
Definition: AltiVec/PacketMath.h:38
EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float *indexA, const float *blockB, Index strideB, Index offsetB, float *result)
Definition: MatrixProduct.h:2959
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(const Scalar *&lhs_ptr, const Scalar *&rhs_ptr0, const Scalar *&rhs_ptr1, const Scalar *&rhs_ptr2, PacketBlock< Packet, accRows > &accZero)
Definition: MatrixProduct.h:1903
EIGEN_STRONG_INLINE void pstore< bfloat16 >(bfloat16 *to, const Packet8bf &from)
Definition: AltiVec/PacketMath.h:662
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index &i, float *result, Index rows, const DataMapper &src)
Definition: MatrixProduct.h:3034
EIGEN_ALWAYS_INLINE void bstore(PacketBlock< Packet, N > &acc, const DataMapper &res, Index row)
Definition: MatrixProduct.h:1621
EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet) *ap0, const __UNPACK_TYPE__(Packet) *ap1, const __UNPACK_TYPE__(Packet) *ap2, Packet &a0, Packet &a1, Packet &a2, Packet &a3)
Definition: MatrixProduct.h:1708
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock< Packet, N > &aReal, PacketBlock< Packet, N > &aImag, const Packet &bReal, const Packet &bImag, PacketBlock< Packet, N > &cReal, PacketBlock< Packet, N > &cImag, const Packet &pMask)
Definition: MatrixProduct.h:1574
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(const Scalar *&lhs_ptr_real, const Scalar *&lhs_ptr_imag, const Scalar *&rhs_ptr_real0, const Scalar *&rhs_ptr_real1, const Scalar *&rhs_ptr_real2, const Scalar *&rhs_ptr_imag0, const Scalar *&rhs_ptr_imag1, const Scalar *&rhs_ptr_imag2, PacketBlock< Packet, accRows > &accReal, PacketBlock< Packet, accRows > &accImag)
Definition: MatrixProduct.h:2311
EIGEN_ALWAYS_INLINE void gemm_extra_row(const DataMapper &res, const Scalar *lhs_base, const Scalar *rhs_base, Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows, Index remaining_rows, const Packet &pAlpha, const Packet &pMask)
Definition: MatrixProduct.h:1982
EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows, Index cols, Index remaining_rows, const Packet &pAlphaReal, const Packet &pAlphaImag, const Packet &pMask)
Definition: MatrixProduct.h:2579
void gemm_complexMMA(const DataMapper &res, const LhsScalar *blockAc, const RhsScalar *blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
Definition: MatrixProductMMA.h:859
EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar *blockA, const Scalar *_lhs, Index lhsStride, Index cols, Index rows)
Definition: MatrixProduct.h:255
__vector unsigned int Packet4ui
Definition: AltiVec/PacketMath.h:35
EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf &a)
Definition: AltiVec/Complex.h:303
EIGEN_ALWAYS_INLINE bool supportsMMA()
Definition: MatrixProduct.h:2623
EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16 *&indexA, const float *indexA2, Index &row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float *indexB, Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
Definition: MatrixProduct.h:3021
EIGEN_STRONG_INLINE void pstore< double >(double *to, const Packet4d &from)
Definition: AVX/PacketMath.h:1611
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
Definition: AltiVec/PacketMath.h:1218
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float *result, Index rows, const bfloat16 *src, Index extra_rows)
Definition: MatrixProduct.h:2742
static const Packet16uc p16uc_GETIMAG32
Definition: MatrixProduct.h:92
static const Packet16uc p16uc_GETREAL32
Definition: MatrixProduct.h:90
EIGEN_STRONG_INLINE Packet2d pload< Packet2d >(const double *from)
Definition: LSX/PacketMath.h:1407
static Packet16uc p16uc_MERGE16_32_2
Definition: MatrixProduct.h:2723
eigen_packet_wrapper< __vector unsigned short int, 0 > Packet8bf
Definition: AltiVec/PacketMath.h:42
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float *result)
Definition: MatrixProduct.h:2633
EIGEN_ALWAYS_INLINE void pbroadcastN< Packet4f, 4, false >(const float *ap0, const float *ap1, const float *ap2, Packet4f &a0, Packet4f &a1, Packet4f &a2, Packet4f &a3)
Definition: MatrixProduct.h:1741
EIGEN_STRONG_INLINE Packet8bf F32ToBf16Both(Packet4f lo, Packet4f hi)
Definition: AltiVec/PacketMath.h:2237
EIGEN_ALWAYS_INLINE void pbroadcastN< Packet4f, 4, true >(const float *ap0, const float *, const float *, Packet4f &a0, Packet4f &a1, Packet4f &a2, Packet4f &a3)
Definition: MatrixProduct.h:1735
EIGEN_DEVICE_FUNC void pstoreu_partial(Scalar *to, const Packet &from, const Index n, const Index offset=0)
Definition: GenericPacketMath.h:917
static Packet16uc p16uc_MERGE16_32_3
Definition: MatrixProduct.h:2724
EIGEN_ALWAYS_INLINE void bscale(PacketBlock< Packet, N > &acc, PacketBlock< Packet, N > &accZ, const Packet &pAlpha)
Definition: MatrixProduct.h:1688
EIGEN_ALWAYS_INLINE void storeBlock(Scalar *to, PacketBlock< Packet, N > &block)
Definition: MatrixProduct.h:366
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float *result, Index cols, Index rows, const bfloat16 *src, Index delta, Index extra_rows)
Definition: MatrixProduct.h:2777
EIGEN_STRONG_INLINE Packet4f pset1< Packet4f >(const float &from)
Definition: AltiVec/PacketMath.h:773
EIGEN_ALWAYS_INLINE void pbroadcastN< Packet2d, 4, false >(const double *ap0, const double *, const double *, Packet2d &a0, Packet2d &a1, Packet2d &a2, Packet2d &a3)
Definition: MatrixProduct.h:1747
EIGEN_ALWAYS_INLINE void band(PacketBlock< Packet, N > &acc, const Packet &pMask)
Definition: MatrixProduct.h:1566
EIGEN_ALWAYS_INLINE void addResults(Packet4f(&acc)[num_acc][4])
Definition: MatrixProduct.h:2853
EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f(&acc)[num_acc][size])
Definition: MatrixProduct.h:2827
EIGEN_ALWAYS_INLINE void gemm_complex_cols(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows, Index remaining_rows, const Packet &pAlphaReal, const Packet &pAlphaImag, const Packet &pMask)
Definition: MatrixProduct.h:2525
EIGEN_ALWAYS_INLINE Packet2d bmask< Packet2d >(const Index remaining_rows)
Definition: MatrixProduct.h:1673
EIGEN_ALWAYS_INLINE void bload(PacketBlock< Packet, N *(Complex ? 2 :1)> &acc, const DataMapper &res, Index row, Index col)
Definition: MatrixProduct.h:1597
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex< Scalar > *blockA, const std::complex< Scalar > *_lhs, Index lhsStride, Index cols, Index rows)
Definition: MatrixProduct.h:178
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet) *lhs)
Definition: MatrixProduct.h:1545
void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float *indexA, const float *blockB, Index strideB, Index offsetB, float *result)
Definition: MatrixProduct.h:2969
static Packet16uc p16uc_MERGE16_32_4
Definition: MatrixProduct.h:2725
EIGEN_DEVICE_FUNC void pstoreu(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:911
EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet &a)
Definition: AltiVec/PacketMath.h:1876
EIGEN_STRONG_INLINE Packet4f ploadu< Packet4f >(const float *from)
Definition: AltiVec/PacketMath.h:1533
EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock< Packet, N > &taccReal, PacketBlock< Packet, N > &taccImag, PacketBlock< Packetc, N > &acc1, PacketBlock< Packetc, N > &acc2)
Definition: MatrixProduct.h:1759
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float *res)
Definition: MatrixProduct.h:3066
static Packet16uc p16uc_MERGE16_32_6
Definition: MatrixProduct.h:2728
EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float *to, PacketBlock< Packet8bf,(N+7)/8 > &block, Index extra=0)
Definition: MatrixProduct.h:2690
EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(const DataMapper &res, const Scalar *lhs_base, const Scalar *rhs_base, Index depth, Index strideA, Index offsetA, Index strideB, Index &row, const Packet &pAlpha, const Packet &pMask)
Definition: MatrixProduct.h:2067
EIGEN_STRONG_INLINE Packet8us pset1< Packet8us >(const unsigned short int &from)
Definition: AltiVec/PacketMath.h:788
EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float *block, Index strideB, Index i, Packet4f &dhs0, Packet4f &dhs1)
Definition: MatrixProduct.h:2880
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float *to, PacketBlock< Packet8bf,(N+7)/8 > &block, Index extra)
Definition: MatrixProduct.h:2702
static Packet16uc p16uc_MERGE16_32_7
Definition: MatrixProduct.h:2729
static Packet16uc p16uc_MERGE16_32_5
Definition: MatrixProduct.h:2727
EIGEN_STRONG_INLINE void pbroadcast4< Packet4f >(const float *a, Packet4f &a0, Packet4f &a1, Packet4f &a2, Packet4f &a3)
Definition: AltiVec/PacketMath.h:823
EIGEN_ALWAYS_INLINE void pger_common(PacketBlock< Packet, N > *acc, const Packet &lhsV, const Packet *rhsV)
Definition: MatrixProduct.h:1492
EIGEN_ALWAYS_INLINE void gemm_cols(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows, Index remaining_rows, const Packet &pAlpha, const Packet &pMask)
Definition: MatrixProduct.h:2113
EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(const DataMapper &res, const Scalar *lhs_base, const Scalar *rhs_base, Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows, const Packet &pAlpha, const Packet &pMask)
Definition: MatrixProduct.h:1912
EIGEN_ALWAYS_INLINE void bcouple(PacketBlock< Packet, N > &taccReal, PacketBlock< Packet, N > &taccImag, PacketBlock< Packetc, N *2 > &tRes, PacketBlock< Packetc, N > &acc1, PacketBlock< Packetc, N > &acc2)
Definition: MatrixProduct.h:1773
__vector float Packet4f
Definition: AltiVec/PacketMath.h:33
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float *result, Index cols, Index rows, const DataMapper &res)
Definition: MatrixProduct.h:3093
EIGEN_ALWAYS_INLINE void KLoop(const float *indexA, const float *indexB, Packet4f(&acc)[num_acc][4], Index strideB, Index k, Index offsetB, Index extra_cols)
Definition: MatrixProduct.h:2892
static Packet16uc p16uc_MERGE16_32_1
Definition: MatrixProduct.h:2722
EIGEN_ALWAYS_INLINE void pger(PacketBlock< Packet, N > *acc, const Scalar *lhs, const Packet *rhsV)
Definition: MatrixProduct.h:1505
static const Packet16uc p16uc_GETREAL32b
Definition: MatrixProduct.h:94
EIGEN_STRONG_INLINE Packet8bf ploadu< Packet8bf >(const bfloat16 *from)
Definition: AltiVec/PacketMath.h:1549
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index &i, float *result, Index rows, bfloat16 *&src, Index resInc)
Definition: MatrixProduct.h:2793
static const Packet4i mask4[4]
Definition: MatrixProduct.h:1656
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float *result, Index col, Index rows, const DataMapper &res)
Definition: MatrixProduct.h:3071
EIGEN_ALWAYS_INLINE void gemm_extra_cols(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows, Index cols, Index remaining_rows, const Packet &pAlpha, const Packet &pMask)
Definition: MatrixProduct.h:2178
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper &src)
Definition: MatrixProduct.h:3053
EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float *indexA, const float *indexB, Index strideB, Index offsetB, float *result, const Index extra_cols, const Index extra_rows)
Definition: MatrixProduct.h:2918
static Packet16uc p16uc_TRANSPOSE64_HI
Definition: AltiVec/PacketMath.h:143
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
const AutoDiffScalar< DerType > & real(const AutoDiffScalar< DerType > &x)
Definition: AutoDiffScalar.h:486
DerType::Scalar imag(const AutoDiffScalar< DerType > &)
Definition: AutoDiffScalar.h:490
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
int delta
Definition: MultiOpt.py:96
void transpose()
Definition: skew_symmetric_matrix3.cpp:135
Definition: Eigen_Colamd.h:49
list x
Definition: plotDoE.py:28
Definition: BFloat16.h:101
Definition: LSX/Complex.h:260
Definition: AltiVec/Complex.h:38
Definition: GenericPacketMath.h:1407
Packet packet[N]
Definition: GenericPacketMath.h:1408
EIGEN_STRONG_INLINE void operator()(std::complex< double > *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset)
Definition: MatrixProduct.h:1444
EIGEN_ALWAYS_INLINE void dhs_ccopy(double *blockBt, const DataMapper &rhs2, Index &i, Index &rir, Index &rii, Index depth, const Index vectorSize)
Definition: MatrixProduct.h:1417
EIGEN_STRONG_INLINE void operator()(std::complex< double > *blockA, const DataMapper &lhs, Index depth, Index rows, Index stride, Index offset)
Definition: MatrixProduct.h:1352
EIGEN_ALWAYS_INLINE void dhs_ccopy(double *blockAt, const DataMapper &lhs2, Index &i, Index &rir, Index &rii, Index depth, const Index vectorSize)
Definition: MatrixProduct.h:1307
Definition: MatrixProduct.h:381
EIGEN_STRONG_INLINE void operator()(std::complex< Scalar > *blockA, const DataMapper &lhs, Index depth, Index rows, Index stride, Index offset)
Definition: MatrixProduct.h:455
EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar *blockAt, const DataMapper &lhs2, Index &i, Index &rir, Index &rii, Index depth, const Index vectorSize)
Definition: MatrixProduct.h:420
EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock< PacketC, 8 > &cblock, PacketBlock< Packet, 4 > &block, Packet16uc permute)
Definition: MatrixProduct.h:383
EIGEN_STRONG_INLINE void operator()(bfloat16 *blockA, const DataMapper &lhs, Index depth, Index rows, Index stride, Index offset)
Definition: MatrixProduct.h:837
EIGEN_STRONG_INLINE void operator()(bfloat16 *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset)
Definition: MatrixProduct.h:1180
EIGEN_STRONG_INLINE void operator()(double *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset)
Definition: MatrixProduct.h:781
EIGEN_ALWAYS_INLINE void dhs_copy(double *blockB, const DataMapper &rhs2, Index &i, Index &ri, Index depth, const Index vectorSize)
Definition: MatrixProduct.h:739
EIGEN_ALWAYS_INLINE void dhs_copy(double *blockA, const DataMapper &lhs2, Index &i, Index &ri, Index depth, const Index vectorSize)
Definition: MatrixProduct.h:662
EIGEN_STRONG_INLINE void operator()(double *blockA, const DataMapper &lhs, Index depth, Index rows, Index stride, Index offset)
Definition: MatrixProduct.h:691
Definition: MatrixProduct.h:557
EIGEN_STRONG_INLINE void operator()(Scalar *blockA, const DataMapper &lhs, Index depth, Index rows, Index stride, Index offset)
Definition: MatrixProduct.h:586
EIGEN_ALWAYS_INLINE void dhs_copy(Scalar *blockA, const DataMapper &lhs2, Index &i, Index &ri, Index depth, const Index vectorSize)
Definition: MatrixProduct.h:559
Definition: GenericPacketMath.h:225
Definition: products/GeneralBlockPanelKernel.h:960
EIGEN_DONT_INLINE void operator()(const DataMapper &res, const LhsScalar *blockA, const RhsScalar *blockB, Index rows, Index depth, Index cols, ResScalar alpha, Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0)
Definition: products/GeneralBlockPanelKernel.h:1425
Definition: BlasUtil.h:34
Definition: BlasUtil.h:30
Definition: GenericPacketMath.h:108
vectortype rhstype
Definition: MatrixProduct.h:82
Packet8bf vectortype
Definition: MatrixProduct.h:80
PacketBlock< vectortype, 4 > type
Definition: MatrixProduct.h:81
Packet2d vectortype
Definition: MatrixProduct.h:72
PacketBlock< Packet2d, 2 > rhstype
Definition: MatrixProduct.h:74
PacketBlock< vectortype, 4 > type
Definition: MatrixProduct.h:73
Definition: MatrixProduct.h:63
PacketBlock< vectortype, 4 > type
Definition: MatrixProduct.h:65
packet_traits< Scalar >::type vectortype
Definition: MatrixProduct.h:64
vectortype rhstype
Definition: MatrixProduct.h:66
@ size
Definition: MatrixProduct.h:67
@ rows
Definition: MatrixProduct.h:67
@ vectorsize
Definition: MatrixProduct.h:67
void operator()(double *blockA, const double *_lhs, Index lhsStride, Index cols, Index rows)
Definition: MatrixProduct.h:349
void operator()(float *blockA, const float *_lhs, Index lhsStride, Index cols, Index rows)
Definition: MatrixProduct.h:334
void operator()(std::complex< double > *blockA, const std::complex< double > *_lhs, Index lhsStride, Index cols, Index rows)
Definition: MatrixProduct.h:318
void operator()(std::complex< float > *blockA, const std::complex< float > *_lhs, Index lhsStride, Index cols, Index rows)
Definition: MatrixProduct.h:300
Definition: SelfadjointMatrixMatrix.h:22
void operator()(double *blockB, const double *_rhs, Index rhsStride, Index rows, Index cols, Index k2)
Definition: MatrixProduct.h:342
void operator()(float *blockB, const float *_rhs, Index rhsStride, Index rows, Index cols, Index k2)
Definition: MatrixProduct.h:327
void operator()(std::complex< double > *blockB, const std::complex< double > *_rhs, Index rhsStride, Index rows, Index cols, Index k2)
Definition: MatrixProduct.h:310
void operator()(std::complex< float > *blockB, const std::complex< float > *_rhs, Index rhsStride, Index rows, Index cols, Index k2)
Definition: MatrixProduct.h:292
Definition: SelfadjointMatrixMatrix.h:100
Definition: datatypes.h:12
EIGEN_DONT_INLINE Scalar zero()
Definition: svd_common.h:232
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
Definition: ZVector/PacketMath.h:50