cxx11_tensor_patch.cpp File Reference
#include "main.h"
#include <Eigen/CXX11/Tensor>

Functions

template<int DataLayout>
static void test_simple_patch ()
 
 EIGEN_DECLARE_TEST (cxx11_tensor_patch)
 

Function Documentation

◆ EIGEN_DECLARE_TEST()

// Test entry point: runs test_simple_patch for both storage orders.
// (A stale, commented-out call to test_expr_shuffling -- a function not defined
// in this file -- was removed.)
EIGEN_DECLARE_TEST(cxx11_tensor_patch)
{
  CALL_SUBTEST(test_simple_patch<ColMajor>());
  CALL_SUBTEST(test_simple_patch<RowMajor>());
}
#define CALL_SUBTEST(FUNC)
Definition: main.h:382

References CALL_SUBTEST.

◆ test_simple_patch()

template<int DataLayout>
static void test_simple_patch ( )
static
17  {
18  Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
19  tensor.setRandom();
20  array<ptrdiff_t, 4> patch_dims;
21 
22  patch_dims[0] = 1;
23  patch_dims[1] = 1;
24  patch_dims[2] = 1;
25  patch_dims[3] = 1;
26 
28  no_patch = tensor.extract_patches(patch_dims);
29 
30  if (DataLayout == ColMajor) {
31  VERIFY_IS_EQUAL(no_patch.dimension(0), 1);
32  VERIFY_IS_EQUAL(no_patch.dimension(1), 1);
33  VERIFY_IS_EQUAL(no_patch.dimension(2), 1);
34  VERIFY_IS_EQUAL(no_patch.dimension(3), 1);
35  VERIFY_IS_EQUAL(no_patch.dimension(4), tensor.size());
36  } else {
37  VERIFY_IS_EQUAL(no_patch.dimension(0), tensor.size());
38  VERIFY_IS_EQUAL(no_patch.dimension(1), 1);
39  VERIFY_IS_EQUAL(no_patch.dimension(2), 1);
40  VERIFY_IS_EQUAL(no_patch.dimension(3), 1);
41  VERIFY_IS_EQUAL(no_patch.dimension(4), 1);
42  }
43 
44  for (int i = 0; i < tensor.size(); ++i) {
45  VERIFY_IS_EQUAL(tensor.data()[i], no_patch.data()[i]);
46  }
47 
48  patch_dims[0] = 2;
49  patch_dims[1] = 3;
50  patch_dims[2] = 5;
51  patch_dims[3] = 7;
52  Tensor<float, 5, DataLayout> single_patch;
53  single_patch = tensor.extract_patches(patch_dims);
54 
55  if (DataLayout == ColMajor) {
56  VERIFY_IS_EQUAL(single_patch.dimension(0), 2);
57  VERIFY_IS_EQUAL(single_patch.dimension(1), 3);
58  VERIFY_IS_EQUAL(single_patch.dimension(2), 5);
59  VERIFY_IS_EQUAL(single_patch.dimension(3), 7);
60  VERIFY_IS_EQUAL(single_patch.dimension(4), 1);
61  } else {
62  VERIFY_IS_EQUAL(single_patch.dimension(0), 1);
63  VERIFY_IS_EQUAL(single_patch.dimension(1), 2);
64  VERIFY_IS_EQUAL(single_patch.dimension(2), 3);
65  VERIFY_IS_EQUAL(single_patch.dimension(3), 5);
66  VERIFY_IS_EQUAL(single_patch.dimension(4), 7);
67  }
68 
69  for (int i = 0; i < tensor.size(); ++i) {
70  VERIFY_IS_EQUAL(tensor.data()[i], single_patch.data()[i]);
71  }
72 
73  patch_dims[0] = 1;
74  patch_dims[1] = 2;
75  patch_dims[2] = 2;
76  patch_dims[3] = 1;
78  twod_patch = tensor.extract_patches(patch_dims);
79 
80  if (DataLayout == ColMajor) {
81  VERIFY_IS_EQUAL(twod_patch.dimension(0), 1);
82  VERIFY_IS_EQUAL(twod_patch.dimension(1), 2);
83  VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
84  VERIFY_IS_EQUAL(twod_patch.dimension(3), 1);
85  VERIFY_IS_EQUAL(twod_patch.dimension(4), 2 * 2 * 4 * 7);
86  } else {
87  VERIFY_IS_EQUAL(twod_patch.dimension(0), 2 * 2 * 4 * 7);
88  VERIFY_IS_EQUAL(twod_patch.dimension(1), 1);
89  VERIFY_IS_EQUAL(twod_patch.dimension(2), 2);
90  VERIFY_IS_EQUAL(twod_patch.dimension(3), 2);
91  VERIFY_IS_EQUAL(twod_patch.dimension(4), 1);
92  }
93 
94  for (int i = 0; i < 2; ++i) {
95  for (int j = 0; j < 2; ++j) {
96  for (int k = 0; k < 4; ++k) {
97  for (int l = 0; l < 7; ++l) {
98  int patch_loc;
99  if (DataLayout == ColMajor) {
100  patch_loc = i + 2 * (j + 2 * (k + 4 * l));
101  } else {
102  patch_loc = l + 7 * (k + 4 * (j + 2 * i));
103  }
104  for (int x = 0; x < 2; ++x) {
105  for (int y = 0; y < 2; ++y) {
106  if (DataLayout == ColMajor) {
107  VERIFY_IS_EQUAL(tensor(i, j + x, k + y, l), twod_patch(0, x, y, 0, patch_loc));
108  } else {
109  VERIFY_IS_EQUAL(tensor(i, j + x, k + y, l), twod_patch(patch_loc, 0, x, y, 0));
110  }
111  }
112  }
113  }
114  }
115  }
116  }
117 
118  patch_dims[0] = 1;
119  patch_dims[1] = 2;
120  patch_dims[2] = 3;
121  patch_dims[3] = 5;
122  Tensor<float, 5, DataLayout> threed_patch;
123  threed_patch = tensor.extract_patches(patch_dims);
124 
125  if (DataLayout == ColMajor) {
126  VERIFY_IS_EQUAL(threed_patch.dimension(0), 1);
127  VERIFY_IS_EQUAL(threed_patch.dimension(1), 2);
128  VERIFY_IS_EQUAL(threed_patch.dimension(2), 3);
129  VERIFY_IS_EQUAL(threed_patch.dimension(3), 5);
130  VERIFY_IS_EQUAL(threed_patch.dimension(4), 2 * 2 * 3 * 3);
131  } else {
132  VERIFY_IS_EQUAL(threed_patch.dimension(0), 2 * 2 * 3 * 3);
133  VERIFY_IS_EQUAL(threed_patch.dimension(1), 1);
134  VERIFY_IS_EQUAL(threed_patch.dimension(2), 2);
135  VERIFY_IS_EQUAL(threed_patch.dimension(3), 3);
136  VERIFY_IS_EQUAL(threed_patch.dimension(4), 5);
137  }
138 
139  for (int i = 0; i < 2; ++i) {
140  for (int j = 0; j < 2; ++j) {
141  for (int k = 0; k < 3; ++k) {
142  for (int l = 0; l < 3; ++l) {
143  int patch_loc;
144  if (DataLayout == ColMajor) {
145  patch_loc = i + 2 * (j + 2 * (k + 3 * l));
146  } else {
147  patch_loc = l + 3 * (k + 3 * (j + 2 * i));
148  }
149  for (int x = 0; x < 2; ++x) {
150  for (int y = 0; y < 3; ++y) {
151  for (int z = 0; z < 5; ++z) {
152  if (DataLayout == ColMajor) {
153  VERIFY_IS_EQUAL(tensor(i, j + x, k + y, l + z), threed_patch(0, x, y, z, patch_loc));
154  } else {
155  VERIFY_IS_EQUAL(tensor(i, j + x, k + y, l + z), threed_patch(patch_loc, 0, x, y, z));
156  }
157  }
158  }
159  }
160  }
161  }
162  }
163  }
164 }
int i
Definition: BiCGSTAB_step_by_step.cpp:9
The tensor class.
Definition: Tensor.h:68
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:99
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
Definition: Tensor.h:102
static const int DataLayout
Definition: cxx11_tensor_image_patch_sycl.cpp:24
@ ColMajor
Definition: Constants.h:318
Scalar * y
Definition: level1_cplx_impl.h:128
char char char int int * k
Definition: level2_impl.h:374
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:367
std::array< T, N > array
Definition: EmulateArray.h:231
list x
Definition: plotDoE.py:28
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2

References Eigen::ColMajor, Eigen::Tensor< Scalar_, NumIndices_, Options_, IndexType_ >::data(), DataLayout, Eigen::Tensor< Scalar_, NumIndices_, Options_, IndexType_ >::dimension(), i, j, k, Eigen::TensorBase< Derived, AccessLevel >::setRandom(), Eigen::Tensor< Scalar_, NumIndices_, Options_, IndexType_ >::size(), VERIFY_IS_EQUAL, plotDoE::x, and y.