Eigen::internal::gemm_class< Scalar, is_unit_inc > Class Template Reference

#include <GemmKernel.h>

Public Member Functions

template<int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void compute_kern ()
 
 gemm_class (Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_, const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_, Index a_off_, Index b_off_)
 

Private Types

using vec = typename packet_traits< Scalar >::type
 
using vec_ymm = typename unpacket_traits< vec >::half
 
using vec_xmm = typename unpacket_traits< vec_ymm >::half
 
using umask_t = typename unpacket_traits< vec >::mask_t
 

Private Member Functions

EIGEN_ALWAYS_INLINE void prefetch_a (const Scalar *a_addr)
 
EIGEN_ALWAYS_INLINE void prefetch_b (const Scalar *b_addr)
 
EIGEN_ALWAYS_INLINE void prefetch_x (const Scalar *x_addr)
 
EIGEN_ALWAYS_INLINE void prefetch_c (const Scalar *c_addr)
 
template<int nelems>
EIGEN_ALWAYS_INLINE void a_load (vec &a_reg, const Scalar *a_addr)
 
EIGEN_ALWAYS_INLINE void b_load (vec &b_reg, const Scalar *b_addr)
 
template<int nelems>
EIGEN_ALWAYS_INLINE void c_store (Scalar *mem, vec &src)
 
template<int nelems>
EIGEN_ALWAYS_INLINE void vaddm (vec &dst, const Scalar *mem, vec &src, vec &reg)
 
EIGEN_STRONG_INLINE void vfmadd (vec &dst, const vec &src1, const vec &src2)
 
template<int nelems>
EIGEN_ALWAYS_INLINE void vfmaddm (vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg)
 
template<int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX)||(i > endY)> a_loads (const Scalar *ao)
 
template<int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j<=endX) &&(i<=endY)> a_loads (const Scalar *ao)
 
template<int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll)||(i > um_vecs)> prefetch_cs (const Scalar *co1, const Scalar *co2)
 
template<int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un<=max_b_unroll) &&(i<=um_vecs)> prefetch_cs (Scalar *&co1, Scalar *&co2)
 
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c (const Scalar *cox, vec &alpha_reg)
 
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i<=um_vecs)> scale_load_c (const Scalar *cox, vec &alpha_reg)
 
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c (Scalar *cox)
 
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i<=um_vecs)> write_c (Scalar *cox)
 
template<int pow, int a_unroll, int idx>
EIGEN_ALWAYS_INLINE void c_update_1count (Scalar *&cox)
 
template<int pow, int a_unroll>
EIGEN_ALWAYS_INLINE void c_update_1pow (Scalar *&co1, Scalar *&co2)
 
template<int max_b_unroll, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE void c_update (Scalar *&co1, Scalar *&co2)
 
template<int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute (const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg)
 
template<int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um<=um_vecs)> compute (const Scalar *ao, const Scalar *bo, int &fetchA_idx, int &fetchB_idx, vec &b_reg)
 
template<int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a (const Scalar *ao)
 
template<int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um<=um_vecs)> load_a (const Scalar *ao)
 
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count >(pow+1)/2)> innerkernel_1pow (const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
 
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count<=(pow+1)/2)> innerkernel_1pow (const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
 
template<int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
EIGEN_ALWAYS_INLINE void innerkernel_1uk (const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo, Scalar *&co2, int &fetchA_idx, int &fetchB_idx)
 
template<int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch, bool no_a_preload = false>
EIGEN_ALWAYS_INLINE void innerkernel (const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2)
 
template<int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void kloop (const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
 
template<int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void nloop (const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
 
template<int a_unroll, int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void mloop (const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2)
 

Private Attributes

vec zmm [32]
 
umask_t mask
 
Index m
 
const Index n
 
const Index k
 
const Index ldc
 
const Index inc
 
const Scalar * alpha
 
const Scalar * a
 
const Scalar * b
 
Scalar * c
 
const bool is_alpha1
 
const bool is_beta0
 
const Index a_stride
 
const Index b_stride
 
const Index a_off
 
const Index b_off
 

Static Private Attributes

static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float)
 
static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double)
 
static constexpr bool use_less_a_regs = !is_unit_inc
 
static constexpr bool use_less_b_regs = !is_unit_inc
 
static constexpr int a_regs [] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5}
 
static constexpr int b_regs [] = {6, use_less_b_regs ? 6 : 7}
 
static constexpr int c_regs []
 
static constexpr int alpha_load_reg = 0
 
static constexpr int c_load_regs [] = {1, 2, 6}
 
static constexpr int a_shift = 128
 
static constexpr int b_shift = 128
 
static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8
 
static constexpr int a_prefetch_size = nelems_in_cache_line * 2
 
static constexpr int b_prefetch_size = nelems_in_cache_line * 8
 

Member Typedef Documentation

◆ umask_t

template<typename Scalar , bool is_unit_inc>
using Eigen::internal::gemm_class< Scalar, is_unit_inc >::umask_t = typename unpacket_traits<vec>::mask_t
private

◆ vec

template<typename Scalar , bool is_unit_inc>
using Eigen::internal::gemm_class< Scalar, is_unit_inc >::vec = typename packet_traits<Scalar>::type
private

◆ vec_xmm

template<typename Scalar , bool is_unit_inc>
using Eigen::internal::gemm_class< Scalar, is_unit_inc >::vec_xmm = typename unpacket_traits<vec_ymm>::half
private

◆ vec_ymm

template<typename Scalar , bool is_unit_inc>
using Eigen::internal::gemm_class< Scalar, is_unit_inc >::vec_ymm = typename unpacket_traits<vec>::half
private

Constructor & Destructor Documentation

◆ gemm_class()

template<typename Scalar , bool is_unit_inc>
Eigen::internal::gemm_class< Scalar, is_unit_inc >::gemm_class ( Index  m_,
Index  n_,
Index  k_,
Index  ldc_,
Index  inc_,
const Scalar *  alpha_,
const Scalar *  a_,
const Scalar *  b_,
Scalar *  c_,
bool  is_alpha1_,
bool  is_beta0_,
Index  a_stride_,
Index  b_stride_,
Index  a_off_,
Index  b_off_ 
)
inline
887  : m(m_),
888  n(n_),
889  k(k_),
890  ldc(ldc_),
891  inc(inc_),
892  alpha(alpha_),
893  a(a_),
894  b(b_),
895  c(c_),
896  is_alpha1(is_alpha1_),
897  is_beta0(is_beta0_),
898  a_stride(a_stride_),
899  b_stride(b_stride_),
900  a_off(a_off_),
901  b_off(b_off_) {
902  // Zero out all accumulation registers.
903  zmm[8] = pzero(zmm[8]);
904  zmm[9] = pzero(zmm[9]);
905  zmm[10] = pzero(zmm[10]);
906  zmm[11] = pzero(zmm[11]);
907  zmm[12] = pzero(zmm[12]);
908  zmm[13] = pzero(zmm[13]);
909  zmm[14] = pzero(zmm[14]);
910  zmm[15] = pzero(zmm[15]);
911  zmm[16] = pzero(zmm[16]);
912  zmm[17] = pzero(zmm[17]);
913  zmm[18] = pzero(zmm[18]);
914  zmm[19] = pzero(zmm[19]);
915  zmm[20] = pzero(zmm[20]);
916  zmm[21] = pzero(zmm[21]);
917  zmm[22] = pzero(zmm[22]);
918  zmm[23] = pzero(zmm[23]);
919  zmm[24] = pzero(zmm[24]);
920  zmm[25] = pzero(zmm[25]);
921  zmm[26] = pzero(zmm[26]);
922  zmm[27] = pzero(zmm[27]);
923  zmm[28] = pzero(zmm[28]);
924  zmm[29] = pzero(zmm[29]);
925  zmm[30] = pzero(zmm[30]);
926  zmm[31] = pzero(zmm[31]);
927  }
const Scalar * alpha
Definition: GemmKernel.h:82
const Index b_off
Definition: GemmKernel.h:91
const Index b_stride
Definition: GemmKernel.h:90
vec zmm[32]
Definition: GemmKernel.h:75
const Index n
Definition: GemmKernel.h:80
const Index a_off
Definition: GemmKernel.h:91
const Index inc
Definition: GemmKernel.h:81
const bool is_beta0
Definition: GemmKernel.h:88
const Index ldc
Definition: GemmKernel.h:80
const Scalar * a
Definition: GemmKernel.h:84
Scalar * c
Definition: GemmKernel.h:85
Index m
Definition: GemmKernel.h:79
const Index a_stride
Definition: GemmKernel.h:90
const Index k
Definition: GemmKernel.h:80
const bool is_alpha1
Definition: GemmKernel.h:87
const Scalar * b
Definition: GemmKernel.h:84
EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f &)
Definition: AVX/PacketMath.h:774

References Eigen::internal::pzero(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.
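
The unrolled pzero calls above clear zmm[8] through zmm[31], which are exactly the accumulator registers enumerated in c_regs (zmm0–zmm7 are reserved for the A/B operands, alpha and the scratch load registers). The following standalone sketch is illustrative only, not part of the kernel, and simply checks that correspondence:

#include <cstdio>

int main() {
  // Accumulator register table, copied from gemm_class::c_regs.
  constexpr int c_regs[] = {8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27,
                            12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31};
  bool is_accumulator[32] = {};
  for (int r : c_regs) is_accumulator[r] = true;
  // The constructor zeroes zmm[8]..zmm[31]; verify that this is the same set.
  for (int r = 8; r < 32; ++r)
    if (!is_accumulator[r]) std::printf("zmm%d zeroed but not in c_regs\n", r);
  std::printf("zmm8..zmm31 == the 24 accumulators in c_regs\n");
  return 0;
}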

Member Function Documentation

◆ a_load()

template<typename Scalar , bool is_unit_inc>
template<int nelems>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_load ( vec &  a_reg,
const Scalar *  a_addr 
)
inlineprivate
112  {
113  switch (nelems * sizeof(*a_addr) * 8) {
114  default:
115  case 512 * 3:
116  a_reg = ploadu<vec>(a_addr);
117  break;
118  case 512 * 2:
119  a_reg = ploadu<vec>(a_addr);
120  break;
121  case 512 * 1:
122  a_reg = ploadu<vec>(a_addr);
123  break;
124  case 256 * 1:
125  a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
126  break;
127  case 128 * 1:
128  a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
129  break;
130  case 64 * 1:
131  a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
132  break;
133  case 32 * 1:
134  a_reg = pload1<vec>(a_addr);
135  break;
136  }
137  }
EIGEN_STRONG_INLINE Packet8d pload1< Packet8d >(const double *from)
Definition: AVX512/PacketMath.h:326
EIGEN_STRONG_INLINE Packet4d ploadu< Packet4d >(const double *from)
Definition: AVX/PacketMath.h:1511
EIGEN_STRONG_INLINE Packet4f ploadu< Packet4f >(const float *from)
Definition: AltiVec/PacketMath.h:1533

References Eigen::internal::pload1< Packet8d >(), Eigen::internal::ploadu< Packet4d >(), and Eigen::internal::ploadu< Packet4f >().
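
The switch above keys on the requested load width in bits, nelems * sizeof(Scalar) * 8, and falls back from full 512-bit loads to 256-, 128-, 64- and 32-bit forms for partial tiles. A small sketch (illustrative only) of the case labels this expression produces for float and double:

#include <cstdio>

int main() {
  // float: a full zmm holds 16 elements; 48/32/16 map to the 1536/1024/512 labels,
  // all of which issue one full 512-bit load per call.
  for (int nelems : {48, 32, 16, 8, 4, 2, 1})
    std::printf("float  nelems=%2d -> %4zu-bit case\n", nelems, nelems * sizeof(float) * 8);
  // double: a full zmm holds 8 elements.
  for (int nelems : {24, 16, 8, 4, 2, 1})
    std::printf("double nelems=%2d -> %4zu-bit case\n", nelems, nelems * sizeof(double) * 8);
  return 0;
}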

◆ a_loads() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_loads ( const Scalar *  ao)
inlineprivate
351  {
353  }
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:966

References EIGEN_UNUSED_VARIABLE.

◆ a_loads() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_loads ( const Scalar *  ao)
inlineprivate
356  {
357  if (j < endX) {
358  if (i < endY) {
359  auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
360  const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
361  a_load<nelems>(a_reg, a_addr);
362 
363  a_loads<j, endX, i + 1, endY, nelems>(ao);
364  } else {
365  a_loads<j + 1, endX, 0, endY, nelems>(ao);
366  }
367  }
368  }
static constexpr int a_regs[]
Definition: GemmKernel.h:59
static constexpr int a_shift
Definition: GemmKernel.h:68
static constexpr int nelems_in_cache_line
Definition: GemmKernel.h:71

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_shift, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.
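
The two a_loads overloads implement a compile-time double loop: the inner index i walks the up-to-three A vectors of one column block, then j advances to the next block and i restarts, with the register bank alternating via a_regs[i + (j % 2) * 3]. A standalone sketch of the same recursion pattern (illustrative only; the prints stand in for the actual a_load calls):

#include <cstdio>
#include <type_traits>

template <int j, int endX, int i, int endY>
std::enable_if_t<(j > endX) || (i > endY)> walk() {}  // terminating overload

template <int j, int endX, int i, int endY>
std::enable_if_t<(j <= endX) && (i <= endY)> walk() {
  if (j < endX) {
    if (i < endY) {
      std::printf("load a_regs[%d] from column block %d\n", i + (j % 2) * 3, j);
      walk<j, endX, i + 1, endY>();   // next vector of this block
    } else {
      walk<j + 1, endX, 0, endY>();   // next block, inner index restarts
    }
  }
}

int main() {
  walk<0, 2, 0, 3>();  // e.g. two column blocks, three vectors each
  return 0;
}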

◆ b_load()

template<typename Scalar , bool is_unit_inc>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_load ( vec &  b_reg,
const Scalar *  b_addr 
)
inlineprivate

◆ c_store()

template<typename Scalar , bool is_unit_inc>
template<int nelems>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_store ( Scalar *  mem,
vec &  src 
)
inlineprivate
142  {
143  if (is_unit_inc) {
144  switch (nelems * sizeof(*mem) * 8) {
145  default:
146  case 512 * 3:
147  pstoreu(mem, src);
148  break;
149  case 512 * 2:
150  pstoreu(mem, src);
151  break;
152  case 512 * 1:
153  pstoreu(mem, src);
154  break;
155  case 256 * 1:
156  pstoreu(mem, preinterpret<vec_ymm>(src));
157  break;
158  case 128 * 1:
159  pstoreu(mem, preinterpret<vec_xmm>(src));
160  break;
161  case 64 * 1:
162  pstorel(mem, preinterpret<vec_xmm>(src));
163  break;
164  case 32 * 1:
165  pstores(mem, preinterpret<vec_xmm>(src));
166  break;
167  }
168  } else {
169  switch (nelems * sizeof(*mem) * 8) {
170  default:
171  case 512 * 3:
172  pscatter(mem, src, inc);
173  break;
174  case 512 * 2:
175  pscatter(mem, src, inc);
176  break;
177  case 512 * 1:
178  pscatter(mem, src, inc);
179  break;
180  case 256 * 1:
181  pscatter(mem, src, inc, mask);
182  break;
183  case 128 * 1:
184  pscatter(mem, src, inc, mask);
185  break;
186  case 64 * 1:
187  pscatter(mem, src, inc, mask);
188  break;
189  case 32 * 1:
190  pscatter(mem, src, inc, mask);
191  break;
192  }
193  }
194  }
umask_t mask
Definition: GemmKernel.h:76
EIGEN_STRONG_INLINE void pstorel(Scalar *to, const Packet &from)
EIGEN_STRONG_INLINE void pstores(Scalar *to, const Packet &from)
EIGEN_DEVICE_FUNC void pscatter(Scalar *to, const Packet &from, Index stride, typename unpacket_traits< Packet >::mask_t umask)
EIGEN_DEVICE_FUNC void pstoreu(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:911

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::inc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::mask, Eigen::internal::pscatter(), Eigen::internal::pstorel(), Eigen::internal::pstores(), and Eigen::internal::pstoreu().
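
When is_unit_inc is false, the rows of C are not contiguous and every store becomes a scatter with stride inc, masked for partial tiles. The scalar model below is illustrative only (it is not the AVX-512 scatter the kernel actually issues) and shows the addressing:

#include <cstdio>

// Element e of the packet goes to mem[e * inc]; only lanes selected by the mask are written.
void scatter_store(float *mem, const float *src, int nelems, long inc, unsigned mask) {
  for (int e = 0; e < nelems; ++e)
    if (mask & (1u << e)) mem[e * inc] = src[e];
}

int main() {
  float c[64] = {};
  float acc[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
  scatter_store(c, acc, 16, 4, (1u << 5) - 1);  // store only the first 5 lanes, stride 4
  for (int i = 0; i < 24; i += 4) std::printf("c[%d] = %g\n", i, c[i]);
  return 0;
}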

◆ c_update()

template<typename Scalar , bool is_unit_inc>
template<int max_b_unroll, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_update ( Scalar *&  co1,
Scalar *&  co2 
)
inlineprivate
521  {
522  auto &alpha_reg = zmm[alpha_load_reg];
523 
524  co2 = co1 + ldc;
525  if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
526  if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);
527 
528  static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");
529 
530  if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
531  if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
532  if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
533  if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
534 
535  if (b_unroll == 1)
536  co1 += ldc;
537  else
538  co1 = co2 + ldc;
539  }
static constexpr int alpha_load_reg
Definition: GemmKernel.h:65
typename unpacket_traits< vec >::mask_t umask_t
Definition: GemmKernel.h:43

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::alpha, Eigen::internal::gemm_class< Scalar, is_unit_inc >::alpha_load_reg, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_alpha1, Eigen::internal::gemm_class< Scalar, is_unit_inc >::ldc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::mask, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.
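
c_update covers the b_unroll columns of the current C tile by decomposing b_unroll into powers of two: c_update_1pow<pow> writes (pow + 1) / 2 columns, so pow = 1, 2, 4, 8 contribute 1 + 1 + 2 + 4 = 8 columns for a full tile. A quick sketch (illustrative only) of that bookkeeping:

#include <cstdio>

int main() {
  for (int b_unroll : {1, 2, 4, 8}) {
    int cols = 0;
    for (int pow : {1, 2, 4, 8})
      if (pow <= b_unroll) cols += (pow + 1) / 2;  // columns written by c_update_1pow<pow>
    std::printf("b_unroll=%d -> %d columns updated\n", b_unroll, cols);
  }
  return 0;
}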

◆ c_update_1count()

template<typename Scalar , bool is_unit_inc>
template<int pow, int a_unroll, int idx>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_update_1count ( Scalar *&  cox)
inlineprivate
496  {
497  if (pow >= 4) cox += ldc;
498 
499  const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
500  auto &alpha_reg = zmm[alpha_load_reg];
501 
502  scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
503  write_c<0, um_vecs, idx, a_unroll>(cox);
504  }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_CONSTEXPR T div_ceil(T a, T b)
Definition: MathFunctions.h:1251

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::alpha_load_reg, Eigen::numext::div_ceil(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::ldc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.

◆ c_update_1pow()

template<typename Scalar , bool is_unit_inc>
template<int pow, int a_unroll>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_update_1pow ( Scalar *&  co1,
Scalar *&  co2 
)
inlineprivate
507  {
508  constexpr int idx = pow / 2;
509  Scalar *&cox = idx == 0 ? co1 : co2;
510 
511  constexpr int max_count = (pow + 1) / 2;
512  static_assert(max_count <= 4, "Unsupported max_count.");
513 
514  if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
515  if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
516  if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
517  if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
518  }


◆ compute() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute ( const Scalar *  ao,
const Scalar *  bo,
int &  fetchA_idx,
int &  fetchB_idx,
vec &  b_reg 
)
inlineprivate
544  {
547  EIGEN_UNUSED_VARIABLE(fetchA_idx);
548  EIGEN_UNUSED_VARIABLE(fetchB_idx);
549  EIGEN_UNUSED_VARIABLE(b_reg);
550  }

References EIGEN_UNUSED_VARIABLE.

◆ compute() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute ( const Scalar *  ao,
const Scalar *  bo,
int &  fetchA_idx,
int &  fetchB_idx,
vec &  b_reg 
)
inlineprivate
554  {
555  if (um < um_vecs) {
556  auto &c_reg = zmm[c_regs[um + idx * 3]];
557  auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
558 
559  vfmadd(c_reg, a_reg, b_reg);
560 
561  if (!fetch_x && um == 0 &&
562  (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
563  (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
564  prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
565  fetchA_idx++;
566  }
567 
568  if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
569  prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
570  fetchB_idx++;
571  }
572 
573  compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
574  }
575  }
static constexpr bool is_f64
Definition: GemmKernel.h:46
EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2)
Definition: GemmKernel.h:264
static constexpr int c_regs[]
Definition: GemmKernel.h:61
EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr)
Definition: GemmKernel.h:93
EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr)
Definition: GemmKernel.h:97

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f64, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_a(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_b(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::vfmadd(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.
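
Each FMA in compute accumulates into zmm[c_regs[um + idx * 3]], so column idx of the C tile occupies a group of up to three zmm registers, while the A operand ping-pongs between a_regs[0..2] and a_regs[3..5] on even and odd k steps (uk % 2). The sketch below is illustrative only and assumes the unit-increment register layout (use_less_a_regs == false); it prints that mapping:

#include <cstdio>

int main() {
  constexpr int c_regs[] = {8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27,
                            12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31};
  constexpr int a_regs[] = {0, 1, 2, 3, 4, 5};  // layout when use_less_a_regs == false
  for (int idx = 0; idx < 8; ++idx)             // one C column per idx
    for (int um = 0; um < 3; ++um)              // up to three zmm rows per column
      std::printf("C(um=%d, idx=%d) -> zmm%-2d  A: even k -> zmm%d, odd k -> zmm%d\n",
                  um, idx, c_regs[um + idx * 3], a_regs[um], a_regs[um + 3]);
  return 0;
}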

◆ compute_kern()

template<typename Scalar , bool is_unit_inc>
template<int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute_kern ( )
inline
853  {
854  a -= -a_shift;
855  b -= -b_shift;
856 
857  const Scalar *ao = nullptr;
858  const Scalar *bo = nullptr;
859  Scalar *co1 = nullptr;
860  Scalar *co2 = nullptr;
861 
862  // Main m-loop.
863  for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
864 
865  // m-remainders.
866  if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
867  if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
868  if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
869  if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
870  if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
871  if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
872 
873  // Copy kernels don't support tails of m = 2 for single precision.
874  // Loop over ones.
875  if (is_f32) {
876  int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
877  while (m_rem > 0) {
878  mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
879  m_rem--;
880  }
881  }
882  }
static constexpr int b_shift
Definition: GemmKernel.h:69
static constexpr bool is_f32
Definition: GemmKernel.h:45

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a, Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_shift, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_shift, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f32, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f64, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::m.
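
compute_kern consumes m in blocks of max_a_unroll, then peels the remainder by testing individual bits; for single precision the final m & 3 rows are handled one at a time because the copy kernels have no m = 2 tail. A sketch of that decomposition, illustrative only and assuming max_a_unroll = 48 as in the float instantiation:

#include <cstdio>

int main() {
  const int max_a_unroll = 48;
  for (int m : {100, 55, 7}) {
    int rows = m;
    int full = rows / max_a_unroll;          // main m-loop iterations
    rows -= full * max_a_unroll;
    std::printf("m=%3d: %d block(s) of %d", m, full, max_a_unroll);
    for (int tail : {32, 16, 8, 4})          // bit-tested m-remainders
      if (rows & tail) { std::printf(" + %d", tail); rows -= tail; }
    std::printf(" + %d single rows\n", rows);  // handled by the ones loop for f32
  }
  return 0;
}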

◆ innerkernel()

template<typename Scalar , bool is_unit_inc>
template<int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch, bool no_a_preload = false>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::innerkernel ( const Scalar *&  aa,
const Scalar *&  ao,
const Scalar *&  bo,
Scalar *&  co2 
)
inlineprivate
702  {
703  int fetchA_idx = 0;
704  int fetchB_idx = 0;
705 
706  const bool fetch_x = k_factor == max_k_factor;
707  const bool ktail = k_factor == 1;
708 
709  static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
710  static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
711  "skipping a preload only allowed when k unroll is 1");
712 
713  if (k_factor > 0)
714  innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
715  aa, ao, bo, co2, fetchA_idx, fetchB_idx);
716  if (k_factor > 1)
717  innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
718  aa, ao, bo, co2, fetchA_idx, fetchB_idx);
719  if (k_factor > 2)
720  innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
721  aa, ao, bo, co2, fetchA_idx, fetchB_idx);
722  if (k_factor > 3)
723  innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
724  aa, ao, bo, co2, fetchA_idx, fetchB_idx);
725 
726  // Advance A/B pointers after uk-loop.
727  ao += a_unroll * k_factor;
728  bo += b_unroll * k_factor;
729  }


◆ innerkernel_1pow() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::innerkernel_1pow ( const Scalar *&  aa,
const Scalar *const &  ao,
const Scalar *const &  bo,
Scalar *&  co2,
int fetchA_idx,
int fetchB_idx 
)
inlineprivate
597  {
602  EIGEN_UNUSED_VARIABLE(fetchA_idx);
603  EIGEN_UNUSED_VARIABLE(fetchB_idx);
604  }

References EIGEN_UNUSED_VARIABLE.

◆ innerkernel_1pow() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::innerkernel_1pow ( const Scalar *&  aa,
const Scalar *const &  ao,
const Scalar *const &  bo,
Scalar *&  co2,
int fetchA_idx,
int fetchB_idx 
)
inlineprivate
610  {
611  const int idx = (pow / 2) + count;
612 
613  if (count < (pow + 1) / 2) {
614  auto &b_reg = zmm[b_regs[idx % 2]];
615 
616  if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
617  if (fetch_x && uk == 3 && idx == 4) aa += 8;
618 
619  if (b_unroll >= pow) {
620  compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
621 
622  const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
623  b_load(b_reg, b_addr);
624  }
625 
626  // Go to the next count.
627  innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
628  fetchB_idx);
629 
630  } else {
631  // Maybe prefetch C data after count-loop.
632  if (pow == 2 && c_fetch) {
633  if (uk % 3 == 0 && uk > 0) {
634  co2 += ldc;
635  } else {
636  prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
637  }
638  }
639  }
640  }
static constexpr int b_regs[]
Definition: GemmKernel.h:60
EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr)
Definition: GemmKernel.h:101
static constexpr bool use_less_b_regs
Definition: GemmKernel.h:54
EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr)
Definition: GemmKernel.h:139
EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr)
Definition: GemmKernel.h:103

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_load(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_shift, Eigen::internal::gemm_class< Scalar, is_unit_inc >::ldc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_c(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_x(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::use_less_b_regs, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.

◆ innerkernel_1uk()

template<typename Scalar , bool is_unit_inc>
template<int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch, bool no_a_preload = false>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::innerkernel_1uk ( const Scalar *&  aa,
const Scalar *const &  ao,
const Scalar *const &  bo,
Scalar *&  co2,
int fetchA_idx,
int fetchB_idx 
)
inlineprivate
645  {
646  const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
647 
648  if (max_b_unroll >= 1)
649  innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
650  if (max_b_unroll >= 2)
651  innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
652  if (max_b_unroll >= 4)
653  innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
654  if (max_b_unroll >= 8)
655  innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
656 
657  // Load A after pow-loop. Skip this at the end to prevent running over the buffer
658  if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
659  }

References Eigen::numext::div_ceil(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line.

◆ kloop()

template<typename Scalar , bool is_unit_inc>
template<int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::kloop ( const Scalar *&  aa,
const Scalar *&  ao,
const Scalar *&  bo,
Scalar *&  co1,
Scalar *&  co2 
)
inlineprivate
732  {
733  const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
734  if (!use_less_a_regs && k > 1)
735  a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
736  else
737  a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
738 
739  b_load(zmm[b_regs[0]], bo - b_shift + 0);
740  if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);
741 
742 #ifndef SECOND_FETCH
743  prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
744 #endif // SECOND_FETCH
745 
746  // Unrolling k-loop by a factor of 4.
747  const int max_k_factor = 4;
748  Index kRem = k % max_k_factor;
749  Index k_ = k - kRem;
750  if (k_ >= max_k_factor) {
751  k_ -= max_k_factor;
752  kRem += max_k_factor;
753  }
754  Index loop_count = k_ / max_k_factor;
755 
756  if (loop_count > 0) {
757 #ifdef SECOND_FETCH
758  loop_count -= SECOND_FETCH;
759 #endif
760  while (loop_count > 0) {
761  innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
762  loop_count--;
763  }
764 #ifdef SECOND_FETCH
765  co2 = co1 + nelems_in_cache_line - 1;
766 
767  loop_count += b_unroll;
768  while (loop_count > 0) {
769  innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
770  loop_count--;
771  }
772 
773  loop_count += SECOND_FETCH - b_unroll;
774  while (loop_count > 0) {
775  innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
776  loop_count--;
777  }
778 #endif
779  }
780 
781  // k-loop remainder handling.
782  loop_count = kRem;
783  while (loop_count > 1) {
784  innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
785  loop_count--;
786  }
787  if (loop_count > 0) {
788  innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
789  }
790 
791  // Update C matrix.
792  c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
793  }
#define SECOND_FETCH
Definition: GemmKernel.h:28
static constexpr bool use_less_a_regs
Definition: GemmKernel.h:49
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_load(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_shift, Eigen::numext::div_ceil(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::k, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, SECOND_FETCH, Eigen::internal::gemm_class< Scalar, is_unit_inc >::use_less_a_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::use_less_b_regs, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.
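
kloop splits k into blocks of max_k_factor = 4 plus a remainder, but deliberately moves one full block into the remainder (kRem += max_k_factor) so that the final k step can run with no_a_preload and not read past the packed A panel. A sketch of the arithmetic (illustrative only):

#include <cstdio>

int main() {
  const int max_k_factor = 4;
  for (long k : {3, 4, 9, 17}) {
    long kRem = k % max_k_factor;
    long k_ = k - kRem;
    if (k_ >= max_k_factor) { k_ -= max_k_factor; kRem += max_k_factor; }  // bias the remainder
    std::printf("k=%2ld -> %ld unrolled-by-4 iterations, %ld remainder steps\n",
                k, k_ / max_k_factor, kRem);
  }
  return 0;
}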

◆ load_a() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::load_a ( const Scalar *  ao)
inlineprivate
579  {
581  }

References EIGEN_UNUSED_VARIABLE.

◆ load_a() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::load_a ( const Scalar *  ao)
inlineprivate
584  {
585  if (um < um_vecs) {
586  auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
587  const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
588  a_load<nelems>(a_reg, a_addr);
589 
590  load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
591  }
592  }

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_shift, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, Eigen::internal::gemm_class< Scalar, is_unit_inc >::use_less_a_regs, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.

◆ mloop()

template<typename Scalar , bool is_unit_inc>
template<int a_unroll, int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::mloop ( const Scalar *&  ao,
const Scalar *&  bo,
Scalar *&  co1,
Scalar *&  co2 
)
inlineprivate
813  {
814  // Set prefetch A pointers.
815  const Scalar *aa = a + a_unroll * a_stride;
816 
817  // Set C matrix pointers.
818  co1 = c;
819  if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
820  if (is_unit_inc)
821  c += a_unroll;
822  else
823  c += a_unroll * inc;
824 
825  // Set B matrix pointer.
826  bo = b;
827 
828  // Main n-loop.
829  for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
830 
831  // n-remainders.
832  if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
833 #if 0
834  if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
835  if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
836 #else
837  // Copy kernels don't support tails of n = 2 for single/double precision.
838  // Loop over ones.
839  int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
840  while (n_rem > 0) {
841  nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
842  n_rem--;
843  }
844 #endif
845 
846  // Advance A matrix pointer.
847  a = ao + a_unroll * (a_stride - k - a_off);
848  }

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a, Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_off, Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_stride, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b, Eigen::internal::gemm_class< Scalar, is_unit_inc >::c, Eigen::internal::gemm_class< Scalar, is_unit_inc >::inc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::k, Eigen::internal::gemm_class< Scalar, is_unit_inc >::ldc, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::n.
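
mloop walks n in blocks of max_b_unroll, takes one block of 4 when n & 4, and then finishes the last n & 3 columns one at a time, since the copy kernels don't support tails of n = 2. A sketch of that decomposition, illustrative only and assuming max_b_unroll = 8:

#include <cstdio>

int main() {
  const int max_b_unroll = 8;
  for (long n : {24, 13, 6}) {
    long full = n / max_b_unroll;                     // main n-loop iterations
    long rem = n - full * max_b_unroll;
    int fours = (rem & 4) ? 1 : 0;                    // one nloop<.., 4, ..> call
    int ones = 2 * ((rem & 2) != 0) + ((rem & 1) != 0);  // the ones loop
    std::printf("n=%2ld: %ld block(s) of 8 + %d block of 4 + %d single columns\n",
                n, full, fours, ones);
  }
  return 0;
}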

◆ nloop()

template<typename Scalar , bool is_unit_inc>
template<int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::nloop ( const Scalar *&  aa,
const Scalar *&  ao,
const Scalar *&  bo,
Scalar *&  co1,
Scalar *&  co2 
)
inlineprivate
796  {
797  // Set A matrix pointer.
798  ao = a + a_off * a_unroll;
799 
800  // Set B matrix pointer if needed.
801  bo += b_unroll * b_off;
802 
803  kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
804 
805  // Advance B matrix pointer if needed.
806  bo += b_unroll * (b_stride - k - b_off);
807 
808  // Advance prefetch A pointer.
809  aa += 16;
810  }

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a, Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_off, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_off, Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_stride, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::k.

◆ prefetch_a()

template<typename Scalar , bool is_unit_inc>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_a ( const Scalar *  a_addr)
inlineprivate
93  {
94  _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
95  }
static constexpr int a_prefetch_size
Definition: GemmKernel.h:72

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_prefetch_size, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_shift.

Referenced by Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute().

◆ prefetch_b()

template<typename Scalar , bool is_unit_inc>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_b ( const Scalar *  b_addr)
inlineprivate
97  {
98  _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
99  }
static constexpr int b_prefetch_size
Definition: GemmKernel.h:73

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_prefetch_size, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_shift.

Referenced by Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute().

◆ prefetch_c()

template<typename Scalar , bool is_unit_inc>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_c ( const Scalar *  c_addr)
inlineprivate
103  {
104 #if defined(__PRFCHW__) && __PRFCHW__ == 1
105  _m_prefetchw((void *)c_addr);
106 #else
107  _mm_prefetch((char *)c_addr, _MM_HINT_T0);
108 #endif
109  }

Referenced by Eigen::internal::gemm_class< Scalar, is_unit_inc >::innerkernel_1pow(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_cs().

◆ prefetch_cs() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_cs ( const Scalar *  co1,
const Scalar *  co2 
)
inlineprivate
372  {
375  }

References EIGEN_UNUSED_VARIABLE.

◆ prefetch_cs() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_cs ( Scalar *&  co1,
Scalar *&  co2 
)
inlineprivate
392  {
393  if (un < max_b_unroll) {
394  if (b_unroll >= un + 1) {
395  if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
396 
397  if (i < um_vecs) {
398  Scalar *co = (un + 1 <= 4) ? co1 : co2;
399  auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
400  prefetch_c(co + co_off);
401 
402  prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
403  } else {
404  prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
405  }
406 
407  } else {
408  prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
409  }
410  }
411  }

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::ldc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, and Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_c().

◆ prefetch_x()

template<typename Scalar , bool is_unit_inc>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::prefetch_x ( const Scalar *  x_addr)
inlineprivate

◆ scale_load_c() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::scale_load_c ( const Scalar *  cox,
vec &  alpha_reg 
)
inlineprivate
415  {
417  EIGEN_UNUSED_VARIABLE(alpha_reg);
418  }

References EIGEN_UNUSED_VARIABLE.

◆ scale_load_c() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::scale_load_c ( const Scalar *  cox,
vec &  alpha_reg 
)
inlineprivate
421  {
422  if (i < um_vecs) {
423  auto &c_reg = zmm[c_regs[i + idx * 3]];
424  auto &c_load_reg = zmm[c_load_regs[i % 3]];
425  auto c_mem = cox;
426  if (is_unit_inc)
427  c_mem += i * nelems_in_cache_line;
428  else
429  c_mem += i * nelems_in_cache_line * inc;
430 
431  if (!is_beta0 && is_alpha1)
432  vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
433  else if (!is_beta0 && !is_alpha1)
434  vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
435  else if (is_beta0 && !is_alpha1)
436  c_reg = pmul(alpha_reg, c_reg);
437 
438  scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
439  }
440  }
static constexpr int c_load_regs[]
Definition: GemmKernel.h:66
EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf &a, const Packet4cf &b)
Definition: AVX/Complex.h:88

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_load_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::inc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_alpha1, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_beta0, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, Eigen::internal::pmul(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.
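
The is_beta0 / is_alpha1 flags select one of four update forms, so the value handed to write_c() is always alpha * acc + beta * C with alpha in {1, *alpha} and beta in {0, 1}. A scalar model of the branch structure (illustrative only):

#include <cstdio>

float scale_load(float acc, float c_mem, float alpha, bool is_alpha1, bool is_beta0) {
  if (!is_beta0 && is_alpha1) return acc + c_mem;           // vaddm path
  if (!is_beta0 && !is_alpha1) return alpha * acc + c_mem;  // vfmaddm path
  if (is_beta0 && !is_alpha1) return alpha * acc;           // pmul path
  return acc;                                               // beta == 0, alpha == 1
}

int main() {
  std::printf("%g %g %g %g\n",
              scale_load(3.f, 10.f, 2.f, true, false),   // 13
              scale_load(3.f, 10.f, 2.f, false, false),  // 16
              scale_load(3.f, 10.f, 2.f, false, true),   // 6
              scale_load(3.f, 10.f, 2.f, true, true));   // 3
  return 0;
}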

◆ vaddm()

template<typename Scalar , bool is_unit_inc>
template<int nelems>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::vaddm ( vec &  dst,
const Scalar *  mem,
vec &  src,
vec &  reg 
)
inlineprivate
197  {
198  if (is_unit_inc) {
199  switch (nelems * sizeof(*mem) * 8) {
200  default:
201  case 512 * 3:
202  dst = padd(src, ploadu<vec>(mem));
203  break;
204  case 512 * 2:
205  dst = padd(src, ploadu<vec>(mem));
206  break;
207  case 512 * 1:
208  dst = padd(src, ploadu<vec>(mem));
209  break;
210  case 256 * 1:
211  dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
212  break;
213  case 128 * 1:
214  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
215  break;
216  case 64 * 1:
217  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
218  break;
219  case 32 * 1:
220  dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
221  break;
222  }
223  } else {
224  // Zero out scratch register
225  reg = pzero(reg);
226 
227  switch (nelems * sizeof(*mem) * 8) {
228  default:
229  case 512 * 3:
230  reg = pgather<Scalar, vec>(mem, inc);
231  dst = padd(src, reg);
232  break;
233  case 512 * 2:
234  reg = pgather<Scalar, vec>(mem, inc);
235  dst = padd(src, reg);
236  break;
237  case 512 * 1:
238  reg = pgather<Scalar, vec>(mem, inc);
239  dst = padd(src, reg);
240  break;
241  case 256 * 1:
242  reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
243  dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
244  break;
245  case 128 * 1:
246  reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
247  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
248  break;
249  case 64 * 1:
250  if (is_f32) {
251  reg = pgather(reg, mem, inc, mask);
252  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
253  } else {
254  dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
255  }
256  break;
257  case 32 * 1:
258  dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
259  break;
260  }
261  }
262  }
EIGEN_DEVICE_FUNC Packet padd(const Packet &a, const Packet &b)
Definition: GenericPacketMath.h:318
EIGEN_STRONG_INLINE Packet padds(const Packet &a, const Packet &b)
EIGEN_DEVICE_FUNC Packet pgather(const Packet &src, const Scalar *from, Index stride, typename unpacket_traits< Packet >::mask_t umask)

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::inc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f32, Eigen::internal::gemm_class< Scalar, is_unit_inc >::mask, Eigen::internal::padd(), Eigen::internal::padds(), Eigen::internal::pgather(), and Eigen::internal::pzero().

◆ vfmadd()

template<typename Scalar , bool is_unit_inc>
EIGEN_STRONG_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::vfmadd ( vec &  dst,
const vec &  src1,
const vec &  src2 
)
inlineprivate
264  {
265  dst = pmadd(src1, src2, dst);
266 
267 #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
268  // Workaround register spills for gcc and clang
269  __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
270 #endif
271  }
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
Definition: AltiVec/PacketMath.h:1218

References Eigen::internal::pmadd().

Referenced by Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute().

◆ vfmaddm()

template<typename Scalar , bool is_unit_inc>
template<int nelems>
EIGEN_ALWAYS_INLINE void Eigen::internal::gemm_class< Scalar, is_unit_inc >::vfmaddm ( vec &  dst,
const Scalar *  mem,
vec &  src,
vec &  scale,
vec &  reg 
)
inlineprivate
274  {
275  if (is_unit_inc) {
276  switch (nelems * sizeof(*mem) * 8) {
277  default:
278  case 512 * 3:
279  dst = pmadd(scale, src, ploadu<vec>(mem));
280  break;
281  case 512 * 2:
282  dst = pmadd(scale, src, ploadu<vec>(mem));
283  break;
284  case 512 * 1:
285  dst = pmadd(scale, src, ploadu<vec>(mem));
286  break;
287  case 256 * 1:
288  dst =
289  preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
290  break;
291  case 128 * 1:
292  dst =
293  preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
294  break;
295  case 64 * 1:
296  dst =
297  preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
298  break;
299  case 32 * 1:
300  dst =
301  preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
302  break;
303  }
304  } else {
305  // Zero out scratch register
306  reg = pzero(reg);
307 
308  switch (nelems * sizeof(*mem) * 8) {
309  default:
310  case 512 * 3:
311  reg = pgather<Scalar, vec>(mem, inc);
312  dst = pmadd(scale, src, reg);
313  break;
314  case 512 * 2:
315  reg = pgather<Scalar, vec>(mem, inc);
316  dst = pmadd(scale, src, reg);
317  break;
318  case 512 * 1:
319  reg = pgather<Scalar, vec>(mem, inc);
320  dst = pmadd(scale, src, reg);
321  break;
322  case 256 * 1:
323  reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
324  dst = preinterpret<vec>(
325  pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
326  break;
327  case 128 * 1:
328  reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
329  dst = preinterpret<vec>(
330  pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
331  break;
332  case 64 * 1:
333  if (is_f32) {
334  reg = pgather(reg, mem, inc, mask);
335  dst = preinterpret<vec>(
336  pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
337  } else {
338  dst = preinterpret<vec>(
339  pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
340  }
341  break;
342  case 32 * 1:
343  dst =
344  preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
345  break;
346  }
347  }
348  }

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::inc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f32, Eigen::internal::gemm_class< Scalar, is_unit_inc >::mask, Eigen::internal::pgather(), Eigen::internal::pmadd(), and Eigen::internal::pzero().

◆ write_c() [1/2]

template<typename Scalar , bool is_unit_inc>
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::write_c ( Scalar *  cox)
inlineprivate
444  {
446  }

References EIGEN_UNUSED_VARIABLE.

◆ write_c() [2/2]

template<typename Scalar , bool is_unit_inc>
template<int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> Eigen::internal::gemm_class< Scalar, is_unit_inc >::write_c ( Scalar *  cox)
inlineprivate
449  {
450  if (i < um_vecs) {
451  auto &c_reg = zmm[c_regs[i + idx * 3]];
452  auto c_mem = cox;
453  if (is_unit_inc)
454  c_mem += i * nelems_in_cache_line;
455  else
456  c_mem += i * nelems_in_cache_line * inc;
457 
458  c_store<nelems>(c_mem, c_reg);
459  c_reg = pzero(c_reg);
460 
461  write_c<i + 1, um_vecs, idx, nelems>(cox);
462  }
463  }

References Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_regs, Eigen::internal::gemm_class< Scalar, is_unit_inc >::inc, Eigen::internal::gemm_class< Scalar, is_unit_inc >::nelems_in_cache_line, Eigen::internal::pzero(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::zmm.

Member Data Documentation

◆ a

◆ a_off

template<typename Scalar , bool is_unit_inc>
const Index Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_off
private

◆ a_prefetch_size

template<typename Scalar , bool is_unit_inc>
constexpr int Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_prefetch_size = nelems_in_cache_line * 2
staticconstexprprivate

◆ a_regs

◆ a_shift

◆ a_stride

template<typename Scalar , bool is_unit_inc>
const Index Eigen::internal::gemm_class< Scalar, is_unit_inc >::a_stride
private

◆ alpha

template<typename Scalar , bool is_unit_inc>
const Scalar* Eigen::internal::gemm_class< Scalar, is_unit_inc >::alpha
private

◆ alpha_load_reg

template<typename Scalar , bool is_unit_inc>
constexpr int Eigen::internal::gemm_class< Scalar, is_unit_inc >::alpha_load_reg = 0
staticconstexprprivate

◆ b

◆ b_off

template<typename Scalar , bool is_unit_inc>
const Index Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_off
private

◆ b_prefetch_size

template<typename Scalar , bool is_unit_inc>
constexpr int Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_prefetch_size = nelems_in_cache_line * 8
staticconstexprprivate

◆ b_regs

template<typename Scalar , bool is_unit_inc>
constexpr int Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_regs[] = {6, use_less_b_regs ? 6 : 7}
staticconstexprprivate

◆ b_shift

◆ b_stride

template<typename Scalar , bool is_unit_inc>
const Index Eigen::internal::gemm_class< Scalar, is_unit_inc >::b_stride
private

◆ c

template<typename Scalar , bool is_unit_inc>
Scalar* Eigen::internal::gemm_class< Scalar, is_unit_inc >::c
private

◆ c_load_regs

template<typename Scalar , bool is_unit_inc>
constexpr int Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_load_regs[] = {1, 2, 6}
staticconstexprprivate

◆ c_regs

template<typename Scalar , bool is_unit_inc>
constexpr int Eigen::internal::gemm_class< Scalar, is_unit_inc >::c_regs[]
staticconstexprprivate
Initial value:
= {
8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
}

Referenced by Eigen::internal::gemm_class< Scalar, is_unit_inc >::compute(), Eigen::internal::gemm_class< Scalar, is_unit_inc >::scale_load_c(), and Eigen::internal::gemm_class< Scalar, is_unit_inc >::write_c().

◆ inc

◆ is_alpha1

template<typename Scalar , bool is_unit_inc>
const bool Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_alpha1
private

◆ is_beta0

template<typename Scalar , bool is_unit_inc>
const bool Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_beta0
private

◆ is_f32

template<typename Scalar , bool is_unit_inc>
constexpr bool Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f32 = sizeof(Scalar) == sizeof(float)
staticconstexprprivate

◆ is_f64

template<typename Scalar , bool is_unit_inc>
constexpr bool Eigen::internal::gemm_class< Scalar, is_unit_inc >::is_f64 = sizeof(Scalar) == sizeof(double)
staticconstexprprivate

◆ k

◆ ldc

◆ m

template<typename Scalar , bool is_unit_inc>
Index Eigen::internal::gemm_class< Scalar, is_unit_inc >::m
private

◆ mask

◆ n

template<typename Scalar , bool is_unit_inc>
const Index Eigen::internal::gemm_class< Scalar, is_unit_inc >::n
private

◆ nelems_in_cache_line

◆ use_less_a_regs

template<typename Scalar , bool is_unit_inc>
constexpr bool Eigen::internal::gemm_class< Scalar, is_unit_inc >::use_less_a_regs = !is_unit_inc
staticconstexprprivate

◆ use_less_b_regs

template<typename Scalar , bool is_unit_inc>
constexpr bool Eigen::internal::gemm_class< Scalar, is_unit_inc >::use_less_b_regs = !is_unit_inc
staticconstexprprivate

◆ zmm


The documentation for this class was generated from the following file: