cxx11_tensor_striding.cpp File Reference
#include "main.h"
#include <Eigen/CXX11/Tensor>

Functions

template<int DataLayout>
static void test_simple_striding ()
 
template<int DataLayout>
static void test_striding_as_lvalue ()
 
 EIGEN_DECLARE_TEST (cxx11_tensor_striding)
 

Function Documentation

◆ EIGEN_DECLARE_TEST()

EIGEN_DECLARE_TEST ( cxx11_tensor_striding  )
109  {
110  CALL_SUBTEST(test_simple_striding<ColMajor>());
111  CALL_SUBTEST(test_simple_striding<RowMajor>());
112  CALL_SUBTEST(test_striding_as_lvalue<ColMajor>());
113  CALL_SUBTEST(test_striding_as_lvalue<RowMajor>());
114 }
#define CALL_SUBTEST(FUNC)
Definition: main.h:382

References CALL_SUBTEST.

◆ test_simple_striding()

template<int DataLayout>
static void test_simple_striding ( )
static
17  {
18  Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
19  tensor.setRandom();
20  array<ptrdiff_t, 4> strides;
21  strides[0] = 1;
22  strides[1] = 1;
23  strides[2] = 1;
24  strides[3] = 1;
25 
26  Tensor<float, 4, DataLayout> no_stride;
27  no_stride = tensor.stride(strides);
28 
29  VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
30  VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
31  VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
32  VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
33 
34  for (int i = 0; i < 2; ++i) {
35  for (int j = 0; j < 3; ++j) {
36  for (int k = 0; k < 5; ++k) {
37  for (int l = 0; l < 7; ++l) {
38  VERIFY_IS_EQUAL(tensor(i, j, k, l), no_stride(i, j, k, l));
39  }
40  }
41  }
42  }
43 
44  strides[0] = 2;
45  strides[1] = 4;
46  strides[2] = 2;
47  strides[3] = 3;
48  Tensor<float, 4, DataLayout> stride;
49  stride = tensor.stride(strides);
50 
51  VERIFY_IS_EQUAL(stride.dimension(0), 1);
52  VERIFY_IS_EQUAL(stride.dimension(1), 1);
53  VERIFY_IS_EQUAL(stride.dimension(2), 3);
54  VERIFY_IS_EQUAL(stride.dimension(3), 3);
55 
56  for (int i = 0; i < 1; ++i) {
57  for (int j = 0; j < 1; ++j) {
58  for (int k = 0; k < 3; ++k) {
59  for (int l = 0; l < 3; ++l) {
60  VERIFY_IS_EQUAL(tensor(2 * i, 4 * j, 2 * k, 3 * l), stride(i, j, k, l));
61  }
62  }
63  }
64  }
65 }
int i
Definition: BiCGSTAB_step_by_step.cpp:9
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorStridingOp< const Strides, const Derived > stride(const Strides &strides) const
Definition: TensorBase.h:1198
The tensor class.
Definition: Tensor.h:68
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:99
char char char int int * k
Definition: level2_impl.h:374
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:367
EIGEN_ALWAYS_INLINE DSizes< IndexType, NumDims > strides(const DSizes< IndexType, NumDims > &dimensions)
Definition: TensorBlock.h:29
std::array< T, N > array
Definition: EmulateArray.h:231
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2

References Eigen::Tensor< Scalar_, NumIndices_, Options_, IndexType_ >::dimension(), i, j, k, Eigen::TensorBase< Derived, AccessLevel >::setRandom(), Eigen::TensorBase< Derived, AccessLevel >::stride(), Eigen::internal::strides(), and VERIFY_IS_EQUAL.

◆ test_striding_as_lvalue()

template<int DataLayout>
static void test_striding_as_lvalue ( )
static
68  {
69  Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
70  tensor.setRandom();
71  array<ptrdiff_t, 4> strides;
72  strides[0] = 2;
73  strides[1] = 4;
74  strides[2] = 2;
75  strides[3] = 3;
76 
77  Tensor<float, 4, DataLayout> result(3, 12, 10, 21);
78  result.stride(strides) = tensor;
79 
80  for (int i = 0; i < 2; ++i) {
81  for (int j = 0; j < 3; ++j) {
82  for (int k = 0; k < 5; ++k) {
83  for (int l = 0; l < 7; ++l) {
84  VERIFY_IS_EQUAL(tensor(i, j, k, l), result(2 * i, 4 * j, 2 * k, 3 * l));
85  }
86  }
87  }
88  }
89 
90  array<ptrdiff_t, 4> no_strides;
91  no_strides[0] = 1;
92  no_strides[1] = 1;
93  no_strides[2] = 1;
94  no_strides[3] = 1;
95  Tensor<float, 4, DataLayout> result2(3, 12, 10, 21);
96  result2.stride(strides) = tensor.stride(no_strides);
97 
98  for (int i = 0; i < 2; ++i) {
99  for (int j = 0; j < 3; ++j) {
100  for (int k = 0; k < 5; ++k) {
101  for (int l = 0; l < 7; ++l) {
102  VERIFY_IS_EQUAL(tensor(i, j, k, l), result2(2 * i, 4 * j, 2 * k, 3 * l));
103  }
104  }
105  }
106  }
107 }

References i, j, k, Eigen::TensorBase< Derived, AccessLevel >::setRandom(), Eigen::TensorBase< Derived, AccessLevel >::stride(), Eigen::internal::strides(), and VERIFY_IS_EQUAL.