10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
// Fragment of a traits-style template parameterized on PatchDim and XprType.
// NOTE(review): the struct header and several members are not visible in this listing.
26 template <
typename PatchDim,
typename XprType>
// Nested is taken from the wrapped expression type; Nested_ is that type with any
// reference qualifier stripped.
32 typedef typename XprType::Nested
Nested;
33 typedef std::remove_reference_t<Nested>
Nested_;
// One more dimension than the wrapped expression's traits report.
34 static constexpr
int NumDimensions = XprTraits::NumDimensions + 1;
// The layout is forwarded unchanged from the wrapped expression's traits.
35 static constexpr
int Layout = XprTraits::Layout;
39 template <
typename PatchDim,
typename XprType>
44 template <
typename PatchDim,
typename XprType>
51 template <
typename PatchDim,
typename XprType>
74 template <
typename PatchDim,
typename ArgType,
typename Device>
// Constructor body fragment: fills m_dimensions, m_inputStrides, m_patchStrides and
// m_outputStrides from the input dimensions and the requested patch dimensions.
// NOTE(review): two alternative set-ups follow; the `if`/`else` that selects between them
// (and the closing braces) is missing from this listing. Presumably only one of them runs
// per instantiation, selected by Layout - confirm against the real header.
//
// num_patches accumulates the product over all input dimensions of
// (input_dims[i] - patch_dims[i] + 1).
102 Index num_patches = 1;
104 const PatchDim& patch_dims =
op.patch_dims();
// --- First variant: patch extents occupy output dims [0, NumDims - 2], and the patch count
// is stored in the last output dim. ---
106 for (
int i = 0;
i < NumDims - 1; ++
i) {
107 m_dimensions[
i] = patch_dims[
i];
108 num_patches *= (input_dims[
i] - patch_dims[
i] + 1);
110 m_dimensions[NumDims - 1] = num_patches;
// Strides grow from index 0 upwards: index 0 has stride 1.
112 m_inputStrides[0] = 1;
113 m_patchStrides[0] = 1;
114 for (
int i = 1;
i < NumDims - 1; ++
i) {
// Input stride: running product of the preceding input extents.
115 m_inputStrides[
i] = m_inputStrides[
i - 1] * input_dims[
i - 1];
// Patch stride: running product of the preceding per-dimension counts
// (input extent - patch extent + 1).
116 m_patchStrides[
i] = m_patchStrides[
i - 1] * (input_dims[
i - 1] - patch_dims[
i - 1] + 1);
// Output strides: running product of the output dimensions computed above.
118 m_outputStrides[0] = 1;
119 for (
int i = 1;
i < NumDims; ++
i) {
120 m_outputStrides[
i] = m_outputStrides[
i - 1] * m_dimensions[
i - 1];
// --- Second variant: the patch count is stored in output dim 0, and the patch extents are
// shifted to output dims [1, NumDims - 1]. ---
123 for (
int i = 0;
i < NumDims - 1; ++
i) {
124 m_dimensions[
i + 1] = patch_dims[
i];
125 num_patches *= (input_dims[
i] - patch_dims[
i] + 1);
127 m_dimensions[0] = num_patches;
// Strides grow from the last index downwards: index NumDims - 2 has stride 1.
129 m_inputStrides[NumDims - 2] = 1;
130 m_patchStrides[NumDims - 2] = 1;
131 for (
int i = NumDims - 3;
i >= 0; --
i) {
132 m_inputStrides[
i] = m_inputStrides[
i + 1] * input_dims[
i + 1];
133 m_patchStrides[
i] = m_patchStrides[
i + 1] * (input_dims[
i + 1] - patch_dims[
i + 1] + 1);
// Output strides: the last output dim has stride 1, earlier ones are running products.
135 m_outputStrides[NumDims - 1] = 1;
136 for (
int i = NumDims - 2;
i >= 0; --
i) {
137 m_outputStrides[
i] = m_outputStrides[
i + 1] * m_dimensions[
i + 1];
// Forwards the call to the wrapped evaluator, passing a null pointer constant as the argument.
// NOTE(review): `NULL` could be spelled `nullptr`; the enclosing function is not visible in
// this listing, so confirm the parameter type accepts it before changing.
145 m_impl.evalSubExprsIfNeeded(NULL);
// coeff() body fragment: converts the linear output `index` into a linear index into the
// wrapped evaluator and returns m_impl.coeff() at that position.
// NOTE(review): the `if`/`else` choosing between the two loops below and the closing braces
// are missing from this listing.
//
// Position of the stride used to split `index`: the last output stride for ColMajor,
// the first one otherwise.
152 Index output_stride_index = (
static_cast<int>(
Layout) ==
static_cast<int>(
ColMajor)) ? NumDims - 1 : 0;
// patchIndex is the quotient of `index` by that stride (which patch).
154 Index patchIndex = index / m_outputStrides[output_stride_index];
// patchOffset is the remainder (linear position inside that patch).
156 Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
157 Index inputIndex = 0;
// Variant that walks dims from NumDims - 2 down to 1 and indexes m_outputStrides with i.
// For each dim, peel off one coordinate of the patch (patchIdx) and one coordinate of the
// offset inside the patch (offsetIdx); their sum is scaled by the input stride.
160 for (
int i = NumDims - 2;
i > 0; --
i) {
161 const Index patchIdx = patchIndex / m_patchStrides[
i];
162 patchIndex -= patchIdx * m_patchStrides[
i];
163 const Index offsetIdx = patchOffset / m_outputStrides[
i];
164 patchOffset -= offsetIdx * m_outputStrides[
i];
165 inputIndex += (patchIdx + offsetIdx) * m_inputStrides[
i];
// Variant that walks dims from 0 up to NumDims - 3; here the offset is decomposed with
// m_outputStrides[i + 1], i.e. the output strides are shifted by one relative to the
// patch/input strides.
169 for (
int i = 0;
i < NumDims - 2; ++
i) {
170 const Index patchIdx = patchIndex / m_patchStrides[
i];
171 patchIndex -= patchIdx * m_patchStrides[
i];
172 const Index offsetIdx = patchOffset / m_outputStrides[
i + 1];
173 patchOffset -= offsetIdx * m_outputStrides[
i + 1];
174 inputIndex += (patchIdx + offsetIdx) * m_inputStrides[
i];
// The dimension skipped by the loop above is added without scaling: what remains of
// patchIndex and patchOffset goes in directly.
177 inputIndex += (patchIndex + patchOffset);
178 return m_impl.coeff(inputIndex);
// packet() body fragment: applies the same output-to-input index mapping as the scalar path,
// but for two output indices at once (indices[0] and indices[1]).
// NOTE(review): the declarations of `indices` and `values`, the `if`/`else` selecting between
// the two loops, and the bodies of the final branches are missing from this listing.
// Presumably indices[] holds the first and last coefficient of the packet - confirm.
181 template <
int LoadMode>
// Position of the stride used to split the indices: the last output stride for ColMajor,
// the first one otherwise.
185 Index output_stride_index = (
static_cast<int>(
Layout) ==
static_cast<int>(
ColMajor)) ? NumDims - 1 : 0;
// Quotients by that stride (which patch) for both indices.
187 Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
188 indices[1] / m_outputStrides[output_stride_index]};
// Remainders (linear position inside the patch) for both indices.
189 Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
190 indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};
192 Index inputIndices[2] = {0, 0};
// Variant walking dims from NumDims - 2 down to 1, with m_outputStrides indexed by i.
195 for (
int i = NumDims - 2;
i > 0; --
i) {
196 const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[
i], patchIndices[1] / m_patchStrides[
i]};
197 patchIndices[0] -= patchIdx[0] * m_patchStrides[
i];
198 patchIndices[1] -= patchIdx[1] * m_patchStrides[
i];
200 const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[
i], patchOffsets[1] / m_outputStrides[
i]};
201 patchOffsets[0] -= offsetIdx[0] * m_outputStrides[
i];
202 patchOffsets[1] -= offsetIdx[1] * m_outputStrides[
i];
204 inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[
i];
205 inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[
i];
// Variant walking dims from 0 up to NumDims - 3, with m_outputStrides indexed by i + 1.
209 for (
int i = 0;
i < NumDims - 2; ++
i) {
210 const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[
i], patchIndices[1] / m_patchStrides[
i]};
211 patchIndices[0] -= patchIdx[0] * m_patchStrides[
i];
212 patchIndices[1] -= patchIdx[1] * m_patchStrides[
i];
214 const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[
i + 1], patchOffsets[1] / m_outputStrides[
i + 1]};
215 patchOffsets[0] -= offsetIdx[0] * m_outputStrides[
i + 1];
216 patchOffsets[1] -= offsetIdx[1] * m_outputStrides[
i + 1];
218 inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[
i];
219 inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[
i];
// The dimension skipped by the loop is added without scaling, for both indices.
222 inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
223 inputIndices[1] += (patchIndices[1] + patchOffsets[1]);
// True when the two mapped input indices are exactly PacketSize - 1 apart.
// NOTE(review): the body of this branch is not visible; presumably it loads the packet
// directly from m_impl - confirm.
225 if (inputIndices[1] - inputIndices[0] ==
PacketSize - 1) {
// Otherwise-path fragment: the first and last slots of `values` are filled from the two
// mapped input indices. How the slots in between are filled is not visible here.
230 values[0] = m_impl.coeff(inputIndices[0]);
231 values[
PacketSize - 1] = m_impl.coeff(inputIndices[1]);
// costPerCoeff() fragment: per-coefficient compute cost estimate, modelled as NumDims times
// (one Index division + one Index multiplication + two Index additions).
// NOTE(review): how compute_cost is combined into the returned TensorOpCost is not visible
// in this listing.
242 const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() + TensorOpCost::MulCost<Index>() +
243 2 * TensorOpCost::AddCost<Index>());
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_UNROLL_LOOP
Definition: Macros.h:1298
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
SCALAR Scalar
Definition: bench_gemm.cpp:45
Generic expression where a coefficient-wise binary operator is applied to two expressions.
Definition: CwiseBinaryOp.h:79
The tensor base class.
Definition: TensorBase.h:1026
Definition: TensorCostModel.h:28
Definition: TensorPatch.h:52
Eigen::internal::traits< TensorPatchOp >::StorageKind StorageKind
Definition: TensorPatch.h:58
EIGEN_DEVICE_FUNC const internal::remove_all_t< typename XprType::Nested > & expression() const
Definition: TensorPatch.h:66
const PatchDim m_patch_dims
Definition: TensorPatch.h:70
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType &expr, const PatchDim &patch_dims)
Definition: TensorPatch.h:61
Eigen::internal::traits< TensorPatchOp >::Index Index
Definition: TensorPatch.h:59
Eigen::internal::traits< TensorPatchOp >::Scalar Scalar
Definition: TensorPatch.h:54
Eigen::NumTraits< Scalar >::Real RealScalar
Definition: TensorPatch.h:55
XprType::CoeffReturnType CoeffReturnType
Definition: TensorPatch.h:56
Eigen::internal::nested< TensorPatchOp >::type Nested
Definition: TensorPatch.h:57
XprType::Nested m_xpr
Definition: TensorPatch.h:69
EIGEN_DEVICE_FUNC const PatchDim & patch_dims() const
Definition: TensorPatch.h:64
Definition: TensorBlock.h:566
@ ColMajor
Definition: Constants.h:318
char char * op
Definition: level2_impl.h:374
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Definition: Eigen_Colamd.h:49
Definition: Constants.h:519
T Real
Definition: NumTraits.h:183
Definition: TensorMeta.h:47
Definition: TensorForwardDeclarations.h:42
Dimensions m_dimensions
Definition: TensorPatch.h:250
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const
Definition: TensorPatch.h:241
EIGEN_STRONG_INLINE void cleanup()
Definition: TensorPatch.h:149
DSizes< Index, NumDims > Dimensions
Definition: TensorPatch.h:79
TensorEvaluator< ArgType, Device > m_impl
Definition: TensorPatch.h:255
XprType::Scalar Scalar
Definition: TensorPatch.h:80
EIGEN_STRONG_INLINE TensorEvaluator(const XprType &op, const Device &device)
Definition: TensorPatch.h:101
TensorPatchOp< PatchDim, ArgType > XprType
Definition: TensorPatch.h:76
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
Definition: TensorPatch.h:182
array< Index, NumDims > m_outputStrides
Definition: TensorPatch.h:251
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorPatch.h:82
XprType::Index Index
Definition: TensorPatch.h:77
XprType::CoeffReturnType CoeffReturnType
Definition: TensorPatch.h:81
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorPatch.h:151
array< Index, NumDims - 1 > m_patchStrides
Definition: TensorPatch.h:253
StorageMemory< CoeffReturnType, Device > Storage
Definition: TensorPatch.h:84
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorPatch.h:142
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
Definition: TensorPatch.h:144
EIGEN_DEVICE_FUNC EvaluatorPointerType data() const
Definition: TensorPatch.h:247
Storage::Type EvaluatorPointerType
Definition: TensorPatch.h:85
internal::TensorBlockNotImplemented TensorBlock
Definition: TensorPatch.h:98
array< Index, NumDims - 1 > m_inputStrides
Definition: TensorPatch.h:252
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
static constexpr int Layout
Definition: TensorEvaluator.h:46
@ PacketAccess
Definition: TensorEvaluator.h:50
@ IsAligned
Definition: TensorEvaluator.h:49
static constexpr int PacketSize
Definition: TensorEvaluator.h:38
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
Definition: TensorEvaluator.h:89
Derived::Index Index
Definition: TensorEvaluator.h:32
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: TensorEvaluator.h:69
const TensorPatchOp< PatchDim, XprType > & type
Definition: TensorPatch.h:41
Definition: XprHelper.h:427
TensorPatchOp< PatchDim, XprType > type
Definition: TensorPatch.h:46
Definition: TensorTraits.h:152
ref_selector< T >::type type
Definition: TensorTraits.h:153
XprType::Scalar Scalar
Definition: TensorPatch.h:28
XprTraits::StorageKind StorageKind
Definition: TensorPatch.h:30
XprTraits::Index Index
Definition: TensorPatch.h:31
XprType::Nested Nested
Definition: TensorPatch.h:32
XprTraits::PointerType PointerType
Definition: TensorPatch.h:36
traits< XprType > XprTraits
Definition: TensorPatch.h:29
std::remove_reference_t< Nested > Nested_
Definition: TensorPatch.h:33
Definition: ForwardDeclarations.h:21