TensorContractionGpu.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014-2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 // Copyright (C) 2015 Navdeep Jaitly <ndjaitly@google.com>
6 // Copyright (C) 2014 Eric Martin <eric@ericmart.in>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
13 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
14 
15 #if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
16 
17 // IWYU pragma: private
18 #include "./InternalHeaderCheck.h"
19 
20 namespace Eigen {
21 
22 template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper,
23  bool needs_edge_check>
24 __device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
25  const OutputMapper output, Scalar* lhs_shmem,
26  Scalar* rhs_shmem, const Index m_size,
27  const Index n_size, const Index k_size) {
28  const Index m_block_idx = blockIdx.x;
29  const Index n_block_idx = blockIdx.y;
30 
31  const Index base_m = 64 * m_block_idx;
32  const Index base_n = 64 * n_block_idx;
33 
34  // declare and initialize 64 registers for output 8x8 block
35 
36  // prefetch registers
37  Scalar lhs_pf0;
38  Scalar lhs_pf1;
39  Scalar lhs_pf2;
40  Scalar lhs_pf3;
41  Scalar lhs_pf4;
42  Scalar lhs_pf5;
43  Scalar lhs_pf6;
44  Scalar lhs_pf7;
45 
46  Scalar rhs_pf0;
47  Scalar rhs_pf1;
48  Scalar rhs_pf2;
49  Scalar rhs_pf3;
50  Scalar rhs_pf4;
51  Scalar rhs_pf5;
52  Scalar rhs_pf6;
53  Scalar rhs_pf7;
54 
55  // shared memory is formatted
56  // (contract idx in block, nocontract idx in block, block idx)
57  // where block idx is column major. This transposition limits the number of
58  // bank conflicts when reading the LHS. The core idea is that since the contracting
59  // index is shared by both sides, it should live in threadIdx.x.
60 
61  // On the LHS, we pad each row inside of each block with an extra element. This makes
62  // each block 8 rows of 9 elements, which is 72 elements. This gives no bank conflicts
63  // on writes and very few 2-way conflicts on reads. There is an 8x8 grid of these blocks.
64 
65  // On the RHS we just add 8 padding elements to the end of each block. This gives no bank
66  // conflicts on writes and also none on reads.
67 
68  // storage indices
69  const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
70  const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
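 // Editor's note (illustrative, not in the original source): a concrete example of
 // these store indices. For thread (threadIdx.x, threadIdx.y, threadIdx.z) = (3, 2, 5):
 //   lhs_store_idx_base = 2 * 72 + 3 * 9 + 5 = 176
 //   rhs_store_idx_base = 2 * 72 + 5 * 8 + 3 = 187
 // The 9-element LHS rows and the 8 trailing RHS pad elements per block are what give
 // the conflict-free writes described above.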
71 
72  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
73  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
74  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
75  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
76  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
77  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
78  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
79  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
80 
81  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
82  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
83  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
84  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
85  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
86  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
87  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
88  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
89 
90  // in the loading code, the following variables are important:
91  // threadIdx.x: the vertical position in an 8x8 block
92  // threadIdx.y: the vertical index of the 8x8 block in the grid
93  // threadIdx.z: the horizontal position in an 8x8 block
94  // k: the horizontal index of the 8x8 block in the grid
95  //
96  // The k parameter is implicit (it was the loop counter for a loop that went
97  // from 0 to 7, but that loop is now unrolled in the code below).
98 
99  const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
100  const Index lhs_vert = base_m + load_idx_vert;
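 // Editor's note (illustrative, not in the original source): this kernel is launched
 // with an 8x8x8 thread block (see LaunchKernels below), so load_idx_vert =
 // threadIdx.x + 8 * threadIdx.y sweeps 0..63 and each (x, y) pair owns one distinct
 // row of the 64-row LHS panel, while threadIdx.z picks the column within each
 // 8-wide k slab.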
101 
102 #define prefetchIntoRegisters(base_k) \
103  { \
104  lhs_pf0 = conv(0); \
105  lhs_pf1 = conv(0); \
106  lhs_pf2 = conv(0); \
107  lhs_pf3 = conv(0); \
108  lhs_pf4 = conv(0); \
109  lhs_pf5 = conv(0); \
110  lhs_pf6 = conv(0); \
111  lhs_pf7 = conv(0); \
112  \
113  rhs_pf0 = conv(0); \
114  rhs_pf1 = conv(0); \
115  rhs_pf2 = conv(0); \
116  rhs_pf3 = conv(0); \
117  rhs_pf4 = conv(0); \
118  rhs_pf5 = conv(0); \
119  rhs_pf6 = conv(0); \
120  rhs_pf7 = conv(0); \
121  \
122  if (!needs_edge_check || lhs_vert < m_size) { \
123  const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
124  const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
125  const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
126  const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
127  const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
128  const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
129  const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
130  const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
131  \
132  if (!needs_edge_check || lhs_horiz_7 < k_size) { \
133  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
134  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
135  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
136  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
137  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
138  lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
139  lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
140  lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
141  } else if (lhs_horiz_6 < k_size) { \
142  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
143  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
144  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
145  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
146  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
147  lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
148  lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
149  } else if (lhs_horiz_5 < k_size) { \
150  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
151  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
152  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
153  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
154  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
155  lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
156  } else if (lhs_horiz_4 < k_size) { \
157  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
158  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
159  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
160  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
161  lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
162  } else if (lhs_horiz_3 < k_size) { \
163  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
164  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
165  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
166  lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
167  } else if (lhs_horiz_2 < k_size) { \
168  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
169  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
170  lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
171  } else if (lhs_horiz_1 < k_size) { \
172  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
173  lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
174  } else if (lhs_horiz_0 < k_size) { \
175  lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
176  } \
177  } \
178  \
179  const Index rhs_vert = base_k + load_idx_vert; \
180  if (!needs_edge_check || rhs_vert < k_size) { \
181  const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
182  const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
183  const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
184  const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
185  const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
186  const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
187  const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
188  const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
189  \
190  if (rhs_horiz_7 < n_size) { \
191  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
192  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
193  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
194  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
195  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
196  rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
197  rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
198  rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
199  } else if (rhs_horiz_6 < n_size) { \
200  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
201  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
202  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
203  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
204  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
205  rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
206  rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
207  } else if (rhs_horiz_5 < n_size) { \
208  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
209  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
210  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
211  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
212  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
213  rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
214  } else if (rhs_horiz_4 < n_size) { \
215  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
216  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
217  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
218  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
219  rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
220  } else if (rhs_horiz_3 < n_size) { \
221  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
222  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
223  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
224  rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
225  } else if (rhs_horiz_2 < n_size) { \
226  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
227  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
228  rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
229  } else if (rhs_horiz_1 < n_size) { \
230  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
231  rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
232  } else if (rhs_horiz_0 < n_size) { \
233  rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
234  } \
235  } \
236  }
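 // Editor's note (not in the original source): the if/else-if ladders above load only
 // the LHS columns that lie inside k_size and the RHS columns that lie inside n_size
 // (the row checks against m_size and k_size sit outside them); everything past an
 // edge keeps the conv(0) value assigned at the top of the macro, so partial tiles
 // are implicitly zero-padded.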
237 
238 #define writeRegToShmem() \
239  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
240  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
241  \
242  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
243  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
244  \
245  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
246  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
247  \
248  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
249  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
250  \
251  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
252  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
253  \
254  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
255  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
256  \
257  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
258  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
259  \
260  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
261  rhs_shmem[rhs_store_idx_7] = rhs_pf7;
262 
263  // declare and initialize result array
264 #define res(i, j) _res_##i##j
265 #define initResultRow(i) \
266  Scalar res(i, 0) = conv(0); \
267  Scalar res(i, 1) = conv(0); \
268  Scalar res(i, 2) = conv(0); \
269  Scalar res(i, 3) = conv(0); \
270  Scalar res(i, 4) = conv(0); \
271  Scalar res(i, 5) = conv(0); \
272  Scalar res(i, 6) = conv(0); \
273  Scalar res(i, 7) = conv(0);
274 
275  internal::scalar_cast_op<int, Scalar> conv;
276  initResultRow(0);
277  initResultRow(1);
278  initResultRow(2);
279  initResultRow(3);
280  initResultRow(4);
281  initResultRow(5);
282  initResultRow(6);
283  initResultRow(7);
284 #undef initResultRow
285 
286  for (Index base_k = 0; base_k < k_size; base_k += 64) {
287  // wait for the previous iteration to finish with shmem. Counterintuitively,
288  // the code is a bit faster with the sync here than at the bottom of the loop.
289  __syncthreads();
290 
291  prefetchIntoRegisters(base_k);
292  writeRegToShmem();
293 
294 #undef prefetchIntoRegisters
295 #undef writeRegToShmem
296 
297  // wait for shared mem packing to be done before starting computation
298  __syncthreads();
299 
300  // compute 8x8 matrix product by outer product. This involves packing one column
301  // of LHS and one row of RHS into registers (takes 16 registers).
302 
303 #define lcol(i) _lcol##i
304  Scalar lcol(0);
305  Scalar lcol(1);
306  Scalar lcol(2);
307  Scalar lcol(3);
308  Scalar lcol(4);
309  Scalar lcol(5);
310  Scalar lcol(6);
311  Scalar lcol(7);
312 
313 #define rrow(j) _rrow##j
314  Scalar rrow(0);
315  Scalar rrow(1);
316  Scalar rrow(2);
317  Scalar rrow(3);
318  Scalar rrow(4);
319  Scalar rrow(5);
320  Scalar rrow(6);
321  Scalar rrow(7);
322 
323  // Now x corresponds to k, y to m, and z to n
324  const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
325  const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
326 
327 #define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
328 #define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
329 
330 #define loadData(i, j) \
331  lcol(0) = lhs_element(0, j); \
332  rrow(0) = rhs_element(i, 0); \
333  lcol(1) = lhs_element(1, j); \
334  rrow(1) = rhs_element(i, 1); \
335  lcol(2) = lhs_element(2, j); \
336  rrow(2) = rhs_element(i, 2); \
337  lcol(3) = lhs_element(3, j); \
338  rrow(3) = rhs_element(i, 3); \
339  lcol(4) = lhs_element(4, j); \
340  rrow(4) = rhs_element(i, 4); \
341  lcol(5) = lhs_element(5, j); \
342  rrow(5) = rhs_element(i, 5); \
343  lcol(6) = lhs_element(6, j); \
344  rrow(6) = rhs_element(i, 6); \
345  lcol(7) = lhs_element(7, j); \
346  rrow(7) = rhs_element(i, 7);
347 
348 #define computeCol(j) \
349  res(0, j) += lcol(0) * rrow(j); \
350  res(1, j) += lcol(1) * rrow(j); \
351  res(2, j) += lcol(2) * rrow(j); \
352  res(3, j) += lcol(3) * rrow(j); \
353  res(4, j) += lcol(4) * rrow(j); \
354  res(5, j) += lcol(5) * rrow(j); \
355  res(6, j) += lcol(6) * rrow(j); \
356  res(7, j) += lcol(7) * rrow(j);
357 
358 #define computePass(i) \
359  loadData(i, i); \
360  \
361  computeCol(0); \
362  computeCol(1); \
363  computeCol(2); \
364  computeCol(3); \
365  computeCol(4); \
366  computeCol(5); \
367  computeCol(6); \
368  computeCol(7);
369 
370  computePass(0);
371  computePass(1);
372  computePass(2);
373  computePass(3);
374  computePass(4);
375  computePass(5);
376  computePass(6);
377  computePass(7);
378 
379 #undef lcol
380 #undef rrow
381 #undef lhs_element
382 #undef rhs_element
383 #undef loadData
384 #undef computeCol
385 #undef computePass
386  } // end loop over k
387 
388  // we've now iterated over all of the large (i.e. width 64) k blocks and
389  // accumulated results in registers. At this point thread (x, y, z) holds, for its
390  // 8x8 piece of the output block, the partial sums over the k indices congruent
391  // to x (mod 8). To compute the final output, we need to reduce the 8 threads
392  // along x by summation.
393 #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
394 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
395 #else
396 #define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
397 #endif
398 
399 #define reduceRow(i, mask) \
400  shuffleInc(i, 0, mask); \
401  shuffleInc(i, 1, mask); \
402  shuffleInc(i, 2, mask); \
403  shuffleInc(i, 3, mask); \
404  shuffleInc(i, 4, mask); \
405  shuffleInc(i, 5, mask); \
406  shuffleInc(i, 6, mask); \
407  shuffleInc(i, 7, mask);
408 
409 #define reduceMatrix(mask) \
410  reduceRow(0, mask); \
411  reduceRow(1, mask); \
412  reduceRow(2, mask); \
413  reduceRow(3, mask); \
414  reduceRow(4, mask); \
415  reduceRow(5, mask); \
416  reduceRow(6, mask); \
417  reduceRow(7, mask);
418 
419  // actually perform the reduction, now each thread of index (_, y, z)
420  // contains the correct values in its registers that belong in the output
421  // block
422  reduceMatrix(1);
423  reduceMatrix(2);
424  reduceMatrix(4);
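 // Editor's note (not in the original source): the three XOR masks implement a
 // butterfly reduction. After masks 1, 2 and 4 have been applied, each of the 8
 // lanes that share (threadIdx.y, threadIdx.z) but differ in threadIdx.x holds the
 // same fully reduced sum, which is why a single threadIdx.x == 0 thread can write
 // the block to shared memory below.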
425 
426 #undef shuffleInc
427 #undef reduceRow
428 #undef reduceMatrix
429 
430  // now we need to copy the 64 values into main memory. We can't split work
431  // among threads because all variables are in registers. There are two ways
432  // to do this:
433  // (1) have 1 thread do 64 writes from registers into global memory
434  // (2) have 1 thread do 64 writes into shared memory, and then 8 threads
435  // each do 8 writes into global memory. We can just overwrite the shared
436  // memory from the problem we just solved.
437  // (2) is slightly faster than (1) due to less branching and more ILP
438 
439  // TODO: won't yield much gain, but could just use currently unused shared mem
440  // and then we won't have to sync
441  // wait for shared mem to be out of use
442  __syncthreads();
443 
444 #define writeResultShmem(i, j) lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);
445 
446 #define writeRow(i) \
447  writeResultShmem(i, 0); \
448  writeResultShmem(i, 1); \
449  writeResultShmem(i, 2); \
450  writeResultShmem(i, 3); \
451  writeResultShmem(i, 4); \
452  writeResultShmem(i, 5); \
453  writeResultShmem(i, 6); \
454  writeResultShmem(i, 7);
455 
456  if (threadIdx.x == 0) {
457  writeRow(0);
458  writeRow(1);
459  writeRow(2);
460  writeRow(3);
461  writeRow(4);
462  writeRow(5);
463  writeRow(6);
464  writeRow(7);
465  }
466 #undef writeResultShmem
467 #undef writeRow
468 
469  const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
470  const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
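 // Editor's note (illustrative, not in the original source): e.g. with m_size = 100,
 // base_m = 64 and threadIdx.y = 3, max_i_write = min((100 - 64 - 3 + 7) / 8, 8) = 5,
 // so only threads with threadIdx.x < 5 write rows 67, 75, 83, 91 and 99, all of
 // which lie inside the output.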
471 
472  if (threadIdx.x < max_i_write) {
473  if (max_j_write == 8) {
474  // TODO: can I trade bank conflicts for coalesced writes?
475  Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
476  Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
477  Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
478  Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
479  Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
480  Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
481  Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
482  Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
483 
484  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
485  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
486  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
487  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
488  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
489  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
490  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
491  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
492  } else {
493 #pragma unroll 7
494  for (int j = 0; j < max_j_write; j++) {
495  Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
496  output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
497  }
498  }
499  }
500 #undef res
501 }
502 
503 template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
504 __global__ void
505 #if defined(EIGEN_HIPCC)
506 __launch_bounds__(512, 1)
507 #else
508 __launch_bounds__(512)
509 #endif
510  EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
511  const Index n_size, const Index k_size) {
512  __shared__ Scalar lhs_shmem[72 * 64];
513  __shared__ Scalar rhs_shmem[72 * 64];
514 
515  const Index m_block_idx = blockIdx.x;
516  const Index n_block_idx = blockIdx.y;
517 
518  const Index base_m = 64 * m_block_idx;
519  const Index base_n = 64 * n_block_idx;
520 
521  if (base_m + 63 < m_size && base_n + 63 < n_size) {
522  EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(
523  lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
524  } else {
525  EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(
526  lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
527  }
528 }
529 
530 template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
531  bool CHECK_RHS_BOUNDARY>
532 __device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
533  const OutputMapper output,
534  float2 lhs_shmem2[][16],
535  float2 rhs_shmem2[][8], const Index m_size,
536  const Index n_size, const Index k_size,
537  const Index base_m, const Index base_n) {
538  // prefetch registers
539  float4 lhs_pf0, rhs_pf0;
540 
541  float4 results[4];
542  for (int i = 0; i < 4; i++) {
543  results[i].x = results[i].y = results[i].z = results[i].w = 0;
544  }
545 
546 #define prefetch_lhs(reg, row, col) \
547  if (!CHECK_LHS_BOUNDARY) { \
548  if (col < k_size) { \
549  reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
550  } \
551  } else { \
552  if (col < k_size) { \
553  if (row + 3 < m_size) { \
554  reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
555  } else if (row + 2 < m_size) { \
556  reg.x = lhs(row + 0, col); \
557  reg.y = lhs(row + 1, col); \
558  reg.z = lhs(row + 2, col); \
559  } else if (row + 1 < m_size) { \
560  reg.x = lhs(row + 0, col); \
561  reg.y = lhs(row + 1, col); \
562  } else if (row < m_size) { \
563  reg.x = lhs(row + 0, col); \
564  } \
565  } \
566  }
567 
568  Index lhs_vert = base_m + threadIdx.x * 4;
569 
570  for (Index k = 0; k < k_size; k += 16) {
571  lhs_pf0 = internal::pset1<float4>(0);
572  rhs_pf0 = internal::pset1<float4>(0);
573 
574  Index lhs_horiz = threadIdx.y + k;
575  prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
576 
577  Index rhs_vert = k + (threadIdx.x % 4) * 4;
578  Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
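 // Editor's note (illustrative, not in the original source): with the 16x16 thread
 // block used for this kernel, lhs_vert = base_m + 4 * threadIdx.x covers the 64 tile
 // rows in float4 chunks, while rhs_vert and rhs_horiz0 map the same 256 threads onto
 // a 16 x 64 RHS tile: rhs_vert = k + 4 * (threadIdx.x % 4) spans the 16 k values and
 // (threadIdx.x >> 2) + 4 * threadIdx.y spans the 64 output columns.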
579 
580  if (!CHECK_RHS_BOUNDARY) {
581  if ((rhs_vert + 3) < k_size) {
582  // just CHECK_RHS_BOUNDARY
583  rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
584  } else if (rhs_vert + 2 < k_size) {
585  // just CHECK_RHS_BOUNDARY
586  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
587  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
588  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
589  } else if (rhs_vert + 1 < k_size) {
590  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
591  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
592  } else if (rhs_vert < k_size) {
593  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
594  }
595  } else {
596  if (rhs_horiz0 < n_size) {
597  if ((rhs_vert + 3) < k_size) {
598  rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
599  } else if ((rhs_vert + 2) < k_size) {
600  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
601  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
602  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
603  } else if ((rhs_vert + 1) < k_size) {
604  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
605  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
606  } else if (rhs_vert < k_size) {
607  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
608  }
609  }
610  }
611  float x1, x2;
612  // the following can be a bitwise operation... some day.
613  if ((threadIdx.x % 8) < 4) {
614  x1 = rhs_pf0.y;
615  x2 = rhs_pf0.w;
616  } else {
617  x1 = rhs_pf0.x;
618  x2 = rhs_pf0.z;
619  }
620 #if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
621  x1 = __shfl_xor(x1, 4);
622  x2 = __shfl_xor(x2, 4);
623 #else
624  x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
625  x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
626 #endif
627  if ((threadIdx.x % 8) < 4) {
628  rhs_pf0.y = x1;
629  rhs_pf0.w = x2;
630  } else {
631  rhs_pf0.x = x1;
632  rhs_pf0.z = x2;
633  }
634 
635  // We have 64 features.
636  // Row 0 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 0, 1.
637  // Row 1 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 2, 3.
638  // ...
639  // Row 31 -> times (0, 4, 8, 12, 1, 5, 9, 13) for features 62, 63
640  // Row 32 -> times (2, 6, 10, 14, 3, 7, 11, 15) for features 0, 1
641  // ...
642  rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
643  rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
644 
645  // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
646  // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
647  // ...
648  // Row 15 (time 15) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61)
649  // Row 16 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63)
650  // ...
651 
652  lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
653  lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
654 
655 #define add_vals(fl1, fl2, fr1, fr2) \
656  results[0].x += fl1.x * fr1.x; \
657  results[0].y += fl1.y * fr1.x; \
658  results[0].z += fl2.x * fr1.x; \
659  results[0].w += fl2.y * fr1.x; \
660  \
661  results[1].x += fl1.x * fr1.y; \
662  results[1].y += fl1.y * fr1.y; \
663  results[1].z += fl2.x * fr1.y; \
664  results[1].w += fl2.y * fr1.y; \
665  \
666  results[2].x += fl1.x * fr2.x; \
667  results[2].y += fl1.y * fr2.x; \
668  results[2].z += fl2.x * fr2.x; \
669  results[2].w += fl2.y * fr2.x; \
670  \
671  results[3].x += fl1.x * fr2.y; \
672  results[3].y += fl1.y * fr2.y; \
673  results[3].z += fl2.x * fr2.y; \
674  results[3].w += fl2.y * fr2.y;
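 // Editor's note (not in the original source): add_vals accumulates a 4x4 outer
 // product: the four LHS values (fl1.x, fl1.y, fl2.x, fl2.y) are the rows and the
 // four RHS values (fr1.x, fr1.y, fr2.x, fr2.y) are the columns, with results[j]
 // holding output column j for this thread's four rows.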
675 
676  __syncthreads();
677 
678 // Do the multiplies.
679 #pragma unroll
680  for (int koff = 0; koff < 16; koff++) {
681  // 32 x threads.
682  float2 fl1 = lhs_shmem2[koff][threadIdx.x];
683  float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
684 
685  int start_feature = threadIdx.y * 4;
686  float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
687  float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
688 
689  add_vals(fl1, fl2, fr1, fr2)
690  }
691  __syncthreads();
692  }
693 
694 #undef prefetch_lhs
695 #undef add_vals
696 
697  Index horiz_base = threadIdx.y * 4 + base_n;
698  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
699  for (int i = 0; i < 4; i++) {
700  output(lhs_vert, horiz_base + i) = results[i].x;
701  output(lhs_vert + 1, horiz_base + i) = results[i].y;
702  output(lhs_vert + 2, horiz_base + i) = results[i].z;
703  output(lhs_vert + 3, horiz_base + i) = results[i].w;
704  }
705  } else if (!CHECK_RHS_BOUNDARY) {
706  // CHECK LHS
707  if (lhs_vert + 3 < m_size) {
708  for (int i = 0; i < 4; i++) {
709  output(lhs_vert, horiz_base + i) = results[i].x;
710  output(lhs_vert + 1, horiz_base + i) = results[i].y;
711  output(lhs_vert + 2, horiz_base + i) = results[i].z;
712  output(lhs_vert + 3, horiz_base + i) = results[i].w;
713  }
714  } else if (lhs_vert + 2 < m_size) {
715  for (int i = 0; i < 4; i++) {
716  output(lhs_vert, horiz_base + i) = results[i].x;
717  output(lhs_vert + 1, horiz_base + i) = results[i].y;
718  output(lhs_vert + 2, horiz_base + i) = results[i].z;
719  }
720  } else if (lhs_vert + 1 < m_size) {
721  for (int i = 0; i < 4; i++) {
722  output(lhs_vert, horiz_base + i) = results[i].x;
723  output(lhs_vert + 1, horiz_base + i) = results[i].y;
724  }
725  } else if (lhs_vert < m_size) {
726  for (int i = 0; i < 4; i++) {
727  output(lhs_vert, horiz_base + i) = results[i].x;
728  }
729  }
730  } else if (!CHECK_LHS_BOUNDARY) {
731  // CHECK RHS
732  /*
733  int ncols_rem = fminf(n_size- horiz_base, 4);
734  for (int i = 0; i < ncols_rem; i++) {
735  output(lhs_vert, horiz_base + i) = results[i].x;
736  output(lhs_vert + 1, horiz_base + i) = results[i].y;
737  output(lhs_vert + 2, horiz_base + i) = results[i].z;
738  output(lhs_vert + 3, horiz_base + i) = results[i].w;
739  }*/
740  for (int i = 0; i < 4; i++) {
741  if (horiz_base + i < n_size) {
742  output(lhs_vert, horiz_base + i) = results[i].x;
743  output(lhs_vert + 1, horiz_base + i) = results[i].y;
744  output(lhs_vert + 2, horiz_base + i) = results[i].z;
745  output(lhs_vert + 3, horiz_base + i) = results[i].w;
746  }
747  }
748  } else {
749  // CHECK both boundaries.
750  for (int i = 0; i < 4; i++) {
751  if (horiz_base + i < n_size) {
752  if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
753  if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
754  if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
755  if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
756  }
757  }
758  }
759 }
760 
761 template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
762  bool CHECK_RHS_BOUNDARY>
763 __device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
764  const OutputMapper output, float2 lhs_shmem2[][32],
765  float2 rhs_shmem2[][8], const Index m_size,
766  const Index n_size, const Index k_size,
767  const Index base_m, const Index base_n) {
768  // prefetch registers
769  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
770  float4 rhs_pf0, rhs_pf1;
771 
772  float4 results[8];
773  for (int i = 0; i < 8; i++) {
774  results[i].x = results[i].y = results[i].z = results[i].w = 0;
775  }
776 
777  Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
778  for (Index k = 0; k < k_size; k += 32) {
779  lhs_pf0 = internal::pset1<float4>(0);
780  lhs_pf1 = internal::pset1<float4>(0);
781  lhs_pf2 = internal::pset1<float4>(0);
782  lhs_pf3 = internal::pset1<float4>(0);
783 
784  rhs_pf0 = internal::pset1<float4>(0);
785  rhs_pf1 = internal::pset1<float4>(0);
786 
787  if (!CHECK_LHS_BOUNDARY) {
788  if ((threadIdx.y / 4 + k + 24) < k_size) {
789  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
790  lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
791  lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
792  lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
793  } else if ((threadIdx.y / 4 + k + 16) < k_size) {
794  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
795  lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
796  lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
797  } else if ((threadIdx.y / 4 + k + 8) < k_size) {
798  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
799  lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
800  } else if ((threadIdx.y / 4 + k) < k_size) {
801  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
802  }
803  } else {
804  // just CHECK_LHS_BOUNDARY
805  if (lhs_vert + 3 < m_size) {
806  if ((threadIdx.y / 4 + k + 24) < k_size) {
807  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
808  lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
809  lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
810  lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
811  } else if ((threadIdx.y / 4 + k + 16) < k_size) {
812  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
813  lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
814  lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
815  } else if ((threadIdx.y / 4 + k + 8) < k_size) {
816  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
817  lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
818  } else if ((threadIdx.y / 4 + k) < k_size) {
819  lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
820  }
821  } else if (lhs_vert + 2 < m_size) {
822  if ((threadIdx.y / 4 + k + 24) < k_size) {
823  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
824  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
825  lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
826  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
827  lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
828  lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
829  lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
830  lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
831  lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
832  lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
833  lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
834  lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 24));
835  } else if ((threadIdx.y / 4 + k + 16) < k_size) {
836  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
837  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
838  lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
839  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
840  lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
841  lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
842  lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
843  lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
844  lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
845  } else if ((threadIdx.y / 4 + k + 8) < k_size) {
846  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
847  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
848  lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
849  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
850  lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
851  lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
852  } else if ((threadIdx.y / 4 + k) < k_size) {
853  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
854  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
855  lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
856  }
857  } else if (lhs_vert + 1 < m_size) {
858  if ((threadIdx.y / 4 + k + 24) < k_size) {
859  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
860  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
861  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
862  lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
863  lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
864  lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
865  lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
866  lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
867  } else if ((threadIdx.y / 4 + k + 16) < k_size) {
868  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
869  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
870  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
871  lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
872  lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
873  lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
874  } else if ((threadIdx.y / 4 + k + 8) < k_size) {
875  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
876  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
877  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
878  lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
879  } else if ((threadIdx.y / 4 + k) < k_size) {
880  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
881  lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
882  }
883  } else if (lhs_vert < m_size) {
884  if ((threadIdx.y / 4 + k + 24) < k_size) {
885  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
886  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
887  lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
888  lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
889  } else if ((threadIdx.y / 4 + k + 16) < k_size) {
890  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
891  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
892  lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
893  } else if ((threadIdx.y / 4 + k + 8) < k_size) {
894  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
895  lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
896  } else if ((threadIdx.y / 4 + k) < k_size) {
897  lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
898  }
899  }
900  }
901  __syncthreads();
902  Index rhs_vert = k + threadIdx.x * 4;
903  Index rhs_horiz0 = threadIdx.y * 2 + base_n;
904  Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
905  if (!CHECK_RHS_BOUNDARY) {
906  if ((rhs_vert + 3) < k_size) {
907  // just CHECK_RHS_BOUNDARY
908  rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
909  rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
910  } else if (rhs_vert + 2 < k_size) {
911  // just CHECK_RHS_BOUNDARY
912  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
913  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
914  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
915  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
916  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
917  rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
918  } else if (rhs_vert + 1 < k_size) {
919  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
920  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
921  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
922  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
923  } else if (rhs_vert < k_size) {
924  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
925  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
926  }
927  } else {
928  if (rhs_horiz1 < n_size) {
929  if ((rhs_vert + 3) < k_size) {
930  // just CHECK_RHS_BOUNDARY
931  rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
932  rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
933  } else if (rhs_vert + 2 < k_size) {
934  // just CHECK_RHS_BOUNDARY
935  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
936  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
937  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
938  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
939  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
940  rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
941  } else if (k + threadIdx.x * 4 + 1 < k_size) {
942  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
943  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
944  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
945  rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
946  } else if (k + threadIdx.x * 4 < k_size) {
947  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
948  rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
949  }
950  } else if (rhs_horiz0 < n_size) {
951  if ((rhs_vert + 3) < k_size) {
952  // just CHECK_RHS_BOUNDARY
953  rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
954  } else if ((rhs_vert + 2) < k_size) {
955  // just CHECK_RHS_BOUNDARY
956  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
957  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
958  rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
959  } else if ((rhs_vert + 1) < k_size) {
960  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
961  rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
962  } else if (rhs_vert < k_size) {
963  rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
964  }
965  }
966  }
967  __syncthreads();
968  // Loaded. Do computation
969  // Row 0 -> times (0, 4, 8, .. 28) for features 0, 1.
970  // Row 1 -> times (0, 4, 8, .. 28) for features 2, 3.
971  // ..
972  // Row 31 -> times (0, 4, 8, .. 28) for features 62, 63
973  rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
974  // Row 32 -> times (1, 5, 9, .. 29) for features 0, 1.
975  // Row 33 -> times (1, 5, 9, .. 29) for features 2, 3.
976  // ..
977  rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
978  // Row 64 -> times (2, 6, 10, .. 30) for features 0, 1.
979  // Row 65 -> times (2, 6, 10, .. 30) for features 2, 3.
980  rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
981  // Row 96 -> times (3, 7, 11, .. 31) for features 0, 1.
982  // Row 97 -> times (3, 7, 11, .. 31) for features 2, 3.
983  rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
984 
985  // LHS.
986  // Row 0 (time 0) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
987  // Row 1 (time 1) -> features (0, 1), (4, 5), .. (28, 29), (32, 33), .. (60, 61) .. (124, 125)
988  // ...
989  // Row 8 (time 0) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
990  // Row 15 (time 7) -> features (2, 3), (6, 7), .. (30, 31), (34, 35), .. (62, 63) .. (126, 127)
991 
992 #define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
993  results[0].x += a_feat1.x * f1.x; \
994  results[1].x += a_feat1.x * f1.y; \
995  results[2].x += a_feat1.x * f2.x; \
996  results[3].x += a_feat1.x * f2.y; \
997  results[4].x += a_feat1.x * f3.x; \
998  results[5].x += a_feat1.x * f3.y; \
999  results[6].x += a_feat1.x * f4.x; \
1000  results[7].x += a_feat1.x * f4.y; \
1001  \
1002  results[0].y += a_feat1.y * f1.x; \
1003  results[1].y += a_feat1.y * f1.y; \
1004  results[2].y += a_feat1.y * f2.x; \
1005  results[3].y += a_feat1.y * f2.y; \
1006  results[4].y += a_feat1.y * f3.x; \
1007  results[5].y += a_feat1.y * f3.y; \
1008  results[6].y += a_feat1.y * f4.x; \
1009  results[7].y += a_feat1.y * f4.y; \
1010  \
1011  results[0].z += a_feat2.x * f1.x; \
1012  results[1].z += a_feat2.x * f1.y; \
1013  results[2].z += a_feat2.x * f2.x; \
1014  results[3].z += a_feat2.x * f2.y; \
1015  results[4].z += a_feat2.x * f3.x; \
1016  results[5].z += a_feat2.x * f3.y; \
1017  results[6].z += a_feat2.x * f4.x; \
1018  results[7].z += a_feat2.x * f4.y; \
1019  \
1020  results[0].w += a_feat2.y * f1.x; \
1021  results[1].w += a_feat2.y * f1.y; \
1022  results[2].w += a_feat2.y * f2.x; \
1023  results[3].w += a_feat2.y * f2.y; \
1024  results[4].w += a_feat2.y * f3.x; \
1025  results[5].w += a_feat2.y * f3.y; \
1026  results[6].w += a_feat2.y * f4.x; \
1027  results[7].w += a_feat2.y * f4.y;
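 // Editor's note (not in the original source): this add_vals variant accumulates a
 // 4x8 outer product: four LHS rows (a_feat1.x, a_feat1.y, a_feat2.x, a_feat2.y)
 // against eight RHS columns packed pairwise into f1..f4, with results[i] holding
 // output column i.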
1028 
1029  lhs_shmem2[threadIdx.y / 4][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
1030  lhs_shmem2[threadIdx.y / 4 + 8][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
1031  lhs_shmem2[threadIdx.y / 4 + 16][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
1032  lhs_shmem2[threadIdx.y / 4 + 24][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
1033 
1034  lhs_shmem2[threadIdx.y / 4 + 32][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
1035  lhs_shmem2[threadIdx.y / 4 + 40][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
1036  lhs_shmem2[threadIdx.y / 4 + 48][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
1037  lhs_shmem2[threadIdx.y / 4 + 56][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
1038 
1039  __syncthreads();
1040 
1041 // Do the multiplies.
1042 #pragma unroll
1043  for (int koff = 0; koff < 32; koff++) {
1044  float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
1045  float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
1046 
1047  // first feature is at start_feature = (threadIdx.y / 4) * 8; the last one handled here is start_feature + 7.
1048  int start_feature = (threadIdx.y / 4) * 8;
1049 
1050  float2 br1 = rhs_shmem2[start_feature / 2 + (koff % 4) * 32][koff / 4];
1051  float2 br2 = rhs_shmem2[start_feature / 2 + 1 + (koff % 4) * 32][koff / 4];
1052  float2 br3 = rhs_shmem2[start_feature / 2 + 2 + (koff % 4) * 32][koff / 4];
1053  float2 br4 = rhs_shmem2[start_feature / 2 + 3 + (koff % 4) * 32][koff / 4];
1054 
1055  add_vals(a3, a4, br1, br2, br3, br4)
1056  }
1057  __syncthreads();
1058  } // end loop over k
1059 
1060 #undef add_vals
1061 
1062  __syncthreads();
1063  Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
1064  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
1065  for (int i = 0; i < 8; i++) {
1066  output(lhs_vert, horiz_base + i) = results[i].x;
1067  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1068  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1069  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1070  }
1071  } else if (!CHECK_RHS_BOUNDARY) {
1072  if (lhs_vert + 3 < m_size) {
1073  for (int i = 0; i < 8; i++) {
1074  output(lhs_vert, horiz_base + i) = results[i].x;
1075  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1076  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1077  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1078  }
1079  } else if (lhs_vert + 2 < m_size) {
1080  for (int i = 0; i < 8; i++) {
1081  output(lhs_vert, horiz_base + i) = results[i].x;
1082  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1083  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1084  }
1085  } else if (lhs_vert + 1 < m_size) {
1086  for (int i = 0; i < 8; i++) {
1087  output(lhs_vert, horiz_base + i) = results[i].x;
1088  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1089  }
1090  } else if (lhs_vert < m_size) {
1091  for (int i = 0; i < 8; i++) {
1092  output(lhs_vert, horiz_base + i) = results[i].x;
1093  }
1094  }
1095  } else if (!CHECK_LHS_BOUNDARY) {
1096  // CHECK BOUNDARY_B
1097  for (int i = 0; i < 8; i++) {
1098  if (horiz_base + i < n_size) {
1099  output(lhs_vert, horiz_base + i) = results[i].x;
1100  output(lhs_vert + 1, horiz_base + i) = results[i].y;
1101  output(lhs_vert + 2, horiz_base + i) = results[i].z;
1102  output(lhs_vert + 3, horiz_base + i) = results[i].w;
1103  }
1104  }
1105  } else {
1106  // CHECK both boundaries.
1107  for (int i = 0; i < 8; i++) {
1108  if (horiz_base + i < n_size) {
1109  if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
1110  if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
1111  if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
1112  if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
1113  }
1114  }
1115  }
1116 }
1117 
1118 template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
1119 __global__ void
1120 #if defined(EIGEN_HIPCC)
1121 __launch_bounds__(256, 1)
1122 #else
1123 __launch_bounds__(256)
1124 #endif
1125  EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
1126  const Index n_size, const Index k_size) {
1127  __shared__ float2 lhs_shmem[64 * 32];
1128  __shared__ float2 rhs_shmem[128 * 8];
1129 
1130  typedef float2 LHS_MEM[64][32];
1131  typedef float2 RHS_MEM[128][8];
1132 
1133  const Index m_block_idx = blockIdx.x;
1134  const Index n_block_idx = blockIdx.y;
1135 
1136  const Index base_m = 128 * m_block_idx;
1137  const Index base_n = 64 * n_block_idx;
1138 
1139  bool check_rhs = (base_n + 63) >= n_size;
1140  bool check_lhs128 = (base_m + 127) >= m_size;
1141 
1142  if (!check_rhs) {
1143  if (!check_lhs128) {
1144  // >= 128 rows left
1145  EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
1146  lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1147  } else {
1148  EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
1149  lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1150  }
1151  } else {
1152  if (!check_lhs128) {
1153  // >= 128 rows left
1154  EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
1155  lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1156  } else {
1157  EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
1158  lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
1159  }
1160  }
1161 }
1162 
1163 template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
1164 __global__ void
1165 #if defined(EIGEN_HIPCC)
1166 __launch_bounds__(256, 1)
1167 #else
1168 __launch_bounds__(256)
1169 #endif
1170  EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output,
1171  const Index m_size, const Index n_size, const Index k_size) {
1172  __shared__ float2 lhs_shmem[32][16];
1173  __shared__ float2 rhs_shmem[64][8];
1174 
1175  const Index m_block_idx = blockIdx.x;
1176  const Index n_block_idx = blockIdx.y;
1177 
1178  const Index base_m = 64 * m_block_idx;
1179  const Index base_n = 64 * n_block_idx;
1180 
1181  if (base_m + 63 < m_size) {
1182  if (base_n + 63 < n_size) {
1183  EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
1184  lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1185  } else {
1186  EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
1187  lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1188  }
1189  } else {
1190  if (base_n + 63 < n_size) {
1191  EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
1192  lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1193  } else {
1194  EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
1195  lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
1196  }
1197  }
1198 }
1199 
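 // Editor's note (illustrative usage sketch, not part of this header): the kernels
 // above are reached by evaluating a tensor contraction on a GpuDevice. The names
 // below follow the public Tensor API; d_a, d_b, d_c are hypothetical device float
 // pointers and all buffer management is elided.
 //
 //   Eigen::GpuStreamDevice stream;
 //   Eigen::GpuDevice gpu_device(&stream);
 //   Eigen::TensorMap<Eigen::Tensor<float, 2>> a(d_a, 1024, 512);
 //   Eigen::TensorMap<Eigen::Tensor<float, 2>> b(d_b, 512, 768);
 //   Eigen::TensorMap<Eigen::Tensor<float, 2>> c(d_c, 1024, 768);
 //   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
 //   c.device(gpu_device) = a.contract(b, dims);  // dispatches to the kernels above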
1200 template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
1201 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice>
1202  : public TensorContractionEvaluatorBase<TensorEvaluator<
1203  const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {
1204  typedef GpuDevice Device;
1205 
1206  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
1207  typedef TensorContractionEvaluatorBase<Self> Base;
1208 
1209  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
1210  typedef std::remove_const_t<typename XprType::Scalar> Scalar;
1211  typedef typename XprType::Index Index;
1212  typedef typename XprType::CoeffReturnType CoeffReturnType;
1213  typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
1214 
1215  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
1216 
1217  // Most of the code is assuming that both input tensors are ColMajor. If the
1218  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
1219  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
1220  // will pretend B is LHS and A is RHS.
1221  typedef std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
1222  typedef std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;
1223 
1224  static constexpr int LDims =
1225  internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
1226  static constexpr int RDims =
1227  internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
1228  static constexpr int ContractDims = internal::array_size<Indices>::value;
1229 
1230  typedef array<Index, LDims> left_dim_mapper_t;
1231  typedef array<Index, RDims> right_dim_mapper_t;
1232 
1233  typedef array<Index, ContractDims> contract_t;
1234  typedef array<Index, LDims - ContractDims> left_nocontract_t;
1235  typedef array<Index, RDims - ContractDims> right_nocontract_t;
1236 
1237  static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
1238 
1239  typedef DSizes<Index, NumDims> Dimensions;
1240 
1241  // typedefs needed in evalTo
1242  typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
1243  typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
1244 
1245  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
1246  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
1247 
1248  typedef typename LeftEvaluator::Dimensions LeftDimensions;
1249  typedef typename RightEvaluator::Dimensions RightDimensions;
1250 
1251  TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {
1252  EIGEN_STATIC_ASSERT((internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
1253  GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
1254  }
1255 
1256  // We need to redefine this method to make nvcc happy
1257  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
1258  this->m_leftImpl.evalSubExprsIfNeeded(NULL);
1259  this->m_rightImpl.evalSubExprsIfNeeded(NULL);
1260  if (data) {
1261  evalTo(data);
1262  return false;
1263  } else {
1264  this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
1265  evalTo(this->m_result);
1266  return true;
1267  }
1268  }
1269 
1270  void evalTo(Scalar* buffer) const {
1271  if (this->m_lhs_inner_dim_contiguous) {
1272  if (this->m_rhs_inner_dim_contiguous) {
1273  if (this->m_rhs_inner_dim_reordered) {
1274  evalTyped<true, true, true, Unaligned>(buffer);
1275  } else {
1276  evalTyped<true, true, false, Unaligned>(buffer);
1277  }
1278  } else {
1279  if (this->m_rhs_inner_dim_reordered) {
1280  evalTyped<true, false, true, Unaligned>(buffer);
1281  } else {
1282  evalTyped<true, false, false, Unaligned>(buffer);
1283  }
1284  }
1285  } else {
1286  if (this->m_rhs_inner_dim_contiguous) {
1287  if (this->m_rhs_inner_dim_reordered) {
1288  evalTyped<false, true, true, Unaligned>(buffer);
1289  } else {
1290  evalTyped<false, true, false, Unaligned>(buffer);
1291  }
1292  } else {
1293  if (this->m_rhs_inner_dim_reordered) {
1294  evalTyped<false, false, true, Unaligned>(buffer);
1295  } else {
1296  evalTyped<false, false, false, Unaligned>(buffer);
1297  }
1298  }
1299  }
1300  }
1301 
1302  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper,
1303  typename OutputMapper>
1304  struct LaunchKernels {
1305  static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k,
1306  const GpuDevice& device) {
1307  const Index m_blocks = (m + 63) / 64;
1308  const Index n_blocks = (n + 63) / 64;
1309  const dim3 num_blocks(m_blocks, n_blocks, 1);
1310  const dim3 block_size(8, 8, 8);
1311  LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
1312  block_size, 0, device, lhs, rhs, output, m, n, k);
1313  }
1314  };
1315 
1316  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
1317  struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
1318  static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k,
1319  const GpuDevice& device) {
1320  if (m < 768 || n < 768) {
1321  const Index m_blocks = (m + 63) / 64;
1322  const Index n_blocks = (n + 63) / 64;
1323  const dim3 num_blocks(m_blocks, n_blocks, 1);
1324  const dim3 block_size(16, 16, 1);
1325  LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
1326  block_size, 0, device, lhs, rhs, output, m, n, k);
1327  } else {
1328  const Index m_blocks = (m + 127) / 128;
1329  const Index n_blocks = (n + 63) / 64;
1330  const dim3 num_blocks(m_blocks, n_blocks, 1);
1331  const dim3 block_size(8, 32, 1);
1332  LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
1333  block_size, 0, device, lhs, rhs, output, m, n, k);
1334  }
1335  }
1336  };
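 // Editor's note (not in the original source): for float the launcher above picks the
 // 64x64-tile kernel (16x16 threads) when m or n is small (< 768) and the 128x64-tile
 // kernel (8x32 threads) otherwise. For example, m = n = 1024 gives a grid of
 // (1024 + 127) / 128 = 8 by (1024 + 63) / 64 = 16 blocks of 256 threads each.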
1337 
1338  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
1339  void evalTyped(Scalar* buffer) const {
1340  // columns in left side, rows in right side
1341  const Index k = this->m_k_size;
1342  EIGEN_UNUSED_VARIABLE(k)
1343 
1344  // rows in left side
1345  const Index m = this->m_i_size;
1346 
1347  // columns in right side
1348  const Index n = this->m_j_size;
1349 
1350  // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
1351  this->m_device.fill(buffer, buffer + m * n, Scalar(0));
1352 
1353  typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
1354  contract_t, 4, lhs_inner_dim_contiguous, false, Unaligned>
1355  LhsMapper;
1356 
1357  typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
1358  contract_t, 4, rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
1359  Unaligned>
1360  RhsMapper;
1361 
1362  typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1363 
1364  // initialize data mappers
1365  LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
1366  this->m_left_contracting_strides, this->m_k_strides);
1367 
1368  RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
1369  this->m_right_contracting_strides, this->m_k_strides);
1370 
1371  OutputMapper output(buffer, m);
1372 
1373 #if defined(EIGEN_USE_HIP)
1374  setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
1375 #else
1376  setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
1377 #endif
1378 
1379  LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k,
1380  this->m_device);
1381  }
1382 };
1383 
1384 } // end namespace Eigen
1385 
1386 #endif // EIGEN_USE_GPU and EIGEN_GPUCC
1387 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H