TensorConcatenation.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
25 namespace internal {
// Expression traits for TensorConcatenationOp.
// According to the typedef summaries extracted alongside this listing, Scalar,
// StorageKind, Index and PointerType are obtained by promoting the lhs and rhs
// operand types (promote_storage_type / promote_index_type); their full
// definitions are truncated in this view.
26 template <typename Axis, typename LhsXprType, typename RhsXprType>
27 struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> > {
28  // Type promotion to handle the case where the types of the lhs and the rhs are different.
32  typedef
34  typedef typename LhsXprType::Nested LhsNested;
35  typedef typename RhsXprType::Nested RhsNested;
36  typedef std::remove_reference_t<LhsNested> LhsNested_;
37  typedef std::remove_reference_t<RhsNested> RhsNested_;
  // Rank and layout are taken from the lhs operand only; the rvalue evaluator
  // statically asserts that the rhs has the same rank.
38  static constexpr int NumDimensions = traits<LhsXprType>::NumDimensions;
39  static constexpr int Layout = traits<LhsXprType>::Layout;
  // No expression flags are set for this op.
40  enum { Flags = 0 };
44 };
45 
// eval<> specialization for the Dense case. Per the extracted summary, its
// `type` is `const TensorConcatenationOp<...>&`, i.e. the op is referenced
// rather than copied (the typedef line itself is missing from this listing).
46 template <typename Axis, typename LhsXprType, typename RhsXprType>
47 struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense> {
49 };
50 
// nested<> specialization for n == 1. Its member typedef (original line 54)
// is not visible in this listing.
51 template <typename Axis, typename LhsXprType, typename RhsXprType>
52 struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1,
53  typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type> {
55 };
56 
57 } // end namespace internal
58 
// Tensor concatenation expression: represents `lhs` and `rhs` joined along the
// dimension selected by `axis`. The op derives from TensorBase with
// WriteAccessors; a non-const TensorEvaluator specialization in this file
// provides coeffRef()/writePacket(), so the op can also be written through.
59 template <typename Axis, typename LhsXprType, typename RhsXprType>
60 class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> {
61  public:
  // Coefficient type produced when reading the op: the promotion of the two
  // operands' CoeffReturnType.
67  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
68  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
70 
  // Stores both operands (as their Nested types) and the concatenation axis.
  // The axis is not validated here; the evaluator asserts 0 <= axis < rank.
71  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
72  : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}
73 
  // Operand accessors lhsExpression()/rhsExpression(); their signature lines
  // (original lines 74 and 78) are missing from this listing.
75  return m_lhs_xpr;
76  }
77 
79  return m_rhs_xpr;
80  }
81 
  // Concatenation axis exactly as passed to the constructor.
82  EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
83 
85  protected:
86  typename LhsXprType::Nested m_lhs_xpr;
87  typename RhsXprType::Nested m_rhs_xpr;
88  const Axis m_axis;
89 };
90 
91 // Eval as rvalue
// Read-only (rvalue) evaluator for TensorConcatenationOp.
// The output has the lhs extents on every dimension except m_axis, where the
// lhs and rhs extents are summed. Each coefficient read decomposes the flat
// output index into per-dimension subscripts using the output strides, then
// forwards to whichever operand owns that position along m_axis.
92 template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
93 struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
95  typedef typename XprType::Index Index;
97  static constexpr int RightNumDims =
100  typedef typename XprType::Scalar Scalar;
106  enum {
107  IsAligned = false,
108  PacketAccess =
110  BlockAccess = false,
113  RawAccess = false
114  };
115 
116  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
118  //===--------------------------------------------------------------------===//
119 
  // Builds the operand evaluators, computes the output dimensions and the
  // lhs/rhs/output strides. Checks (via eigen_assert) that the axis is in
  // range, every extent is positive, and non-axis extents of lhs and rhs match.
  // NOTE(review): the first static assert is only partially visible here
  // (original lines 122-123 are missing); it appears to allow NumDims == 1.
