1 #ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
2 #define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
5 #define BFLOAT16_UNROLL _Pragma("unroll 8")
7 #define BFLOAT16_UNROLL _Pragma("GCC unroll(8)")
27 return loadBfloat16<zero>(blockB + strideB *
i);
30 template <
Index num_acc,
Index num_packets,
bool zero,
bool rhsExtraCols,
bool lhsExtraRows,
Index num_rhs,
37 for (
Index i = 0;
i < (num_rhs - (rhsExtraCols ? 1 : 0));
i++) {
38 rhs[
i] = loadRhsBfloat16<zero>(indexB +
k * 4, strideB,
i);
41 rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB +
k * extra_cols - offsetB, strideB, num_rhs - 1);
44 indexA +=
k * (lhsExtraRows ? extra_rows : num_packets);
46 lhs[0] = loadBfloat16<zero>(indexA);
49 for (
Index j = 0;
j < num_lhs;
j += 2) {
63 for (
Index i = 0,
x = 0;
i < num_rhs;
i++) {
65 for (
Index j = 0;
j < num_lhs;
j++,
x++) {
66 __builtin_mma_xvbf16ger2pp(&(quad_acc[
x]),
reinterpret_cast<Packet16uc>(rhs[
i].m_val),
72 template <Index num_acc>
75 for (
Index k = 0;
k < num_acc;
k++) __builtin_mma_xxsetaccz(&(quad_acc[
k]));
78 template <Index num_acc>
81 for (
Index k = 0;
k < num_acc;
k++) __builtin_mma_disassemble_acc((
void*)acc[
k], &(quad_acc[
k]));
84 template <Index num_acc,
bool rhsExtraCols,
bool lhsExtraRows, Index num_rhs, Index num_lhs>
88 for (
Index i = 0,
k = 0;
i < num_rhs - (rhsExtraCols ? 1 : 0);
i++, result += 4 *
rows) {
90 for (
Index j = 0;
j < num_lhs;
j++,
k++) {
91 storeResults<false, lhsExtraRows>(acc[
k],
rows, pAlpha, result +
j * 4, extra_cols, extra_rows);
95 storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1],
rows, pAlpha, result, extra_cols, extra_rows);
99 template <const Index num_acc, const Index num_packets,
bool rhsExtraCols,
bool lhsExtraRows,
bool multiIter = false>
102 const Index extra_cols,
const Index extra_rows) {
103 constexpr
Index num_lhs = multiIter ? (num_packets / 4) : 1;
104 constexpr
Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;
106 for (
Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8),
107 indexB += (multiIter ? (num_rhs * strideB) : 0), result += (multiIter ? (4 *
rows * num_rhs) : 4)) {
109 __vector_quad quad_acc[num_acc];
111 zeroAccumulators<num_acc>(quad_acc);
114 for (
k = 0;
k + 2 <= depth;
k += 2) {
115 KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(
116 indexA, indexB, quad_acc, strideB,
k, offsetB, extra_cols, extra_rows);
119 KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(
120 indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB,
k, offsetB, extra_cols, extra_rows);
123 disassembleAccumulators<num_acc>(quad_acc, acc);
125 outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc,
rows, pAlpha, result, extra_cols,
130 #define MAX_BFLOAT16_ACC 8
132 template <const Index num_acc, const Index num_packets,
bool rhsExtraCols,
bool lhsExtraRows>
135 constexpr
Index step = (num_acc * 4);
136 const Index extra_cols = (rhsExtraCols) ? (
cols & 3) : 0;
137 const Index extra_rows = (lhsExtraRows) ? (
rows & 3) : 0;
139 constexpr
bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);
142 colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(
143 depth,
rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
145 indexB += strideB * num_acc;
146 result +=
rows * step;
147 }
while (multiIters && (step <=
cols - (
col += step)));
150 template <const Index num_acc, const Index num_packets,
bool rhsExtraCols,
bool lhsExtraRows>
155 colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(
156 col, depth,
cols,
rows, pAlpha, indexA, blockB, strideB, offsetB, result);
160 template <const Index num_packets,
bool rhsExtraCols,
bool lhsExtraRows>
165 colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
166 strideB, offsetB, result);
169 colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
170 strideB, offsetB, result);
173 colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
174 strideB, offsetB, result);
177 colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
178 strideB, offsetB, result);
181 colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
182 strideB, offsetB, result);
185 colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
186 strideB, offsetB, result);
189 colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
190 strideB, offsetB, result);
194 colLoopBody<1, num_packets, true, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB, strideB,
201 template <const Index num_packets,
bool lhsExtraRows = false>
206 colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB,
208 blockB += (strideB >> 2) *
col;
212 colLoopBodyExtra<num_packets, true, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB, strideB, offsetB,
215 colLoopBodyExtra<num_packets, false, lhsExtraRows>(
col, depth,
cols,
rows, pAlpha, indexA, blockB, strideB, 0,
222 __vector_pair fp16_vp = *
reinterpret_cast<__vector_pair*
>(
const_cast<float*
>(
res));
223 __builtin_vsx_disassemble_pair(
reinterpret_cast<void*
>(fp16), &fp16_vp);
224 fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
225 fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
226 return vec_pack(
reinterpret_cast<Packet4ui>(fp16[0]),
reinterpret_cast<Packet4ui>(fp16[1]));
229 template <
typename DataMapper, const Index size>
231 const DataMapper res2 =
res.getSubMapper(0,
col);
233 float* result2 = result +
col *
rows;
241 res2.template storePacketBlock<Packet8bf, size>(
row, 0,
block);
248 res2.template storePacketPartial<Packet8bf>(
row,
j, fp16,
rows & 7);
253 template <const Index size,
bool non_unit_str
ide = false>
267 storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc,
rows & 7);
269 storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
272 storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
273 storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
276 dst += extra * resInc;
277 if (
size != 32)
break;
281 template <
bool non_unit_str
ide = false>
284 convertPointerF32toBF16<32, non_unit_stride>(
i, result,
rows, dst, resInc);
285 convertPointerF32toBF16<16, non_unit_stride>(
i, result,
rows, dst, resInc);
286 convertPointerF32toBF16<8, non_unit_stride>(
i, result,
rows, dst, resInc);
287 convertPointerF32toBF16<1, non_unit_stride>(
i, result,
rows, dst, resInc);
290 template <
typename DataMapper>
294 convertArrayF32toBF16Col<DataMapper, 4>(result,
col,
rows,
res);
299 convertArrayF32toBF16Col<DataMapper, 1>(result,
col,
rows,
res);
302 convertArrayF32toBF16Col<DataMapper, 2>(result,
col,
rows,
res);
305 convertArrayF32toBF16Col<DataMapper, 3>(result,
col,
rows,
res);
310 template <Index size>
313 Index offsetB,
Index bigSuffix,
float* result) {
315 indexA +=
size * offsetA;
316 colLoops<size>(depth,
cols,
rows, pAlpha, indexA, indexB, strideB, offsetB, result +
row);
318 indexA += bigSuffix *
size / 16;
322 template <
typename DataMapper>
329 convertArrayBF16toF32<DataMapper>(result,
cols,
rows,
res);
331 if (strideA == -1) strideA = depth;
332 if (strideB == -1) strideB = depth;
341 Index bigSuffix = (2 * 8) * (strideA - offsetA);
342 indexB += 4 * offsetB;
348 calcColLoops<16>(indexA,
row, depth,
cols,
rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
351 calcColLoops<8>(indexA,
row, depth,
cols,
rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
353 calcColLoops<4>(indexA,
row, depth,
cols,
rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
357 colLoops<4, true>(depth,
cols,
rows, pAlpha, indexA, indexB, strideB, offsetB, result +
row);
361 convertArrayF32toBF16<DataMapper>(result,
cols,
rows,
res);
364 #undef MAX_BFLOAT16_ACC
366 #if !EIGEN_ALTIVEC_DISABLE_MMA
367 template <Index num_acc,
typename LhsMapper,
bool zero>
369 a0[
k + 0] = lhs.template loadPacket<Packet8bf>(
k * 4, 0);
371 b1 = lhs.template loadPacket<Packet8bf>(
k * 4, 1);
373 if (num_acc > (
k + 1)) {
374 a0[
k + 1] = vec_mergel(a0[
k + 0].m_val, b1.
m_val);
376 a0[
k + 0] = vec_mergeh(a0[
k + 0].m_val, b1.
m_val);
379 template <Index num_acc>
382 for (
Index k = 0;
k < num_acc;
k++) {
383 __builtin_mma_xvbf16ger2pp(&(quad_acc[
k]),
reinterpret_cast<Packet16uc>(b0.
m_val),
388 template <Index num_acc,
typename LhsMapper,
typename RhsMapper,
bool zero,
bool linear>
392 Packet8bf b0 = loadColData<RhsMapper, linear>(rhs,
j);
398 using LhsSubMapper =
typename LhsMapper::SubMapper;
400 LhsSubMapper lhs2 = lhs.getSubMapper(0,
j);
402 for (
Index k = 0;
k < num_acc;
k += 2) {
403 loadVecLoop<num_acc, LhsSubMapper, zero>(
k, lhs2, a0, b1);
406 multVec<num_acc>(quad_acc, a0, b0);
409 #define MAX_BFLOAT16_VEC_ACC 8
411 template <const Index num_acc,
typename LhsMapper,
typename RhsMapper,
bool extraRows,
bool linear>
414 constexpr
Index step = (num_acc * 4);
415 const Index extra_rows = (extraRows) ? (
rows & 3) : 0;
420 __vector_quad quad_acc[num_acc];
422 zeroAccumulators<num_acc>(quad_acc);
424 using LhsSubMapper =
typename LhsMapper::SubMapper;
426 LhsSubMapper lhs2 = lhs.getSubMapper(
row, 0);
427 for (
Index j = 0;
j + 2 <= cend;
j += 2) {
428 vecColLoop<num_acc, LhsSubMapper, RhsMapper, false, linear>(
j, lhs2, rhs, quad_acc);
431 vecColLoop<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
434 disassembleAccumulators<num_acc>(quad_acc, acc);
436 outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);
439 }
while (multiIters && (step <=
rows - (
row += step)));
442 template <const Index num_acc,
typename LhsMapper,
typename RhsMapper,
bool extraRows,
bool linear>
444 const Packet4f pAlpha,
float* result) {
446 colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs,
451 template <
typename LhsMapper,
typename RhsMapper,
bool extraRows,
bool linear>
453 const Packet4f pAlpha,
float* result) {
456 colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
459 colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
462 colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
465 colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
468 colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
471 colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
474 colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
478 colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
484 template <
typename LhsMapper,
typename RhsMapper,
bool linear>
489 colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(
row, cend,
rows, lhs, rhs, pAlpha,
494 colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
496 colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(
row, cend,
rows, lhs, rhs, pAlpha, result);
500 template <
typename RhsMapper,
typename LhsMapper,
typename =
void>
504 using RhsSubMapper =
typename RhsMapper::SubMapper;
506 RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
507 calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2,
rows, lhs, rhs2, pAlpha, result);
511 template <
typename RhsMapper,
typename LhsMapper>
513 std::enable_if_t<std::is_member_function_pointer<decltype(&RhsMapper::stride)>::value>>
517 using RhsSubMapper =
typename RhsMapper::SubMapper;
519 RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
520 if (rhs.stride() == 1) {
521 calcVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2,
rows, lhs, rhs2, pAlpha, result);
523 calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2,
rows, lhs, rhs2, pAlpha, result);
528 template <
typename LhsMapper,
typename RhsMapper>
539 const Index lhsStride = lhs.stride();
550 for (
Index j2 = 0; j2 <
cols; j2 += block_cols) {
553 using LhsSubMapper =
typename LhsMapper::SubMapper;
555 LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
563 0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f};
565 template <Index num_acc>
567 if (num_acc > (
k + 1)) {
568 acc[
k][0] = vec_mergeh(acc[
k][0], acc[
k + 1][0]);
569 acc[
k][1] = vec_mergeo(acc[
k][1], acc[
k + 1][1]);
570 acc[
k][2] = vec_mergel(acc[
k][2], acc[
k + 1][2]);
573 acc[
k][0] = (acc[
k][0] + acc[
k][2]) + (acc[
k][1] + acc[
k][3]);
575 acc[
k][0] = vec_mergeh(acc[
k][0], acc[
k][1]);
576 acc[
k][0] += vec_mergel(acc[
k][2], acc[
k][3]);
578 acc[
k][0] += vec_sld(acc[
k][0], acc[
k][0], 12);
580 acc[
k][0] += vec_sld(acc[
k][0], acc[
k][0], 4);
585 template <Index num_acc>
588 for (
Index k = 0;
k < num_acc;
k += 4) {
589 preduxVecResults2<num_acc>(acc,
k + 0);
590 if (num_acc > (
k + 2)) {
591 preduxVecResults2<num_acc>(acc,
k + 2);
592 acc[
k + 0][0] =
reinterpret_cast<Packet4f>(
593 vec_mergeh(
reinterpret_cast<Packet2ul>(acc[
k + 0][0]),
reinterpret_cast<Packet2ul>(acc[
k + 2][0])));
598 template <Index num_acc,
typename LhsMapper,
typename RhsMapper,
bool extra>
604 b0 = rhs.template loadPacketPartial<Packet8bf>(
j, extra_cols);
606 b0 = rhs.template loadPacket<Packet8bf>(
j);
609 const LhsMapper lhs2 = lhs.getSubMapper(0,
j);
611 for (
Index k = 0;
k < num_acc;
k++) {
613 a0[
k] = lhs2.template loadPacketPartial<Packet8bf>(
k, 0, extra_cols);
615 a0[
k] = lhs2.template loadPacket<Packet8bf>(
k, 0);
619 multVec<num_acc>(quad_acc, a0, b0);
622 template <Index num_acc,
typename LhsMapper,
typename RhsMapper>
626 for (;
j + 8 <=
cols;
j += 8) {
627 multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs,
j, extra_cols);
631 multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs,
j, extra_cols);
635 template <const Index num_acc,
typename LhsMapper,
typename RhsMapper>
643 __vector_quad quad_acc[num_acc];
645 zeroAccumulators<num_acc>(quad_acc);
647 const LhsMapper lhs2 = lhs.getSubMapper(
row, 0);
648 vecLoop<num_acc, LhsMapper, RhsMapper>(
cols, lhs2, rhs, quad_acc, extra_cols);
650 disassembleAccumulators<num_acc>(quad_acc, acc);
652 preduxVecResults<num_acc>(acc);
654 outputVecResults<num_acc>(acc, result, pAlpha);
657 }
while (multiIters && (num_acc <=
rows - (
row += num_acc)));
660 template <const Index num_acc,
typename LhsMapper,
typename RhsMapper>
662 const Packet4f pAlpha,
float* result) {
664 colVecLoopBody<num_acc, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
668 template <
typename LhsMapper,
typename RhsMapper>
670 const Packet4f pAlpha,
float* result) {
673 colVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
676 colVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
679 colVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
682 colVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
685 colVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
688 colVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
691 colVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
696 template <
typename LhsMapper,
typename RhsMapper>
701 colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
704 colVecLoopBodyExtra<LhsMapper, RhsMapper>(
row,
cols,
rows, lhs, rhs, pAlpha, result);
707 template <
typename LhsMapper,
typename RhsMapper>
710 typedef typename RhsMapper::LinearMapper LinearMapper;
715 LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
726 convertArrayPointerBF16toF32<true>(result, 1,
rows,
res, resIncr);
728 calcVecLoops<LhsMapper, LinearMapper>(
cols,
rows, lhs, rhs2, pAlpha, result);
732 convertArrayPointerF32toBF16<true>(result,
rows,
res, resIncr);
737 #undef MAX_BFLOAT16_VEC_ACC
738 #undef BFLOAT16_UNROLL
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_ALWAYS_INLINE
Definition: Macros.h:845
#define eigen_internal_assert(x)
Definition: Macros.h:916
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:966
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define MAX_BFLOAT16_ACC
Definition: MatrixProductMMAbfloat16.h:130
#define MAX_BFLOAT16_VEC_ACC
Definition: MatrixProductMMAbfloat16.h:409
#define BFLOAT16_UNROLL
Definition: MatrixProductMMAbfloat16.h:7
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:806
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
m m block(1, 0, 2, 2)<< 4
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
RealScalar alpha
Definition: level1_cplx_impl.h:151
char char char int int * k
Definition: level2_impl.h:374
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h)
Definition: BFloat16.h:581
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(Index rows, Index cols, const LhsMapper &alhs, const RhsMapper &rhs, bfloat16 *res, Index resIncr, bfloat16 alpha)
Definition: MatrixProductMMAbfloat16.h:708
static Packet16uc p16uc_ELEMENT_VEC3
Definition: MatrixProductMMAbfloat16.h:562
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index &row, Index cols, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:669
__vector unsigned char Packet16uc
Definition: AltiVec/PacketMath.h:41
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f(&acc)[num_acc][4])
Definition: MatrixProductMMAbfloat16.h:586
void gemmMMAbfloat16(const DataMapper &res, const bfloat16 *indexA, const bfloat16 *indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
Definition: MatrixProductMMAbfloat16.h:323
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16 *indexA)
Definition: MatrixProductMMAbfloat16.h:15
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16 *src, Index resInc)
Definition: MatrixProduct.h:2813
void colVecColLoopBody(Index &row, Index cend, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:412
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper &res)
Definition: MatrixProductMMAbfloat16.h:291
void colVecLoopBody(Index &row, Index cols, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:636
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad(&quad_acc)[num_acc], const LhsMapper &lhs, RhsMapper &rhs, Index j, Index extra_cols)
Definition: MatrixProductMMAbfloat16.h:599
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Index rows, const DataMapper &res)
Definition: MatrixProductMMAbfloat16.h:230
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16 *blockB, Index strideB, Index i)
Definition: MatrixProductMMAbfloat16.h:26
__vector unsigned int Packet4ui
Definition: AltiVec/PacketMath.h:35
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f(&acc)[num_acc][4], Index k)
Definition: MatrixProductMMAbfloat16.h:566
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16 *indexA, const bfloat16 *blockB, Index strideB, Index offsetB, float *result)
Definition: MatrixProductMMAbfloat16.h:151
eigen_packet_wrapper< __vector unsigned short int, 0 > Packet8bf
Definition: AltiVec/PacketMath.h:42
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:697
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16 *indexA, const bfloat16 *blockB, Index strideB, Index offsetB, float *result)
Definition: MatrixProductMMAbfloat16.h:202
EIGEN_ALWAYS_INLINE void outputResults(Packet4f(&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float *result, const Index extra_cols, Index extra_rows)
Definition: MatrixProductMMAbfloat16.h:85
EIGEN_STRONG_INLINE Packet4f pset1< Packet4f >(const float &from)
Definition: AltiVec/PacketMath.h:773
void colLoopBody(Index &col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16 *indexA, const bfloat16 *indexB, Index strideB, Index offsetB, float *result)
Definition: MatrixProductMMAbfloat16.h:133
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper &lhs, RhsMapper &rhs, __vector_quad(&quad_acc)[num_acc])
Definition: MatrixProductMMAbfloat16.h:389
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index &row, Index cend, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:443
EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f(&acc)[num_acc][size])
Definition: MatrixProduct.h:2827
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16 *indexA, const bfloat16 *indexB, Index strideB, Index offsetB, float *result, const Index extra_cols, const Index extra_rows)
Definition: MatrixProductMMAbfloat16.h:100
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper &lhs, RhsMapper &rhs, __vector_quad(&quad_acc)[num_acc], Index extra_cols)
Definition: MatrixProductMMAbfloat16.h:623
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16 *dst, Index resInc=1)
Definition: MatrixProductMMAbfloat16.h:282
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16 *&indexA, Index &row, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16 *indexB, Index strideB, Index offsetA, Index offsetB, Index bigSuffix, float *result)
Definition: MatrixProductMMAbfloat16.h:311
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad(&quad_acc)[num_acc], Packet4f(&acc)[num_acc][4])
Definition: MatrixProductMMAbfloat16.h:79
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper &lhs, Packet8bf(&a0)[num_acc], Packet8bf b1)
Definition: MatrixProductMMAbfloat16.h:368
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtraN(Index &row, Index cols, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:661
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index &i, float *result, Index rows, bfloat16 *&dst, Index resInc=1)
Definition: MatrixProductMMAbfloat16.h:254
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:485
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
Definition: MatrixProductMMAbfloat16.h:220
__vector float Packet4f
Definition: AltiVec/PacketMath.h:33
EIGEN_ALWAYS_INLINE void KLoop(const float *indexA, const float *indexB, Packet4f(&acc)[num_acc][4], Index strideB, Index k, Index offsetB, Index extra_cols)
Definition: MatrixProduct.h:2892
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index &row, Index cend, Index rows, LhsMapper &lhs, RhsMapper &rhs, const Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:452
EIGEN_STRONG_INLINE Packet8bf ploadu< Packet8bf >(const bfloat16 *from)
Definition: AltiVec/PacketMath.h:1549
EIGEN_ALWAYS_INLINE void multVec(__vector_quad(&quad_acc)[num_acc], Packet8bf(&a0)[num_acc], Packet8bf b0)
Definition: MatrixProductMMAbfloat16.h:380
EIGEN_STRONG_INLINE Packet8bf pset1< Packet8bf >(const bfloat16 &from)
Definition: AltiVec/PacketMath.h:808
void gemvMMA_bfloat16_col(Index rows, Index cols, const LhsMapper &alhs, const RhsMapper &rhs, bfloat16 *res, Index resIncr, bfloat16 alpha)
Definition: MatrixProductMMAbfloat16.h:529
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16 *indexA, const bfloat16 *blockB, Index strideB, Index offsetB, float *result)
Definition: MatrixProductMMAbfloat16.h:161
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition: MathFunctions.h:920
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Definition: Eigen_Colamd.h:49
list x
Definition: plotDoE.py:28
Definition: BFloat16.h:101
Definition: GenericPacketMath.h:1407
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper &lhs, RhsMapper &rhs, Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:515
Definition: MatrixProductMMAbfloat16.h:501
static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper &lhs, RhsMapper &rhs, Packet4f pAlpha, float *result)
Definition: MatrixProductMMAbfloat16.h:502
Definition: GenericPacketMath.h:225
T m_val
Definition: GenericPacketMath.h:235
EIGEN_DONT_INLINE Scalar zero()
Definition: svd_common.h:232
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2