#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H

#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)

#include "./InternalHeaderCheck.h"
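// GPU (CUDA/HIP) kernels for Tensor contractions. This header provides the
// generic tiled contraction kernel, two float-specialized kernels built on
// float4 packet loads, and the TensorEvaluator specialization for GpuDevice
// that selects and launches them.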
template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper,
          bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                                                    /* ... */) {
  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
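  // Each thread block works on a 64x64 tile of the output matrix; base_m and
  // base_n above are the row/column origin of that tile.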
  const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
  const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
  const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
  const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
  const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
  const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
  const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
  const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;

  const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
  const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
  const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
  const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
  const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
  const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
  const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
  const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
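  // Each thread stores eight lhs and eight rhs values per tile iteration into
  // the 72 * 64 element shared-memory buffers, spaced 576 (= 72 * 8) elements
  // apart.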
  const Index lhs_vert = base_m + load_idx_vert;
#define prefetchIntoRegisters(base_k) \
  if (!needs_edge_check || lhs_vert < m_size) { \
    const Index lhs_horiz_0 = base_k + threadIdx.z + 0 * 8; \
    const Index lhs_horiz_1 = base_k + threadIdx.z + 1 * 8; \
    const Index lhs_horiz_2 = base_k + threadIdx.z + 2 * 8; \
    const Index lhs_horiz_3 = base_k + threadIdx.z + 3 * 8; \
    const Index lhs_horiz_4 = base_k + threadIdx.z + 4 * 8; \
    const Index lhs_horiz_5 = base_k + threadIdx.z + 5 * 8; \
    const Index lhs_horiz_6 = base_k + threadIdx.z + 6 * 8; \
    const Index lhs_horiz_7 = base_k + threadIdx.z + 7 * 8; \
    if (!needs_edge_check || lhs_horiz_7 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
      lhs_pf7 = lhs(lhs_vert, lhs_horiz_7); \
    } else if (lhs_horiz_6 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
      lhs_pf6 = lhs(lhs_vert, lhs_horiz_6); \
    } else if (lhs_horiz_5 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
      lhs_pf5 = lhs(lhs_vert, lhs_horiz_5); \
    } else if (lhs_horiz_4 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
      lhs_pf4 = lhs(lhs_vert, lhs_horiz_4); \
    } else if (lhs_horiz_3 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
      lhs_pf3 = lhs(lhs_vert, lhs_horiz_3); \
    } else if (lhs_horiz_2 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
      lhs_pf2 = lhs(lhs_vert, lhs_horiz_2); \
    } else if (lhs_horiz_1 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
      lhs_pf1 = lhs(lhs_vert, lhs_horiz_1); \
    } else if (lhs_horiz_0 < k_size) { \
      lhs_pf0 = lhs(lhs_vert, lhs_horiz_0); \
    } \
  } \
  const Index rhs_vert = base_k + load_idx_vert; \
  if (!needs_edge_check || rhs_vert < k_size) { \
    const Index rhs_horiz_0 = base_n + threadIdx.z + 0 * 8; \
    const Index rhs_horiz_1 = base_n + threadIdx.z + 1 * 8; \
    const Index rhs_horiz_2 = base_n + threadIdx.z + 2 * 8; \
    const Index rhs_horiz_3 = base_n + threadIdx.z + 3 * 8; \
    const Index rhs_horiz_4 = base_n + threadIdx.z + 4 * 8; \
    const Index rhs_horiz_5 = base_n + threadIdx.z + 5 * 8; \
    const Index rhs_horiz_6 = base_n + threadIdx.z + 6 * 8; \
    const Index rhs_horiz_7 = base_n + threadIdx.z + 7 * 8; \
    if (rhs_horiz_7 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
      rhs_pf7 = rhs(rhs_vert, rhs_horiz_7); \
    } else if (rhs_horiz_6 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
      rhs_pf6 = rhs(rhs_vert, rhs_horiz_6); \
    } else if (rhs_horiz_5 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
      rhs_pf5 = rhs(rhs_vert, rhs_horiz_5); \
    } else if (rhs_horiz_4 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
      rhs_pf4 = rhs(rhs_vert, rhs_horiz_4); \
    } else if (rhs_horiz_3 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
      rhs_pf3 = rhs(rhs_vert, rhs_horiz_3); \
    } else if (rhs_horiz_2 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
      rhs_pf2 = rhs(rhs_vert, rhs_horiz_2); \
    } else if (rhs_horiz_1 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
      rhs_pf1 = rhs(rhs_vert, rhs_horiz_1); \
    } else if (rhs_horiz_0 < n_size) { \
      rhs_pf0 = rhs(rhs_vert, rhs_horiz_0); \
    } \
  }
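// prefetchIntoRegisters pulls up to eight lhs and eight rhs coefficients per
// thread from global memory into registers, degrading to progressively shorter
// loads at the k/n edges when needs_edge_check is set.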
#define writeRegToShmem() \
  lhs_shmem[lhs_store_idx_0] = lhs_pf0; \
  rhs_shmem[rhs_store_idx_0] = rhs_pf0; \
  lhs_shmem[lhs_store_idx_1] = lhs_pf1; \
  rhs_shmem[rhs_store_idx_1] = rhs_pf1; \
  lhs_shmem[lhs_store_idx_2] = lhs_pf2; \
  rhs_shmem[rhs_store_idx_2] = rhs_pf2; \
  lhs_shmem[lhs_store_idx_3] = lhs_pf3; \
  rhs_shmem[rhs_store_idx_3] = rhs_pf3; \
  lhs_shmem[lhs_store_idx_4] = lhs_pf4; \
  rhs_shmem[rhs_store_idx_4] = rhs_pf4; \
  lhs_shmem[lhs_store_idx_5] = lhs_pf5; \
  rhs_shmem[rhs_store_idx_5] = rhs_pf5; \
  lhs_shmem[lhs_store_idx_6] = lhs_pf6; \
  rhs_shmem[rhs_store_idx_6] = rhs_pf6; \
  lhs_shmem[lhs_store_idx_7] = lhs_pf7; \
  rhs_shmem[rhs_store_idx_7] = rhs_pf7;
#define res(i, j) _res_##i##j
#define initResultRow(i) \
  Scalar res(i, 0) = conv(0); \
  Scalar res(i, 1) = conv(0); \
  Scalar res(i, 2) = conv(0); \
  Scalar res(i, 3) = conv(0); \
  Scalar res(i, 4) = conv(0); \
  Scalar res(i, 5) = conv(0); \
  Scalar res(i, 6) = conv(0); \
  Scalar res(i, 7) = conv(0);

  internal::scalar_cast_op<int, Scalar> conv;
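  // Each thread keeps an 8x8 block of accumulators (res(0,0) .. res(7,7)) in
  // registers; scalar_cast_op<int, Scalar> turns the literal 0 into whatever
  // Scalar the kernel was instantiated with.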
  for (Index base_k = 0; base_k < k_size; base_k += 64) {
    prefetchIntoRegisters(base_k);

#undef prefetchIntoRegisters
#undef writeRegToShmem
#define lcol(i) _lcol##i

#define rrow(j) _rrow##j

#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
#define loadData(i, j) \
  lcol(0) = lhs_element(0, j); \
  rrow(0) = rhs_element(i, 0); \
  lcol(1) = lhs_element(1, j); \
  rrow(1) = rhs_element(i, 1); \
  lcol(2) = lhs_element(2, j); \
  rrow(2) = rhs_element(i, 2); \
  lcol(3) = lhs_element(3, j); \
  rrow(3) = rhs_element(i, 3); \
  lcol(4) = lhs_element(4, j); \
  rrow(4) = rhs_element(i, 4); \
  lcol(5) = lhs_element(5, j); \
  rrow(5) = rhs_element(i, 5); \
  lcol(6) = lhs_element(6, j); \
  rrow(6) = rhs_element(i, 6); \
  lcol(7) = lhs_element(7, j); \
  rrow(7) = rhs_element(i, 7);
#define computeCol(j) \
  res(0, j) += lcol(0) * rrow(j); \
  res(1, j) += lcol(1) * rrow(j); \
  res(2, j) += lcol(2) * rrow(j); \
  res(3, j) += lcol(3) * rrow(j); \
  res(4, j) += lcol(4) * rrow(j); \
  res(5, j) += lcol(5) * rrow(j); \
  res(6, j) += lcol(6) * rrow(j); \
  res(7, j) += lcol(7) * rrow(j);
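// One rank-1 update step: column j of every thread's 8x8 accumulator block is
// updated with the cached lhs column (lcol) scaled by the matching rhs value.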
#define computePass(i) \
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif
#define reduceRow(i, mask) \
  shuffleInc(i, 0, mask); \
  shuffleInc(i, 1, mask); \
  shuffleInc(i, 2, mask); \
  shuffleInc(i, 3, mask); \
  shuffleInc(i, 4, mask); \
  shuffleInc(i, 5, mask); \
  shuffleInc(i, 6, mask); \
  shuffleInc(i, 7, mask);
#define reduceMatrix(mask) \
  reduceRow(0, mask); \
  reduceRow(1, mask); \
  reduceRow(2, mask); \
  reduceRow(3, mask); \
  reduceRow(4, mask); \
  reduceRow(5, mask); \
  reduceRow(6, mask); \
  reduceRow(7, mask);
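// reduceMatrix folds the partial sums held by different lanes of a warp into
// one value per accumulator via XOR butterfly shuffles; the #if above selects
// __shfl_xor for HIP / pre-9.0 CUDA and __shfl_xor_sync otherwise.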
#define writeResultShmem(i, j) lhs_shmem[i + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j] = res(i, j);

#define writeRow(i) \
  writeResultShmem(i, 0); \
  writeResultShmem(i, 1); \
  writeResultShmem(i, 2); \
  writeResultShmem(i, 3); \
  writeResultShmem(i, 4); \
  writeResultShmem(i, 5); \
  writeResultShmem(i, 6); \
  writeResultShmem(i, 7);
#undef writeResultShmem
  if (max_j_write == 8) {

    for (int j = 0; j < max_j_write; j++) {
template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
                       /* ... */) {
  __shared__ Scalar lhs_shmem[72 * 64];
  __shared__ Scalar rhs_shmem[72 * 64];
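  // Each operand gets a 72 * 64 Scalar staging tile in shared memory; the
  // extra 8 elements of pitch per 64-wide slice (cf. the 576-element store
  // stride and the 72-strided lhs_element/rhs_element accessors above)
  // presumably keep the accesses off the same shared-memory banks.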
  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  if (base_m + 63 < m_size && base_n + 63 < n_size) {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  } else {
    EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(
        lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
  }
}
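// Tiles that lie fully inside the output skip all bounds checks (the `false`
// instantiation); only tiles touching the right/bottom edge pay for the
// edge-checked variant.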
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
          bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
                                                                          const OutputMapper output,
                                                                          float2 lhs_shmem2[][16],
                                                                          float2 rhs_shmem2[][8],
                                                                          const Index m_size, /* ... */) {
  float4 lhs_pf0, rhs_pf0;
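  // Float path: operands move as float4 packets (loadPacket) and are staged in
  // float2 shared-memory tiles; each 16x16 thread block covers a 64x64 output
  // tile.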
  for (int i = 0; i < 4; i++) {
#define prefetch_lhs(reg, row, col) \
  if (!CHECK_LHS_BOUNDARY) { \
    if (col < k_size) { \
      reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
    } \
  } else { \
    if (col < k_size) { \
      if (row + 3 < m_size) { \
        reg = lhs.template loadPacket<float4, Unaligned>(row, col); \
      } else if (row + 2 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
        reg.z = lhs(row + 2, col); \
      } else if (row + 1 < m_size) { \
        reg.x = lhs(row + 0, col); \
        reg.y = lhs(row + 1, col); \
      } else if (row < m_size) { \
        reg.x = lhs(row + 0, col); \
      } \
    } \
  }
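// prefetch_lhs issues one vectorized float4 load when the whole 4-row chunk is
// known to be in bounds, and falls back to 1-3 scalar loads at the bottom edge
// of the matrix.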
  for (Index k = 0; k < k_size; k += 16) {
    lhs_pf0 = internal::pset1<float4>(0);
    rhs_pf0 = internal::pset1<float4>(0);

    prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
      }
    } else {
      if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
    x1 = __shfl_xor(x1, 4);
    x2 = __shfl_xor(x2, 4);
#else
    x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
    x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
#define add_vals(fl1, fl2, fr1, fr2) \
  results[0].x += fl1.x * fr1.x; \
  results[0].y += fl1.y * fr1.x; \
  results[0].z += fl2.x * fr1.x; \
  results[0].w += fl2.y * fr1.x; \
  results[1].x += fl1.x * fr1.y; \
  results[1].y += fl1.y * fr1.y; \
  results[1].z += fl2.x * fr1.y; \
  results[1].w += fl2.y * fr1.y; \
  results[2].x += fl1.x * fr2.x; \
  results[2].y += fl1.y * fr2.x; \
  results[2].z += fl2.x * fr2.x; \
  results[2].w += fl2.y * fr2.x; \
  results[3].x += fl1.x * fr2.y; \
  results[3].y += fl1.y * fr2.y; \
  results[3].z += fl2.x * fr2.y; \
  results[3].w += fl2.y * fr2.y;
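// add_vals accumulates a 4x4 sub-block: two float2 halves of an lhs column
// (fl1, fl2) times two float2 halves of an rhs row (fr1, fr2), one
// multiply-add per component of results[0..3].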
    for (int koff = 0; koff < 16; koff++) {

      float2 fl1 = lhs_shmem2[koff][threadIdx.x];
      float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];

      float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
      float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];

      add_vals(fl1, fl2, fr1, fr2)
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {

  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 4; i++) {

    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 4; i++) {

    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 4; i++) {

    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 4; i++) {

  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {

    for (int i = 0; i < 4; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
          bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
                                                                     const OutputMapper output,
                                                                     float2 lhs_shmem2[][32],
                                                                     float2 rhs_shmem2[][8],
                                                                     const Index m_size, /* ... */) {
  float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
  float4 rhs_pf0, rhs_pf1;
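  // Larger-tile float kernel: per 32-wide step along k, each thread prefetches
  // four lhs float4 packets and two rhs float4 packets, compared with the
  // single lhs_pf0 / rhs_pf0 pair in the 16x16 variant above.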
  for (int i = 0; i < 8; i++) {

  for (Index k = 0; k < k_size; k += 32) {
    lhs_pf0 = internal::pset1<float4>(0);
    lhs_pf1 = internal::pset1<float4>(0);
    lhs_pf2 = internal::pset1<float4>(0);
    lhs_pf3 = internal::pset1<float4>(0);

    rhs_pf0 = internal::pset1<float4>(0);
    rhs_pf1 = internal::pset1<float4>(0);
    if (!CHECK_LHS_BOUNDARY) {
      if ((threadIdx.y / 4 + k + 24) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
        lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
        lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
        lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
      } else if ((threadIdx.y / 4 + k + 16) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
        lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
        lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
      } else if ((threadIdx.y / 4 + k + 8) < k_size) {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
        lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
      } else /* ... */ {
        lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
      }
    } else {
      if (lhs_vert + 3 < m_size) {
        if ((threadIdx.y / 4 + k + 24) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
          lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
          lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
          lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
        } else if ((threadIdx.y / 4 + k + 16) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
          lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
          lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
        } else if ((threadIdx.y / 4 + k + 8) < k_size) {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
          lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
        } else /* ... */ {
          lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
        }
      } else if (lhs_vert + 2 < m_size) {
        if ((threadIdx.y / 4 + k + 24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
          lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 24));
        } else if ((threadIdx.y / 4 + k + 16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
          lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
        } else if ((threadIdx.y / 4 + k + 8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
          lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
        } else /* ... */ {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
        }
      } else if (lhs_vert + 1 < m_size) {
        if ((threadIdx.y / 4 + k + 24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
          lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
        } else if ((threadIdx.y / 4 + k + 16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
          lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
        } else if ((threadIdx.y / 4 + k + 8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
        } else /* ... */ {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
        }
      } else if (lhs_vert < m_size) {
        if ((threadIdx.y / 4 + k + 24) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
          lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
        } else if ((threadIdx.y / 4 + k + 16) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
          lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
        } else if ((threadIdx.y / 4 + k + 8) < k_size) {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
          lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
        } else /* ... */ {
          lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
        }
      }
    }
    if (!CHECK_RHS_BOUNDARY) {
      if ((rhs_vert + 3) < k_size) {
        rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
        rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
      } else if (rhs_vert + 2 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
      } else if (rhs_vert + 1 < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
      } else if (rhs_vert < k_size) {
        rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
        rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
      }
    } else {
      if (rhs_horiz1 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
          rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
        } else if (rhs_vert + 2 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
          rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
        } else if (k + threadIdx.x * 4 + 1 < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
          rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
        } else /* ... */ {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
        }
      } else if (rhs_horiz0 < n_size) {
        if ((rhs_vert + 3) < k_size) {
          rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
        } else if ((rhs_vert + 2) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
          rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
        } else if ((rhs_vert + 1) < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
          rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
        } else if (rhs_vert < k_size) {
          rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
#define add_vals(a_feat1, a_feat2, f1, f2, f3, f4) \
  results[0].x += a_feat1.x * f1.x; \
  results[1].x += a_feat1.x * f1.y; \
  results[2].x += a_feat1.x * f2.x; \
  results[3].x += a_feat1.x * f2.y; \
  results[4].x += a_feat1.x * f3.x; \
  results[5].x += a_feat1.x * f3.y; \
  results[6].x += a_feat1.x * f4.x; \
  results[7].x += a_feat1.x * f4.y; \
  results[0].y += a_feat1.y * f1.x; \
  results[1].y += a_feat1.y * f1.y; \
  results[2].y += a_feat1.y * f2.x; \
  results[3].y += a_feat1.y * f2.y; \
  results[4].y += a_feat1.y * f3.x; \
  results[5].y += a_feat1.y * f3.y; \
  results[6].y += a_feat1.y * f4.x; \
  results[7].y += a_feat1.y * f4.y; \
  results[0].z += a_feat2.x * f1.x; \
  results[1].z += a_feat2.x * f1.y; \
  results[2].z += a_feat2.x * f2.x; \
  results[3].z += a_feat2.x * f2.y; \
  results[4].z += a_feat2.x * f3.x; \
  results[5].z += a_feat2.x * f3.y; \
  results[6].z += a_feat2.x * f4.x; \
  results[7].z += a_feat2.x * f4.y; \
  results[0].w += a_feat2.y * f1.x; \
  results[1].w += a_feat2.y * f1.y; \
  results[2].w += a_feat2.y * f2.x; \
  results[3].w += a_feat2.y * f2.y; \
  results[4].w += a_feat2.y * f3.x; \
  results[5].w += a_feat2.y * f3.y; \
  results[6].w += a_feat2.y * f4.x; \
  results[7].w += a_feat2.y * f4.y;
    for (int koff = 0; koff < 32; koff++) {

      int start_feature = (threadIdx.y / 4) * 8;

      float2 br1 = rhs_shmem2[start_feature / 2 + (koff % 4) * 32][koff / 4];
      float2 br2 = rhs_shmem2[start_feature / 2 + 1 + (koff % 4) * 32][koff / 4];
      float2 br3 = rhs_shmem2[start_feature / 2 + 2 + (koff % 4) * 32][koff / 4];
      float2 br4 = rhs_shmem2[start_feature / 2 + 3 + (koff % 4) * 32][koff / 4];

      add_vals(a3, a4, br1, br2, br3, br4)
  if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {

  } else if (!CHECK_RHS_BOUNDARY) {
    if (lhs_vert + 3 < m_size) {
      for (int i = 0; i < 8; i++) {

    } else if (lhs_vert + 2 < m_size) {
      for (int i = 0; i < 8; i++) {

    } else if (lhs_vert + 1 < m_size) {
      for (int i = 0; i < 8; i++) {

    } else if (lhs_vert < m_size) {
      for (int i = 0; i < 8; i++) {

  } else if (!CHECK_LHS_BOUNDARY) {
    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {

    for (int i = 0; i < 8; i++) {
      if (horiz_base + i < n_size) {
        if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
        if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
        if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
        if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
                            /* ... */) {
  __shared__ float2 lhs_shmem[64 * 32];
  __shared__ float2 rhs_shmem[128 * 8];

  typedef float2 LHS_MEM[64][32];
  typedef float2 RHS_MEM[128][8];
  const Index base_m = 128 * m_block_idx;
  const Index base_n = 64 * n_block_idx;

  bool check_rhs = (base_n + 63) >= n_size;
  bool check_lhs128 = (base_m + 127) >= m_size;
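  // The large float kernel works on 128x64 output tiles (vs. 64x64 for the
  // 16x16 kernel); the flags above decide whether the lhs/rhs boundary-checked
  // instantiations are needed for this particular tile.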
  if (!check_rhs) {
    if (!check_lhs128) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (!check_lhs128) {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
    }
  }
}
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output,
                                 /* ... */) {
  __shared__ float2 lhs_shmem[32][16];
  __shared__ float2 rhs_shmem[64][8];

  const Index base_m = 64 * m_block_idx;
  const Index base_n = 64 * n_block_idx;
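  // 16x16 thread blocks, one 64x64 output tile per block; the float launcher
  // below picks this kernel whenever m or n is small.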
  if (base_m + 63 < m_size) {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
          lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
          lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  } else {
    if (base_n + 63 < n_size) {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
          lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    } else {
      EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
          lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
    }
  }
}
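// Usage sketch (not part of this header): evaluating a contraction on a
// GpuDevice is what eventually routes through the TensorEvaluator below and
// the LaunchKernels dispatch at the end of this file. Assuming the unsupported
// Tensor module and device pointers d_a, d_b, d_c of matching sizes:
//
//   Eigen::GpuStreamDevice stream;
//   Eigen::GpuDevice device(&stream);
//   Eigen::TensorMap<Eigen::Tensor<float, 2>> A(d_a, m, k), B(d_b, k, n), C(d_c, m, n);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   C.device(device) = A.contract(B, dims);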
template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice>
    : public TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {
  typedef GpuDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef std::remove_const_t<typename XprType::Scalar> Scalar;
  typedef std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
  typedef std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;

  static constexpr int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static constexpr int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static constexpr int NumDims = LDims + RDims - 2 * ContractDims;

  typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
  typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;
                      GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);

    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);

    evalTo(this->m_result);
  void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
  template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper,
            typename OutputMapper>
  struct LaunchKernels {
    static void Run(/* ... */
                    const GpuDevice& device) {
      const Index m_blocks = (m + 63) / 64;
      const Index n_blocks = (n + 63) / 64;
      const dim3 num_blocks(m_blocks, n_blocks, 1);
      const dim3 block_size(8, 8, 8);
      LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
                        block_size, 0, device, lhs, rhs, output, m, n, k);
  template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
  struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
    static void Run(/* ... */
                    const GpuDevice& device) {
      if (m < 768 || n < 768) {
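        // Small problems in m or n go to the 16x16-thread kernel (64x64 output
        // tiles); anything larger uses the 8x32-thread kernel with 128x64 tiles.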
        const Index m_blocks = (m + 63) / 64;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(16, 16, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
                          block_size, 0, device, lhs, rhs, output, m, n, k);
      } else {
        const Index m_blocks = (m + 127) / 128;
        const Index n_blocks = (n + 63) / 64;
        const dim3 num_blocks(m_blocks, n_blocks, 1);
        const dim3 block_size(8, 32, 1);
        LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
                          block_size, 0, device, lhs, rhs, output, m, n, k);
      }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    const Index k = this->m_k_size;
    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
                                                   contract_t, 4, lhs_inner_dim_contiguous, false, Unaligned>
        LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
                                                   contract_t, 4, rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                                                   Unaligned>
        RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);
#if defined(EIGEN_USE_HIP)
    setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
    setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif
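    // Both paths request 8-byte shared-memory banks; the kernels above stage
    // data as float2 (or a wider Scalar), so this presumably keeps consecutive
    // accesses conflict-free.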
    LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k,
                                                                                        this->m_device);
  }
};

#endif  // EIGEN_USE_GPU && EIGEN_GPUCC
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H