120  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
121  : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis()) {
124  NumDims == 1),
125  YOU_MADE_A_PROGRAMMING_MISTAKE);
126  EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
127  EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
128 
129  eigen_assert(0 <= m_axis && m_axis < NumDims);
130  const Dimensions& lhs_dims = m_leftImpl.dimensions();
131  const Dimensions& rhs_dims = m_rightImpl.dimensions();
    // Output extent: the lhs extent on every dimension except m_axis, where it
    // is lhs extent + rhs extent.
132  {
133  int i = 0;
134  for (; i < m_axis; ++i) {
135  eigen_assert(lhs_dims[i] > 0);
136  eigen_assert(lhs_dims[i] == rhs_dims[i]);
137  m_dimensions[i] = lhs_dims[i];
138  }
139  eigen_assert(lhs_dims[i] > 0); // Now i == m_axis.
140  eigen_assert(rhs_dims[i] > 0);
141  m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
142  for (++i; i < NumDims; ++i) {
143  eigen_assert(lhs_dims[i] > 0);
144  eigen_assert(lhs_dims[i] == rhs_dims[i]);
145  m_dimensions[i] = lhs_dims[i];
146  }
147  }
148 
    // Strides: the stride of dimension j is the product of the extents of all
    // faster-varying dimensions (lower indices for ColMajor, higher for RowMajor).
149  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
150  m_leftStrides[0] = 1;
151  m_rightStrides[0] = 1;
152  m_outputStrides[0] = 1;
153 
154  for (int j = 1; j < NumDims; ++j) {
155  m_leftStrides[j] = m_leftStrides[j - 1] * lhs_dims[j - 1];
156  m_rightStrides[j] = m_rightStrides[j - 1] * rhs_dims[j - 1];
157  m_outputStrides[j] = m_outputStrides[j - 1] * m_dimensions[j - 1];
158  }
159  } else {
160  m_leftStrides[NumDims - 1] = 1;
161  m_rightStrides[NumDims - 1] = 1;
162  m_outputStrides[NumDims - 1] = 1;
163 
164  for (int j = NumDims - 2; j >= 0; --j) {
165  m_leftStrides[j] = m_leftStrides[j + 1] * lhs_dims[j + 1];
166  m_rightStrides[j] = m_rightStrides[j + 1] * rhs_dims[j + 1];
167  m_outputStrides[j] = m_outputStrides[j + 1] * m_dimensions[j + 1];
168  }
169  }
170  }
171 
  // Dimensions of the concatenated result (computed in the constructor).
172  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
173 
  // evalSubExprsIfNeeded: evaluates both operands' sub-expressions, passing
  // NULL to each, and always returns true (no destination buffer is used).
174  // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
176  m_leftImpl.evalSubExprsIfNeeded(NULL);
177  m_rightImpl.evalSubExprsIfNeeded(NULL);
178  return true;
179  }
180 
  // cleanup: forwards cleanup() to both operand evaluators.
182  m_leftImpl.cleanup();
183  m_rightImpl.cleanup();
184  }
185 
  // coeff(index): maps a flat output index to the owning operand's flat index
  // and returns that operand's coefficient.
186  // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
187  // See CL/76180724 comments for more ideas.
189  // Collect dimension-wise indices (subs).
191  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
192  for (int i = NumDims - 1; i > 0; --i) {
193  subs[i] = index / m_outputStrides[i];
194  index -= subs[i] * m_outputStrides[i];
195  }
196  subs[0] = index;
197  } else {
198  for (int i = 0; i < NumDims - 1; ++i) {
199  subs[i] = index / m_outputStrides[i];
200  index -= subs[i] * m_outputStrides[i];
201  }
202  subs[NumDims - 1] = index;
203  }
204 
    // Along m_axis, subscripts below the lhs extent belong to the lhs; the rest
    // belong to the rhs after subtracting the lhs extent.
    // NOTE(review): for in-range indices every subs[i] is already below
    // left_dims[i] (resp. right_dims[i]), so the `%` below appears redundant.
