TrsmKernel.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2022 Intel Corporation
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
11 #define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
12 
13 // IWYU pragma: private
14 #include "../../InternalHeaderCheck.h"
15 
16 #if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
17 #define EIGEN_USE_AVX512_TRSM_KERNELS 1
18 #endif
19 
20 // TRSM kernels currently unconditionally rely on malloc with AVX512.
21 // Disable them if malloc is explicitly disabled at compile-time.
22 #ifdef EIGEN_NO_MALLOC
23 #undef EIGEN_USE_AVX512_TRSM_KERNELS
24 #define EIGEN_USE_AVX512_TRSM_KERNELS 0
25 #endif
26 
27 #if EIGEN_USE_AVX512_TRSM_KERNELS
28 #if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
29 #define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
30 #endif
31 #if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
32 #define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
33 #endif
34 #else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
35 #define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
36 #define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
37 #endif
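The switches above can also be set from the build line before any Eigen header is included; a minimal sketch (the compiler flags are illustrative, adjust to your toolchain):

  // Turn the AVX-512 TRSM kernels off explicitly:
  //   c++ -O2 -march=skylake-avx512 -DEIGEN_USE_AVX512_TRSM_KERNELS=0 app.cpp
  // Or turn them off implicitly by forbidding heap allocation inside Eigen:
  //   c++ -O2 -march=skylake-avx512 -DEIGEN_NO_MALLOC app.cpp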
38 
39 // std::min is used below; undefine any min macro (e.g. the one from windows.h) that would conflict with it.
40 #ifdef min
41 #undef min
42 #endif
43 
44 namespace Eigen {
45 namespace internal {
46 
47 #define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
48 #define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
49 #define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
50 #define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
51 #define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
52 typedef Packet16f vecFullFloat;
53 typedef Packet8d vecFullDouble;
54 typedef Packet8f vecHalfFloat;
55 typedef Packet4d vecHalfDouble;
56 
57 // Compile-time unrolls are implemented here.
58 // Note: this depends on macros and typedefs above.
59 #include "TrsmUnrolls.inc"
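These limits plausibly budget the 32 zmm registers available with AVX-512: EIGEN_AVX_MAX_NUM_ACC = 24 accumulators (3 column packets x 8 rows), up to EIGEN_AVX_B_LOAD_SETS * 3 = 6 packets of B in flight, and EIGEN_AVX_MAX_A_BCAST = 2 broadcast values of A, i.e. 24 + 6 + 2 = 32.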
60 
61 #if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
78 #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
79 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
80 #endif
81 
82 #if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
83 
84 #if EIGEN_USE_AVX512_TRSM_R_KERNELS
85 #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
86 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
87 #endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
88 #endif
89 
90 #if EIGEN_USE_AVX512_TRSM_L_KERNELS
91 #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
92 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
93 #endif
94 #endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
95 
96 #else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
97 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
98 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
99 #endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
100 
101 template <typename Scalar>
102 int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
103  const int64_t U3 = 3 * packet_traits<Scalar>::size;
104  const int64_t MaxNb = 5 * U3;
105  int64_t Nb = std::min(MaxNb, N);
106  double cutoff_d =
107  (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
108  int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
109  return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
110 }
111 #else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
112 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
113 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
114 #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
115 #endif
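As a worked illustration of avx512_trsm_cutoff (the inputs are hypothetical, not tuned values): for Scalar = float (PacketSize 16), L2Size = 256 * 1024, N = 64 and L2Cap = 0.5,

  //   U3 = 48, MaxNb = 240, Nb = std::min(240, 64) = 64
  //   cutoff_d = ((262144 * 0.5) / 4 - 8 * 64) / (8 + 64) = (32768 - 512) / 72 = 448
  //   avx512_trsm_cutoff<float>(256 * 1024, 64, 0.5) == (448 / 8) * 8 == 448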
116 
120 template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
121 EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
122  int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
123  EIGEN_UNUSED_VARIABLE(remN_);
124  EIGEN_UNUSED_VARIABLE(remM_);
125  using urolls = unrolls::trans<Scalar>;
126 
127  constexpr int64_t U3 = urolls::PacketSize * 3;
128  constexpr int64_t U2 = urolls::PacketSize * 2;
129  constexpr int64_t U1 = urolls::PacketSize * 1;
130 
131  static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
132  static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
133 
134  urolls::template transpose<unrollN, 0>(zmm);
135  EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
136  EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
137 
138  static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
139  EIGEN_IF_CONSTEXPR(!remN) {
140  urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
141  EIGEN_IF_CONSTEXPR(unrollN > U1) {
142  constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
143  urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
144  }
145  EIGEN_IF_CONSTEXPR(unrollN > U2) {
146  constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
147  urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
148  }
149  }
150  else {
151  EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
152  // Note: without "if constexpr" this section of code will also be
153  // parsed by the compiler so each of the storeC will still be instantiated.
154  // We use enable_if in aux_storeC to set it to an empty function for
155  // these cases.
156  if (remN_ == 15)
157  urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
158  else if (remN_ == 14)
159  urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
160  else if (remN_ == 13)
161  urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
162  else if (remN_ == 12)
163  urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
164  else if (remN_ == 11)
165  urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
166  else if (remN_ == 10)
167  urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
168  else if (remN_ == 9)
169  urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
170  else if (remN_ == 8)
171  urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
172  else if (remN_ == 7)
173  urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
174  else if (remN_ == 6)
175  urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
176  else if (remN_ == 5)
177  urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
178  else if (remN_ == 4)
179  urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
180  else if (remN_ == 3)
181  urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
182  else if (remN_ == 2)
183  urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
184  else if (remN_ == 1)
185  urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
186  }
187  else {
188  if (remN_ == 7)
189  urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
190  else if (remN_ == 6)
191  urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
192  else if (remN_ == 5)
193  urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
194  else if (remN_ == 4)
195  urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
196  else if (remN_ == 3)
197  urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
198  else if (remN_ == 2)
199  urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
200  else if (remN_ == 1)
201  urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
202  }
203  }
204 }
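Conceptually, transStoreC writes the unrollM-by-unrollN tile held row-wise in the accumulators into a column-major C; a scalar sketch of the net effect (illustration only, the real path uses the transpose/storeC unrolls above):

  // for (int64_t jj = 0; jj < unrollN; ++jj)     // tile columns (clipped to remN_ when remN is set)
  //   for (int64_t ii = 0; ii < unrollM; ++ii)   // tile rows (clipped to remM_ when remM is set)
  //     C_arr[ii + jj * LDC] = tile(ii, jj);     // tile = contents of zmm before the transpose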
205 
220 template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
221 void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
222  int64_t LDC) {
223  using urolls = unrolls::gemm<Scalar, isAdd>;
224  constexpr int64_t U3 = urolls::PacketSize * 3;
225  constexpr int64_t U2 = urolls::PacketSize * 2;
226  constexpr int64_t U1 = urolls::PacketSize * 1;
227  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
228  int64_t N_ = (N / U3) * U3;
229  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
230  int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
231  int64_t j = 0;
232  for (; j < N_; j += U3) {
233  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
234  int64_t i = 0;
235  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
236  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
238  urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
239  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
240  urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
241  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
242  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
243  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
244  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
245  }
246  EIGEN_IF_CONSTEXPR(handleKRem) {
247  for (int64_t k = K_; k < K; k++) {
248  urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
249  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
250  B_t += LDB;
251  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
252  else A_t += LDA;
253  }
254  }
255  EIGEN_IF_CONSTEXPR(isCRowMajor) {
256  urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
257  urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
258  }
259  else {
260  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
261  }
262  }
263  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
264  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
265  Scalar *B_t = &B_arr[0 * LDB + j];
267  urolls::template setzero<3, 4>(zmm);
268  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
269  urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
270  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
271  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
272  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
273  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
274  }
275  EIGEN_IF_CONSTEXPR(handleKRem) {
276  for (int64_t k = K_; k < K; k++) {
277  urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
278  B_t, A_t, LDB, LDA, zmm);
279  B_t += LDB;
280  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
281  else A_t += LDA;
282  }
283  }
284  EIGEN_IF_CONSTEXPR(isCRowMajor) {
285  urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
286  urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
287  }
288  else {
289  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
290  }
291  i += 4;
292  }
293  if (M - i >= 2) {
294  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
295  Scalar *B_t = &B_arr[0 * LDB + j];
297  urolls::template setzero<3, 2>(zmm);
298  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
299  urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
300  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
301  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
302  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
303  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
304  }
305  EIGEN_IF_CONSTEXPR(handleKRem) {
306  for (int64_t k = K_; k < K; k++) {
307  urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
308  B_t, A_t, LDB, LDA, zmm);
309  B_t += LDB;
310  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
311  else A_t += LDA;
312  }
313  }
314  EIGEN_IF_CONSTEXPR(isCRowMajor) {
315  urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
316  urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
317  }
318  else {
319  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
320  }
321  i += 2;
322  }
323  if (M - i > 0) {
324  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
325  Scalar *B_t = &B_arr[0 * LDB + j];
327  urolls::template setzero<3, 1>(zmm);
328  {
329  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
330  urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
331  B_t, A_t, LDB, LDA, zmm);
332  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
333  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
334  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
335  }
336  EIGEN_IF_CONSTEXPR(handleKRem) {
337  for (int64_t k = K_; k < K; k++) {
338  urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
339  B_t += LDB;
340  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
341  else A_t += LDA;
342  }
343  }
344  EIGEN_IF_CONSTEXPR(isCRowMajor) {
345  urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
346  urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
347  }
348  else {
349  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
350  }
351  }
352  }
353  }
354  if (N - j >= U2) {
355  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
356  int64_t i = 0;
357  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
358  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
359  EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
361  urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
362  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
363  urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
364  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
365  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
366  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
367  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
368  }
369  EIGEN_IF_CONSTEXPR(handleKRem) {
370  for (int64_t k = K_; k < K; k++) {
371  urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
372  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
373  B_t += LDB;
374  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
375  else A_t += LDA;
376  }
377  }
378  EIGEN_IF_CONSTEXPR(isCRowMajor) {
379  urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
380  urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
381  }
382  else {
383  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
384  }
385  }
386  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
387  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
388  Scalar *B_t = &B_arr[0 * LDB + j];
390  urolls::template setzero<2, 4>(zmm);
391  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
392  urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
393  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
394  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
395  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
396  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
397  }
398  EIGEN_IF_CONSTEXPR(handleKRem) {
399  for (int64_t k = K_; k < K; k++) {
400  urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
401  LDA, zmm);
402  B_t += LDB;
403  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
404  else A_t += LDA;
405  }
406  }
407  EIGEN_IF_CONSTEXPR(isCRowMajor) {
408  urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
409  urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
410  }
411  else {
412  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
413  }
414  i += 4;
415  }
416  if (M - i >= 2) {
417  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
418  Scalar *B_t = &B_arr[0 * LDB + j];
420  urolls::template setzero<2, 2>(zmm);
421  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
422  urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
423  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
424  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
425  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
426  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
427  }
428  EIGEN_IF_CONSTEXPR(handleKRem) {
429  for (int64_t k = K_; k < K; k++) {
430  urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
431  LDA, zmm);
432  B_t += LDB;
433  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
434  else A_t += LDA;
435  }
436  }
437  EIGEN_IF_CONSTEXPR(isCRowMajor) {
438  urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
439  urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
440  }
441  else {
442  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
443  }
444  i += 2;
445  }
446  if (M - i > 0) {
447  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
448  Scalar *B_t = &B_arr[0 * LDB + j];
450  urolls::template setzero<2, 1>(zmm);
451  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
452  urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
453  LDA, zmm);
454  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
455  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
456  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
457  }
458  EIGEN_IF_CONSTEXPR(handleKRem) {
459  for (int64_t k = K_; k < K; k++) {
460  urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
461  B_t += LDB;
462  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
463  else A_t += LDA;
464  }
465  }
466  EIGEN_IF_CONSTEXPR(isCRowMajor) {
467  urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
468  urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
469  }
470  else {
471  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
472  }
473  }
474  j += U2;
475  }
476  if (N - j >= U1) {
477  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
478  int64_t i = 0;
479  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
480  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
482  urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
483  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
484  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
485  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
486  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
487  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
488  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
489  }
490  EIGEN_IF_CONSTEXPR(handleKRem) {
491  for (int64_t k = K_; k < K; k++) {
492  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
493  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
494  B_t += LDB;
495  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
496  else A_t += LDA;
497  }
498  }
499  EIGEN_IF_CONSTEXPR(isCRowMajor) {
500  urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
501  urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
502  }
503  else {
504  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
505  }
506  }
507  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
508  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
509  Scalar *B_t = &B_arr[0 * LDB + j];
511  urolls::template setzero<1, 4>(zmm);
512  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
513  urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
514  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
515  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
516  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
517  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
518  }
519  EIGEN_IF_CONSTEXPR(handleKRem) {
520  for (int64_t k = K_; k < K; k++) {
521  urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
522  LDA, zmm);
523  B_t += LDB;
524  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
525  else A_t += LDA;
526  }
527  }
528  EIGEN_IF_CONSTEXPR(isCRowMajor) {
529  urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
530  urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
531  }
532  else {
533  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
534  }
535  i += 4;
536  }
537  if (M - i >= 2) {
538  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
539  Scalar *B_t = &B_arr[0 * LDB + j];
541  urolls::template setzero<1, 2>(zmm);
542  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
543  urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
544  EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
545  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
546  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
547  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
548  }
549  EIGEN_IF_CONSTEXPR(handleKRem) {
550  for (int64_t k = K_; k < K; k++) {
551  urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
552  LDA, zmm);
553  B_t += LDB;
554  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
555  else A_t += LDA;
556  }
557  }
558  EIGEN_IF_CONSTEXPR(isCRowMajor) {
559  urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
560  urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
561  }
562  else {
563  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
564  }
565  i += 2;
566  }
567  if (M - i > 0) {
568  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
569  Scalar *B_t = &B_arr[0 * LDB + j];
571  urolls::template setzero<1, 1>(zmm);
572  {
573  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
574  urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
575  LDA, zmm);
576  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
577  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
578  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
579  }
580  EIGEN_IF_CONSTEXPR(handleKRem) {
581  for (int64_t k = K_; k < K; k++) {
582  urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
583  B_t += LDB;
584  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
585  else A_t += LDA;
586  }
587  }
588  EIGEN_IF_CONSTEXPR(isCRowMajor) {
589  urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
590  urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
591  }
592  else {
593  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
594  }
595  }
596  }
597  j += U1;
598  }
599  if (N - j > 0) {
600  constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
601  int64_t i = 0;
602  for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
603  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
604  Scalar *B_t = &B_arr[0 * LDB + j];
606  urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
607  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
608  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
609  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
610  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
611  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
612  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
613  }
614  EIGEN_IF_CONSTEXPR(handleKRem) {
615  for (int64_t k = K_; k < K; k++) {
616  urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
617  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
618  B_t += LDB;
619  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
620  else A_t += LDA;
621  }
622  }
623  EIGEN_IF_CONSTEXPR(isCRowMajor) {
624  urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
625  urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
626  }
627  else {
628  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
629  }
630  }
631  if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
632  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
633  Scalar *B_t = &B_arr[0 * LDB + j];
635  urolls::template setzero<1, 4>(zmm);
636  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
637  urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
638  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
639  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
640  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
641  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
642  }
643  EIGEN_IF_CONSTEXPR(handleKRem) {
644  for (int64_t k = K_; k < K; k++) {
645  urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
646  B_t, A_t, LDB, LDA, zmm, N - j);
647  B_t += LDB;
648  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
649  else A_t += LDA;
650  }
651  }
652  EIGEN_IF_CONSTEXPR(isCRowMajor) {
653  urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
654  urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
655  }
656  else {
657  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
658  }
659  i += 4;
660  }
661  if (M - i >= 2) {
662  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
663  Scalar *B_t = &B_arr[0 * LDB + j];
665  urolls::template setzero<1, 2>(zmm);
666  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
667  urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
668  EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
669  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
670  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
671  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
672  }
673  EIGEN_IF_CONSTEXPR(handleKRem) {
674  for (int64_t k = K_; k < K; k++) {
675  urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
676  B_t, A_t, LDB, LDA, zmm, N - j);
677  B_t += LDB;
678  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
679  else A_t += LDA;
680  }
681  }
682  EIGEN_IF_CONSTEXPR(isCRowMajor) {
683  urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
684  urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
685  }
686  else {
687  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
688  }
689  i += 2;
690  }
691  if (M - i > 0) {
692  Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
693  Scalar *B_t = &B_arr[0 * LDB + j];
695  urolls::template setzero<1, 1>(zmm);
696  for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
697  urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
698  B_t, A_t, LDB, LDA, zmm, N - j);
699  B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
700  EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
701  else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
702  }
703  EIGEN_IF_CONSTEXPR(handleKRem) {
704  for (int64_t k = K_; k < K; k++) {
705  urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
706  N - j);
707  B_t += LDB;
708  EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
709  else A_t += LDA;
710  }
711  }
712  EIGEN_IF_CONSTEXPR(isCRowMajor) {
713  urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
714  urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
715  }
716  else {
717  transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
718  }
719  }
720  }
721 }
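Stripped of intrinsics, gemmKernel is a three-level tiling; a schematic sketch of the loop nest above (not the actual implementation):

  // for (j = 0; j < N; j += {U3, U2, U1, remainder})    // columns of C, 3/2/1 packets wide
  //   for (i = 0; i < M; i += {8, 4, 2, 1})             // rows of C, 8 == EIGEN_AVX_MAX_NUM_ROW
  //     zero the accumulators;
  //     for (k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) // unrolled inner product
  //       accumulators += A(i.., k..) * B(k.., j..);    // microKernel: A broadcasts, B packet loads
  //     if (handleKRem) finish k = K_ .. K one step at a time;
  //     write the accumulators to C: updateC/storeC when C is row-major, transStoreC otherwise;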
722 
731 template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
732 EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
733  static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be at most EIGEN_AVX_MAX_NUM_ROW");
734  using urolls = unrolls::trsm<Scalar>;
735  constexpr int64_t U3 = urolls::PacketSize * 3;
736  constexpr int64_t U2 = urolls::PacketSize * 2;
737  constexpr int64_t U1 = urolls::PacketSize * 1;
738 
739  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
740  PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
741 
742  int64_t k = 0;
743  while (K - k >= U3) {
744  urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
745  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
746  AInPacket);
747  urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
748  k += U3;
749  }
750  if (K - k >= U2) {
751  urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
752  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
753  AInPacket);
754  urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
755  k += U2;
756  }
757  if (K - k >= U1) {
758  urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
759  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
760  AInPacket);
761  urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
762  k += U1;
763  }
764  if (K - k > 0) {
765  // Handle the remaining RHS columns (fewer than one full packet).
766  urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
767  urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
768  AInPacket);
769  urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
770  }
771 }
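For reference, the computation performed by triSolveKernel on one unrollM-by-unrollM diagonal block is plain substitution applied to K right-hand-side columns at once; a scalar sketch for the forward, row-major-A case (the vector code does the same on packets of RHS columns):

  // for (int64_t r = 0; r < unrollM; ++r)
  //   for (int64_t col = 0; col < K; ++col) {
  //     Scalar x = B_arr[r * LDB + col];
  //     for (int64_t c = 0; c < r; ++c) x -= A_arr[r * LDA + c] * B_arr[c * LDB + col];
  //     B_arr[r * LDB + col] = isUnitDiag ? x : x / A_arr[r * LDA + r];
  //   }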
772 
781 template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
782 void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
783  // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
784  // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
785  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
786  if (M == 8)
787  triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
788  else if (M == 7)
789  triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
790  else if (M == 6)
791  triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
792  else if (M == 5)
793  triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
794  else if (M == 4)
795  triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
796  else if (M == 3)
797  triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
798  else if (M == 2)
799  triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
800  else if (M == 1)
801  triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
802  return;
803 }
804 
812 template <typename Scalar, bool toTemp = true, bool remM = false>
813 EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
814  int64_t remM_ = 0) {
815  EIGEN_UNUSED_VARIABLE(remM_);
816  using urolls = unrolls::transB<Scalar>;
817  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecHalfDouble>::type;
818  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
819  constexpr int64_t U3 = urolls::PacketSize * 3;
820  constexpr int64_t U2 = urolls::PacketSize * 2;
821  constexpr int64_t U1 = urolls::PacketSize * 1;
822  int64_t K_ = K / U3 * U3;
823  int64_t k = 0;
824 
825  for (; k < K_; k += U3) {
826  urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
827  B_temp += U3;
828  }
829  if (K - k >= U2) {
830  urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
831  B_temp += U2;
832  k += U2;
833  }
834  if (K - k >= U1) {
835  urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
836  B_temp += U1;
837  k += U1;
838  }
839  EIGEN_IF_CONSTEXPR(U1 > 8) {
840  // Note: without "if constexpr" this section of code will also be
841  // parsed by the compiler so there is an additional check in {load/store}BBlock
842  // to make sure the counter stays non-negative.
843  if (K - k >= 8) {
844  urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
845  B_temp += 8;
846  k += 8;
847  }
848  }
849  EIGEN_IF_CONSTEXPR(U1 > 4) {
850  // Note: without "if constexpr" this section of code will also be
851  // parsed by the compiler so there is an additional check in {load/store}BBlock
852  // to make sure the counter stays non-negative.
853  if (K - k >= 4) {
854  urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
855  B_temp += 4;
856  k += 4;
857  }
858  }
859  if (K - k >= 2) {
860  urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
861  B_temp += 2;
862  k += 2;
863  }
864  if (K - k >= 1) {
865  urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
866  B_temp += 1;
867  k += 1;
868  }
869 }
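In effect, copyBToRowMajor moves an (up to) 8-row, K-column panel between a column-major B (leading dimension LDB) and the row-major staging buffer B_temp (leading dimension LDB_); conceptually (illustration only):

  // for (int64_t col = 0; col < K; ++col)
  //   for (int64_t row = 0; row < rows; ++row) {  // rows = EIGEN_AVX_MAX_NUM_ROW, or remM_ when remM is set
  //     if (toTemp) B_temp[row * LDB_ + col] = B_arr[row + col * LDB];
  //     else        B_arr[row + col * LDB]  = B_temp[row * LDB_ + col];
  //   }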
870 
898 template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
899  bool isUnitDiag = false>
900 void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
901  constexpr int64_t psize = packet_traits<Scalar>::size;
915  constexpr int64_t kB = (3 * psize) * 5; // 5*U3
916  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
917 
918  int64_t sizeBTemp = 0;
919  Scalar *B_temp = NULL;
920  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
926  sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
927  }
928 
929  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
930 
931  for (int64_t k = 0; k < numRHS; k += kB) {
932  int64_t bK = numRHS - k > kB ? kB : numRHS - k;
933  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
934 
935  // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
936  // instead of bK RHS in triSolveKernelLxK.
937  int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
938  const int64_t numScalarPerCache = 64 / sizeof(Scalar);
939  // Leading dimension of B_temp, will be a multiple of the cache line size.
940  int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
941  int64_t offsetBTemp = 0;
942  for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
943  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
944  int64_t indA_i = isFWDSolve ? i : M - 1 - i;
945  int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
946  int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
947  int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
948  // Copy values from B to B_temp.
949  copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
950  // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
951  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
952  &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
953  // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
954  copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
955 
956  offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
957  }
958  else {
959  int64_t ind = isFWDSolve ? i : M - 1 - i;
960  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
961  &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
962  }
963  if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
976  EIGEN_IF_CONSTEXPR(isBRowMajor) {
977  int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
978  int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
979  int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
980  int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
981  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
982  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
983  EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
984  }
985  else {
986  if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
995  int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
996  int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
997  int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
998  int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
999  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1000  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1001  M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1002  offsetBTemp = 0;
1003  gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
1004  } else {
1008  int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1009  int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
1010  int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1011  int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1012  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1013  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1014  EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1015  }
1016  }
1017  }
1018  }
1019  // Handle the M remainder.
1020  int64_t bM = M - M_;
1021  if (bM > 0) {
1022  if (M_ > 0) {
1023  EIGEN_IF_CONSTEXPR(isBRowMajor) {
1024  int64_t indA_i = isFWDSolve ? M_ : 0;
1025  int64_t indA_j = isFWDSolve ? 0 : bM;
1026  int64_t indB_i = isFWDSolve ? 0 : bM;
1027  int64_t indB_i2 = isFWDSolve ? M_ : 0;
1028  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1029  &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
1030  bK, M_, LDA, LDB, LDB);
1031  }
1032  else {
1033  int64_t indA_i = isFWDSolve ? M_ : 0;
1034  int64_t indA_j = isFWDSolve ? gemmOff : bM;
1035  int64_t indB_i = isFWDSolve ? M_ : 0;
1036  int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1037  gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
1038  B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
1039  M_ - gemmOff, LDA, LDT, LDB);
1040  }
1041  }
1042  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
1043  int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
1044  int64_t indB_i = isFWDSolve ? M_ : 0;
1045  int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
1046  copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1047  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
1048  B_temp + offB_1, bM, bkL, LDA, bkL);
1049  copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1050  }
1051  else {
1052  int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
1053  triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
1054  B_arr + k + ind * LDB, bM, bK, LDA, LDB);
1055  }
1056  }
1057  }
1058 
1059  EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
1060 }
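These kernels are reached through Eigen's regular triangular-solve path; a minimal, self-contained sketch of a call that can end up in triSolve on an AVX-512 build (matrix sizes are arbitrary):

  #include <Eigen/Dense>
  int main() {
    Eigen::MatrixXd A = Eigen::MatrixXd::Random(512, 512);
    A.diagonal().array() += 512.0;  // keep the triangular factor well conditioned
    Eigen::MatrixXd B = Eigen::MatrixXd::Random(512, 96);
    // Solves L * X = B in place, with L the lower-triangular part of A; on an AVX-512 build this
    // dispatches to the trsmKernelL specializations defined later in this file.
    A.triangularView<Eigen::Lower>().solveInPlace(B);
    return 0;
  }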
1061 
1062 // Template specializations of trsmKernelL/R for float/double and inner strides of 1.
1063 #if (EIGEN_USE_AVX512_TRSM_KERNELS)
1064 #if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1065 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
1066  bool Specialized>
1067 struct trsmKernelR;
1068 
1069 template <typename Index, int Mode, int TriStorageOrder>
1070 struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
1071  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1072  Index otherStride);
1073 };
1074 
1075 template <typename Index, int Mode, int TriStorageOrder>
1076 struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
1077  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1078  Index otherStride);
1079 };
1080 
1081 template <typename Index, int Mode, int TriStorageOrder>
1082 EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1083  Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1084  Index otherStride) {
1085  EIGEN_UNUSED_VARIABLE(otherIncr);
1086 #ifdef EIGEN_RUNTIME_NO_MALLOC
1087  if (!is_malloc_allowed()) {
1088  trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1089  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1090  return;
1091  }
1092 #endif
1093  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
1094  const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
1095 }
1096 
1097 template <typename Index, int Mode, int TriStorageOrder>
1098 EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1099  Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1100  Index otherStride) {
1101  EIGEN_UNUSED_VARIABLE(otherIncr);
1102 #ifdef EIGEN_RUNTIME_NO_MALLOC
1103  if (!is_malloc_allowed()) {
1104  trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1105  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1106  return;
1107  }
1108 #endif
1109  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
1110  const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
1111 }
1112 #endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1113 
1114 // These trsm kernels require temporary memory allocation
1115 #if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
1116 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
1117  bool Specialized = true>
1118 struct trsmKernelL;
1119 
1120 template <typename Index, int Mode, int TriStorageOrder>
1121 struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
1122  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1123  Index otherStride);
1124 };
1125 
1126 template <typename Index, int Mode, int TriStorageOrder>
1127 struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
1128  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1129  Index otherStride);
1130 };
1131 
1132 template <typename Index, int Mode, int TriStorageOrder>
1133 EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1134  Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
1135  Index otherStride) {
1136  EIGEN_UNUSED_VARIABLE(otherIncr);
1137 #ifdef EIGEN_RUNTIME_NO_MALLOC
1138  if (!is_malloc_allowed()) {
1139  trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1140  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1141  return;
1142  }
1143 #endif
1144  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
1145  const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
1146 }
1147 
1148 template <typename Index, int Mode, int TriStorageOrder>
1149 EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
1150  Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
1151  Index otherStride) {
1152  EIGEN_UNUSED_VARIABLE(otherIncr);
1153 #ifdef EIGEN_RUNTIME_NO_MALLOC
1154  if (!is_malloc_allowed()) {
1155  trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
1156  size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
1157  return;
1158  }
1159 #endif
1160  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
1161  const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
1162 }
1163 #endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
1164 #endif // EIGEN_USE_AVX512_TRSM_KERNELS
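The is_malloc_allowed() guards in the specializations above only come into play when EIGEN_RUNTIME_NO_MALLOC is defined; a sketch of that interaction (assumes the macro is defined before including Eigen):

  // Eigen::internal::set_is_malloc_allowed(false);
  // A.triangularView<Eigen::Lower>().solveInPlace(B);  // still valid: falls back to the
  //                                                     // generic (Specialized = false) kernel
  // Eigen::internal::set_is_malloc_allowed(true);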
1165 } // namespace internal
1166 } // namespace Eigen
1167 #endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H