TensorInflation.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Ke Yang <yangke@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
25 namespace internal {
26 template <typename Strides, typename XprType>
27 struct traits<TensorInflationOp<Strides, XprType> > : public traits<XprType> {
28  typedef typename XprType::Scalar Scalar;
30  typedef typename XprTraits::StorageKind StorageKind;
31  typedef typename XprTraits::Index Index;
32  typedef typename XprType::Nested Nested;
33  typedef std::remove_reference_t<Nested> Nested_;
34  static constexpr int NumDimensions = XprTraits::NumDimensions;
35  static constexpr int Layout = XprTraits::Layout;
36  typedef typename XprTraits::PointerType PointerType;
37 };
38 
39 template <typename Strides, typename XprType>
40 struct eval<TensorInflationOp<Strides, XprType>, Eigen::Dense> {
42 };
43 
44 template <typename Strides, typename XprType>
45 struct nested<TensorInflationOp<Strides, XprType>, 1, typename eval<TensorInflationOp<Strides, XprType> >::type> {
47 };
48 
49 } // end namespace internal
50 
51 template <typename Strides, typename XprType>
52 class TensorInflationOp : public TensorBase<TensorInflationOp<Strides, XprType>, ReadOnlyAccessors> {
53  public:
56  typedef typename XprType::CoeffReturnType CoeffReturnType;
60 
62  : m_xpr(expr), m_strides(strides) {}
63 
64  EIGEN_DEVICE_FUNC const Strides& strides() const { return m_strides; }
65 
67 
68  protected:
69  typename XprType::Nested m_xpr;
70  const Strides m_strides;
71 };
72 
73 // Eval as rvalue
74 template <typename Strides, typename ArgType, typename Device>
75 struct TensorEvaluator<const TensorInflationOp<Strides, ArgType>, Device> {
77  typedef typename XprType::Index Index;
80  typedef typename XprType::Scalar Scalar;
86 
88  enum {
89  IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
91  BlockAccess = false,
92  PreferBlockAccess = false,
93  CoordAccess = false, // to be implemented
94  RawAccess = false
95  };
96 
97  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
99  //===--------------------------------------------------------------------===//
100 
101  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
102  : m_impl(op.expression(), device), m_strides(op.strides()) {
103  m_dimensions = m_impl.dimensions();
104  // Expand each dimension to the inflated dimension.
105  for (int i = 0; i < NumDims; ++i) {
106  m_dimensions[i] = (m_dimensions[i] - 1) * op.strides()[i] + 1;
107  }
108 
109  // Remember the strides for fast division.
110  for (int i = 0; i < NumDims; ++i) {
111  m_fastStrides[i] = internal::TensorIntDivisor<Index>(m_strides[i]);
112  }
113 
114  const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
115  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
116  m_outputStrides[0] = 1;
117  m_inputStrides[0] = 1;
118  for (int i = 1; i < NumDims; ++i) {
119  m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
120  m_inputStrides[i] = m_inputStrides[i - 1] * input_dims[i - 1];
121  }
122  } else { // RowMajor
123  m_outputStrides[NumDims - 1] = 1;
124  m_inputStrides[NumDims - 1] = 1;
125  for (int i = NumDims - 2; i >= 0; --i) {
126  m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
127  m_inputStrides[i] = m_inputStrides[i + 1] * input_dims[i + 1];
128  }
129  }
130  }
131 
132  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
133 
135  m_impl.evalSubExprsIfNeeded(NULL);
136  return true;
137  }
138  EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
139 
140  // Computes the input index given the output index. Returns true if the output
141  // index doesn't fall into a hole.
143  eigen_assert(index < dimensions().TotalSize());
144  *inputIndex = 0;
145  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
147  for (int i = NumDims - 1; i > 0; --i) {
148  const Index idx = index / m_outputStrides[i];
149  if (idx != idx / m_fastStrides[i] * m_strides[i]) {
150  return false;
151  }
152  *inputIndex += idx / m_strides[i] * m_inputStrides[i];
153  index -= idx * m_outputStrides[i];
154  }
155  if (index != index / m_fastStrides[0] * m_strides[0]) {
156  return false;
157  }
158  *inputIndex += index / m_strides[0];
159  return true;
160  } else {
162  for (int i = 0; i < NumDims - 1; ++i) {
163  const Index idx = index / m_outputStrides[i];
164  if (idx != idx / m_fastStrides[i] * m_strides[i]) {
165  return false;
166  }
167  *inputIndex += idx / m_strides[i] * m_inputStrides[i];
168  index -= idx * m_outputStrides[i];
169  }
170  if (index != index / m_fastStrides[NumDims - 1] * m_strides[NumDims - 1]) {
171  return false;
172  }
173  *inputIndex += index / m_strides[NumDims - 1];
174  }
175  return true;
176  }
177 
179  Index inputIndex = 0;
180  if (getInputIndex(index, &inputIndex)) {
181  return m_impl.coeff(inputIndex);
182  } else {
183  return Scalar(0);
184  }
185  }
186 
187  // TODO(yangke): optimize this function so that we can detect and produce
188  // all-zero packets
189  template <int LoadMode>
191  EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
192  eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
193 
194  EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
196  for (int i = 0; i < PacketSize; ++i) {
197  values[i] = coeff(index + i);
198  }
199  PacketReturnType rslt = internal::pload<PacketReturnType>(values);
200  return rslt;
201  }
202 
204  const double compute_cost = NumDims * (3 * TensorOpCost::DivCost<Index>() + 3 * TensorOpCost::MulCost<Index>() +
205  2 * TensorOpCost::AddCost<Index>());
206  const double input_size = m_impl.dimensions().TotalSize();
207  const double output_size = m_dimensions.TotalSize();
208  if (output_size == 0) return TensorOpCost();
209  return m_impl.costPerCoeff(vectorized) +
210  TensorOpCost(sizeof(CoeffReturnType) * input_size / output_size, 0, compute_cost, vectorized, PacketSize);
211  }
212 
213  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
214 
215  protected:
220  const Strides m_strides;
222 };
223 
224 } // end namespace Eigen
225 
226 #endif // EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_ALIGN_MAX
Definition: ConfigureVectorization.h:146
#define EIGEN_UNROLL_LOOP
Definition: Macros.h:1298
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
SCALAR Scalar
Definition: bench_gemm.cpp:45
Generic expression where a coefficient-wise binary operator is applied to two expressions.
Definition: CwiseBinaryOp.h:79
The tensor base class.
Definition: TensorBase.h:1026
Definition: TensorInflation.h:52
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorInflationOp(const XprType &expr, const Strides &strides)
Definition: TensorInflation.h:61
Eigen::internal::traits< TensorInflationOp >::Scalar Scalar
Definition: TensorInflation.h:54
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorInflation.h:55
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename XprType::Nested > & expression() const
Definition: TensorInflation.h:66
XprType::CoeffReturnType CoeffReturnType
Definition: TensorInflation.h:56
XprType::Nested m_xpr
Definition: TensorInflation.h:69
EIGEN_DEVICE_FUNC const Strides & strides() const
Definition: TensorInflation.h:64
Eigen::internal::nested< TensorInflationOp >::type Nested
Definition: TensorInflation.h:57
const Strides m_strides
Definition: TensorInflation.h:70
Eigen::internal::traits< TensorInflationOp >::StorageKind StorageKind
Definition: TensorInflation.h:58
Eigen::internal::traits< TensorInflationOp >::Index Index
Definition: TensorInflation.h:59
Definition: TensorCostModel.h:28
Definition: TensorBlock.h:566
@ ColMajor
Definition: Constants.h:318
char char * op
Definition: level2_impl.h:374
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
EIGEN_ALWAYS_INLINE DSizes< IndexType, NumDims > strides(const DSizes< IndexType, NumDims > &dimensions)
Definition: TensorBlock.h:29
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Definition: Eigen_Colamd.h:49
Definition: Constants.h:519
T Real
Definition: NumTraits.h:183
Definition: TensorMeta.h:47
Definition: TensorForwardDeclarations.h:42
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorInflation.h:138
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorInflation.h:132
DSizes< Index, NumDims > Dimensions
Definition: TensorInflation.h:79
StorageMemory< CoeffReturnType, Device > Storage
Definition: TensorInflation.h:84
array< Index, NumDims > m_inputStrides
Definition: TensorInflation.h:218
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorInflation.h:82
Storage::Type EvaluatorPointerType
Definition: TensorInflation.h:85
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const
Definition: TensorInflation.h:213
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool getInputIndex(Index index, Index *inputIndex) const
Definition: TensorInflation.h:142
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorInflation.h:178
TensorInflationOp< Strides, ArgType > XprType
Definition: TensorInflation.h:76
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
Definition: TensorInflation.h:134
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorInflation.h:98
array< Index, NumDims > m_outputStrides
Definition: TensorInflation.h:217
XprType::CoeffReturnType CoeffReturnType
Definition: TensorInflation.h:81
array< internal::TensorIntDivisor< Index >, NumDims > m_fastStrides
Definition: TensorInflation.h:221
TensorEvaluator< ArgType, Device > m_impl
Definition: TensorInflation.h:219
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition: TensorInflation.h:190
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const
Definition: TensorInflation.h:203
EIGEN_STRONG_INLINE TensorEvaluator(const XprType &op, const Device &device)
Definition: TensorInflation.h:101
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
static constexpr int Layout
Definition: TensorEvaluator.h:46
Derived::Scalar Scalar
Definition: TensorEvaluator.h:33
@ PacketAccess
Definition: TensorEvaluator.h:50
@ IsAligned
Definition: TensorEvaluator.h:49
static constexpr int PacketSize
Definition: TensorEvaluator.h:38
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorEvaluator.h:89
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorEvaluator.h:69
Definition: Meta.h:305
const TensorInflationOp< Strides, XprType > & type
Definition: TensorInflation.h:41
Definition: XprHelper.h:427
Definition: TensorTraits.h:152
ref_selector< T >::type type
Definition: TensorTraits.h:153
traits< XprType > XprTraits
Definition: TensorInflation.h:29
XprTraits::PointerType PointerType
Definition: TensorInflation.h:36
XprType::Scalar Scalar
Definition: TensorInflation.h:28
XprTraits::StorageKind StorageKind
Definition: TensorInflation.h:30
std::remove_reference_t< Nested > Nested_
Definition: TensorInflation.h:33
XprType::Nested Nested
Definition: TensorInflation.h:32
XprTraits::Index Index
Definition: TensorInflation.h:31
Definition: ForwardDeclarations.h:21