MatrixProductMMAbfloat16.h
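
// bfloat16 GEMM/GEMV kernels for PowerPC using the Matrix-Multiply Assist
// (MMA) facility introduced with Power ISA 3.1 (Power10). Products are
// accumulated in float32 through __vector_quad accumulators and converted
// back to bfloat16 on output.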
#ifndef EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H

#if EIGEN_COMP_LLVM
#define BFLOAT16_UNROLL _Pragma("unroll 8")
#else
#define BFLOAT16_UNROLL _Pragma("GCC unroll(8)")
#endif

namespace Eigen {

namespace internal {

template <bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadBfloat16(const bfloat16* indexA) {
  Packet8bf lhs1 = ploadu<Packet8bf>(indexA);
  if (zero) {
    Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
    return vec_mergeh(lhs1.m_val, lhs2.m_val);
  } else {
    return lhs1;
  }
}
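
// Note: when `zero` is set, loadBfloat16 interleaves the loaded values with a
// zero vector. The bf16 MMA instructions consume two k-values per lane, so the
// final iteration of an odd depth pads the second k-value of each pair with
// zeros (see the `depth & 1` tail in colLoopBodyIter below).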

template <bool zero>
EIGEN_ALWAYS_INLINE Packet8bf loadRhsBfloat16(const bfloat16* blockB, Index strideB, Index i) {
  return loadBfloat16<zero>(blockB + strideB * i);
}

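// KLoop performs one step of the inner (depth) loop: it loads RHS and LHS
// packets covering two k-values and feeds every rhs/lhs pair to
// __builtin_mma_xvbf16ger2pp, a rank-2 outer-product accumulate that adds the
// products of bf16 pairs into a 4x4 float32 accumulator (__vector_quad).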
template <Index num_acc, Index num_packets, bool zero, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs,
          Index num_lhs>
EIGEN_ALWAYS_INLINE void KLoop(const bfloat16* indexA, const bfloat16* indexB, __vector_quad (&quad_acc)[num_acc],
                               Index strideB, Index k, Index offsetB, Index extra_cols, Index extra_rows) {
  Packet8bf lhs[num_lhs], rhs[num_rhs];

  BFLOAT16_UNROLL
  for (Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++) {
    rhs[i] = loadRhsBfloat16<zero>(indexB + k * 4, strideB, i);
  }
  if (rhsExtraCols) {
    rhs[num_rhs - 1] = loadRhsBfloat16<zero>(indexB + k * extra_cols - offsetB, strideB, num_rhs - 1);
  }

  indexA += k * (lhsExtraRows ? extra_rows : num_packets);
  if (num_lhs == 1) {
    lhs[0] = loadBfloat16<zero>(indexA);
  } else {
    BFLOAT16_UNROLL
    for (Index j = 0; j < num_lhs; j += 2) {
      Packet8bf lhs1 = ploadu<Packet8bf>(indexA + (j + 0) * (zero ? 4 : 8));
      if (zero) {
        Packet8bf lhs2 = ploadu<Packet8bf>(indexA + (j + 1) * 4);
        lhs[j + 0] = vec_mergeh(lhs1.m_val, lhs2.m_val);
        lhs[j + 1] = vec_mergel(lhs1.m_val, lhs2.m_val);
      } else {
        lhs[j + 0] = lhs1;
        lhs[j + 1] = ploadu<Packet8bf>(indexA + (j + 1) * 8);
      }
    }
  }

  BFLOAT16_UNROLL
  for (Index i = 0, x = 0; i < num_rhs; i++) {
    BFLOAT16_UNROLL
    for (Index j = 0; j < num_lhs; j++, x++) {
      __builtin_mma_xvbf16ger2pp(&(quad_acc[x]), reinterpret_cast<Packet16uc>(rhs[i].m_val),
                                 reinterpret_cast<Packet16uc>(lhs[j].m_val));
    }
  }
}
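
// Thin wrappers over the MMA accumulator builtins: __builtin_mma_xxsetaccz
// zeroes a 512-bit accumulator, and __builtin_mma_disassemble_acc copies its
// four 128-bit rows out into regular Packet4f vectors.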
template <Index num_acc>
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc]) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) __builtin_mma_xxsetaccz(&(quad_acc[k]));
}

template <Index num_acc>
EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_acc], Packet4f (&acc)[num_acc][4]) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) __builtin_mma_disassemble_acc((void*)acc[k], &(quad_acc[k]));
}

template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result,
                                       const Index extra_cols, Index extra_rows) {
  BFLOAT16_UNROLL
  for (Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4 * rows) {
    BFLOAT16_UNROLL
    for (Index j = 0; j < num_lhs; j++, k++) {
      storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j * 4, extra_cols, extra_rows);
    }
  }
  if (rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[num_acc - 1], rows, pAlpha, result, extra_cols, extra_rows);
  }
}

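// One iteration of the column loop: zero the accumulators, run KLoop over the
// depth two k-values at a time, re-run it once with zero padding if the depth
// is odd, then disassemble the accumulators and scale/store the results.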
template <const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows, bool multiIter = false>
EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                                         const bfloat16* indexB, Index strideB, Index offsetB, float* result,
                                         const Index extra_cols, const Index extra_rows) {
  constexpr Index num_lhs = multiIter ? (num_packets / 4) : 1;
  constexpr Index num_rhs = (num_acc + num_lhs - 1) / num_lhs;

  for (Index offset_row = 0; offset_row < num_packets; offset_row += 4, indexA += (multiIter ? 0 : 8),
             indexB += (multiIter ? (num_rhs * strideB) : 0), result += (multiIter ? (4 * rows * num_rhs) : 4)) {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    Index k;
    for (k = 0; k + 2 <= depth; k += 2) {
      KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(
          indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }
    if (depth & 1) {
      KLoop<num_acc, num_packets, true, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(
          indexA - (multiIter ? 0 : offset_row), indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputResults<num_acc, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(acc, rows, pAlpha, result, extra_cols,
                                                                         extra_rows);
  }
}

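// Up to MAX_BFLOAT16_ACC accumulators are used per pass, each covering 4
// result columns, so a full pass computes num_acc * 4 columns at once.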
#define MAX_BFLOAT16_ACC 8

template <const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                 const bfloat16* indexB, Index strideB, Index offsetB, float* result) {
  constexpr Index step = (num_acc * 4);  // each accumulator has 4 elements
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
  constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);

  do {
    colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(
        depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);

    indexB += strideB * num_acc;
    result += rows * step;
  } while (multiIters && (step <= cols - (col += step)));
}

template <const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void colLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha,
                                           const bfloat16* indexA, const bfloat16* blockB, Index strideB,
                                           Index offsetB, float* result) {
  if (MAX_BFLOAT16_ACC > num_acc) {
    colLoopBody<num_acc + (rhsExtraCols ? 1 : 0), num_packets, rhsExtraCols, lhsExtraRows>(
        col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
  }
}

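// Dispatch the remaining columns (in groups of 4) to a fixed-size
// instantiation of colLoopBody; a final sub-4 column remainder is handled by
// the rhsExtraCols path.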
template <const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
void colLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                      const bfloat16* blockB, Index strideB, Index offsetB, float* result) {
  switch ((cols - col) >> 2) {
    case 7:
      colLoopBodyExtraN<7, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 6:
      colLoopBodyExtraN<6, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 5:
      colLoopBodyExtraN<5, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 4:
      colLoopBodyExtraN<4, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 3:
      colLoopBodyExtraN<3, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 2:
      colLoopBodyExtraN<2, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    case 1:
      colLoopBodyExtraN<1, num_packets, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, offsetB, result);
      break;
    default:
      if (rhsExtraCols) {
        colLoopBody<1, num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
                                                        offsetB, result);
      }
      break;
  }
}

template <const Index num_packets, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                                  const bfloat16* blockB, Index strideB, Index offsetB, float* result) {
  Index col = 0;
  if (cols >= (MAX_BFLOAT16_ACC * 4)) {
    colLoopBody<MAX_BFLOAT16_ACC, num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB,
                                                                    strideB, 0, result);
    blockB += (strideB >> 2) * col;
    result += rows * col;
  }
  if (cols & 3) {
    colLoopBodyExtra<num_packets, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB,
                                                      result);
  } else {
    colLoopBodyExtra<num_packets, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, 0,
                                                       result);
  }
}

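// Converts 8 float32 values to bfloat16: the 32-byte input is split with
// __builtin_vsx_disassemble_pair, each half is converted by
// __builtin_vsx_xvcvspbf16 (VSX Vector Convert Single Precision to bfloat16),
// and the two results are packed into one Packet8bf.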
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float* res) {
  Packet16uc fp16[2];
  __vector_pair fp16_vp = *reinterpret_cast<__vector_pair*>(const_cast<float*>(res));
  __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
  fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
  fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
  return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
}

template <typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float* result, Index col, Index rows, const DataMapper& res) {
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float* result2 = result + col * rows;
  for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
    // get and save block
    PacketBlock<Packet8bf, size> block;
    BFLOAT16_UNROLL
    for (Index j = 0; j < size; j++) {
      block.packet[j] = convertF32toBF16(result2 + j * rows);
    }
    res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
  }
  // extra rows
  if (row < rows) {
    BFLOAT16_UNROLL
    for (Index j = 0; j < size; j++) {
      Packet8bf fp16 = convertF32toBF16(result2 + j * rows);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}

template <const Index size, bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst,
                                                 Index resInc = 1) {
  constexpr Index extra = ((size < 8) ? 8 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = convertF32toBF16(result + i + 0);
    if (size >= 16) {
      r32.packet[1] = convertF32toBF16(result + i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = convertF32toBF16(result + i + 16);
      r32.packet[3] = convertF32toBF16(result + i + 24);
    }
    storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
    if (size >= 16) {
      storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
    }
    i += extra;
    dst += extra * resInc;
    if (size != 32) break;
  }
}

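// Runs the conversion at descending block sizes (32, 16, 8, then 1) so that a
// result vector of any length is fully covered; each call advances `i` and
// `dst` past the rows it converted.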
template <bool non_unit_stride = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float* result, Index rows, bfloat16* dst, Index resInc = 1) {
  Index i = 0;
  convertPointerF32toBF16<32, non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<16, non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<8, non_unit_stride>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<1, non_unit_stride>(i, result, rows, dst, resInc);
}

template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float* result, Index cols, Index rows, const DataMapper& res) {
  Index col;
  for (col = 0; col + 4 <= cols; col += 4) {
    convertArrayF32toBF16Col<DataMapper, 4>(result, col, rows, res);
  }
  // extra cols
  switch (cols - col) {
    case 1:
      convertArrayF32toBF16Col<DataMapper, 1>(result, col, rows, res);
      break;
    case 2:
      convertArrayF32toBF16Col<DataMapper, 2>(result, col, rows, res);
      break;
    case 3:
      convertArrayF32toBF16Col<DataMapper, 3>(result, col, rows, res);
      break;
  }
}

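// Processes one packed LHS block of `size` rows when present, advancing both
// the row counter and indexA past the block (bigSuffix is the distance to the
// next packed panel in the 8x16 case, scaled proportionally for smaller
// blocks).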
template <Index size>
EIGEN_ALWAYS_INLINE void calcColLoops(const bfloat16*& indexA, Index& row, Index depth, Index cols, Index rows,
                                      const Packet4f pAlpha, const bfloat16* indexB, Index strideB, Index offsetA,
                                      Index offsetB, Index bigSuffix, float* result) {
  if ((size == 16) || (rows & size)) {
    indexA += size * offsetA;
    colLoops<size>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix * size / 16;
  }
}

template <typename DataMapper>
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
                     Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);
  ei_declare_aligned_stack_constructed_variable(float, result, cols * rows, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;
  // Packing is done in blocks.
  // There are 4 possible block sizes:
  //   Blocks of 8 columns with 16 elements (8x16)
  //   Blocks of 8 columns with 8 elements (8x8), used when 16 > rows >= 8
  //   Blocks of 8 columns with 4 elements (8x4), used when 8 > rows >= 4
  //   Blocks of 8 columns with < 4 elements, used when fewer than 4 rows remain

  // Loop for LHS standard block (8x16)
  Index bigSuffix = (2 * 8) * (strideA - offsetA);
  indexB += 4 * offsetB;
  strideB *= 4;
  offsetB *= 3;

  Index row = 0;
  while (row + 16 <= rows) {
    calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  }
  // LHS (8x8) block
  calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  // LHS (8x4) block
  calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
  // extra rows
  if (rows & 3) {
    // This index is the beginning of the remaining block.
    colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
  }

  // Convert back to bfloat16
  convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
}

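// In effect (a scalar sketch of the semantics above, ignoring the packed
// operand layout):
//
//   for (Index j = 0; j < cols; j++)
//     for (Index i = 0; i < rows; i++) {
//       float sum = 0;
//       for (Index p = 0; p < depth; p++)
//         sum += float(A(i, p)) * float(B(p, j));
//       res(i, j) = bfloat16(float(res(i, j)) + falpha * sum);
//     }
//
// with all accumulation performed in float32 and a single conversion back to
// bfloat16 at the end.
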
#undef MAX_BFLOAT16_ACC

#if !EIGEN_ALTIVEC_DISABLE_MMA
template <Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1) {
  a0[k + 0] = lhs.template loadPacket<Packet8bf>(k * 4, 0);
  if (!zero) {
    b1 = lhs.template loadPacket<Packet8bf>(k * 4, 1);
  }
  if (num_acc > (k + 1)) {
    a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
  }
  a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
}

template <Index num_acc>
EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) {
    __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val),
                               reinterpret_cast<Packet16uc>(a0[k].m_val));
  }
}

template <Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc]) {
  Packet8bf a0[num_acc];
  Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);

  if (zero) {
    b0 = vec_mergeh(b0.m_val, b1.m_val);
  }

  using LhsSubMapper = typename LhsMapper::SubMapper;

  LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k += 2) {
    loadVecLoop<num_acc, LhsSubMapper, zero>(k, lhs2, a0, b1);
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

#define MAX_BFLOAT16_VEC_ACC 8

template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                       float* result) {
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);

  do {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
    for (Index j = 0; j + 2 <= cend; j += 2) {
      vecColLoop<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
    }
    if (cend & 1) {
      vecColLoop<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);

    result += step;
  } while (multiIters && (step <= rows - (row += step)));
}

template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                 const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs,
                                                                                              rhs, pAlpha, result);
  }
}

template <typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                                const Packet4f pAlpha, float* result) {
  switch ((rows - row) >> 2) {
    case 7:
      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
        colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
}

template <typename LhsMapper, typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                                         float* result) {
  Index row = 0;
  if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha,
                                                                                 result);
    result += row;
  }
  if (rows & 3) {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  } else {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}

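// UseMMAStride selects the RHS access path for the column-major GEMV kernel
// via SFINAE: the specialization engages when RhsMapper has a stride() member
// function, allowing a unit-stride (linear) fast path to be chosen at runtime;
// otherwise the generic non-linear path is used.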
template <typename RhsMapper, typename LhsMapper, typename = void>
struct UseMMAStride : std::false_type {
  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                      Packet4f pAlpha, float* result) {
    using RhsSubMapper = typename RhsMapper::SubMapper;

    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
    calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
  }
};

template <typename RhsMapper, typename LhsMapper>
struct UseMMAStride<RhsMapper, LhsMapper,
                    std::enable_if_t<std::is_member_function_pointer<decltype(&RhsMapper::stride)>::value>>
    : std::true_type {
  static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                      Packet4f pAlpha, float* result) {
    using RhsSubMapper = typename RhsMapper::SubMapper;

    RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
    if (rhs.stride() == 1) {
      calcVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
    } else {
      calcVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
    }
  }
};

template <typename LhsMapper, typename RhsMapper>
void gemvMMA_bfloat16_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, bfloat16* res,
                          Index resIncr, bfloat16 alpha) {
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC generate proper code.
  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  const Index lhsStride = lhs.stride();

  // TODO: improve the following heuristic:
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);

  convertArrayPointerBF16toF32(result, 1, rows, res);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);

    using LhsSubMapper = typename LhsMapper::SubMapper;

    LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
    UseMMAStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
  }

  convertArrayPointerF32toBF16(result, rows, res);
}

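// Permute mask selecting 32-bit element 3 of each input vector (bytes
// 0x0c-0x0f of the first and 0x1c-0x1f of the second), repeated in both
// halves of the result.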
static Packet16uc p16uc_ELEMENT_VEC3 = {0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f,
                                        0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f};

template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k) {
  if (num_acc > (k + 1)) {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
    acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
    acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
    acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);

    acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
  } else {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
    acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
#ifdef _BIG_ENDIAN
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
#else
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
#endif
  }
}

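// Horizontal reduction of the disassembled accumulators into scalar results:
// preduxVecResults2 combines two accumulators at a time (falling back to a
// rotate-and-add via vec_sld for a lone accumulator), and preduxVecResults
// then merges pairs of reduced vectors, walking four accumulators per step.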
template <Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4]) {
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2<num_acc>(acc, k + 2);
      acc[k + 0][0] = reinterpret_cast<Packet4f>(
          vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
    }
  }
}

template <Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j,
                                     Index extra_cols) {
  Packet8bf a0[num_acc], b0;

  if (extra) {
    b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
  } else {
    b0 = rhs.template loadPacket<Packet8bf>(j);
  }

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  BFLOAT16_UNROLL
  for (Index k = 0; k < num_acc; k++) {
    if (extra) {
      a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
    } else {
      a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
    }
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

template <Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc],
                                 Index extra_cols) {
  Index j = 0;
  for (; j + 8 <= cols; j += 8) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
  }

  if (extra_cols) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
  }
}

template <const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                    float* result) {
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_cols = (cols & 7);

  do {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, quad_acc, extra_cols);

    disassembleAccumulators<num_acc>(quad_acc, acc);

    preduxVecResults<num_acc>(acc);

    outputVecResults<num_acc>(acc, result, pAlpha);

    result += num_acc;
  } while (multiIters && (num_acc <= rows - (row += num_acc)));
}

template <const Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                              const Packet4f pAlpha, float* result) {
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
  }
}

template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
                                             const Packet4f pAlpha, float* result) {
  switch (rows - row) {
    case 7:
      colVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
  }
}

template <typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
                                      float* result) {
  Index row = 0;
  if (rows >= MAX_BFLOAT16_VEC_ACC) {
    colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  colVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
}

template <typename LhsMapper, typename RhsMapper>
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                              bfloat16* res, Index resIncr, bfloat16 alpha) {
  typedef typename RhsMapper::LinearMapper LinearMapper;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC generate proper code.
  LhsMapper lhs(alhs);
  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);

  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
  if (resIncr == 1) {
    convertArrayPointerBF16toF32(result, 1, rows, res);
  } else {
    convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
  }
  calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
  if (resIncr == 1) {
    convertArrayPointerF32toBF16(result, rows, res);
  } else {
    convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
  }
}
#endif

#undef MAX_BFLOAT16_VEC_ACC
#undef BFLOAT16_UNROLL

}  // namespace internal
}  // namespace Eigen
#endif  // EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H