11 #ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
12 #define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
28 template <
typename Dims,
typename XprType>
34 typedef typename XprType::Nested
Nested;
35 typedef std::remove_reference_t<Nested>
Nested_;
37 static constexpr
int Layout = XprTraits::Layout;
40 template <
typename Dims,
typename XprType>
45 template <
typename Dims,
typename XprType>
52 template <
typename Dims,
typename XprType>
77 template <
typename Dims,
typename ArgType,
typename Device>
80 static constexpr
int NumInputDims =
83 static constexpr
int NumOutputDims = NumInputDims - NumReducedDims;
108 : m_impl(
op.expression(), device), m_traceDim(1),
m_device(device) {
110 EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)),
111 YOU_MADE_A_PROGRAMMING_MISTAKE);
113 for (
int i = 0;
i < NumInputDims; ++
i) {
114 m_reduced[
i] =
false;
117 const Dims& op_dims =
op.dims();
118 for (
int i = 0;
i < NumReducedDims; ++
i) {
121 m_reduced[op_dims[
i]] =
true;
125 int num_distinct_reduce_dims = 0;
126 for (
int i = 0;
i < NumInputDims; ++
i) {
128 ++num_distinct_reduce_dims;
133 eigen_assert(num_distinct_reduce_dims == NumReducedDims);
138 int output_index = 0;
139 int reduced_index = 0;
140 for (
int i = 0;
i < NumInputDims; ++
i) {
142 m_reducedDims[reduced_index] = input_dims[
i];
143 if (reduced_index > 0) {
145 eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
149 m_dimensions[output_index] = input_dims[
i];
154 if (NumReducedDims != 0) {
155 m_traceDim = m_reducedDims[0];
159 if (NumOutputDims > 0) {
161 m_outputStrides[0] = 1;
162 for (
int i = 1;
i < NumOutputDims; ++
i) {
163 m_outputStrides[
i] = m_outputStrides[
i - 1] * m_dimensions[
i - 1];
166 m_outputStrides.back() = 1;
167 for (
int i = NumOutputDims - 2;
i >= 0; --
i) {
168 m_outputStrides[
i] = m_outputStrides[
i + 1] * m_dimensions[
i + 1];
174 if (NumInputDims > 0) {
177 input_strides[0] = 1;
178 for (
int i = 1;
i < NumInputDims; ++
i) {
179 input_strides[
i] = input_strides[
i - 1] * input_dims[
i - 1];
182 input_strides.back() = 1;
183 for (
int i = NumInputDims - 2;
i >= 0; --
i) {
184 input_strides[
i] = input_strides[
i + 1] * input_dims[
i + 1];
190 for (
int i = 0;
i < NumInputDims; ++
i) {
192 m_reducedStrides[reduced_index] = input_strides[
i];
195 m_preservedStrides[output_index] = input_strides[
i];
205 m_impl.evalSubExprsIfNeeded(NULL);
214 Index index_stride = 0;
215 for (
int i = 0;
i < NumReducedDims; ++
i) {
216 index_stride += m_reducedStrides[
i];
221 if (NumOutputDims != 0) cur_index = firstInput(index);
222 for (
Index i = 0;
i < m_traceDim; ++
i) {
223 result += m_impl.coeff(cur_index);
224 cur_index += index_stride;
230 template <
int LoadMode>
238 PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
245 Index startInput = 0;
247 for (
int i = NumOutputDims - 1;
i > 0; --
i) {
248 const Index idx = index / m_outputStrides[
i];
249 startInput += idx * m_preservedStrides[
i];
250 index -= idx * m_outputStrides[
i];
252 startInput += index * m_preservedStrides[0];
254 for (
int i = 0;
i < NumOutputDims - 1; ++
i) {
255 const Index idx = index / m_outputStrides[
i];
256 startInput += idx * m_preservedStrides[
i];
257 index -= idx * m_outputStrides[
i];
259 startInput += index * m_preservedStrides[NumOutputDims - 1];
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define EIGEN_ONLY_USED_FOR_DEBUG(x)
Definition: Macros.h:922
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
#define EIGEN_DEVICE_REF
Definition: TensorMacros.h:34
SCALAR Scalar
Definition: bench_gemm.cpp:45
Generic expression where a coefficient-wise binary operator is applied to two expressions.
Definition: CwiseBinaryOp.h:79
The tensor base class.
Definition: TensorBase.h:1026
Definition: TensorTrace.h:53
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const internal::remove_all_t< typename XprType::Nested > & expression() const
Definition: TensorTrace.h:67
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dims & dims() const
Definition: TensorTrace.h:65
Eigen::internal::nested< TensorTraceOp >::type Nested
Definition: TensorTrace.h:58
const Dims m_dims
Definition: TensorTrace.h:73
XprType::Nested m_xpr
Definition: TensorTrace.h:72
XprType::CoeffReturnType CoeffReturnType
Definition: TensorTrace.h:57
Eigen::internal::traits< TensorTraceOp >::StorageKind StorageKind
Definition: TensorTrace.h:59
Eigen::internal::traits< TensorTraceOp >::Scalar Scalar
Definition: TensorTrace.h:55
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorTrace.h:56
Eigen::internal::traits< TensorTraceOp >::Index Index
Definition: TensorTrace.h:60
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(const XprType &expr, const Dims &dims)
Definition: TensorTrace.h:62
Definition: TensorBlock.h:566
@ ColMajor
Definition: Constants.h:318
char char * op
Definition: level2_impl.h:374
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Definition: Eigen_Colamd.h:49
Definition: Constants.h:519
T Real
Definition: NumTraits.h:183
Definition: TensorForwardDeclarations.h:42
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorTrace.h:211
XprType::Index Index
Definition: TensorTrace.h:84
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorTrace.h:202
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorTrace.h:104
array< Index, NumOutputDims > m_preservedStrides
Definition: TensorTrace.h:273
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const
Definition: TensorTrace.h:244
array< Index, NumReducedDims > m_reducedDims
Definition: TensorTrace.h:270
array< bool, NumInputDims > m_reduced
Definition: TensorTrace.h:269
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorTrace.h:88
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorTrace.h:209
Storage::Type EvaluatorPointerType
Definition: TensorTrace.h:91
TensorTraceOp< Dims, ArgType > XprType
Definition: TensorTrace.h:79
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition: TensorTrace.h:231
array< Index, NumReducedDims > m_reducedStrides
Definition: TensorTrace.h:272
const Device EIGEN_DEVICE_REF m_device
Definition: TensorTrace.h:268
XprType::CoeffReturnType CoeffReturnType
Definition: TensorTrace.h:87
array< Index, NumOutputDims > m_outputStrides
Definition: TensorTrace.h:271
Index m_traceDim
Definition: TensorTrace.h:267
EIGEN_STRONG_INLINE TensorEvaluator(const XprType &op, const Device &device)
Definition: TensorTrace.h:107
TensorEvaluator< ArgType, Device > m_impl
Definition: TensorTrace.h:265
Dimensions m_dimensions
Definition: TensorTrace.h:264
StorageMemory< CoeffReturnType, Device > Storage
Definition: TensorTrace.h:90
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
Definition: TensorTrace.h:204
DSizes< Index, NumOutputDims > Dimensions
Definition: TensorTrace.h:85
XprType::Scalar Scalar
Definition: TensorTrace.h:86
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
static constexpr int Layout
Definition: TensorEvaluator.h:46
const Device EIGEN_DEVICE_REF m_device
Definition: TensorEvaluator.h:170
@ PacketAccess
Definition: TensorEvaluator.h:50
@ IsAligned
Definition: TensorEvaluator.h:49
static constexpr int PacketSize
Definition: TensorEvaluator.h:38
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorEvaluator.h:89
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorEvaluator.h:69
const TensorTraceOp< Dims, XprType > & type
Definition: TensorTrace.h:42
Definition: XprHelper.h:427
TensorTraceOp< Dims, XprType > type
Definition: TensorTrace.h:47
Definition: TensorTraits.h:152
ref_selector< T >::type type
Definition: TensorTraits.h:153
std::remove_reference_t< Nested > Nested_
Definition: TensorTrace.h:35
XprType::Scalar Scalar
Definition: TensorTrace.h:30
XprType::Nested Nested
Definition: TensorTrace.h:34
XprTraits::StorageKind StorageKind
Definition: TensorTrace.h:32
traits< XprType > XprTraits
Definition: TensorTrace.h:31
XprTraits::Index Index
Definition: TensorTrace.h:33
Definition: ForwardDeclarations.h:21
Definition: GenericPacketMath.h:134