10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H
26 template <
typename Str
ides,
typename XprType>
32 typedef typename XprType::Nested
Nested;
33 typedef std::remove_reference_t<Nested>
Nested_;
34 static constexpr
int NumDimensions = XprTraits::NumDimensions;
35 static constexpr
int Layout = XprTraits::Layout;
39 template <
typename Str
ides,
typename XprType>
44 template <
typename Str
ides,
typename XprType>
51 template <
typename Str
ides,
typename XprType>
74 template <
typename Str
ides,
typename ArgType,
typename Device>
92 PreferBlockAccess =
false,
102 : m_impl(
op.expression(), device), m_strides(
op.
strides()) {
103 m_dimensions = m_impl.dimensions();
105 for (
int i = 0;
i < NumDims; ++
i) {
106 m_dimensions[
i] = (m_dimensions[
i] - 1) *
op.strides()[
i] + 1;
110 for (
int i = 0;
i < NumDims; ++
i) {
116 m_outputStrides[0] = 1;
117 m_inputStrides[0] = 1;
118 for (
int i = 1;
i < NumDims; ++
i) {
119 m_outputStrides[
i] = m_outputStrides[
i - 1] * m_dimensions[
i - 1];
120 m_inputStrides[
i] = m_inputStrides[
i - 1] * input_dims[
i - 1];
123 m_outputStrides[NumDims - 1] = 1;
124 m_inputStrides[NumDims - 1] = 1;
125 for (
int i = NumDims - 2;
i >= 0; --
i) {
126 m_outputStrides[
i] = m_outputStrides[
i + 1] * m_dimensions[
i + 1];
127 m_inputStrides[
i] = m_inputStrides[
i + 1] * input_dims[
i + 1];
135 m_impl.evalSubExprsIfNeeded(NULL);
147 for (
int i = NumDims - 1;
i > 0; --
i) {
148 const Index idx = index / m_outputStrides[
i];
149 if (idx != idx / m_fastStrides[
i] * m_strides[
i]) {
152 *inputIndex += idx / m_strides[
i] * m_inputStrides[
i];
153 index -= idx * m_outputStrides[
i];
155 if (index != index / m_fastStrides[0] * m_strides[0]) {
158 *inputIndex += index / m_strides[0];
162 for (
int i = 0;
i < NumDims - 1; ++
i) {
163 const Index idx = index / m_outputStrides[
i];
164 if (idx != idx / m_fastStrides[
i] * m_strides[
i]) {
167 *inputIndex += idx / m_strides[
i] * m_inputStrides[
i];
168 index -= idx * m_outputStrides[
i];
170 if (index != index / m_fastStrides[NumDims - 1] * m_strides[NumDims - 1]) {
173 *inputIndex += index / m_strides[NumDims - 1];
179 Index inputIndex = 0;
180 if (getInputIndex(index, &inputIndex)) {
181 return m_impl.coeff(inputIndex);
189 template <
int LoadMode>
204 const double compute_cost = NumDims * (3 * TensorOpCost::DivCost<Index>() + 3 * TensorOpCost::MulCost<Index>() +
205 2 * TensorOpCost::AddCost<Index>());
206 const double input_size = m_impl.dimensions().TotalSize();
207 const double output_size = m_dimensions.TotalSize();
209 return m_impl.costPerCoeff(vectorized) +
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_UNROLL_LOOP
Definition: Macros.h:1298
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
SCALAR Scalar
Definition: bench_gemm.cpp:45
Generic expression where a coefficient-wise binary operator is applied to two expressions.
Definition: CwiseBinaryOp.h:79
The tensor base class.
Definition: TensorBase.h:1026
Definition: TensorInflation.h:52
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorInflationOp(const XprType &expr, const Strides &strides)
Definition: TensorInflation.h:61
Eigen::internal::traits< TensorInflationOp >::Scalar Scalar
Definition: TensorInflation.h:54
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorInflation.h:55
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename XprType::Nested > & expression() const
Definition: TensorInflation.h:66
XprType::CoeffReturnType CoeffReturnType
Definition: TensorInflation.h:56
XprType::Nested m_xpr
Definition: TensorInflation.h:69
EIGEN_DEVICE_FUNC const Strides & strides() const
Definition: TensorInflation.h:64
Eigen::internal::nested< TensorInflationOp >::type Nested
Definition: TensorInflation.h:57
const Strides m_strides
Definition: TensorInflation.h:70
Eigen::internal::traits< TensorInflationOp >::StorageKind StorageKind
Definition: TensorInflation.h:58
Eigen::internal::traits< TensorInflationOp >::Index Index
Definition: TensorInflation.h:59
Definition: TensorCostModel.h:28
Definition: TensorBlock.h:566
@ ColMajor
Definition: Constants.h:318
char char * op
Definition: level2_impl.h:374
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
EIGEN_ALWAYS_INLINE DSizes< IndexType, NumDims > strides(const DSizes< IndexType, NumDims > &dimensions)
Definition: TensorBlock.h:29
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Definition: Eigen_Colamd.h:49
Definition: Constants.h:519
T Real
Definition: NumTraits.h:183
Definition: TensorMeta.h:47
Definition: TensorForwardDeclarations.h:42
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorInflation.h:138
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorInflation.h:132
DSizes< Index, NumDims > Dimensions
Definition: TensorInflation.h:79
StorageMemory< CoeffReturnType, Device > Storage
Definition: TensorInflation.h:84
array< Index, NumDims > m_inputStrides
Definition: TensorInflation.h:218
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorInflation.h:82
const Strides m_strides
Definition: TensorInflation.h:220
Storage::Type EvaluatorPointerType
Definition: TensorInflation.h:85
XprType::Index Index
Definition: TensorInflation.h:77
Dimensions m_dimensions
Definition: TensorInflation.h:216
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const
Definition: TensorInflation.h:213
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool getInputIndex(Index index, Index *inputIndex) const
Definition: TensorInflation.h:142
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorInflation.h:178
TensorInflationOp< Strides, ArgType > XprType
Definition: TensorInflation.h:76
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
Definition: TensorInflation.h:134
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorInflation.h:98
XprType::Scalar Scalar
Definition: TensorInflation.h:80
array< Index, NumDims > m_outputStrides
Definition: TensorInflation.h:217
XprType::CoeffReturnType CoeffReturnType
Definition: TensorInflation.h:81
array< internal::TensorIntDivisor< Index >, NumDims > m_fastStrides
Definition: TensorInflation.h:221
TensorEvaluator< ArgType, Device > m_impl
Definition: TensorInflation.h:219
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition: TensorInflation.h:190
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const
Definition: TensorInflation.h:203
EIGEN_STRONG_INLINE TensorEvaluator(const XprType &op, const Device &device)
Definition: TensorInflation.h:101
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
static constexpr int Layout
Definition: TensorEvaluator.h:46
Derived::Scalar Scalar
Definition: TensorEvaluator.h:33
@ PacketAccess
Definition: TensorEvaluator.h:50
@ IsAligned
Definition: TensorEvaluator.h:49
static constexpr int PacketSize
Definition: TensorEvaluator.h:38
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorEvaluator.h:89
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorEvaluator.h:69
const TensorInflationOp< Strides, XprType > & type
Definition: TensorInflation.h:41
Definition: XprHelper.h:427
TensorInflationOp< Strides, XprType > type
Definition: TensorInflation.h:46
Definition: TensorTraits.h:152
ref_selector< T >::type type
Definition: TensorTraits.h:153
traits< XprType > XprTraits
Definition: TensorInflation.h:29
XprTraits::PointerType PointerType
Definition: TensorInflation.h:36
XprType::Scalar Scalar
Definition: TensorInflation.h:28
XprTraits::StorageKind StorageKind
Definition: TensorInflation.h:30
std::remove_reference_t< Nested > Nested_
Definition: TensorInflation.h:33
XprType::Nested Nested
Definition: TensorInflation.h:32
XprTraits::Index Index
Definition: TensorInflation.h:31
Definition: ForwardDeclarations.h:21