#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
#define EIGEN_ALTIVEC_DISABLE_MMA 0
#endif

// Check for MMA builtin support.
#if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
#if __has_builtin(__builtin_mma_assemble_acc)
#define EIGEN_ALTIVEC_MMA_SUPPORT
#endif
#endif

#if defined(EIGEN_ALTIVEC_MMA_SUPPORT)

#if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
#define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
#endif

// Dynamic dispatch is not supported with LLVM.
#if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
#define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
#elif defined(__MMA__)
#define EIGEN_ALTIVEC_MMA_ONLY 1
#endif

#endif  // EIGEN_ALTIVEC_MMA_SUPPORT

#if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
#include "MatrixProductMMA.h"
#endif

#include "../../InternalHeaderCheck.h"
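// The preprocessor logic above yields three possible configurations: VSX-only
// kernels, MMA-only kernels (builds with __MMA__ defined), or both sets of
// kernels compiled in with the selection deferred to runtime CPU detection
// (dynamic dispatch; GCC only, since LLVM is excluded above).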
template <typename Scalar>

const static Packet16uc p16uc_GETREAL32 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
const static Packet16uc p16uc_GETIMAG32 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
const static Packet16uc p16uc_GETREAL32b = {0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};
const static Packet16uc p16uc_GETIMAG32b = {4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
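// These are vec_perm selector masks for splitting interleaved complex<float>
// data: p16uc_GETREAL32 gathers the four 32-bit real lanes from a pair of
// vectors and p16uc_GETIMAG32 the four imaginary lanes; the "b" variants
// produce an alternate pairwise lane ordering used elsewhere in this file.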
template <typename Scalar, int StorageOrder>
std::complex<Scalar> v;
template <typename Scalar, int StorageOrder, int N>
const Index vectorDelta = vectorSize * rows;
Index rir = 0, rii, j = 0;
for (; j + vectorSize <= cols; j += vectorSize) {
  rii = rir + vectorDelta;

  for (Index i = k2; i < depth; i++) {
    for (Index k = 0; k < vectorSize; k++) {
      std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);

      blockBf[rir + k] = v.real();
      blockBf[rii + k] = v.imag();

for (Index i = k2; i < depth; i++) {
  std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);

  blockBf[rir] = v.real();
  blockBf[rii] = v.imag();
template <typename Scalar, int StorageOrder>
const Index vectorDelta = vectorSize * depth;
Index rir = 0, rii, j = 0;
for (; j + vectorSize <= rows; j += vectorSize) {
  rii = rir + vectorDelta;

  for (Index k = 0; k < vectorSize; k++) {
    std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j + k, i, lhs);

    blockAf[rir + k] = v.real();
    blockAf[rii + k] = v.imag();

rii = rir + ((rows - j) * depth);

std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);

blockAf[rir] = v.real();
blockAf[rii] = v.imag();
template <typename Scalar, int StorageOrder, int N>
for (; j + N * vectorSize <= cols; j += N * vectorSize) {
  for (; i < depth; i++) {
    for (Index k = 0; k < N * vectorSize; k++) {
      blockB[ri + k] = rhs(j + k, i);
      blockB[ri + k] = rhs(i, j + k);

    ri += N * vectorSize;

for (Index i = k2; i < depth; i++) {
  blockB[ri] = rhs(i, j);
  blockB[ri] = rhs(j, i);
template <typename Scalar, int StorageOrder>
for (; j + vectorSize <= rows; j += vectorSize) {
  for (; i < depth; i++) {
    for (Index k = 0; k < vectorSize; k++) {
      blockA[ri + k] = lhs(j + k, i);
      blockA[ri + k] = lhs(i, j + k);

blockA[ri] = lhs(k, i);
blockA[ri] = lhs(i, k);
template <typename Index, int nr, int StorageOrder>
symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);

template <typename Index, int nr, int StorageOrder>
symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);

template <typename Index, int nr, int StorageOrder>
symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);

template <typename Index, int nr, int StorageOrder>
symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
template <typename Scalar, typename Packet, int N>
pstore<Scalar>(to + (0 * size), block.packet[0]);
pstore<Scalar>(to + (1 * size), block.packet[1]);
pstore<Scalar>(to + (2 * size), block.packet[2]);
pstore<Scalar>(to + (3 * size), block.packet[3]);
template <typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate,
          bool PanelMode, bool UseLhs>
template <bool transpose>
#ifdef EIGEN_VECTORIZE_VSX
t0 = reinterpret_cast<Packet>(
t1 = reinterpret_cast<Packet>(
t2 = reinterpret_cast<Packet>(
t3 = reinterpret_cast<Packet>(

block.packet[0] = t0;
block.packet[1] = t1;
block.packet[2] = t2;
block.packet[3] = t3;
for (; i + vectorSize <= depth; i += vectorSize) {
  bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
  bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);

  if (((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) {

  storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
  storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);

  rir += 4 * vectorSize;
  rii += 4 * vectorSize;

const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;

for (; j + vectorSize <= rows; j += vectorSize) {
  const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);

  rii = rir + vectorDelta;

  dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

  for (; i < depth; i++) {
    if (((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) {
      cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
      cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
      cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
      cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);

    pstore<Scalar>(blockAt + rir, blockr.packet[0]);
    pstore<Scalar>(blockAt + rii, blocki.packet[0]);

  rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);

if (PanelMode) rir -= (offset * (vectorSize - 1));

const DataMapper lhs2 = lhs.getSubMapper(0, j);
rii = rir + ((PanelMode) ? stride : depth);

blockAt[rir] = lhs2(i, 0).real();
blockAt[rii] = -lhs2(i, 0).imag();
blockAt[rii] = lhs2(i, 0).imag();

rir += ((PanelMode) ? (2 * stride - depth) : depth);

if (PanelMode) rir += (offset * (rows - j - vectorSize));
rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

blockAt[rir] = lhs(k, i).real();
blockAt[rii] = -lhs(k, i).imag();
blockAt[rii] = lhs(k, i).imag();
template <typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
                                 const Index vectorSize) {
for (; i + n * vectorSize <= depth; i += n * vectorSize) {
  bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k * vectorSize);
  bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k * vectorSize, 0);

  if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {

  storeBlock<Scalar, Packet, 4>(blockA + ri + k * 4 * vectorSize, block[k]);

  ri += n * 4 * vectorSize;

for (; j + vectorSize <= rows; j += vectorSize) {
  const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);

  if (PanelMode) ri += vectorSize * offset;

  dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
  dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
  dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

  for (; i < depth; i++) {
    if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
      blockA[ri + 0] = lhs2(0, i);
      blockA[ri + 1] = lhs2(1, i);
      blockA[ri + 2] = lhs2(2, i);
      blockA[ri + 3] = lhs2(3, i);
      blockA[ri + 0] = lhs2(i, 0);
      blockA[ri + 1] = lhs2(i, 1);
      blockA[ri + 2] = lhs2(i, 2);
      blockA[ri + 3] = lhs2(i, 3);

    lhsV = lhs2.template loadPacket<Packet>(0, i);
    lhsV = lhs2.template loadPacket<Packet>(i, 0);
    pstore<Scalar>(blockA + ri, lhsV);

  if (PanelMode) ri += vectorSize * (stride - offset - depth);

if (PanelMode) ri += offset;

const DataMapper lhs2 = lhs.getSubMapper(0, j);

blockA[ri] = lhs2(i, 0);

if (PanelMode) ri += stride - depth;

if (PanelMode) ri += offset * (rows - j);

blockA[ri] = lhs(k, i);
template <typename DataMapper, int StorageOrder, bool PanelMode>
                                 const Index vectorSize) {
for (; i + n * vectorSize <= depth; i += n * vectorSize) {
  block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize);
  block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k * vectorSize);
  block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 0);
  block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 1);

  storeBlock<double, Packet2d, 2>(blockA + ri + k * 2 * vectorSize, block[k]);

  ri += n * 2 * vectorSize;

for (; j + vectorSize <= rows; j += vectorSize) {
  const DataMapper lhs2 = lhs.getSubMapper(j, 0);

  if (PanelMode) ri += vectorSize * offset;

  dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
  dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
  dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

  for (; i < depth; i++) {
    blockA[ri + 0] = lhs2(0, i);
    blockA[ri + 1] = lhs2(1, i);
    Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);

  if (PanelMode) ri += vectorSize * (stride - offset - depth);

if (PanelMode) ri += offset * (rows - j);

blockA[ri] = lhs(k, i);
template <typename DataMapper, int StorageOrder, bool PanelMode>
                                 const Index vectorSize) {
for (; i + n * vectorSize <= depth; i += n * vectorSize) {
  block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 0);
  block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 1);
  block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 2);
  block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 3);

  block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 0);
  block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 2);
  block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 0);
  block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 2);

  storeBlock<double, Packet2d, 4>(blockB + ri + k * 4 * vectorSize, block3[k]);

  ri += n * 4 * vectorSize;

for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
  const DataMapper rhs2 = rhs.getSubMapper(0, j);

  if (PanelMode) ri += offset * (2 * vectorSize);

  dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
  dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
  dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);

  for (; i < depth; i++) {
    blockB[ri + 0] = rhs2(i, 0);
    blockB[ri + 1] = rhs2(i, 1);
    blockB[ri + 0] = rhs2(i, 2);
    blockB[ri + 1] = rhs2(i, 3);
    Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
    rhsV = rhs2.template loadPacket<Packet2d>(i, 2);

  if (PanelMode) ri += (2 * vectorSize) * (stride - offset - depth);

if (PanelMode) ri += offset;

const DataMapper rhs2 = rhs.getSubMapper(0, j);

blockB[ri] = rhs2(i, 0);

if (PanelMode) ri += stride - depth;
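// The real double-precision RHS is packed four columns at a time (j advances
// by 2 * vectorSize, i.e. 4 with Packet2d): one storage-order branch gathers
// columns pairwise into block1/block2, the other gathers the transposed order
// into block3, and storeBlock then writes the tile out contiguously.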
template <typename DataMapper, int StorageOrder, bool PanelMode>
for (; j + 2 * vectorSize <= rows; j += 2 * vectorSize) {
  const DataMapper lhs2 = lhs.getSubMapper(j, 0);

  if (PanelMode) ri += 2 * vectorSize * offset;

  for (; i + 2 <= depth; i += 2) {
    block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
    block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
    block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
    block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);

    t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
    t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
    block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
    block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
    block.packet[0] = t0;
    block.packet[1] = t1;

    storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);

    ri += 2 * 2 * vectorSize;

  block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
  block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);

  storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);

  ri += 2 * vectorSize;

  for (; i + vectorSize <= depth; i += vectorSize) {
    bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
    bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);

#ifdef EIGEN_VECTORIZE_VSX

    ri += 2 * vectorSize * vectorSize;

  for (; i + 2 <= depth; i += 2) {
    for (Index M = 0; M < 2 * vectorSize; M++) {
      blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
      blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);

    ri += 2 * 2 * vectorSize;

  for (Index M = 0; M < 2 * vectorSize; M++) {
    blockA[ri + M] = lhs2(M, i);

  ri += 2 * vectorSize;

  if (PanelMode) ri += 2 * vectorSize * (stride - offset - depth);
for (; j + vectorSize <= rows; j += vectorSize) {
  const DataMapper lhs2 = lhs.getSubMapper(j, 0);

  if (PanelMode) ri += vectorSize * offset;

  for (; i + 2 <= depth; i += 2) {
    block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
    block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);

    t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
    block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
    block.packet[0] = t0;

    storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);

    ri += 2 * vectorSize;

  Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);

  for (; i + vectorSize <= depth; i += vectorSize) {
    bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);

#ifdef EIGEN_VECTORIZE_VSX

    ri += vectorSize * vectorSize;

  for (; i + 2 <= depth; i += 2) {
    for (Index M = 0; M < vectorSize; M++) {
      blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
      blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);

    ri += 2 * vectorSize;

  for (Index M = 0; M < vectorSize; M++) {
    blockA[ri + M] = lhs2(M, i);

  if (PanelMode) ri += vectorSize * (stride - offset - depth);

if (j + 4 <= rows) {
  const DataMapper lhs2 = lhs.getSubMapper(j, 0);

  if (PanelMode) ri += 4 * offset;

  for (; i + 2 <= depth; i += 2) {
    block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
    block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);

    block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);

    blockA[ri + 0] = lhs2(0, i + 0);
    blockA[ri + 1] = lhs2(0, i + 1);
    blockA[ri + 2] = lhs2(1, i + 0);
    blockA[ri + 3] = lhs2(1, i + 1);
    blockA[ri + 4] = lhs2(2, i + 0);
    blockA[ri + 5] = lhs2(2, i + 1);
    blockA[ri + 6] = lhs2(3, i + 0);
    blockA[ri + 7] = lhs2(3, i + 1);

  Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);

  blockA[ri + 0] = lhs2(0, i);
  blockA[ri + 1] = lhs2(1, i);
  blockA[ri + 2] = lhs2(2, i);
  blockA[ri + 3] = lhs2(3, i);

  if (PanelMode) ri += 4 * (stride - offset - depth);

if (PanelMode) ri += offset * (rows - j);

for (; i + 2 <= depth; i += 2) {
  blockA[ri + 0] = lhs(k, i + 0);
  blockA[ri + 1] = lhs(k, i + 1);

blockA[ri] = lhs(j, i);
template <typename DataMapper, int StorageOrder, bool PanelMode>
for (; j + 4 <= cols; j += 4) {
  const DataMapper rhs2 = rhs.getSubMapper(0, j);

  if (PanelMode) ri += 4 * offset;

  for (; i + vectorSize <= depth; i += vectorSize) {
    bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);

    t0 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
    t1 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
    t2 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
    t3 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val),

#ifdef EIGEN_VECTORIZE_VSX

    storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);

  for (int M = 0; M < 8; M++) {
    block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);

  block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
  block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
  block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
  block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);

  for (int M = 0; M < 4; M++) {

  ri += 4 * vectorSize;

  for (; i + 2 <= depth; i += 2) {
    blockB[ri + 0] = rhs2(i + 0, 0);
    blockB[ri + 1] = rhs2(i + 1, 0);
    blockB[ri + 2] = rhs2(i + 0, 1);
    blockB[ri + 3] = rhs2(i + 1, 1);
    blockB[ri + 4] = rhs2(i + 0, 2);
    blockB[ri + 5] = rhs2(i + 1, 2);
    blockB[ri + 6] = rhs2(i + 0, 3);
    blockB[ri + 7] = rhs2(i + 1, 3);

  for (int M = 0; M < 2; M++) {
    block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);

  block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);

  blockB[ri + 0] = rhs2(i, 0);
  blockB[ri + 1] = rhs2(i, 1);
  blockB[ri + 2] = rhs2(i, 2);
  blockB[ri + 3] = rhs2(i, 3);

  if (PanelMode) ri += 4 * (stride - offset - depth);

if (PanelMode) ri += offset * (cols - j);

for (; i + 2 <= depth; i += 2) {
  blockB[ri + 0] = rhs(i + 0, k);
  blockB[ri + 1] = rhs(i + 1, k);

blockB[ri] = rhs(i, j);
template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
for (; i + vectorSize <= depth; i += vectorSize) {
  cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0);
  cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1);
  cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0);
  cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1);

  cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
  cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);
  cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1);
  cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1);

  storeBlock<double, Packet, 2>(blockAt + rir, blockr);
  storeBlock<double, Packet, 2>(blockAt + rii, blocki);

  rir += 2 * vectorSize;
  rii += 2 * vectorSize;

const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
double* blockAt = reinterpret_cast<double*>(blockA);

for (; j + vectorSize <= rows; j += vectorSize) {
  const DataMapper lhs2 = lhs.getSubMapper(j, 0);

  rii = rir + vectorDelta;

  dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

  for (; i < depth; i++) {
    cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
    cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);

  rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);

if (PanelMode) rir += (offset * (rows - j - vectorSize));
rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

for (Index i = 0; i < depth; i++) {
  blockAt[rir] = lhs(k, i).real();
  blockAt[rii] = -lhs(k, i).imag();
  blockAt[rii] = lhs(k, i).imag();
template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
for (; i < depth; i++) {
  bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);

  storeBlock<double, Packet, 2>(blockBt + rir, blockr);
  storeBlock<double, Packet, 2>(blockBt + rii, blocki);

  rir += 2 * vectorSize;
  rii += 2 * vectorSize;

const Index vectorDelta = 2 * vectorSize * ((PanelMode) ? stride : depth);
Index rir = ((PanelMode) ? (2 * vectorSize * offset) : 0), rii;
double* blockBt = reinterpret_cast<double*>(blockB);

for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
  const DataMapper rhs2 = rhs.getSubMapper(0, j);

  rii = rir + vectorDelta;

  dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);

  rir += ((PanelMode) ? (2 * vectorSize * (2 * stride - depth)) : vectorDelta);

if (PanelMode) rir -= (offset * (2 * vectorSize - 1));

const DataMapper rhs2 = rhs.getSubMapper(0, j);
rii = rir + ((PanelMode) ? stride : depth);

for (Index i = 0; i < depth; i++) {
  blockBt[rir] = rhs2(i, 0).real();
  blockBt[rii] = -rhs2(i, 0).imag();
  blockBt[rii] = rhs2(i, 0).imag();

rir += ((PanelMode) ? (2 * stride - depth) : depth);
template <typename Packet, bool NegativeAccumulate, int N>
if (NegativeAccumulate) {
  for (int M = 0; M < N; M++) {

  for (int M = 0; M < N; M++) {

template <int N, typename Scalar, typename Packet, bool NegativeAccumulate>
Packet lhsV = pload<Packet>(lhs);
pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);

template <int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
pger_common<Packet, false, N>(accReal, lhsV, rhsV);
pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);

template <int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
Packet lhsV = ploadLhs<Packet>(lhs_ptr);
lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
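// Complex rank-1 update math: with lhs value a = ar + i*ai and rhs value
// b = br + i*bi, a*b = (ar*br - ai*bi) + i*(ar*bi + ai*br).  Conjugation only
// flips the sign of the ai or bi contributions, so pgerc_common above simply
// toggles the NegativeAccumulate flag of pger_common with ConjugateLhs /
// ConjugateRhs instead of negating any data: e.g. the accReal update of
// lhsVi*rhsVi subtracts when ConjugateLhs == ConjugateRhs and adds otherwise.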
template <typename Packet>
return ploadu<Packet>(lhs);

template <typename Packet, int N>
for (int M = 0; M < N; M++) {

template <typename Packet, int N>
for (int M = 0; M < N; M++) {

template <typename Packet, int N>
for (int M = 0; M < N; M++) {

template <typename Packet, int N, bool mask>
band<Packet, N>(aReal, pMask);
band<Packet, N>(aImag, pMask);

bscalec_common<Packet, N>(cReal, aReal, bReal);
bscalec_common<Packet, N>(cImag, aImag, bReal);

pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
template <typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
for (int M = 0; M < N; M++) {
  acc.packet[M] = res.template loadPacket<Packet>(row + M, col);

for (int M = 0; M < N; M++) {
  acc.packet[M + N] = res.template loadPacket<Packet>(row + M, col + accCols);

for (int M = 0; M < N; M++) {
  acc.packet[M] = res.template loadPacket<Packet>(row, col + M);

for (int M = 0; M < N; M++) {
  acc.packet[M + N] = res.template loadPacket<Packet>(row + accCols, col + M);

template <typename DataMapper, typename Packet, int N>
for (int M = 0; M < N; M++) {
#ifdef USE_PARTIAL_PACKETS
template <typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full>
acc.packet[M] = res.template loadPacketPartial<Packet>(row, M, elements);
acc.packet[M + N] = res.template loadPacketPartial<Packet>(row + accCols, M, elements);

template <typename DataMapper, typename Packet, Index N>
res.template storePacketPartial<Packet>(row, M, acc.packet[M], elements);
#define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#define USE_P10_AND_PVIPR2_0 0

#if !USE_P10_AND_PVIPR2_0
const static Packet4i mask4[4] = {{0, 0, 0, 0}, {-1, 0, 0, 0}, {-1, -1, 0, 0}, {-1, -1, -1, 0}};

template <typename Packet>
#if USE_P10_AND_PVIPR2_0
return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
return Packet(vec_genwm((1 << remaining_rows) - 1));
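// bmask builds a per-lane load/store guard from remaining_rows.  For example,
// remaining_rows == 2 selects {-1, -1, 0, 0} (from mask4 on pre-Power10
// compilers, or via vec_genwm((1 << 2) - 1) on Power10), so only the first
// two lanes of the packet participate in the tail computation.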
#if USE_P10_AND_PVIPR2_0

template <typename Packet, int N>
for (int M = 0; M < N; M++) {

template <typename Packet, int N, bool mask>
band<Packet, N>(accZ, pMask);
bscale<Packet, N>(acc, accZ, pAlpha);

template <typename Packet, int N, bool real>
a0 = pset1<Packet>(ap0[0]);
a1 = pset1<Packet>(ap0[1]);
a2 = pset1<Packet>(ap0[2]);
a3 = pset1<Packet>(ap0[3]);
a1 = pset1<Packet>(ap1[0]);
a2 = pset1<Packet>(ap2[0]);
a0 = vec_splat(a1, 0);
a1 = vec_splat(a1, 1);
a2 = vec_splat(a3, 0);
a3 = vec_splat(a3, 1);
template <typename Packet, typename Packetc, int N, bool full>
for (int M = 0; M < N; M++) {

for (int M = 0; M < N; M++) {

template <typename Packet, typename Packetc, int N, bool full>
bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);

for (int M = 0; M < N; M++) {

for (int M = 0; M < N; M++) {
#define MICRO_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_NORMAL_ROWS accRows == quad_traits<Scalar>::rows || accRows == 1

#define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)

#define MICRO_RHS(ptr, N) rhs_##ptr##N

#define MICRO_ZERO_PEEL(peel)                 \
  if ((PEEL_ROW > peel) && (peel != 0)) {     \
    bsetzero<Packet, accRows>(accZero##peel); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(accZero##peel);     \
  }

#define MICRO_ADD(ptr, N)               \
  if (MICRO_NORMAL_ROWS) {              \
    MICRO_RHS(ptr, 0) += (accRows * N); \
  } else {                              \
    MICRO_RHS(ptr, 0) += N;             \
    MICRO_RHS(ptr, 1) += N;             \
    if (accRows == 3) {                 \
      MICRO_RHS(ptr, 2) += N;           \
    }                                   \
  }

#define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)
#define MICRO_BROADCAST1(peel, ptr, rhsV, real)                                                                      \
  if (MICRO_NORMAL_ROWS) {                                                                                           \
    pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + (accRows * peel), MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 0),   \
                                       rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]);                  \
  } else {                                                                                                           \
    pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + peel, MICRO_RHS(ptr, 1) + peel, MICRO_RHS(ptr, 2) + peel, \
                                       rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]);                  \
  }

#define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)

#define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real)                                                                 \
  pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 1), MICRO_RHS(ptr, 2), rhsV[0], rhsV[1], \
                                     rhsV[2], rhsV[3]);

#define MICRO_BROADCAST_EXTRA             \
  Packet rhsV[4];                         \
  MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
  MICRO_ADD_ROWS(1)

#define MICRO_SRC2(ptr, N, M)                   \
  if (MICRO_NORMAL_ROWS) {                      \
    EIGEN_UNUSED_VARIABLE(strideB);             \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 1));   \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2));   \
  } else {                                      \
    MICRO_RHS(ptr, 1) = rhs_base + N + M;       \
    if (accRows == 3) {                         \
      MICRO_RHS(ptr, 2) = rhs_base + N * 2 + M; \
    } else {                                    \
      EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
    }                                           \
  }

#define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)
#define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)

#define MICRO_WORK_PEEL(peel)                                                                            \
  if (PEEL_ROW > peel) {                                                                                 \
    MICRO_BROADCAST(peel)                                                                                \
    pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
  } else {                                                                                               \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);                                                                   \
  }

#define MICRO_WORK_PEEL_ROW                                                              \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
  MICRO_UNROLL(MICRO_WORK_PEEL)                                                          \
  lhs_ptr += (remaining_rows * PEEL_ROW);                                                \
  MICRO_ADD_ROWS(PEEL_ROW)

#define MICRO_ADD_PEEL(peel, sum)                        \
  if (PEEL_ROW > peel) {                                 \
    for (Index i = 0; i < accRows; i++) {                \
      accZero##sum.packet[i] += accZero##peel.packet[i]; \
    }                                                    \
  }

#define MICRO_ADD_PEEL_ROW \
  MICRO_ADD_PEEL(4, 0)     \
  MICRO_ADD_PEEL(5, 1)     \
  MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
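// The peeled row kernels keep up to eight independent accumulators
// (accZero0..accZero7) so consecutive FMAs do not stall on each other;
// MICRO_ADD_PEEL_ROW then folds them back pairwise (4..7 into 0..3, then
// 2..3 into 0..1, then 1 into 0) before the single scale-and-store.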
#define MICRO_PREFETCHN1(ptr, N)               \
  EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 0));     \
  if (N == 2 || N == 3) {                      \
    EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 1));   \
    if (N == 3) {                              \
      EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 2)); \
    }                                          \
  }

#define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)

#define MICRO_COMPLEX_PREFETCHN(N) \
  MICRO_PREFETCHN1(ptr_real, N);   \
  if (!RhsIsReal) {                \
    MICRO_PREFETCHN1(ptr_imag, N); \
  }
template <typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
lhs_ptr += remaining_rows;

template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols,
          const Index remaining_rows>
const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
const Scalar* lhs_ptr = lhs_base + row * strideA + remaining_rows * offsetA;

bsetzero<Packet, accRows>(accZero0);

for (; k < depth; k++) {
  MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);

#ifdef USE_PARTIAL_PACKETS
bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row, remaining_rows);
bscale<Packet, accRows>(acc, accZero0, pAlpha);
bstore_partial<DataMapper, Packet, accRows>(acc, res, row, remaining_rows);

bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
if ((accRows == 1) || (rows >= accCols)) {
  bscale<Packet, accRows, true>(acc, accZero0, pAlpha, pMask);
  bstore<DataMapper, Packet, accRows>(acc, res, row);

  bscale<Packet, accRows, false>(acc, accZero0, pAlpha, pMask);
  for (Index j = 0; j < accRows; j++) {
    for (Index i = 0; i < remaining_rows; i++) {
#define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col)   \
  switch (value) {                                       \
    default: {                                           \
      MICRO_EXTRA_UNROLL(1)                              \
    } break;                                             \
    case 2: {                                            \
      if (is_col || (sizeof(Scalar) == sizeof(float))) { \
        MICRO_EXTRA_UNROLL(2)                            \
      }                                                  \
    } break;                                             \
    case 3: {                                            \
      if (is_col || (sizeof(Scalar) == sizeof(float))) { \
        MICRO_EXTRA_UNROLL(3)                            \
      }                                                  \
    } break;                                             \
  }

#define MICRO_EXTRA_ROWS(N)                                                     \
  gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>( \
      res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>

#define MICRO_UNROLL_WORK(func, func2, peel) \
  MICRO_UNROLL(func2);                       \
  func(0, peel) func(1, peel) func(2, peel) func(3, peel) func(4, peel) func(5, peel) func(6, peel) func(7, peel)
#define MICRO_WORK_ONE(iter, peel)                                               \
  if (unroll_factor > iter) {                                                    \
    pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }

#define MICRO_TYPE_PEEL4(func, func2, peel)                       \
  if (PEEL > peel) {                                              \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    MICRO_BROADCAST(peel)                                         \
    MICRO_UNROLL_WORK(func, func2, peel)                          \
  } else {                                                        \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);                            \
  }

#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2)                                                           \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M];                        \
  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3) func(func1, func2, 4) \
      func(func1, func2, 5) func(func1, func2, 6) func(func1, func2, 7)

#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M];                                   \
  func(func1, func2, 0)

#define MICRO_UNROLL_TYPE(MICRO_TYPE, size)                       \
  MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
  MICRO_ADD_ROWS(size)

#define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)

#define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)

#define MICRO_DST_PTR_ONE(iter)               \
  if (unroll_factor > iter) {                 \
    bsetzero<Packet, accRows>(accZero##iter); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(accZero##iter);     \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)

#ifdef USE_PARTIAL_PACKETS
#define MICRO_STORE_ONE(iter)                                                                         \
  if (unroll_factor > iter) {                                                                         \
    if (MICRO_NORMAL_PARTIAL(iter)) {                                                                 \
      bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0);      \
      bscale<Packet, accRows>(acc, accZero##iter, pAlpha);                                            \
      bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols);                            \
    } else {                                                                                          \
      bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row + iter * accCols, accCols2); \
      bscale<Packet, accRows>(acc, accZero##iter, pAlpha);                                            \
      bstore_partial<DataMapper, Packet, accRows>(acc, res, row + iter * accCols, accCols2);          \
    }                                                                                                 \
  }
#else
#define MICRO_STORE_ONE(iter)                                                                  \
  if (unroll_factor > iter) {                                                                  \
    bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
    bscale<Packet, accRows, !(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask);         \
    bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols);                       \
  }
#endif

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
#ifdef USE_PARTIAL_PACKETS
template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
          const Index accCols, bool full>
#else
template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
          const Index accCols, const Index accCols2>
#endif

#ifdef USE_PARTIAL_PACKETS

const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
             *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;

for (; k < depth; k++) {

#ifdef USE_PARTIAL_PACKETS
#define MICRO_UNROLL_ITER2(N, M)                                                                              \
  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, !M>(               \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, M ? remaining_rows : accCols); \
  if (M) return;
#else
#define MICRO_UNROLL_ITER2(N, M)                                                                            \
  gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>( \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask);                      \
  if (M) return;
#endif
template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
const DataMapper res3 = res.getSubMapper(0, col);

const Scalar* lhs_base = blockA + accCols * offsetA;

#define MAX_UNROLL 7
switch ((rows - row) / accCols) {

if (remaining_rows > 0) {
  gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA,
                                                               strideB, row, rows, remaining_rows, pAlpha, pMask);

#define MICRO_EXTRA_COLS(N)                                                                                         \
  gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, \
                                                    col, rows, remaining_rows, pAlpha, pMask);
template <typename Scalar, typename Packet, typename DataMapper, const Index accCols>

template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
          const Index accCols>
const Index remaining_rows = rows % accCols;

if (strideA == -1) strideA = depth;
if (strideB == -1) strideB = depth;

const Packet pMask = bmask<Packet>(remaining_rows);

for (; col + accRows <= cols; col += accRows) {
  gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB,
                                                          offsetB, col, rows, remaining_rows, pAlpha, pMask);

gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
                                                     col, rows, cols, remaining_rows, pAlpha, pMask);
#define accColsC (accCols / 2)
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

#define PEEL_COMPLEX 3
#define PEEL_COMPLEX_ROW 3
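// accColsC counts columns in complex (two-scalar) units, and advanceRows /
// advanceCols are 1 when the corresponding operand is purely real and 2 when
// its packed panel carries separate real and imaginary blocks.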
#define MICRO_COMPLEX_UNROLL(func) func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_ZERO_PEEL(peel)             \
  if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
    bsetzero<Packet, accRows>(accReal##peel);     \
    bsetzero<Packet, accRows>(accImag##peel);     \
  } else {                                        \
    EIGEN_UNUSED_VARIABLE(accReal##peel);         \
    EIGEN_UNUSED_VARIABLE(accImag##peel);         \
  }

#define MICRO_COMPLEX_ADD_ROWS(N, used)            \
  MICRO_ADD(ptr_real, N)                           \
  if (!RhsIsReal) {                                \
    MICRO_ADD(ptr_imag, N)                         \
  } else if (used) {                               \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
  }

#define MICRO_COMPLEX_BROADCAST(peel)              \
  MICRO_BROADCAST1(peel, ptr_real, rhsV, false)    \
  if (!RhsIsReal) {                                \
    MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
  } else {                                         \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);            \
  }

#define MICRO_COMPLEX_BROADCAST_EXTRA              \
  Packet rhsV[4], rhsVi[4];                        \
  MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false)    \
  if (!RhsIsReal) {                                \
    MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
  } else {                                         \
    EIGEN_UNUSED_VARIABLE(rhsVi);                  \
  }                                                \
  MICRO_COMPLEX_ADD_ROWS(1, true)

#define MICRO_COMPLEX_SRC2_PTR                                    \
  MICRO_SRC2(ptr_real, strideB* advanceCols, 0)                   \
  if (!RhsIsReal) {                                               \
    MICRO_RHS(ptr_imag, 0) = rhs_base + MICRO_NEW_ROWS * strideB; \
    MICRO_SRC2(ptr_imag, strideB* advanceCols, strideB)           \
  } else {                                                        \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0));                \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1));                \
    EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2));                \
  }
#define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)

#define MICRO_COMPLEX_WORK_PEEL(peel)                                                 \
  if (PEEL_COMPLEX_ROW > peel) {                                                      \
    MICRO_COMPLEX_BROADCAST(peel)                                                     \
    pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
        &accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel),       \
        lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel);             \
  } else {                                                                            \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);                                                \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);                                               \
  }

#define MICRO_COMPLEX_ADD_COLS(size)         \
  lhs_ptr_real += (remaining_rows * size);   \
  if (!LhsIsReal)                            \
    lhs_ptr_imag += (remaining_rows * size); \
  else                                       \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);

#define MICRO_COMPLEX_WORK_PEEL_ROW                  \
  Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4];     \
  Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
  MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL)      \
  MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW)           \
  MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)

#define MICRO_COMPLEX_ADD_PEEL(peel, sum)                \
  if (PEEL_COMPLEX_ROW > peel) {                         \
    for (Index i = 0; i < accRows; i++) {                \
      accReal##sum.packet[i] += accReal##peel.packet[i]; \
      accImag##sum.packet[i] += accImag##peel.packet[i]; \
    }                                                    \
  }

#define MICRO_COMPLEX_ADD_PEEL_ROW \
  MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) MICRO_COMPLEX_ADD_PEEL(1, 0)
template <typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
          bool RhsIsReal, const Index remaining_rows>
                  const Scalar*& rhs_ptr_real0, const Scalar*& rhs_ptr_real1,
                  const Scalar*& rhs_ptr_real2, const Scalar*& rhs_ptr_imag0,
                  const Scalar*& rhs_ptr_imag1, const Scalar*& rhs_ptr_imag2,
pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real,
                                                                                 lhs_ptr_imag, rhsV, rhsVi);

template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal,
          const Index remaining_rows>
const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
const Scalar* lhs_ptr_real = lhs_base + advanceRows * row * strideA + remaining_rows * offsetA;
const Scalar* lhs_ptr_imag = NULL;
lhs_ptr_imag = lhs_ptr_real + remaining_rows * strideA;

bsetzero<Packet, accRows>(accReal0);
bsetzero<Packet, accRows>(accImag0);

for (; k < depth; k++) {
  MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(
      lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1,
      rhs_ptr_imag2, accReal0, accImag0);

constexpr bool full = (remaining_rows > accColsC);
bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
if ((accRows == 1) || (rows >= accCols)) {
  bscalec<Packet, accRows, true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
  bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
  bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);

  bscalec<Packet, accRows, false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
  bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);

  if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) {
    for (Index j = 0; j < accRows; j++) {

    bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);

    for (Index j = 0; j < accRows; j++) {

#define MICRO_COMPLEX_EXTRA_ROWS(N)                                                                        \
  gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, \
                                      ConjugateRhs, LhsIsReal, RhsIsReal, N>(                              \
      res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>

#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  MICRO_COMPLEX_UNROLL(func2);                       \
  func(0, peel) func(1, peel) func(2, peel) func(3, peel)

#define MICRO_COMPLEX_WORK_ONE4(iter, peel)                                                \
  if (unroll_factor > iter) {                                                              \
    pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(       \
        &accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
  if (PEEL_COMPLEX > peel) {                        \
    Packet lhsV0, lhsV1, lhsV2, lhsV3;              \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3;          \
    MICRO_COMPLEX_BROADCAST(peel)                   \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel)    \
  } else {                                          \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);              \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);             \
  }

#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M];              \
  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M];          \
  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3)

#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M], rhsVi0[M];                                \
  func(func1, func2, 0)

#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size)                                        \
  MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE) \
  MICRO_COMPLEX_ADD_ROWS(size, false)

#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)

#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)

#define MICRO_COMPLEX_DST_PTR_ONE(iter)       \
  if (unroll_factor > iter) {                 \
    bsetzero<Packet, accRows>(accReal##iter); \
    bsetzero<Packet, accRows>(accImag##iter); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(accReal##iter);     \
    EIGEN_UNUSED_VARIABLE(accImag##iter);     \
  }

#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)

#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_STORE_ONE(iter)                                                                        \
  if (unroll_factor > iter) {                                                                                \
    constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC));                                   \
    bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter * accCols, 0); \
    bscalec<Packet, accRows, !(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag,    \
                                                    taccReal, taccImag, pMask);                              \
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);                           \
    bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter * accCols + 0);                               \
    if (full) {                                                                                              \
      bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter * accCols + accColsC);                      \
    }                                                                                                        \
  }

#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper,
          const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs,
          bool LhsIsReal, bool RhsIsReal>
const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
const Index imag_delta = accCols * strideA;
const Index imag_delta2 = accCols2 * strideA;
const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;

for (; k < depth; k++) {

#define MICRO_COMPLEX_UNROLL_ITER2(N, M)                                                                  \
  gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, \
                                  M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(     \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask);    \
  if (M) return;
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
const DataMapper res3 = res.getSubMapper(0, col);

const Scalar* lhs_base = blockA + accCols * offsetA;

#define MAX_COMPLEX_UNROLL 4
switch ((rows - row) / accCols) {
#if MAX_COMPLEX_UNROLL > 4
#if MAX_COMPLEX_UNROLL > 3
#if MAX_COMPLEX_UNROLL > 2
#if MAX_COMPLEX_UNROLL > 1
#undef MAX_COMPLEX_UNROLL

if (remaining_rows > 0) {
  gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                         RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows,
                                    remaining_rows, pAlphaReal, pAlphaImag, pMask);

#define MICRO_COMPLEX_EXTRA_COLS(N)                                                                         \
  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
                    RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,   \
                               remaining_rows, pAlphaReal, pAlphaImag, pMask);
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols,
          bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>

template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
          typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
const Index remaining_rows = rows % accCols;

if (strideA == -1) strideA = depth;
if (strideB == -1) strideB = depth;

const Packet pAlphaReal = pset1<Packet>(alpha.real());
const Packet pAlphaImag = pset1<Packet>(alpha.imag());
const Packet pMask = bmask<Packet>(remaining_rows);

for (; col + accRows <= cols; col += accRows) {
  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                    RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,
                               remaining_rows, pAlphaReal, pAlphaImag, pMask);

gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                        RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
                                   remaining_rows, pAlphaReal, pAlphaImag, pMask);
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
return true;
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
return __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma");
return pmadd(acc, pAlpha, result_block);

template <bool lhsExtraRows>
pstoreu(result, result_block);

template <bool rhsExtraCols, bool lhsExtraRows>
  storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
} while (++x < extra_cols);

float* result2 = result;

storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);

return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
return reinterpret_cast<Packet4f>(vec_mergel(data, z));
return reinterpret_cast<Packet4f>(vec_mergel(z, data));
template <Index N, Index M>
} else if (N >= (M * 8 + 4)) {

storeConvertTwoBF16<N, 0>(to + 0, block, extra);
storeConvertTwoBF16<N, 1>(to + 8, block);
storeConvertTwoBF16<N, 2>(to + 16, block);
storeConvertTwoBF16<N, 3>(to + 24, block);

template <bool non_unit_stride, Index delta>
if (non_unit_stride) {

static Packet16uc p16uc_MERGE16_32_1 = {0, 1, 16, 17, 2, 3, 18, 19, 0, 1, 16, 17, 2, 3, 18, 19};
static Packet16uc p16uc_MERGE16_32_2 = {4, 5, 20, 21, 6, 7, 22, 23, 4, 5, 20, 21, 6, 7, 22, 23};
static Packet16uc p16uc_MERGE16_32_3 = {8, 9, 24, 25, 10, 11, 26, 27, 8, 9, 24, 25, 10, 11, 26, 27};
static Packet16uc p16uc_MERGE16_32_4 = {12, 13, 28, 29, 14, 15, 30, 31, 12, 13, 28, 29, 14, 15, 30, 31};

static Packet16uc p16uc_MERGE16_32_5 = {0, 1, 16, 17, 16, 17, 16, 17, 0, 1, 16, 17, 16, 17, 16, 17};
static Packet16uc p16uc_MERGE16_32_6 = {2, 3, 18, 19, 18, 19, 18, 19, 2, 3, 18, 19, 18, 19, 18, 19};
static Packet16uc p16uc_MERGE16_32_7 = {4, 5, 20, 21, 20, 21, 20, 21, 4, 5, 20, 21, 20, 21, 20, 21};
static Packet16uc p16uc_MERGE16_32_8 = {6, 7, 22, 23, 22, 23, 22, 23, 6, 7, 22, 23, 22, 23, 22, 23};

return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
template <bool lhsExtraRows, bool odd, Index size>
} while (++i < extra_rows);

template <bool lhsExtraRows>
for (; col + 4 * 2 <= cols; col += 4 * 2, result += 4 * 4 * 4, src += 4 * rows) {
  convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 4>(result, rows, src, extra_rows);
for (; col + 2 <= cols; col += 2, result += 4 * 4, src += rows) {
  convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 1>(result, rows, src, extra_rows);

convertArrayPointerBF16toF32DupOne<lhsExtraRows, true, 1>(result, rows, src - delta, extra_rows);
template <const Index size, bool non_unit_stride>
r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);

storeConvertBlockBF16<size>(result + i, r32, rows & 3);

src += extra * resInc;
if (size != 32) break;

template <bool non_unit_stride>
convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
template <Index num_acc, Index size = 4>
for (Index k = 0; k < num_acc; k++) {

template <Index num_acc>
for (Index i = 0; i < num_acc; i++) {
  t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
  t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
  t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
  t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
  acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
  acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
  acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
  acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
template <Index num_acc>
for (Index i = 0, j = 0; j < num_acc; i++, j += 2) {
  for (Index x = 0, y = 0; x < 2; x++, y += 2) {
    for (Index w = 0, z = 0; w < 2; w++, z += 2) {
      acc[i][y + w] = acc[j + x][z + 0] + acc[j + x][z + 1];

template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
tranposeResults<num_acc>(acc);
addResults<num_acc>(acc);

constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));

for (Index i = 0; i < real_rhs; i++, result += 4 * rows, k++) {
  storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);

storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
template <bool zero>
dhs1 = vec_mergel(dhs0, dhs2);
dhs0 = vec_mergeh(dhs0, dhs2);

template <Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
constexpr Index num_lhs = 4;
Packet4f lhs[num_lhs], rhs[num_rhs];

constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
for (Index i = 0; i < real_rhs; i += 2) {
  loadTwoRhsFloat32<zero>(indexB + k * 4, strideB, i, rhs[i + 0], rhs[i + 1]);

loadTwoRhsFloat32<zero>(indexB + k * extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);

indexA += 2 * k * 4;
for (Index j = 0; j < num_lhs; j++) {

for (Index j = 0; j < num_rhs; j++) {
  for (Index i = 0; i < num_lhs; i++) {
template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
                  const float* indexB, Index strideB, Index offsetB, float* result,
                  const Index extra_cols, const Index extra_rows) {
constexpr Index num_rhs = num_acc;

zeroAccumulators<num_acc>(acc);

for (k = 0; k + 2 <= depth; k += 2) {
  KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);

KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);

outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);

#define MAX_BFLOAT16_ACC_VSX 4
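// The VSX bfloat16 GEMM widens both operands to float with the convert*
// helpers above, accumulates in float, and converts back to bfloat16 at the
// end; MAX_BFLOAT16_ACC_VSX caps how many 4x4 float accumulator tiles are
// live during one sweep over the columns.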
template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
                  const float* indexB, Index strideB, Index offsetB, float* result) {
constexpr Index step = (num_acc * 4);
const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;

colVSXLoopBodyIter<num_acc * 2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB,
                                                            result, extra_cols, extra_rows);

indexB += strideB * (num_acc * 2);
result += rows * step;
} while (multiIters && (step <= cols - (col += step)));
template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
                  const float* indexA, const float* blockB, Index strideB, Index offsetB,
colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA,
                                                                             blockB, strideB, offsetB, result);

template <bool rhsExtraCols, bool lhsExtraRows>
                  const float* blockB, Index strideB, Index offsetB, float* result) {
colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                    offsetB, result);
colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                    offsetB, result);
colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                    offsetB, result);
colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
template <Index size, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha,
                                     const bfloat16* indexA, const float* indexA2, const float* blockB2, Index strideA,
                                     Index strideB, Index offsetB, float* result2) {
  Index delta_rows = 2 * (lhsExtraRows ? (rows & 3) : size);
  for (Index row = 0; row < size; row += 4) {
    convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float*>(indexA2), strideA, delta_rows, indexA, row,
                                                  rows & 3);

    const float* blockB = blockB2;
    float* result = result2 + row;

    Index col = 0;
    if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
      colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB,
                                                                strideB, 0, result);
      blockB += (strideB >> 1) * col;
      result += rows * col;
    }
    if (cols & 3) {
      colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB,
                                              result);
    } else {
      colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
    }
  }
}
template <Index size>
EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth,
                                         Index cols, Index rows, const Packet4f pAlpha, const float* indexB,
                                         Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix,
                                         float* result) {
  if ((size == 16) || (rows & size)) {
    indexA += size * offsetA;
    colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix * size / 16;
  }
}
template <const Index size, typename DataMapper>
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float* result, Index rows, const DataMapper& src) {
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
    if (size >= 16) {
      r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
      r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra;
    if (size != 32) break;
  }
}
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float* result, Index cols, Index rows, const DataMapper& src) {
  typedef typename DataMapper::LinearMapper LinearMapper;
  for (Index j = 0; j < cols; j++, result += rows) {
    const LinearMapper src2 = src.getLinearMapper(0, j);
    Index i = 0;
    convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
  }
}
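// Illustration (hedged): the descending 32/16/8/4/1 cascade peels each column from the widest
// packet block that still fits: a column of 45 rows is converted as 32 + 8 + 4 + 1 (partial),
// with i carried by reference across the calls.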
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float* res) {
  return F32ToBf16Both(ploadu<Packet4f>(res + 0), ploadu<Packet4f>(res + 4));
}

template <typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float* result, Index col, Index rows, const DataMapper& res) {
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float* result2 = result + col * rows;
  for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
    // Convert and store a full 8-row block of `size` columns.
    PacketBlock<Packet8bf, size> block;
    for (Index j = 0; j < size; j++) {
      block.packet[j] = convertF32toBF16VSX(result2 + j * rows);
    }
    res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
  }
  // Extra rows (partial store of the remaining rows & 7 elements).
  if (row < rows) {
    for (Index j = 0; j < size; j++) {
      Packet8bf fp16 = convertF32toBF16VSX(result2 + j * rows);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float* result, Index cols, Index rows, const DataMapper& res) {
  Index col;
  for (col = 0; col + 4 <= cols; col += 4) {
    convertArrayF32toBF16ColVSX<DataMapper, 4>(result, col, rows, res);
  }
  // Extra columns.
  switch (cols - col) {
    case 1:
      convertArrayF32toBF16ColVSX<DataMapper, 1>(result, col, rows, res);
      break;
    case 2:
      convertArrayF32toBF16ColVSX<DataMapper, 2>(result, col, rows, res);
      break;
    case 3:
      convertArrayF32toBF16ColVSX<DataMapper, 3>(result, col, rows, res);
      break;
  }
}
template <typename DataMapper>
void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
                  Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  ei_declare_aligned_stack_constructed_variable(float, result, cols * rows, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexA2, strideA * 32, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB * cols, 0);

  // Widen the bf16 result block and the packed RHS to float once, upfront.
  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
  convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast<bfloat16*>(indexB), 1);

  Index bigSuffix = 2 * 8 * (strideA - offsetA);
  float* indexBF32 = indexB2 + 4 * offsetB;

  Index row = 0;
  // LHS 16-row panels.
  while (row + 16 <= rows) {
    calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                        bigSuffix, result);
  }
  // LHS 8-row panel.
  calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                     bigSuffix, result);
  // LHS 4-row panel.
  calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                     bigSuffix, result);
  // Extra rows (masked remainder).
  if (rows & 3) {
    colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB,
                         result + row);
  }

  // Narrow the float results back to bfloat16 and store them.
  convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
}

#undef MAX_BFLOAT16_ACC_VSX
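// Illustrative usage (hedged): this kernel is reached through Eigen's normal product
// evaluators rather than called directly. On a build where this header is active, e.g.:
//
//   #include <Eigen/Dense>
//   using MatBF16 = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>;
//   MatBF16 A = MatBF16::Random(64, 64), B = MatBF16::Random(64, 64);
//   MatBF16 C = A * B;  // dispatches into the bfloat16 GEMM path above
//
// The path is a convert-compute-convert pipeline: widen bf16 inputs/results to float, run the
// VSX float kernels over 16/8/4-row LHS panels plus a masked remainder, then narrow back.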
/************************************
 * ppc64le template specializations *
 ************************************/
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
    double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
    double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
    bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
    float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
    float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
                   PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
                   PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
#if EIGEN_ALTIVEC_USE_CUSTOM_PACK
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
#endif
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
                   PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
                   PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                          Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
  pack(blockA, lhs, depth, rows, stride, offset);
}
template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
    std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
  void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
                  Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
    std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
  dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
  pack(blockB, rhs, depth, cols, stride, offset);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<float>::vectortype Packet;
  typedef typename quad_traits<float>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols,
                  float alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha,
    Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index,
                               Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
                      &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
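// Illustration (hedged): every gebp_kernel specialization below follows the same dispatch
// pattern: when the MMA kernels are compiled in (EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H), a static
// function pointer is chosen once at runtime via supportsMMA() between the MMA kernel and the
// plain VSX kernel; otherwise the VSX kernel is taken unconditionally.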
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
                  Index rows, Index depth, Index cols, std::complex<float> alpha, Index strideA = -1,
                  Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs,
                 ConjugateRhs>::operator()(const DataMapper& res, const std::complex<float>* blockA,
                                           const std::complex<float>* blockB, Index rows, Index depth, Index cols,
                                           std::complex<float> alpha, Index strideA, Index strideB, Index offsetA,
                                           Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*, Index, Index,
                               Index, std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>,
                                                          std::complex<float>, float, Packet, Packetc, RhsPacket,
                                                          DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs,
                                                          false, false> :
#endif
                      &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>,
                                                     float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                     ConjugateLhs, ConjugateRhs, false, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows,
                  Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows, Index depth, Index cols,
    std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*, Index, Index, Index,
                               std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, true, false> :
#endif
                      &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet,
                                                     Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                                     ConjugateRhs, true, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef Packet4f Packet;
  typedef Packet2cf Packetc;
  typedef Packet4f RhsPacket;

  void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows,
                  Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows, Index depth, Index cols,
    std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<float>::rows;
  const Index accCols = quad_traits<float>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*, Index, Index, Index,
                               std::complex<float>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, false, true> :
#endif
                      &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet,
                                                     Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
                                                     ConjugateRhs, false, true>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<double>::vectortype Packet;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth,
                  Index cols, double alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
                  Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth, Index cols,
    double alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index,
                               Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
#endif
                      &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
                  Index rows, Index depth, Index cols, std::complex<double> alpha, Index strideA = -1,
                  Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs,
                 ConjugateRhs>::operator()(const DataMapper& res, const std::complex<double>* blockA,
                                           const std::complex<double>* blockB, Index rows, Index depth, Index cols,
                                           std::complex<double> alpha, Index strideA, Index strideB, Index offsetA,
                                           Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*, Index,
                               Index, Index, std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>,
                                                          std::complex<double>, double, Packet, Packetc, RhsPacket,
                                                          DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs,
                                                          false, false> :
#endif
                      &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>,
                                                     double, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                     ConjugateLhs, ConjugateRhs, false, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows,
                  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows, Index depth,
    Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*, Index, Index, Index,
                               std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, false, true> :
#endif
                      &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double,
                                                     Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                     ConjugateLhs, ConjugateRhs, false, true>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<double>::vectortype Packet;
  typedef Packet1cd Packetc;
  typedef typename quad_traits<double>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows,
                  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows, Index depth,
    Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  const Index accRows = quad_traits<double>::rows;
  const Index accCols = quad_traits<double>::size;
  static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*, Index, Index, Index,
                               std::complex<double>, Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double,
                                                          Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                          ConjugateLhs, ConjugateRhs, true, false> :
#endif
                      &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double,
                                                     Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
                                                     ConjugateLhs, ConjugateRhs, true, false>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<bfloat16>::vectortype Packet;
  typedef typename quad_traits<bfloat16>::rhstype RhsPacket;

  void operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth,
                  Index cols, bfloat16 alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
                  Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols,
    bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16,
                               Index, Index, Index, Index) =
#ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
      (supportsMMA()) ? &Eigen::internal::gemmMMAbfloat16<DataMapper> :
#endif
                      &Eigen::internal::gemmbfloat16<DataMapper>;
  gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
}

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_MATRIX_PRODUCT_ALTIVEC_H