205  const Dimensions& left_dims = m_leftImpl.dimensions();
206  if (subs[m_axis] < left_dims[m_axis]) {
207  Index left_index;
208  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
209  left_index = subs[0];
211  for (int i = 1; i < NumDims; ++i) {
212  left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
213  }
214  } else {
215  left_index = subs[NumDims - 1];
217  for (int i = NumDims - 2; i >= 0; --i) {
218  left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
219  }
220  }
221  return m_leftImpl.coeff(left_index);
222  } else {
223  subs[m_axis] -= left_dims[m_axis];
224  const Dimensions& right_dims = m_rightImpl.dimensions();
225  Index right_index;
226  if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
227  right_index = subs[0];
229  for (int i = 1; i < NumDims; ++i) {
230  right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
231  }
232  } else {
233  right_index = subs[NumDims - 1];
235  for (int i = NumDims - 2; i >= 0; --i) {
236  right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
237  }
238  }
239  return m_rightImpl.coeff(right_index);
240  }
241  }
242 
  // packet(index): gathers packetSize consecutive coefficients through coeff()
  // into an aligned scratch buffer, then loads them as one packet.
243  // TODO(phli): Add a real vectorization.
244  template <int LoadMode>
246  const int packetSize = PacketType<CoeffReturnType, Device>::size;
247  EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
248  eigen_assert(index + packetSize - 1 < dimensions().TotalSize());
249 
250  EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
252  for (int i = 0; i < packetSize; ++i) {
253  values[i] = coeff(index + i);
254  }
255  PacketReturnType rslt = internal::pload<PacketReturnType>(values);
256  return rslt;
257  }
258 
  // costPerCoeff: each operand's cost weighted by its share of the total
  // coefficient count, plus per-dimension index arithmetic
  // (2 adds, 2 muls, 1 div and 1 mod per dimension).
260  const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
261  TensorOpCost::DivCost<Index>() + TensorOpCost::ModCost<Index>());
262  const double lhs_size = m_leftImpl.dimensions().TotalSize();
263  const double rhs_size = m_rightImpl.dimensions().TotalSize();
264  return (lhs_size / (lhs_size + rhs_size)) * m_leftImpl.costPerCoeff(vectorized) +
265  (rhs_size / (lhs_size + rhs_size)) * m_rightImpl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
266  }
267 
  // No contiguous buffer is exposed for the concatenated result.
268  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
269 
270  protected:
  // (m_dimensions, the stride arrays and the operand evaluators are declared
  // on lines missing from this listing.)
277  const Axis m_axis;
278 };
279 
280 // Eval as lvalue
// Writable (lvalue) evaluator for TensorConcatenationOp.
// Reuses the rvalue evaluator (Base) for dimensions, strides and operand
// evaluators, and adds coeffRef()/writePacket(). Only ColMajor layout is
// accepted (static assert in the constructor), so the index decomposition
// below has no RowMajor branch.
281 template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
282 struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
283  : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
286  typedef typename Base::Dimensions Dimensions;
288  enum {
289  IsAligned = false,
290  PacketAccess =
292  BlockAccess = false,
295  RawAccess = false
296  };
297 
298  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
300  //===--------------------------------------------------------------------===//
301 
  // Delegates all setup to Base; rejects non-ColMajor layouts at compile time.
302  EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device) : Base(op, device) {
303  EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
304  }
305 
306  typedef typename XprType::Index Index;
307  typedef typename XprType::Scalar Scalar;
310 
  // coeffRef(index): same index mapping as the ColMajor path of Base::coeff(),
  // but returns a reference into the owning operand via its coeffRef().
312  // Collect dimension-wise indices (subs).
314  for (int i = Base::NumDims - 1; i > 0; --i) {
315  subs[i] = index / this->m_outputStrides[i];
316  index -= subs[i] * this->m_outputStrides[i];
317  }
318  subs[0] = index;
319 
    // Subscripts below the lhs extent along m_axis belong to the lhs; the rest
    // belong to the rhs after subtracting the lhs extent.
