10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
27 template <
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
31 std::remove_const_t<typename RhsXprType::Scalar>>::ResScalar
Scalar;
43 static constexpr
int NumDimensions =
53 template <
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
58 template <
typename Dimensions,
typename LhsXprType,
typename RhsXprType,
typename OutputKernelType>
64 template <
typename Indices_,
typename LeftArgType_,
typename RightArgType_,
typename OutputKernelType_,
75 static constexpr
int NumDimensions =
80 template <
typename LhsScalar,
typename RhsScalar>
84 template <
typename Device>
86 LhsScalar** lhs_block, RhsScalar** rhs_block) {
91 *lhs_block =
static_cast<LhsScalar*
>(
static_cast<void*
>(block_mem));
92 *rhs_block =
static_cast<RhsScalar*
>(
static_cast<void*
>(block_mem + sz.
lhs_size));
96 template <
typename Device>
99 const Index num_slices, std::vector<LhsScalar*>* lhs_blocks,
100 std::vector<RhsScalar*>* rhs_blocks) {
106 void* block_mem = d.allocate((num_lhs * sz.
lhs_size + num_rhs * sz.
rhs_size) * num_slices);
108 char* mem =
static_cast<char*
>(block_mem);
110 for (
Index x = 0;
x < num_slices;
x++) {
111 if (num_lhs > 0) lhs_blocks[
x].resize(num_lhs);
112 for (
Index m = 0;
m < num_lhs;
m++) {
113 lhs_blocks[
x][
m] =
static_cast<LhsScalar*
>(
static_cast<void*
>(mem));
116 if (num_rhs > 0) rhs_blocks[
x].resize(num_rhs);
117 for (
Index n = 0;
n < num_rhs;
n++) {
118 rhs_blocks[
x][
n] =
static_cast<RhsScalar*
>(
static_cast<void*
>(mem));
126 template <
typename Device>
128 d.deallocate(handle);
139 sz.
lhs_size = numext::div_ceil<Index>(bm * bk *
sizeof(LhsScalar), align) * align;
140 sz.
rhs_size = numext::div_ceil<Index>(bn * bk *
sizeof(RhsScalar), align) * align;
173 template <
typename ResScalar,
typename LhsScalar,
typename RhsScalar,
typename StorageIndex,
typename OutputMapper,
174 typename LhsMapper,
typename RhsMapper>
181 StorageIndex bk_, StorageIndex bn_)
182 :
m(m_),
k(k_),
n(n_),
bm(bm_),
bk(bk_),
bn(bn_) {}
205 template <
typename Device>
210 template <
typename Device>
212 const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
213 std::vector<RhsBlock>* rhs_blocks) {
217 template <
typename Device>
223 const StorageIndex depth,
const StorageIndex
rows) {
229 const StorageIndex depth,
const StorageIndex
cols) {
235 const StorageIndex depth,
const StorageIndex
cols,
236 const ResScalar
alpha,
const ResScalar
beta) {
239 static const int kComputeStrideFromBlockDimensions = -1;
241 kComputeStrideFromBlockDimensions,
242 kComputeStrideFromBlockDimensions,
250 const StorageIndex
m;
251 const StorageIndex
k;
252 const StorageIndex
n;
253 const StorageIndex
bm;
254 const StorageIndex
bk;
255 const StorageIndex
bn;
291 template <
typename Index,
typename Scalar>
294 Index num_cols)
const {
304 template <
typename Indices,
typename LhsXprType,
typename RhsXprType,
305 typename OutputKernelType =
const NoOpOutputKernel>
307 :
public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors> {
318 const OutputKernelType& output_kernel = OutputKernelType())
341 template <
typename Derived>
350 typedef std::remove_const_t<typename XprType::Scalar>
Scalar;
408 YOU_MADE_A_PROGRAMMING_MISTAKE);
423 eval_op_indices[
i].first =
op.indices()[
i].first;
424 eval_op_indices[
i].second =
op.indices()[
i].second;
446 eigen_assert(eval_op_indices[
j].first != eval_op_indices[
i].first &&
447 eval_op_indices[
j].second != eval_op_indices[
i].second &&
"contraction axes should be unique");
448 if (eval_op_indices[
j].first < eval_op_indices[
i].first) {
456 for (
int i = 0;
i <
LDims - 1; ++
i) {
457 lhs_strides[
i + 1] = lhs_strides[
i] * eval_left_dims[
i];
462 for (
int i = 0;
i <
RDims - 1; ++
i) {
463 rhs_strides[
i + 1] = rhs_strides[
i] * eval_right_dims[
i];
480 Index nocontract_idx = 0;
484 bool contracting =
false;
486 if (eval_op_indices[
j].first ==
i) {
510 bool contracting =
false;
513 if (eval_op_indices[
j].second ==
i) {
539 Index left = eval_op_indices[
i].first;
540 Index right = eval_op_indices[
i].second;
543 eigen_assert(
size == eval_right_dims[right] &&
"Contraction axes must be same size");
553 if (
i > 0 && right < eval_op_indices[
i - 1].second) {
590 #ifdef EIGEN_USE_THREADS
591 template <
typename EvalSubExprsCallback>
593 m_leftImpl.evalSubExprsIfNeededAsync(
nullptr, [
this, done, dest](
bool) {
594 m_rightImpl.evalSubExprsIfNeededAsync(
nullptr, [
this, done, dest](
bool) {
596 evalToAsync(dest, [done]() { done(
false); });
599 evalToAsync(
m_result, [done]() { done(
true); });
606 #ifndef TENSOR_CONTRACTION_DISPATCH
607 #define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \
608 if (this->m_lhs_inner_dim_contiguous) { \
609 if (this->m_rhs_inner_dim_contiguous) { \
610 if (this->m_rhs_inner_dim_reordered) { \
611 METHOD<true, true, true, ALIGNMENT> ARGS; \
613 METHOD<true, true, false, ALIGNMENT> ARGS; \
616 if (this->m_rhs_inner_dim_reordered) { \
617 METHOD<true, false, true, ALIGNMENT> ARGS; \
619 METHOD<true, false, false, ALIGNMENT> ARGS; \
623 if (this->m_rhs_inner_dim_contiguous) { \
624 if (this->m_rhs_inner_dim_reordered) { \
625 METHOD<false, true, true, ALIGNMENT> ARGS; \
627 METHOD<false, true, false, ALIGNMENT> ARGS; \
630 if (this->m_rhs_inner_dim_reordered) { \
631 METHOD<false, false, true, ALIGNMENT> ARGS; \
633 METHOD<false, false, false, ALIGNMENT> ARGS; \
639 #ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
640 #define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
641 if (this->m_lhs_inner_dim_contiguous) { \
642 if (this->m_rhs_inner_dim_contiguous) { \
643 if (this->m_rhs_inner_dim_reordered) { \
644 (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN; \
646 (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN; \
649 if (this->m_rhs_inner_dim_reordered) { \
650 (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN; \
652 (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN; \
656 if (this->m_rhs_inner_dim_contiguous) { \
657 if (this->m_rhs_inner_dim_reordered) { \
658 (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN; \
660 (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN; \
663 if (this->m_rhs_inner_dim_reordered) { \
664 (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN; \
666 (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN; \
673 static_cast<const Derived*
>(
this)->
template evalProduct<Unaligned>(buffer);
676 #ifdef EIGEN_USE_THREADS
677 template <
typename EvalToCallback>
678 void evalToAsync(
Scalar* buffer, EvalToCallback done)
const {
679 static_cast<const Derived*
>(
this)->
template evalProductAsync<EvalToCallback, Unaligned>(buffer, std::move(done));
683 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
686 this->
template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(
689 this->
template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(
694 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
695 #if !defined(EIGEN_HIPCC)
703 typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
704 typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
712 contract_t, lhs_packet_size, lhs_inner_dim_contiguous,
false,
717 contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
718 rhs_inner_dim_reordered, rhs_alignment>
725 const Index resIncr(1);
738 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
739 #if !defined(EIGEN_HIPCC)
746 this->
template evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
747 Alignment,
true>(buffer, 0,
k, 1);
750 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
752 int num_threads)
const {
753 evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment,
754 false>(buffer, k_start, k_end, num_threads);
757 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment,
758 bool use_output_kernel>
762 const Index k_slice = k_end - k_start;
771 typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
772 typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
781 contract_t, lhs_packet_size, lhs_inner_dim_contiguous,
false,
786 contract_t, rhs_packet_size, rhs_inner_dim_contiguous,
793 TensorContractionKernel;
802 OutputMapper
output(buffer,
m);
806 k_slice,
m,
n, num_threads);
807 const Index kc = blocking.
kc();
811 typedef typename TensorContractionKernel::LhsBlock LhsBlock;
812 typedef typename TensorContractionKernel::RhsBlock RhsBlock;
817 TensorContractionKernel kernel(
m, k_slice,
n, mc, kc, nc);
819 typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
820 const BlockMemHandle packed_mem = kernel.allocate(this->
m_device, &blockA, &blockB);
824 if (!TensorContractionKernel::HasBeta) {
828 for (
Index i2 = 0; i2 <
m; i2 += mc) {
830 for (
Index k2 = k_start; k2 < k_end; k2 += kc) {
833 kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
841 for (
Index j2 = 0; j2 <
n; j2 += nc) {
844 kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc);
848 const OutputMapper output_mapper =
output.getSubMapper(i2, j2);
849 kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc,
alpha,
beta);
852 if (use_output_kernel && k2 + kc >= k_end) {
859 kernel.deallocate(this->
m_device, packed_mem);
878 template <
int LoadMode>
880 return internal::ploadt<PacketReturnType, LoadMode>(
m_result + index);
915 template <
typename Indices,
typename LeftArgType,
typename RightArgType,
typename OutputKernelType,
typename Device>
918 TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>> {
923 typedef std::remove_const_t<typename XprType::Scalar>
Scalar;
937 static constexpr
int LDims =
939 static constexpr
int RDims =
947 static constexpr
int NumDims = LDims + RDims - 2 * ContractDims;
954 template <
int Alignment>
int i
Definition: BiCGSTAB_step_by_step.cpp:9
const unsigned n
Definition: CG3DPackingUnitTest.cpp:11
#define EIGEN_ALWAYS_INLINE
Definition: Macros.h:845
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:966
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define EIGEN_DONT_INLINE
Definition: Macros.h:853
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)
Definition: TensorContraction.h:607
#define EIGEN_DEVICE_REF
Definition: TensorMacros.h:34
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
The tensor base class.
Definition: TensorBase.h:1026
Definition: TensorContraction.h:307
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename LhsXprType::Nested > & lhsExpression() const
Definition: TensorContraction.h:324
EIGEN_DEVICE_FUNC const OutputKernelType & outputKernel() const
Definition: TensorContraction.h:332
const OutputKernelType m_output_kernel
Definition: TensorContraction.h:338
Eigen::internal::traits< TensorContractionOp >::Index Index
Definition: TensorContraction.h:314
Eigen::internal::nested< TensorContractionOp >::type Nested
Definition: TensorContraction.h:312
Eigen::internal::traits< TensorContractionOp >::StorageKind StorageKind
Definition: TensorContraction.h:313
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename RhsXprType::Nested > & rhsExpression() const
Definition: TensorContraction.h:328
internal::gebp_traits< typename LhsXprType::CoeffReturnType, typename RhsXprType::CoeffReturnType >::ResScalar CoeffReturnType
Definition: TensorContraction.h:311
Eigen::internal::traits< TensorContractionOp >::Scalar Scalar
Definition: TensorContraction.h:309
const Indices m_indices
Definition: TensorContraction.h:337
LhsXprType::Nested m_lhs_xpr
Definition: TensorContraction.h:335
RhsXprType::Nested m_rhs_xpr
Definition: TensorContraction.h:336
EIGEN_DEVICE_FUNC const Indices & indices() const
Definition: TensorContraction.h:321
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType &lhs, const RhsXprType &rhs, const Indices &dims, const OutputKernelType &output_kernel=OutputKernelType())
Definition: TensorContraction.h:316
Definition: TensorCostModel.h:28
Definition: TensorBlock.h:566
Definition: TensorContractionBlocking.h:24
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex nc() const
Definition: TensorContractionBlocking.h:58
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex mc() const
Definition: TensorContractionBlocking.h:57
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const
Definition: TensorContractionBlocking.h:56
Definition: BlasUtil.h:304
Definition: products/GeneralBlockPanelKernel.h:397
LhsPacket LhsPacket4Packing
Definition: products/GeneralBlockPanelKernel.h:440
@ nr
Definition: products/GeneralBlockPanelKernel.h:418
@ LhsProgress
Definition: products/GeneralBlockPanelKernel.h:433
@ mr
Definition: products/GeneralBlockPanelKernel.h:430
@ Unaligned
Definition: Constants.h:235
@ Aligned
Definition: Constants.h:242
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
Eigen::DenseIndex ret
Definition: level1_cplx_impl.h:43
RealScalar alpha
Definition: level1_cplx_impl.h:151
int * m
Definition: level2_cplx_impl.h:294
Scalar beta
Definition: level2_cplx_impl.h:36
char char char int int * k
Definition: level2_impl.h:374
char char * op
Definition: level2_impl.h:374
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
Definition: MathFunctions.h:926
EIGEN_STRONG_INLINE void swap(T &a, T &b)
Definition: Meta.h:536
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition: MathFunctions.h:920
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
auto run(Kernel kernel, Args &&... args) -> decltype(kernel(args...))
Definition: gpu_test_helper.h:414
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const T1 & choose(Cond< true >, const T1 &first, const T2 &)
Definition: TensorMeta.h:22
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Extend namespace for flags.
Definition: fsi_chan_precond_driver.cc:56
dictionary params
Definition: Particles2023AnalysisHung.py:35
val
Definition: calibrate.py:119
type
Definition: compute_granudrum_aor.py:141
Definition: Eigen_Colamd.h:49
list x
Definition: plotDoE.py:28
void output(std::ostream &outfile, const unsigned &nplot)
Overload output function.
Definition: overloaded_element_body.h:490
Definition: TensorMeta.h:19
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const
Definition: TensorDimensions.h:167
Definition: Constants.h:519
Definition: TensorContraction.h:275
EIGEN_ALWAYS_INLINE void operator()(const internal::blas_data_mapper< Scalar, Index, ColMajor > &output_mapper, const TensorContractionParams ¶ms, Index i, Index j, Index num_rows, Index num_cols) const
Definition: TensorContraction.h:292
Definition: TensorMeta.h:47
Definition: TensorForwardDeclarations.h:42
Definition: TensorContraction.h:342
XprType::CoeffReturnType CoeffReturnType
Definition: TensorContraction.h:352
static constexpr int Layout
Definition: TensorContraction.h:357
TensorEvaluator< EvalRightArgType, Device > RightEvaluatorType
Definition: TensorContraction.h:381
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorContraction.h:862
Index m_i_size
Definition: TensorContraction.h:901
DSizes< Index, NumDims > Dimensions
Definition: TensorContraction.h:394
static constexpr int NumDims
Definition: TensorContraction.h:388
EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType &op, const Device &device)
Definition: TensorContraction.h:396
static constexpr int LDims
Definition: TensorContraction.h:383
StorageMemory< Scalar, Device > Storage
Definition: TensorContraction.h:354
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition: TensorContraction.h:879
right_nocontract_t m_j_strides
Definition: TensorContraction.h:897
internal::traits< Derived >::Device Device
Definition: TensorContraction.h:347
right_nocontract_t m_right_nocontract_strides
Definition: TensorContraction.h:899
internal::traits< Derived >::LeftArgType LeftArgType
Definition: TensorContraction.h:344
contract_t m_right_contracting_strides
Definition: TensorContraction.h:890
const Device EIGEN_DEVICE_REF m_device
Definition: TensorContraction.h:909
static constexpr int RDims
Definition: TensorContraction.h:385
array< Index, RDims - ContractDims > right_nocontract_t
Definition: TensorContraction.h:392
TensorEvaluator< EvalLeftArgType, Device > LeftEvaluatorType
Definition: TensorContraction.h:380
left_nocontract_t m_left_nocontract_strides
Definition: TensorContraction.h:898
std::conditional_t< static_cast< int >Layout)==static_cast< int >ColMajor), LeftArgType, RightArgType > EvalLeftArgType
Definition: TensorContraction.h:376
EvaluatorPointerType m_result
Definition: TensorContraction.h:911
EIGEN_DEVICE_FUNC void evalGemv(Scalar *buffer) const
Definition: TensorContraction.h:699
XprType::Index Index
Definition: TensorContraction.h:351
Storage::Type EvaluatorPointerType
Definition: TensorContraction.h:355
bool m_rhs_inner_dim_reordered
Definition: TensorContraction.h:894
internal::traits< Derived >::RightArgType RightArgType
Definition: TensorContraction.h:345
bool m_rhs_inner_dim_contiguous
Definition: TensorContraction.h:893
EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar *buffer, Index k_start, Index k_end, int num_threads) const
Definition: TensorContraction.h:759
contract_t m_left_contracting_strides
Definition: TensorContraction.h:889
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data)
Definition: TensorContraction.h:577
static constexpr int ContractDims
Definition: TensorContraction.h:387
TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType > XprType
Definition: TensorContraction.h:349
std::remove_const_t< typename XprType::Scalar > Scalar
Definition: TensorContraction.h:350
EIGEN_DEVICE_FUNC void evalTo(Scalar *buffer) const
Definition: TensorContraction.h:672
Index m_j_size
Definition: TensorContraction.h:902
TensorEvaluator< EvalRightArgType, Device > m_rightImpl
Definition: TensorContraction.h:908
std::conditional_t< static_cast< int >Layout)==static_cast< int >ColMajor), RightArgType, LeftArgType > EvalRightArgType
Definition: TensorContraction.h:378
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorContraction.h:575
Index m_k_size
Definition: TensorContraction.h:903
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorContraction.h:368
void evalProductSequential(Scalar *buffer) const
Definition: TensorContraction.h:684
OutputKernelType m_output_kernel
Definition: TensorContraction.h:910
TensorContractionParams m_tensor_contraction_params
Definition: TensorContraction.h:905
bool m_lhs_inner_dim_contiguous
Definition: TensorContraction.h:892
Dimensions m_dimensions
Definition: TensorContraction.h:886
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const
Definition: TensorContraction.h:874
array< Index, ContractDims > contract_t
Definition: TensorContraction.h:390
internal::traits< Derived >::OutputKernelType OutputKernelType
Definition: TensorContraction.h:346
TensorEvaluator< EvalLeftArgType, Device > m_leftImpl
Definition: TensorContraction.h:907
@ PreferBlockAccess
Definition: TensorContraction.h:362
@ PacketAccess
Definition: TensorContraction.h:360
@ RawAccess
Definition: TensorContraction.h:364
@ CoordAccess
Definition: TensorContraction.h:363
@ IsAligned
Definition: TensorContraction.h:359
@ BlockAccess
Definition: TensorContraction.h:361
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorContraction.h:353
array< Index, LDims - ContractDims > left_nocontract_t
Definition: TensorContraction.h:391
EIGEN_DEVICE_FUNC void evalGemm(Scalar *buffer) const
Definition: TensorContraction.h:743
contract_t m_k_strides
Definition: TensorContraction.h:888
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const
Definition: TensorContraction.h:883
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorContraction.h:872
EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(Scalar *buffer, Index k_start, Index k_end, int num_threads) const
Definition: TensorContraction.h:751
internal::traits< Derived >::Indices Indices
Definition: TensorContraction.h:343
left_nocontract_t m_i_strides
Definition: TensorContraction.h:896
Definition: TensorContraction.h:262
bool swapped_arguments
Definition: TensorContraction.h:265
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorContraction.h:926
TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType > XprType
Definition: TensorContraction.h:922
XprType::Index Index
Definition: TensorContraction.h:924
std::conditional_t< Layout==static_cast< int >ColMajor), RightArgType, LeftArgType > EvalRightArgType
Definition: TensorContraction.h:935
TensorEvaluator< const TensorContractionOp< Indices, LeftArgType, RightArgType, OutputKernelType >, Device > Self
Definition: TensorContraction.h:919
DSizes< Index, NumDims > Dimensions
Definition: TensorContraction.h:950
array< Index, LDims - ContractDims > left_nocontract_t
Definition: TensorContraction.h:944
TensorEvaluator(const XprType &op, const Device &device)
Definition: TensorContraction.h:952
XprType::CoeffReturnType CoeffReturnType
Definition: TensorContraction.h:925
std::remove_const_t< typename XprType::Scalar > Scalar
Definition: TensorContraction.h:923
TensorContractionEvaluatorBase< Self > Base
Definition: TensorContraction.h:920
std::conditional_t< Layout==static_cast< int >ColMajor), LeftArgType, RightArgType > EvalLeftArgType
Definition: TensorContraction.h:934
void evalProduct(Scalar *buffer) const
Definition: TensorContraction.h:955
array< Index, ContractDims > contract_t
Definition: TensorContraction.h:943
array< Index, RDims - ContractDims > right_nocontract_t
Definition: TensorContraction.h:945
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
static constexpr int Layout
Definition: TensorEvaluator.h:46
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType dest)
Definition: TensorEvaluator.h:71
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorEvaluator.h:87
Derived::Index Index
Definition: TensorEvaluator.h:32
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorEvaluator.h:69
Definition: TensorContraction.h:132
Index lhs_size
Definition: TensorContraction.h:133
Index rhs_size
Definition: TensorContraction.h:134
Definition: TensorContraction.h:81
static EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device &d, const Index bm, const Index bk, const Index bn, LhsScalar **lhs_block, RhsScalar **rhs_block)
Definition: TensorContraction.h:85
static EIGEN_DEVICE_FUNC void deallocate(Device &d, BlockMemHandle handle)
Definition: TensorContraction.h:127
void * BlockMemHandle
Definition: TensorContraction.h:82
static EIGEN_DEVICE_FUNC BlockSizes ComputeLhsRhsBlockSizes(const Index bm, const Index bk, const Index bn)
Definition: TensorContraction.h:136
static EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(Device &d, const Index bm, const Index bk, const Index bn, const Index num_lhs, const Index num_rhs, const Index num_slices, std::vector< LhsScalar * > *lhs_blocks, std::vector< RhsScalar * > *rhs_blocks)
Definition: TensorContraction.h:97
Definition: TensorContraction.h:175
const StorageIndex m
Definition: TensorContraction.h:250
LhsScalar * LhsBlock
Definition: TensorContraction.h:185
internal::gemm_pack_lhs< LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor > LhsPacker
Definition: TensorContraction.h:196
RhsScalar * RhsBlock
Definition: TensorContraction.h:186
EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(Device &d, const StorageIndex num_lhs, const StorageIndex num_rhs, const StorageIndex num_slices, std::vector< LhsBlock > *lhs_blocks, std::vector< RhsBlock > *rhs_blocks)
Definition: TensorContraction.h:211
const StorageIndex n
Definition: TensorContraction.h:252
EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
Definition: TensorContraction.h:180
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(LhsBlock *lhsBlock, const typename LhsMapper::SubMapper &data_mapper, const StorageIndex depth, const StorageIndex rows)
Definition: TensorContraction.h:222
internal::gebp_kernel< LhsScalar, RhsScalar, StorageIndex, OutputMapper, Traits::mr, Traits::nr, false, false > GebpKernel
Definition: TensorContraction.h:203
const StorageIndex bm
Definition: TensorContraction.h:253
const StorageIndex bk
Definition: TensorContraction.h:254
const StorageIndex k
Definition: TensorContraction.h:251
internal::gemm_pack_rhs< RhsScalar, StorageIndex, typename RhsMapper::SubMapper, Traits::nr, ColMajor > RhsPacker
Definition: TensorContraction.h:199
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(const OutputMapper &output_mapper, const LhsBlock &lhsBlock, const RhsBlock &rhsBlock, const StorageIndex rows, const StorageIndex depth, const StorageIndex cols, const ResScalar alpha, const ResScalar beta)
Definition: TensorContraction.h:233
static EIGEN_DEVICE_FUNC void deallocate(Device &d, BlockMemHandle handle)
Definition: TensorContraction.h:218
@ HasBeta
Definition: TensorContraction.h:178
BlockMemAllocator::BlockMemHandle BlockMemHandle
Definition: TensorContraction.h:190
internal::gebp_traits< LhsScalar, RhsScalar > Traits
Definition: TensorContraction.h:192
EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device &d, LhsBlock *lhs_block, RhsBlock *rhs_block)
Definition: TensorContraction.h:206
TensorContractionBlockMemAllocator< LhsScalar, RhsScalar > BlockMemAllocator
Definition: TensorContraction.h:189
const StorageIndex bn
Definition: TensorContraction.h:255
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(RhsBlock *rhsBlock, const typename RhsMapper::SubMapper &data_mapper, const StorageIndex depth, const StorageIndex cols)
Definition: TensorContraction.h:228
const TensorContractionOp< Dimensions, LhsXprType, RhsXprType, OutputKernelType > & type
Definition: TensorContraction.h:55
Definition: XprHelper.h:427
Definition: products/GeneralBlockPanelKernel.h:960
Definition: BlasUtil.h:34
Definition: BlasUtil.h:30
Definition: BlasUtil.h:42
TensorContractionOp< Dimensions, LhsXprType, RhsXprType, OutputKernelType > type
Definition: TensorContraction.h:61
Definition: TensorTraits.h:152
ref_selector< T >::type type
Definition: TensorTraits.h:153
std::remove_reference_t< RhsNested > RhsNested_
Definition: TensorContraction.h:40
RhsXprType::Nested RhsNested
Definition: TensorContraction.h:38
gebp_traits< std::remove_const_t< typename LhsXprType::Scalar >, std::remove_const_t< typename RhsXprType::Scalar > >::ResScalar Scalar
Definition: TensorContraction.h:31
promote_storage_type< typename traits< LhsXprType >::StorageKind, typename traits< RhsXprType >::StorageKind >::ret StorageKind
Definition: TensorContraction.h:34
LhsXprType::Nested LhsNested
Definition: TensorContraction.h:37
std::conditional_t< Pointer_type_promotion< typename LhsXprType::Scalar, Scalar >::val, typename traits< LhsXprType >::PointerType, typename traits< RhsXprType >::PointerType > PointerType
Definition: TensorContraction.h:48
promote_index_type< typename traits< LhsXprType >::Index, typename traits< RhsXprType >::Index >::type Index
Definition: TensorContraction.h:36
std::remove_reference_t< LhsNested > LhsNested_
Definition: TensorContraction.h:39
Device_ Device
Definition: TensorContraction.h:72
RightArgType_ RightArgType
Definition: TensorContraction.h:70
LeftArgType_ LeftArgType
Definition: TensorContraction.h:69
Indices_ Indices
Definition: TensorContraction.h:68
OutputKernelType_ OutputKernelType
Definition: TensorContraction.h:71
Definition: ForwardDeclarations.h:21
Definition: GenericPacketMath.h:134
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2