#ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
#define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H

#include <x86intrin.h>
#include <immintrin.h>
#include <type_traits>

#include "../../InternalHeaderCheck.h"

#if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
#define EIGEN_USE_AVX512_GEMM_KERNELS 1
#endif

#define SECOND_FETCH (32)
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
// Use fewer registers for the A-panel loads to work around compiler register spills.
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
#endif

namespace Eigen {
namespace internal {
template <typename Scalar, bool is_unit_inc>
class gemm_class {
  using vec = typename packet_traits<Scalar>::type;
  using vec_ymm = typename unpacket_traits<vec>::half;
  using vec_xmm = typename unpacket_traits<vec_ymm>::half;
  using umask_t = typename unpacket_traits<vec>::mask_t;

  static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
  static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);

#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
  static constexpr bool use_less_a_regs = !is_unit_inc;
#else
  static constexpr bool use_less_a_regs = true;
#endif
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
  static constexpr bool use_less_b_regs = !is_unit_inc;
#else
  static constexpr bool use_less_b_regs = true;
#endif

  // Register assignment: a_regs[] and b_regs[] give the zmm indices used for the
  // A and B operands (their values are elided in this listing); c_regs[] maps the
  // C accumulators.
  static constexpr int c_regs[] = {
      8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
  };

  // ... alpha_load_reg, c_load_regs[], a_shift/b_shift and the prefetch distances elided ...

  vec zmm[32];
  umask_t mask;

  // GEMM arguments.
  Index m;
  const Index n, k, ldc;
  const Index inc;
  const Scalar *alpha;
  const Scalar *a, *b;
  Scalar *c;
  const bool is_alpha1;
  const bool is_beta0;
  const Index a_stride, b_stride;
  const Index a_off, b_off;
  // ... prefetch_a(), prefetch_b(), prefetch_x() elided ...

  EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
#if defined(__PRFCHW__) && __PRFCHW__ == 1
    _m_prefetchw((void *)c_addr);
#else
    _mm_prefetch((char *)c_addr, _MM_HINT_T0);
#endif
  }
  // Load (or broadcast) one A register; the switch key is the load width in bits.
  template <int nelems>
  EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
    switch (nelems * sizeof(*a_addr) * 8) {
      case 512 * 3:
      case 512 * 2:
      case 512 * 1:
        a_reg = ploadu<vec>(a_addr);
        break;
      case 256:
        a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
        break;
      case 128:
        a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
        break;
      case 64:
        a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
        break;
      case 32:
        a_reg = pload1<vec>(a_addr);
        break;
      default:
        break;
    }
  }
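
  // Added note (illustrative): for float, nelems == 48/32/16 gives 1536/1024/512
  // bits and takes the plain ploadu path above, while the sub-cache-line tails
  // nelems == 8, 4, 2, 1 map to the 256/128/64/32-bit broadcast paths.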
  // ... b_load() elided ...

  template <int nelems>
  EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        // ... full 512-bit stores elided ...
        case 256:
          pstoreu(mem, preinterpret<vec_ymm>(src));
          break;
        case 128:
          pstoreu(mem, preinterpret<vec_xmm>(src));
          break;
        case 64:
          pstorel(mem, preinterpret<vec_xmm>(src));
          break;
        case 32:
          pstores(mem, preinterpret<vec_xmm>(src));
          break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        // ... strided stores via pscatter and masked tails elided ...
      }
    }
  }
  // vaddm: dst = src + C(mem), for unit and non-unit C strides.
  template <int nelems>
  EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
        case 512 * 2:
        case 512 * 1:
          dst = padd(src, ploadu<vec>(mem));
          break;
        case 256:
          dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
          break;
        case 128:
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
          break;
        case 64:
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          break;
        case 32:
          dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
        default:
          break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
        case 512 * 2:
        case 512 * 1:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = padd(src, reg);
          break;
        case 256:
          reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
          dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
          break;
        case 128:
          reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          break;
        case 64:
          if (is_f32) {
            // ... gather the two strided elements into reg ...
            dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          } else {
            dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          }
          break;
        case 32:
          dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
        default:
          break;
      }
    }
  }
  EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
    dst = pmadd(src1, src2, dst);

#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
    // Empty asm "comment" that only constrains register allocation: "+v" keeps
    // dst as a read-write vector-register operand, "v" keeps src2 in a vector
    // register, and "%" marks src1 as commutable with the following operand.
    // This works around register spills around the FMA chain on GCC and Clang.
    __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
#endif
  }
  // vfmaddm: dst = scale * src + C(mem), for unit and non-unit C strides.
  template <int nelems>
  EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
    if (is_unit_inc) {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
        case 512 * 2:
        case 512 * 1:
          dst = pmadd(scale, src, ploadu<vec>(mem));
          break;
        case 256:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
          break;
        case 128:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
          break;
        case 64:
          dst = preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          break;
        case 32:
          dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
        default:
          break;
      }
    } else {
      switch (nelems * sizeof(*mem) * 8) {
        case 512 * 3:
        case 512 * 2:
        case 512 * 1:
          reg = pgather<Scalar, vec>(mem, inc);
          dst = pmadd(scale, src, reg);
          break;
        case 256:
          reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
          dst = preinterpret<vec>(
              pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
          break;
        case 128:
          reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
          dst = preinterpret<vec>(
              pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          break;
        case 64:
          if (is_f32) {
            // ... gather the two strided elements into reg ...
            dst = preinterpret<vec>(
                pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
          } else {
            dst = preinterpret<vec>(
                pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
          }
          break;
        case 32:
          dst = preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
          break;
        default:
          break;
      }
    }
  }
  // Recursively pre-load the A micro-panel into the a_regs registers.
  template <int j, int endX, int i, int endY, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
    EIGEN_UNUSED_VARIABLE(ao);
  }

  template <int j, int endX, int i, int endY, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
    if (j < endX) {
      if (i < endY) {
        auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
        const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
        a_load<nelems>(a_reg, a_addr);

        a_loads<j, endX, i + 1, endY, nelems>(ao);
      } else {
        a_loads<j + 1, endX, 0, endY, nelems>(ao);
      }
    }
  }
  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
                                                                                         const Scalar *co2) {
    EIGEN_UNUSED_VARIABLE(co1);
    EIGEN_UNUSED_VARIABLE(co2);
  }

  // Recursively prefetch the C tile, column (un) by column and cache line (i) by
  // cache line.
  template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
    if (un < max_b_unroll) {
      if (b_unroll >= un + 1) {
        if (un == 4 && i == 0) co2 = co1 + 4 * ldc;

        if (i < um_vecs) {
          Scalar *co = (un + 1 <= 4) ? co1 : co2;
          prefetch_c(co /* + per-column, per-cache-line offset (elided) */);

          prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
        } else {
          prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
        }
      } else {
        prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
      }
    }
  }
  // load_c: blend the existing C values into the accumulators, scaling by alpha
  // when required.
  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
    EIGEN_UNUSED_VARIABLE(cox);
    EIGEN_UNUSED_VARIABLE(alpha_reg);
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
    if (i < um_vecs) {
      auto &c_reg = zmm[c_regs[i + idx * 3]];
      auto &c_load_reg = zmm[c_load_regs[i % 3]];
      auto c_mem = cox;
      if (is_unit_inc)
        c_mem += i * nelems_in_cache_line;
      else
        c_mem += i * nelems_in_cache_line * inc;

      if (!is_beta0 && is_alpha1)
        vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
      else if (!is_beta0 && !is_alpha1)
        vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
      else if (is_beta0 && !is_alpha1)
        c_reg = pmul(alpha_reg, c_reg);

      scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
    }
  }

  // store_c: write the accumulators back to C and clear them.
  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
    EIGEN_UNUSED_VARIABLE(cox);
  }

  template <int i, int um_vecs, int idx, int nelems>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
    if (i < um_vecs) {
      auto &c_reg = zmm[c_regs[i + idx * 3]];
      auto c_mem = cox;
      if (is_unit_inc)
        c_mem += i * nelems_in_cache_line;
      else
        c_mem += i * nelems_in_cache_line * inc;

      c_store<nelems>(c_mem, c_reg);
      c_reg = pzero(c_reg);

      write_c<i + 1, um_vecs, idx, nelems>(cox);
    }
  }
  template <int pow, int a_unroll, int idx>
  EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
    if (pow >= 4) cox += ldc;

    const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
    auto &alpha_reg = zmm[alpha_load_reg];

    scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
    write_c<0, um_vecs, idx, a_unroll>(cox);
  }

  template <int pow, int a_unroll>
  EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
    constexpr int idx = pow / 2;
    Scalar *&cox = idx == 0 ? co1 : co2;

    constexpr int max_count = (pow + 1) / 2;
    static_assert(max_count <= 4, "Unsupported max_count.");

    if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
    if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
    if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
    if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
  }
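
  // Added note (illustrative): pow == 1 gives idx == 0 and max_count == 1, so a
  // single C column is updated through co1; pow == 8 gives idx == 4 and
  // max_count == 4, so columns 4..7 of the tile are updated through co2.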
  template <int max_b_unroll, int a_unroll, int b_unroll>
  EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
    // ... broadcast alpha into zmm[alpha_load_reg] and set the store mask for
    // strided tails ...

    static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");

    if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
    if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
    if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
    if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);

    // ... advance co1/co2 past the updated columns ...
  }
  // compute: one FMA per C accumulator, interleaved with A/B prefetches.
  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                               int &fetchB_idx, vec &b_reg) {
    EIGEN_UNUSED_VARIABLE(ao);
    EIGEN_UNUSED_VARIABLE(bo);
    EIGEN_UNUSED_VARIABLE(fetchA_idx);
    EIGEN_UNUSED_VARIABLE(fetchB_idx);
    EIGEN_UNUSED_VARIABLE(b_reg);
  }

  template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                                int &fetchB_idx, vec &b_reg) {
    if (um < um_vecs) {
      auto &c_reg = zmm[c_regs[um + idx * 3]];
      auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];

      vfmadd(c_reg, a_reg, b_reg);

      if (!fetch_x && um == 0 &&
          (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
           (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
        prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
        fetchA_idx++;
      }

      if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
        prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
        fetchB_idx++;
      }

      compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
    }
  }
  // load_a: reload the A registers for k iteration uk.
  template <int um, int um_vecs, int uk, int nelems, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
    EIGEN_UNUSED_VARIABLE(ao);
  }

  template <int um, int um_vecs, int uk, int nelems, bool ktail>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
    if (um < um_vecs) {
      auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
      const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
      a_load<nelems>(a_reg, a_addr);

      load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
    }
  }
  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                 const Scalar *const &ao,
                                                                                 const Scalar *const &bo, Scalar *&co2,
                                                                                 int &fetchA_idx, int &fetchB_idx) {
    EIGEN_UNUSED_VARIABLE(aa);
    EIGEN_UNUSED_VARIABLE(ao);
    EIGEN_UNUSED_VARIABLE(bo);
    EIGEN_UNUSED_VARIABLE(co2);
    EIGEN_UNUSED_VARIABLE(fetchA_idx);
    EIGEN_UNUSED_VARIABLE(fetchB_idx);
  }

  template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
  EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                  const Scalar *const &ao,
                                                                                  const Scalar *const &bo, Scalar *&co2,
                                                                                  int &fetchA_idx, int &fetchB_idx) {
    const int idx = (pow / 2) + count;

    if (count < (pow + 1) / 2) {
      auto &b_reg = zmm[b_regs[idx % 2]];

      if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
      if (fetch_x && uk == 3 && idx == 4) aa += 8;

      if (b_unroll >= pow) {
        compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);

        // ... b_load() of the next B element into b_reg ...
      }

      // Go to the next count.
      innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
                                                                                       fetchB_idx);
    } else {
      // Maybe prefetch C data after the count loop.
      if (pow == 2 && c_fetch) {
        if (uk % 3 == 0 && uk > 0) {
          co2 += ldc;
        } else {
          prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
        }
      }
    }
  }
  template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
            bool no_a_preload = false>
  EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                           Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
    const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);

    if (max_b_unroll >= 1)
      innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 2)
      innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 4)
      innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (max_b_unroll >= 8)
      innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);

    // Load A for the next k iteration after the pow loop; skipped at the very end
    // to avoid reading past the packed buffer.
    if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
  }
  template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
            bool no_a_preload = false>
  EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
    int fetchA_idx = 0;
    int fetchB_idx = 0;

    const bool fetch_x = k_factor == max_k_factor;
    const bool ktail = k_factor == 1;

    static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
    static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
                  "skipping a preload only allowed when k unroll is 1");

    if (k_factor > 0)
      innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (k_factor > 1)
      innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (k_factor > 2)
      innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);
    if (k_factor > 3)
      innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
          aa, ao, bo, co2, fetchA_idx, fetchB_idx);

    // Advance the A/B pointers past the consumed k iterations.
    ao += a_unroll * k_factor;
    bo += b_unroll * k_factor;
  }
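
  // Added note (illustrative): with a_unroll == 48, b_unroll == 8 and
  // k_factor == 4, one innerkernel() call consumes four packed k iterations and
  // advances ao by 192 and bo by 32 scalars.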
  template <int a_unroll, int b_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);

    // Pre-load the A registers (double-buffered when enough registers are free).
    if (!use_less_a_regs && k > 1)
      a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
    else
      a_loads<0, 1, 0, um_vecs, a_unroll>(ao);

    // ... initial b_load()s into the b_regs ...

#ifndef SECOND_FETCH
    prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
#endif  // SECOND_FETCH

    // Unroll the k loop by a factor of 4; push one unrolled block into the
    // remainder so the final iterations can skip the A preload.
    const int max_k_factor = 4;
    Index kRem = k % max_k_factor;
    Index k_ = k - kRem;
    if (k_ >= max_k_factor) {
      k_ -= max_k_factor;
      kRem += max_k_factor;
    }
    Index loop_count = k_ / max_k_factor;

    if (loop_count > 0) {
#ifdef SECOND_FETCH
      loop_count -= SECOND_FETCH;
#endif
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
        loop_count--;
      }
#ifdef SECOND_FETCH
      // ... point co2 at the C data to prefetch ...

      loop_count += b_unroll;
      while (loop_count > 0) {
        // These iterations also prefetch the next C tile (c_fetch == 1).
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
        loop_count--;
      }

      loop_count += SECOND_FETCH - b_unroll;
      while (loop_count > 0) {
        innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
        loop_count--;
      }
#endif
    }

    // k remainder: unroll factor 1, skipping the A preload on the last step.
    loop_count = kRem;
    while (loop_count > 1) {
      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
      loop_count--;
    }
    if (loop_count > 0) {
      innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
    }

    // Update the C tile.
    c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
  }
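
  // Added note (hedged): with SECOND_FETCH defined as 32, the tail of the
  // unrolled k loop is split so that b_unroll of the final iterations run with
  // c_fetch == 1 and prefetch the next C tile while still computing.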
  template <int a_unroll, int b_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    // Set the A pointer for this panel.
    ao = a + a_off * a_unroll;

    // ... set/advance the B pointer for this panel ...

    kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

    // ... advance bo and the prefetch pointer aa ...
  }
  template <int a_unroll, int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
    // Set the A prefetch pointer (exact starting offset elided in this listing).
    const Scalar *aa = a;

    // Set the C pointers for this panel of rows.
    co1 = c;
    if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
    if (is_unit_inc)
      c += a_unroll;
    else
      c += a_unroll * inc;

    // Set the B pointer.
    bo = b;

    // Main n loop over full panels of max_b_unroll columns.
    for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

    // n remainders.
    if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
#if 0
    if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
    if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
#else
    // The copy kernels do not support tails of n == 2, so loop over single
    // columns instead.
    int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
    while (n_rem > 0) {
      nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
      n_rem--;
    }
#endif

    // ... advance the A pointer past this panel of rows ...
  }
 public:
  // Compute-kernel driver: loops over m in panels of up to max_a_unroll rows.
  template <int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void compute_kern() {
    // ... apply the a_shift/b_shift bias to the packed A/B pointers ...

    const Scalar *ao = nullptr;
    const Scalar *bo = nullptr;
    Scalar *co1 = nullptr;
    Scalar *co2 = nullptr;

    // Main m loop.
    for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // m remainders.
    if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // The copy kernels do not support m tails of 2 for single precision, so loop
    // over single rows instead.
    if (is_f32) {
      int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
      while (m_rem > 0) {
        mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
        m_rem--;
      }
    }
  }

  // ... constructor (stores the GEMM arguments listed above) elided ...
};
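
// Added note (illustrative): for float with max_a_unroll == 48, a call with
// m == 27 performs no full 48-row panel, takes the 16-row and 8-row remainder
// panels, and finishes the last three rows with three mloop<1, ...> calls from
// the single-row loop in compute_kern() above.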
// Compute-kernel entry point used by the gebp_kernel specializations below.
template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
                                        Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
                                        Index a_off = 0, Index b_off = 0) {
  if (a_stride == -1) a_stride = k;
  if (b_stride == -1) b_stride = k;

  gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
                                    b_off);
  g.template compute_kern<max_a_unroll, max_b_unroll>();
}
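
// Hedged usage sketch (added; not part of the original header): the kernel is
// normally reached through the gebp_kernel specializations below, but a direct
// call on packed panels could look like the following.  The unroll factors 48/8
// are assumptions (three float packets of 16, nr == 8); the buffer names are
// hypothetical.
//
//   float alpha = 1.0f;
//   gemm_kern_avx512<float, /*max_a_unroll=*/48, /*max_b_unroll=*/8,
//                    /*is_alpha1=*/true, /*is_beta0=*/false, /*is_unit_inc=*/true>(
//       m, n, k, &alpha, packedA, packedB, C, ldc);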
// Template specializations of the GEBP traits and kernels with nr == 8 for the
// AVX512 GEMM kernels.
#if EIGEN_USE_AVX512_GEMM_KERNELS
template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
struct gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

  enum { nr = Base::Vectorizable ? 8 : 4 };
};

template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
struct gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

  enum { nr = Base::Vectorizable ? 8 : 4 };
};
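
// Added note (illustrative): raising nr from 4 to 8 makes the RHS packers below
// emit 8-column panels first; e.g. for cols == 21 they pack 16 columns as two
// 8-wide panels, 4 more as one 4-wide panel, and the last column element by
// element.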
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum { PacketSize = packet_traits<Scalar>::size };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};

template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;
  const Index peeled_k = (depth / PacketSize) * PacketSize;

  // Pack full panels of 8 columns.
  for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
    // Skip what we have before.
    if (PanelMode) count += 8 * offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
    const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
    const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
    const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
    const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);

    Index k = 0;
    if ((PacketSize % 8) == 0) {
      // Explicit vectorization: transpose PacketSize x 8 blocks.
      for (; k < peeled_k; k += PacketSize) {
        PacketBlock<Packet, 8> kernel;

        kernel.packet[0] = dm0.template loadPacket<Packet>(k);
        kernel.packet[1] = dm1.template loadPacket<Packet>(k);
        kernel.packet[2] = dm2.template loadPacket<Packet>(k);
        kernel.packet[3] = dm3.template loadPacket<Packet>(k);
        kernel.packet[4] = dm4.template loadPacket<Packet>(k);
        kernel.packet[5] = dm5.template loadPacket<Packet>(k);
        kernel.packet[6] = dm6.template loadPacket<Packet>(k);
        kernel.packet[7] = dm7.template loadPacket<Packet>(k);

        ptranspose(kernel);

        pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
        pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
        pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
        pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
        pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
        pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
        pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
        pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
        count += 8 * PacketSize;
      }
    }
    for (; k < depth; k++) {
      blockB[count + 0] = cj(dm0(k));
      blockB[count + 1] = cj(dm1(k));
      blockB[count + 2] = cj(dm2(k));
      blockB[count + 3] = cj(dm3(k));
      blockB[count + 4] = cj(dm4(k));
      blockB[count + 5] = cj(dm5(k));
      blockB[count + 6] = cj(dm6(k));
      blockB[count + 7] = cj(dm7(k));
      count += 8;
    }
    // Skip what we have after.
    if (PanelMode) count += 8 * (stride - offset - depth);
  }

  // Pack panels of 4 columns.
  for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
    if (PanelMode) count += 4 * offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

    Index k = 0;
    if ((PacketSize % 4) == 0) {
      // Explicit vectorization.
      for (; k < peeled_k; k += PacketSize) {
        PacketBlock<Packet, 4> kernel;
        kernel.packet[0] = dm0.template loadPacket<Packet>(k);
        kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
        kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
        kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
        ptranspose(kernel);
        pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
        pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
        pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
        pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
        count += 4 * PacketSize;
      }
    }
    for (; k < depth; k++) {
      blockB[count + 0] = cj(dm0(k));
      blockB[count + 1] = cj(dm1(k));
      blockB[count + 2] = cj(dm2(k));
      blockB[count + 3] = cj(dm3(k));
      count += 4;
    }
    if (PanelMode) count += 4 * (stride - offset - depth);
  }

  // Remaining single columns.
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(dm0(k));
      count++;
    }
    if (PanelMode) count += (stride - offset - depth);
  }
}
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum {
    PacketSize = packet_traits<Scalar>::size,
    HalfPacketSize = unpacket_traits<HalfPacket>::size,
    QuarterPacketSize = unpacket_traits<QuarterPacket>::size
  };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};

template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
  const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;

  // Pack full panels of 8 columns.
  for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
    if (PanelMode) count += 8 * offset;
    for (Index k = 0; k < depth; k++) {
      if (PacketSize == 8) {
        Packet A = rhs.template loadPacket<Packet>(k, j2);
        pstoreu(blockB + count, cj.pconj(A));
      } else if (HasHalf && HalfPacketSize == 8) {
        HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
        pstoreu(blockB + count, cj.pconj(A));
      } else if (HasQuarter && QuarterPacketSize == 8) {
        QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
        pstoreu(blockB + count, cj.pconj(A));
      } else if (PacketSize == 4) {
        // Two packets per row of 8 columns.
        Packet A = rhs.template loadPacket<Packet>(k, j2);
        Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
        pstoreu(blockB + count, cj.pconj(A));
        pstoreu(blockB + count + PacketSize, cj.pconj(B));
      } else {
        const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
        blockB[count + 0] = cj(dm0(0));
        blockB[count + 1] = cj(dm0(1));
        blockB[count + 2] = cj(dm0(2));
        blockB[count + 3] = cj(dm0(3));
        blockB[count + 4] = cj(dm0(4));
        blockB[count + 5] = cj(dm0(5));
        blockB[count + 6] = cj(dm0(6));
        blockB[count + 7] = cj(dm0(7));
      }
      count += 8;
    }
    if (PanelMode) count += 8 * (stride - offset - depth);
  }

  // Pack panels of 4 columns.
  for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
    if (PanelMode) count += 4 * offset;
    for (Index k = 0; k < depth; k++) {
      if (PacketSize == 4) {
        Packet A = rhs.template loadPacket<Packet>(k, j2);
        pstoreu(blockB + count, cj.pconj(A));
        count += PacketSize;
      } else if (HasHalf && HalfPacketSize == 4) {
        HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
        pstoreu(blockB + count, cj.pconj(A));
        count += HalfPacketSize;
      } else if (HasQuarter && QuarterPacketSize == 4) {
        QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
        pstoreu(blockB + count, cj.pconj(A));
        count += QuarterPacketSize;
      } else {
        const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
        blockB[count + 0] = cj(dm0(0));
        blockB[count + 1] = cj(dm0(1));
        blockB[count + 2] = cj(dm0(2));
        blockB[count + 3] = cj(dm0(3));
        count += 4;
      }
    }
    if (PanelMode) count += 4 * (stride - offset - depth);
  }

  // Remaining single columns.
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(rhs(k, j2));
      count++;
    }
    if (PanelMode) count += stride - offset - depth;
  }
}
// GEBP kernel specialization for nr == 8 that dispatches to gemm_kern_avx512.
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
  EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
                                      Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
                                      Index offsetA = 0, Index offsetB = 0);
};

template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
    Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  if (res.incr() == 1) {
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                         (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                         strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    }
  } else {
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                           (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                           strideB, offsetA, offsetB);
    }
  }
}
#endif  // EIGEN_USE_AVX512_GEMM_KERNELS

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H