320  const Dimensions& left_dims = this->m_leftImpl.dimensions();
321  if (subs[this->m_axis] < left_dims[this->m_axis]) {
322  Index left_index = subs[0];
323  for (int i = 1; i < Base::NumDims; ++i) {
324  left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
325  }
326  return this->m_leftImpl.coeffRef(left_index);
327  } else {
328  subs[this->m_axis] -= left_dims[this->m_axis];
329  const Dimensions& right_dims = this->m_rightImpl.dimensions();
330  Index right_index = subs[0];
331  for (int i = 1; i < Base::NumDims; ++i) {
332  right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
333  }
334  return this->m_rightImpl.coeffRef(right_index);
335  }
336  }
337 
  // writePacket(index, x): stores the packet into an aligned scratch buffer,
  // then writes each lane individually through coeffRef().
338  template <int StoreMode>
340  const int packetSize = PacketType<CoeffReturnType, Device>::size;
341  EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
342  eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
343 
344  EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
345  internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
346  for (int i = 0; i < packetSize; ++i) {
347  coeffRef(index + i) = values[i];
348  }
349  }
350 };
351 
352 } // end namespace Eigen
353 
354 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_ALIGN_MAX
Definition: ConfigureVectorization.h:146
#define EIGEN_UNROLL_LOOP
Definition: Macros.h:1298
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
#define EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(Derived)
Macro to manually inherit assignment operators. This is necessary, because the implicitly defined ass...
Definition: TensorMacros.h:81
The tensor base class.
Definition: TensorBase.h:1026
Tensor concatenation class.
Definition: TensorConcatenation.h:60
EIGEN_DEVICE_FUNC const Axis & axis() const
Definition: TensorConcatenation.h:82
internal::traits< TensorConcatenationOp >::StorageKind StorageKind
Definition: TensorConcatenation.h:64
TensorBase< TensorConcatenationOp< Axis, LhsXprType, RhsXprType >, WriteAccessors > Base
Definition: TensorConcatenation.h:62
internal::nested< TensorConcatenationOp >::type Nested
Definition: TensorConcatenation.h:66
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename RhsXprType::Nested > & rhsExpression() const
Definition: TensorConcatenation.h:78
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType &lhs, const RhsXprType &rhs, Axis axis)
Definition: TensorConcatenation.h:71
internal::traits< TensorConcatenationOp >::Index Index
Definition: TensorConcatenation.h:65
RhsXprType::Nested m_rhs_xpr
Definition: TensorConcatenation.h:87
NumTraits< Scalar >::Real RealScalar
Definition: TensorConcatenation.h:69
const Axis m_axis
Definition: TensorConcatenation.h:88
internal::traits< TensorConcatenationOp >::Scalar Scalar
Definition: TensorConcatenation.h:63
internal::promote_storage_type< typename LhsXprType::CoeffReturnType, typename RhsXprType::CoeffReturnType >::ret CoeffReturnType
Definition: TensorConcatenation.h:68
LhsXprType::Nested m_lhs_xpr
Definition: TensorConcatenation.h:86
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename LhsXprType::Nested > & lhsExpression() const
Definition: TensorConcatenation.h:74
Definition: TensorCostModel.h:28
Definition: TensorBlock.h:566
@ WriteAccessors
Definition: Constants.h:374
@ ColMajor
Definition: Constants.h:318
Eigen::DenseIndex ret
Definition: level1_cplx_impl.h:43
char char * op
Definition: level2_impl.h:374
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
Extend namespace for flags.
Definition: fsi_chan_precond_driver.cc:56
val
Definition: calibrate.py:119
type
Definition: compute_granudrum_aor.py:141
Definition: Eigen_Colamd.h:49
list x
Definition: plotDoE.py:28
Definition: Constants.h:519
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition: NumTraits.h:217
Definition: TensorMeta.h:47
Definition: TensorForwardDeclarations.h:42
TensorConcatenationOp< Axis, LeftArgType, RightArgType > XprType
Definition: TensorConcatenation.h:285
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorConcatenation.h:299
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index index, const PacketReturnType &x) const
Definition: TensorConcatenation.h:339
XprType::CoeffReturnType CoeffReturnType
Definition: TensorConcatenation.h:308
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorConcatenation.h:309
TensorEvaluator< const TensorConcatenationOp< Axis, LeftArgType, RightArgType >, Device > Base
Definition: TensorConcatenation.h:284
EIGEN_STRONG_INLINE TensorEvaluator(XprType &op, const Device &device)
Definition: TensorConcatenation.h:302
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType & coeffRef(Index index) const
Definition: TensorConcatenation.h:311
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorConcatenation.h:188
TensorEvaluator< RightArgType, Device > m_rightImpl
Definition: TensorConcatenation.h:276
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorConcatenation.h:102
TensorEvaluator< LeftArgType, Device > m_leftImpl
Definition: TensorConcatenation.h:275
array< Index, NumDims > m_rightStrides
Definition: TensorConcatenation.h:274
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const
Definition: TensorConcatenation.h:268
EIGEN_STRONG_INLINE TensorEvaluator(const XprType &op, const Device &device)
Definition: TensorConcatenation.h:120
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const
Definition: TensorConcatenation.h:259
StorageMemory< CoeffReturnType, Device > Storage
Definition: TensorConcatenation.h:103
array< Index, NumDims > m_leftStrides
Definition: TensorConcatenation.h:273
array< Index, NumDims > m_outputStrides
Definition: TensorConcatenation.h:272
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition: TensorConcatenation.h:245
TensorConcatenationOp< Axis, LeftArgType, RightArgType > XprType
Definition: TensorConcatenation.h:94
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorConcatenation.h:117
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorConcatenation.h:172
XprType::CoeffReturnType CoeffReturnType
Definition: TensorConcatenation.h:101
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
Definition: TensorConcatenation.h:175
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorConcatenation.h:181
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
static constexpr int Layout
Definition: TensorEvaluator.h:46
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType & coeffRef(Index index) const
Definition: TensorEvaluator.h:94
@ PacketAccess
Definition: TensorEvaluator.h:50
@ IsAligned
Definition: TensorEvaluator.h:49
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorEvaluator.h:89
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorEvaluator.h:69
Definition: Meta.h:305
const TensorConcatenationOp< Axis, LhsXprType, RhsXprType > & type
Definition: TensorConcatenation.h:48
Definition: XprHelper.h:427
Definition: TensorTraits.h:152
ref_selector< T >::type type
Definition: TensorTraits.h:153
Definition: XprHelper.h:145
Definition: XprHelper.h:591
RhsXprType::Nested RhsNested
Definition: TensorConcatenation.h:35
promote_storage_type< typename LhsXprType::Scalar, typename RhsXprType::Scalar >::ret Scalar
Definition: TensorConcatenation.h:29
LhsXprType::Nested LhsNested
Definition: TensorConcatenation.h:34
promote_storage_type< typename traits< LhsXprType >::StorageKind, typename traits< RhsXprType >::StorageKind >::ret StorageKind
Definition: TensorConcatenation.h:31
std::remove_reference_t< LhsNested > LhsNested_
Definition: TensorConcatenation.h:36
promote_index_type< typename traits< LhsXprType >::Index, typename traits< RhsXprType >::Index >::type Index
Definition: TensorConcatenation.h:33
std::conditional_t< Pointer_type_promotion< typename LhsXprType::Scalar, Scalar >::val, typename traits< LhsXprType >::PointerType, typename traits< RhsXprType >::PointerType > PointerType
Definition: TensorConcatenation.h:43
std::remove_reference_t< RhsNested > RhsNested_
Definition: TensorConcatenation.h:37
Definition: ForwardDeclarations.h:21
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2