cxx11_tensor_concatenation.cpp File Reference
#include "main.h"
#include <Eigen/CXX11/Tensor>

Functions

template<int DataLayout>
static void test_dimension_failures ()
 
template<int DataLayout>
static void test_static_dimension_failure ()
 
template<int DataLayout>
static void test_simple_concatenation ()
 
static void test_concatenation_as_lvalue ()
 
 EIGEN_DECLARE_TEST (cxx11_tensor_concatenation)
 

Function Documentation

◆ EIGEN_DECLARE_TEST()

EIGEN_DECLARE_TEST ( cxx11_tensor_concatenation  )
119  {
120  CALL_SUBTEST(test_dimension_failures<ColMajor>());
121  CALL_SUBTEST(test_dimension_failures<RowMajor>());
122  CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
123  CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
124  CALL_SUBTEST(test_simple_concatenation<ColMajor>());
125  CALL_SUBTEST(test_simple_concatenation<RowMajor>());
126  // CALL_SUBTEST(test_vectorized_concatenation());
127  CALL_SUBTEST(test_concatenation_as_lvalue());
128 }
static void test_concatenation_as_lvalue()
Definition: cxx11_tensor_concatenation.cpp:101
#define CALL_SUBTEST(FUNC)
Definition: main.h:382

References CALL_SUBTEST, and test_concatenation_as_lvalue().

◆ test_concatenation_as_lvalue()

static void test_concatenation_as_lvalue ( )
static
101  {
102  Tensor<int, 2> t1(2, 3);
103  Tensor<int, 2> t2(2, 3);
104  t1.setRandom();
105  t2.setRandom();
106 
107  Tensor<int, 2> result(4, 3);
108  result.setRandom();
109  t1.concatenate(t2, 0) = result;
110 
111  for (int i = 0; i < 2; ++i) {
112  for (int j = 0; j < 3; ++j) {
113  VERIFY_IS_EQUAL(t1(i, j), result(i, j));
114  VERIFY_IS_EQUAL(t2(i, j), result(i + 2, j));
115  }
116  }
117 }
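int i
Definition: BiCGSTAB_step_by_step.cpp:9
The tensor class.
Definition: Tensor.h:68
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:367
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
Note: the `int i` and `std::ptrdiff_t j` entries above are Doxygen cross-reference artifacts. In this listing, `i` and `j` are the loop-local `int` indices declared at source lines 111–112, not those external symbols.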

References Eigen::TensorBase< Derived, AccessLevel >::concatenate(), i, j, Eigen::TensorBase< Derived, AccessLevel >::setRandom(), and VERIFY_IS_EQUAL.

Referenced by EIGEN_DECLARE_TEST().

◆ test_dimension_failures()

template<int DataLayout>
static void test_dimension_failures ( )
static
17  {
18  Tensor<int, 3, DataLayout> left(2, 3, 1);
19  Tensor<int, 3, DataLayout> right(3, 3, 1);
20  left.setRandom();
21  right.setRandom();
22 
23  // Okay; other dimensions are equal.
24  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
25 
26  // Dimension mismatches.
27  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
28  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
29 
30  // Axis > NumDims or < 0.
31  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
32  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
33 }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorConcatenationOp< const Axis, const Derived, const OtherDerived > concatenate(const OtherDerived &other, const Axis &axis) const
Definition: TensorBase.h:1095
#define VERIFY_RAISES_ASSERT(a)
Definition: main.h:329

References Eigen::TensorBase< Derived, AccessLevel >::concatenate(), Eigen::TensorBase< Derived, AccessLevel >::setRandom(), and VERIFY_RAISES_ASSERT.

◆ test_simple_concatenation()

template<int DataLayout>
static void test_simple_concatenation ( )
static
54  {
55  Tensor<int, 3, DataLayout> left(2, 3, 1);
56  Tensor<int, 3, DataLayout> right(2, 3, 1);
57  left.setRandom();
58  right.setRandom();
59 
60  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
61  VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
62  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
63  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
64  for (int j = 0; j < 3; ++j) {
65  for (int i = 0; i < 2; ++i) {
66  VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
67  }
68  for (int i = 2; i < 4; ++i) {
69  VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
70  }
71  }
72 
73  concatenation = left.concatenate(right, 1);
74  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
75  VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
76  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
77  for (int i = 0; i < 2; ++i) {
78  for (int j = 0; j < 3; ++j) {
79  VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
80  }
81  for (int j = 3; j < 6; ++j) {
82  VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
83  }
84  }
85 
86  concatenation = left.concatenate(right, 2);
87  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
88  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
89  VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
90  for (int i = 0; i < 2; ++i) {
91  for (int j = 0; j < 3; ++j) {
92  VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
93  VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
94  }
95  }
96 }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:99

References Eigen::TensorBase< Derived, AccessLevel >::concatenate(), Eigen::Tensor< Scalar_, NumIndices_, Options_, IndexType_ >::dimension(), i, j, Eigen::TensorBase< Derived, AccessLevel >::setRandom(), and VERIFY_IS_EQUAL.

◆ test_static_dimension_failure()

template<int DataLayout>
static void test_static_dimension_failure ( )
static
36  {
37  Tensor<int, 2, DataLayout> left(2, 3);
38  Tensor<int, 3, DataLayout> right(2, 3, 1);
39  left.setRandom();
40  right.setRandom();
41 
42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43  // Technically compatible, but we static assert that the inputs have same
44  // NumDims.
45  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46 #endif
47 
48  // This can be worked around in this case.
49  Tensor<int, 3, DataLayout> concatenation = left.reshape(Tensor<int, 3>::Dimensions(2, 3, 1)).concatenate(right, 0);
50  Tensor<int, 2, DataLayout> alternative = left.concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
51 }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorReshapingOp< const NewDimensions, const Derived > reshape(const NewDimensions &newDimensions) const
Definition: TensorBase.h:1106
Definition: TensorDimensions.h:161

References Eigen::TensorBase< Derived, AccessLevel >::concatenate(), Eigen::TensorBase< Derived, AccessLevel >::reshape(), and Eigen::TensorBase< Derived, AccessLevel >::setRandom().