#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H

#include "../../InternalHeaderCheck.h"
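
// AVX512 triangular-solve (TRSM) kernels for float/double with multiple
// right-hand sides, operating directly on the inputs rather than through the
// generic packing machinery.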
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif
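
// The kernels allocate temporary storage internally, so they are disabled when
// EIGEN_NO_MALLOC is requested.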
#ifdef EIGEN_NO_MALLOC
#undef EIGEN_USE_AVX512_TRSM_KERNELS
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
#endif
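
// trsmKernelR (triangular matrix on the right) and trsmKernelL (triangular
// matrix on the left) can be enabled independently.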
#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
#endif
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
#endif
#else
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#endif

namespace Eigen {
namespace internal {
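
// Register-blocking parameters for the micro-kernels: number of accumulator
// registers, rows per block, k-loop unroll factor, and B-load/A-broadcast
// grouping.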
#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8))
#define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
#define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
#define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
typedef Packet16f vecFullFloat;
typedef Packet8d vecFullDouble;
typedef Packet8f vecHalfFloat;
typedef Packet4d vecHalfDouble;

#include "TrsmUnrolls.inc"
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
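
// For sufficiently small problems, calling these kernels directly on the
// unpacked inputs beats the packed path in TriangularSolverMatrix.h. The
// "NOCOPY cutoff" macros below control whether that fast path is taken
// (currently only enabled for clang).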
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
#endif
#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS

#if EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
#endif
#endif

#if EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
#endif
#endif

#else
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
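
// Heuristic cutoff: picks a problem-size threshold (a multiple of
// EIGEN_AVX_MAX_NUM_ROW) such that the working set fits within a fraction
// L2Cap of the L2 cache.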
template <typename Scalar>
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap);
#else  // !(EIGEN_USE_AVX512_TRSM_KERNELS && EIGEN_COMP_CLANG)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
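
// Transposes the accumulator register block and stores it to a column-major C,
// optionally masking M (remM) and N (remN) remainders.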
template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
                                     int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
  EIGEN_UNUSED_VARIABLE(remM_);
  EIGEN_UNUSED_VARIABLE(remN_);
  using urolls = unrolls::trans<Scalar>;

  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");

  urolls::template transpose<unrollN, 0>(zmm);
  static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
  EIGEN_IF_CONSTEXPR(!remN) {
    urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    EIGEN_IF_CONSTEXPR(unrollN > U1) {
      constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
      urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
    }
    EIGEN_IF_CONSTEXPR(unrollN > U2) {
      constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
      urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
    }
  }
  else {
    EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
      if (remN_ == 15)
        urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 14)
        urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 13)
        urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 12)
        urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 11)
        urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 10)
        urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 9)
        urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 8)
        urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 7)
        urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 6)
        urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 5)
        urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 4)
        urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 3)
        urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 2)
        urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 1)
        urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    }
    else {
      if (remN_ == 7)
        urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 6)
        urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 5)
        urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 4)
        urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 3)
        urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 2)
        urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 1)
        urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    }
  }
}
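
// GEMM-like kernel used for the trailing updates in triSolve: accumulates A*B
// into C when isAdd is true, otherwise subtracts it; handleKRem enables the
// scalar tail loop for K not divisible by EIGEN_AVX_MAX_K_UNROL.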
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA,
                int64_t LDB, int64_t LDC) {
  using urolls = unrolls::gemm<Scalar, isAdd>;
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  int64_t N_ = (N / U3) * U3;
  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
  int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
  int64_t j = 0;
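  // Handle N in panels of three, two, then one packet width (U3/U2/U1), with a
  // masked pass for the final N % PacketSize columns; within each panel,
  // process rows in blocks of EIGEN_AVX_MAX_NUM_ROW with 4/2/1-row remainders.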
  for (; j < N_; j += U3) {
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      // ... k loops accumulating A*B into zmm via urolls::microKernel elided ...
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    if (M - i >= 4) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 4>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 2>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
            B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
      }
    }
  }
  if (N - j >= U2) {
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      // ... k loops elided (same pattern as the U3 panel, two column packets) ...
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    if (M - i >= 4) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 4>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 2>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(
            B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
      }
    }
    j += U2;
  }
  if (N - j >= U1) {
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      // ... k loops elided (single column packet) ...
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    if (M - i >= 4) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(
            B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
      }
    }
    j += U1;
  }
  // Masked pass for the final N % PacketSize columns.
  if (N - j > 0) {
    int64_t i = 0;
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      // ... masked k loops elided ...
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
      }
    }
    if (M - i >= 4) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
      }
      i += 4;
    }
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      // ... main k loop elided ...
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
      }
      i += 2;
    }
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
            B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
                                                                                            N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
      }
    }
  }
}
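
// Solves an unrollM x unrollM triangular block of A against K right-hand sides
// held in B, in place. RHS columns are consumed in chunks of three, two, then
// one packet, with a masked tail for the final K % PacketSize columns;
// isFWDSolve selects forward vs. backward substitution.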
template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
  using urolls = unrolls::trsm<Scalar>;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;

  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;

  int64_t k = 0;
  while (K - k >= U3) {
    urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
    k += U3;
  }
  if (K - k >= U2) {
    urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
    k += U1;
  }
  if (K - k > 0) {
    urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
    urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
                                                                                          AInPacket);
    urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
  }
}
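
// Dispatches on the runtime block size M (1..EIGEN_AVX_MAX_NUM_ROW, assumed to
// be 8 here) to the corresponding compile-time-unrolled triSolveKernel.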
template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  if (M == 8)
    triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 7)
    triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 6)
    triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 5)
    triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 4)
    triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 3)
    triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 2)
    triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  else if (M == 1)
    triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
  return;
}
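
// Copies a panel of B into the row-major temporary B_temp (toTemp = true), or
// copies results back from B_temp into B (toTemp = false), transposing in
// chunks of U3/U2/U1/8/4/2/1 columns; remM enables masked handling of a short
// final row block.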
template <typename Scalar, bool toTemp = true, bool remM = false>
EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
                                         int64_t remM_ = 0) {
  EIGEN_UNUSED_VARIABLE(remM_);
  using urolls = unrolls::transB<Scalar>;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecHalfDouble>::type;
  PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ACC> ymm;
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  int64_t K_ = (K / U3) * U3;
  int64_t k = 0;
  for (; k < K_; k += U3) {
    urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U3;
  }
  if (K - k >= U2) {
    urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U2;
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U1;
    k += U1;
  }
  EIGEN_IF_CONSTEXPR(U1 > 8) {
    if (K - k >= 8) {
      urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 8;
      k += 8;
    }
  }
  EIGEN_IF_CONSTEXPR(U1 > 4) {
    if (K - k >= 4) {
      urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 4;
      k += 4;
    }
  }
  if (K - k >= 2) {
    urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 2;
    k += 2;
  }
  if (K - k >= 1) {
    urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 1;
    k += 1;
  }
}
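
// Top-level blocked triangular solve: overwrites B (M x numRHS) with the
// solution X of A*X = B (or the backward-substitution variant). The RHS are
// processed kB columns at a time; for column-major B each panel is staged
// row-major in an aligned temporary so the inner kernels see unit-stride rows.
// kB and numM were chosen experimentally.
//
// A minimal call sketch (internal API; a hypothetical column-major, M x M
// lower-triangular A with an M x K column-major B, both with leading
// dimension M):
//
//   triSolve<double, /*isARowMajor=*/false, /*isBRowMajor=*/false,
//            /*isFWDSolve=*/true, /*isUnitDiag=*/false>(A_data, B_data, M, K,
//                                                       /*LDA=*/M, /*LDB=*/M);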
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
          bool isUnitDiag = false>
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
  constexpr int64_t psize = packet_traits<Scalar>::size;
  constexpr int64_t kB = (3 * psize) * 5;
  constexpr int64_t numM = 64;
  int64_t sizeBTemp = 0;
  Scalar *B_temp = NULL;
  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
    sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
    B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp);
  }

  for (int64_t k = 0; k < numRHS; k += kB) {
    int64_t bK = numRHS - k > kB ? kB : numRHS - k;
    // ... (definitions of M_, bM, gemmOff, bkL, numScalarPerCache, offsetBTemp,
    //      ind/indA_*/indB_* index setup, and the loop over
    //      EIGEN_AVX_MAX_NUM_ROW-sized row blocks of A are elided here) ...
    int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;

    // Diagonal-block solve. Column-major B: stage the RHS panel row-major in
    // B_temp, solve, and copy back; row-major B: solve in place.
    int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
    copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
    triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
        &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
    copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
    triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
        &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);

    // Trailing update of the not-yet-solved rows: B_trailing -= A_panel * X.
    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
        &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
        /* size and stride arguments elided */);
    int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
    gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
        &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + k * LDB,
        /* size and stride arguments elided */);

    // Update against the remainder rows (M not a multiple of EIGEN_AVX_MAX_NUM_ROW).
    {
      int64_t indA_i = isFWDSolve ? M_ : 0;
      int64_t indA_j = isFWDSolve ? 0 : bM;
      int64_t indB_i = isFWDSolve ? 0 : bM;
      int64_t indB_i2 = isFWDSolve ? M_ : 0;
      gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
          &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
          /* remaining arguments elided */);
    }
    {
      int64_t indA_i = isFWDSolve ? M_ : 0;
      int64_t indA_j = isFWDSolve ? gemmOff : bM;
      int64_t indB_i = isFWDSolve ? M_ : 0;
      int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
      gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
                                                                 B_temp + offB_1, B_arr + indB_i + k * LDB, bM, bK,
                                                                 M_ - gemmOff, LDA, LDT, LDB);
    }

    // Solve the final bM x bM diagonal block.
    {
      int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
      int64_t indB_i = isFWDSolve ? M_ : 0;
      int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
      copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
      triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
                                                                     B_temp + offB_1, bM, bkL, LDA, bkL);
      copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
      // Row-major B path solves the remainder block in place:
      triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
                                                                     B_arr + k + ind * LDB, bM, bK, LDA, LDB);
    }
  }
  EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
}
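
// Specializations of the trsmKernelR/L entry points from
// TriangularSolverMatrix.h for float/double with unit inner stride; these
// route the solve to triSolve above, falling back to the generic kernel when
// runtime malloc is disallowed.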
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelR;
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other,
                     Index otherIncr, Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // EIGEN_USE_AVX512_TRSM_R_KERNELS
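
// The left-side (A*X = B) specializations below mirror the right-side ones
// above, with the storage-order and Lower/Upper mappings flipped.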
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelL;
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other,
                     Index otherIncr, Index otherStride);
};
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS
#endif  // EIGEN_USE_AVX512_TRSM_KERNELS

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H