KroneckerTensorProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Kolja Brix <brix@igpm.rwth-aachen.de>
5 // Copyright (C) 2011 Andreas Platen <andiplaten@gmx.de>
6 // Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef KRONECKER_TENSOR_PRODUCT_H
13 #define KRONECKER_TENSOR_PRODUCT_H
14 
15 // IWYU pragma: private
16 #include "./InternalHeaderCheck.h"
17 
18 namespace Eigen {
19 
27 template <typename Derived>
28 class KroneckerProductBase : public ReturnByValue<Derived> {
29  private:
31  typedef typename Traits::Scalar Scalar;
32 
33  protected:
34  typedef typename Traits::Lhs Lhs;
35  typedef typename Traits::Rhs Rhs;
36 
37  public:
39  KroneckerProductBase(const Lhs& A, const Rhs& B) : m_A(A), m_B(B) {}
40 
41  inline Index rows() const { return m_A.rows() * m_B.rows(); }
42  inline Index cols() const { return m_A.cols() * m_B.cols(); }
43 
49  return m_A.coeff(row / m_B.rows(), col / m_B.cols()) * m_B.coeff(row % m_B.rows(), col % m_B.cols());
50  }
51 
56  Scalar coeff(Index i) const {
58  return m_A.coeff(i / m_A.size()) * m_B.coeff(i % m_A.size());
59  }
60 
61  protected:
62  typename Lhs::Nested m_A;
63  typename Rhs::Nested m_B;
64 };
65 
78 template <typename Lhs, typename Rhs>
79 class KroneckerProduct : public KroneckerProductBase<KroneckerProduct<Lhs, Rhs> > {
80  private:
82  using Base::m_A;
83  using Base::m_B;
84 
85  public:
87  KroneckerProduct(const Lhs& A, const Rhs& B) : Base(A, B) {}
88 
90  template <typename Dest>
91  void evalTo(Dest& dst) const;
92 };
93 
109 template <typename Lhs, typename Rhs>
110 class KroneckerProductSparse : public KroneckerProductBase<KroneckerProductSparse<Lhs, Rhs> > {
111  private:
113  using Base::m_A;
114  using Base::m_B;
115 
116  public:
118  KroneckerProductSparse(const Lhs& A, const Rhs& B) : Base(A, B) {}
119 
121  template <typename Dest>
122  void evalTo(Dest& dst) const;
123 };
124 
125 template <typename Lhs, typename Rhs>
126 template <typename Dest>
127 void KroneckerProduct<Lhs, Rhs>::evalTo(Dest& dst) const {
128  const int BlockRows = Rhs::RowsAtCompileTime, BlockCols = Rhs::ColsAtCompileTime;
129  const Index Br = m_B.rows(), Bc = m_B.cols();
130  for (Index i = 0; i < m_A.rows(); ++i)
131  for (Index j = 0; j < m_A.cols(); ++j)
132  Block<Dest, BlockRows, BlockCols>(dst, i * Br, j * Bc, Br, Bc) = m_A.coeff(i, j) * m_B;
133 }
134 
// Definition of KroneckerProductSparse<Lhs, Rhs>::evalTo(Dest& dst) const.
// Writes A (x) B into the sparse matrix dst: dst is resized to rows() x cols()
// and emptied, storage is reserved per inner vector, then each pair of stored
// entries (a from A, b from B) is inserted at
// (row(a) * rows(B) + row(b), col(a) * cols(B) + col(b)) with value a * b.
135 template <typename Lhs, typename Rhs>
136 template <typename Dest>
138  Index Br = m_B.rows(), Bc = m_B.cols();
139  dst.resize(this->rows(), this->cols());
140  dst.resizeNonZeros(0);
141 
142  // 1 - evaluate the operands if needed:
// nested_eval<..., Dynamic> decides whether each operand is copied into a
// temporary or referenced directly, so that inner iterators can be built on it.
143  typedef typename internal::nested_eval<Lhs, Dynamic>::type Lhs1;
144  typedef internal::remove_all_t<Lhs1> Lhs1Cleaned;
145  const Lhs1 lhs1(m_A);
146  typedef typename internal::nested_eval<Rhs, Dynamic>::type Rhs1;
147  typedef internal::remove_all_t<Rhs1> Rhs1Cleaned;
148  const Rhs1 rhs1(m_B);
149 
150  // 2 - construct respective iterators
151  typedef Eigen::InnerIterator<Lhs1Cleaned> LhsInnerIterator;
152  typedef Eigen::InnerIterator<Rhs1Cleaned> RhsInnerIterator;
153 
154  // compute number of non-zeros per innervectors of dst
// nnzA / nnzB count stored entries per column (column-major dst) or per row
// (row-major dst) of each operand. Inner vector a*|nnzB| + b of dst receives
// nnzA(a) * nnzB(b) entries, which is exactly the column-major linear index of
// entry (b, a) in nnzB * nnzA^T -- hence the ColMajor nnzAB mapped as a vector.
155  {
// NOTE(review): counts and their products are computed in int; for very large
// operands nnzA(a) * nnzB(b) could exceed int range -- confirm acceptable.
156  // TODO VectorXi is not necessarily big enough!
157  VectorXi nnzA = VectorXi::Zero(Dest::IsRowMajor ? m_A.rows() : m_A.cols());
158  for (Index kA = 0; kA < m_A.outerSize(); ++kA)
159  for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) nnzA(Dest::IsRowMajor ? itA.row() : itA.col())++;
160 
161  VectorXi nnzB = VectorXi::Zero(Dest::IsRowMajor ? m_B.rows() : m_B.cols());
162  for (Index kB = 0; kB < m_B.outerSize(); ++kB)
163  for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) nnzB(Dest::IsRowMajor ? itB.row() : itB.col())++;
164 
165  Matrix<int, Dynamic, Dynamic, ColMajor> nnzAB = nnzB * nnzA.transpose();
166  dst.reserve(VectorXi::Map(nnzAB.data(), nnzAB.size()));
167  }
168 
// Insert every product of a stored entry of A with a stored entry of B. With
// matching storage orders (and inner iterators yielding increasing inner
// indices) this loop order visits dst's outer vectors in increasing order and
// inner indices increasingly within each, which suits the per-vector reserve.
169  for (Index kA = 0; kA < m_A.outerSize(); ++kA) {
170  for (Index kB = 0; kB < m_B.outerSize(); ++kB) {
171  for (LhsInnerIterator itA(lhs1, kA); itA; ++itA) {
172  for (RhsInnerIterator itB(rhs1, kB); itB; ++itB) {
173  Index i = itA.row() * Br + itB.row(), j = itA.col() * Bc + itB.col();
174  dst.insert(i, j) = itA.value() * itB.value();
175  }
176  }
177  }
178  }
179 }
180 
181 namespace internal {
182 
// Traits of the dense Kronecker product expression KroneckerProduct<Lhs_, Rhs_>.
// Declares:
//   - Lhs and Rhs: the operand types stripped with remove_all_t;
//   - Scalar: the ScalarBinaryOpTraits result of the two operand scalars;
//   - StorageIndex: the promote_index_type of the operands' index types;
//   - ReturnType = Matrix<Scalar, Rows, Cols>: the plain matrix the expression
//     is evaluated into.
183 template <typename Lhs_, typename Rhs_>
184 struct traits<KroneckerProduct<Lhs_, Rhs_> > {
189 
190  enum {
// NOTE(review): the enumerators are not shown here; ReturnType uses Rows and
// Cols, so presumably they hold the product's compile-time dimensions.
195  };
196 
198 };
199 
// Traits of the sparse Kronecker product expression KroneckerProductSparse<Lhs_, Rhs_>.
// Declares:
//   - XprKind = MatrixXpr;
//   - Lhs and Rhs: the operand types stripped with remove_all_t;
//   - Scalar: the ScalarBinaryOpTraits result of the two operand scalars;
//   - StorageKind: from cwise_promote_storage_type;
//   - StorageIndex: the promote_index_type of the operands' index types;
//   - ReturnType = SparseMatrix<Scalar, 0, StorageIndex>.
// NOTE(review): ReturnType uses Options = 0 (presumably ColMajor) even when
// EvalToRowMajor below is set -- confirm this is intended.
200 template <typename Lhs_, typename Rhs_>
201 struct traits<KroneckerProductSparse<Lhs_, Rhs_> > {
210 
211  enum {
212  LhsFlags = Lhs::Flags,
213  RhsFlags = Rhs::Flags,
214 
219 
// Non-zero only when both operands carry RowMajorBit.
220  EvalToRowMajor = (int(LhsFlags) & int(RhsFlags) & RowMajorBit),
// Mask that clears RowMajorBit unless both operands are row-major.
221  RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
222 
// Hereditary bits of either operand (minus RowMajorBit per the mask above),
// plus EvalBeforeNestingBit so the expression is evaluated before nesting.
223  Flags = ((int(LhsFlags) | int(RhsFlags)) & HereditaryBits & RemovedBits) | EvalBeforeNestingBit,
// Coefficient access goes through full evaluation, so it is marked very costly.
224  CoeffReadCost = HugeCost
225  };
226 
228 };
229 
230 } // end namespace internal
231 
// Kronecker tensor product of two dense expressions:
//   KroneckerProduct<A, B> kroneckerProduct(const MatrixBase<A>& a, const MatrixBase<B>& b)
// Returns a KroneckerProduct expression holding a.derived() and b.derived();
// the coefficients are written by KroneckerProduct::evalTo when the result is
// evaluated into a matrix.
251 template <typename A, typename B>
253  return KroneckerProduct<A, B>(a.derived(), b.derived());
254 }
255 
// Kronecker tensor product overload yielding KroneckerProductSparse<A, B>.
// That expression's evalTo builds a SparseMatrix from the stored entries of
// both operands.
// NOTE(review): this overload's parameter types are not visible here --
// confirm which operand kinds (presumably sparse ones) select it.
277 template <typename A, typename B>
279  return KroneckerProductSparse<A, B>(a.derived(), b.derived());
280 }
281 
282 } // end namespace Eigen
283 
284 #endif // KRONECKER_TENSOR_PRODUCT_H
int i
Definition: BiCGSTAB_step_by_step.cpp:9
m col(1)
m row(1)
#define EIGEN_STATIC_ASSERT_VECTOR_ONLY(TYPE)
Definition: StaticAssert.h:36
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
Scalar * b
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
Expression of a fixed-size or dynamic-size block.
Definition: Block.h:110
An InnerIterator allows to loop over the element of any matrix expression.
Definition: CoreIterators.h:37
The base class of dense and sparse Kronecker product.
Definition: KroneckerTensorProduct.h:28
Scalar coeff(Index row, Index col) const
Definition: KroneckerTensorProduct.h:48
KroneckerProductBase(const Lhs &A, const Rhs &B)
Constructor.
Definition: KroneckerTensorProduct.h:39
Traits::Rhs Rhs
Definition: KroneckerTensorProduct.h:35
Rhs::Nested m_B
Definition: KroneckerTensorProduct.h:63
Lhs::Nested m_A
Definition: KroneckerTensorProduct.h:62
Scalar coeff(Index i) const
Definition: KroneckerTensorProduct.h:56
Traits::Scalar Scalar
Definition: KroneckerTensorProduct.h:31
Traits::Lhs Lhs
Definition: KroneckerTensorProduct.h:34
internal::traits< Derived > Traits
Definition: KroneckerTensorProduct.h:30
Index cols() const
Definition: KroneckerTensorProduct.h:42
Index rows() const
Definition: KroneckerTensorProduct.h:41
Kronecker tensor product helper class for sparse matrices.
Definition: KroneckerTensorProduct.h:110
KroneckerProductBase< KroneckerProductSparse > Base
Definition: KroneckerTensorProduct.h:112
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition: KroneckerTensorProduct.h:137
KroneckerProductSparse(const Lhs &A, const Rhs &B)
Constructor.
Definition: KroneckerTensorProduct.h:118
Kronecker tensor product helper class for dense matrices.
Definition: KroneckerTensorProduct.h:79
KroneckerProduct(const Lhs &A, const Rhs &B)
Constructor.
Definition: KroneckerTensorProduct.h:87
void evalTo(Dest &dst) const
Evaluate the Kronecker tensor product.
Definition: KroneckerTensorProduct.h:127
KroneckerProductBase< KroneckerProduct > Base
Definition: KroneckerTensorProduct.h:81
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
The matrix class, also used for vectors and row-vectors.
Definition: Eigen/Eigen/src/Core/Matrix.h:186
constexpr EIGEN_DEVICE_FUNC const Scalar * data() const
Definition: PlainObjectBase.h:273
Definition: ReturnByValue.h:50
internal::dense_xpr_base< ReturnByValue >::type Base
Definition: ReturnByValue.h:54
A versatile sparse matrix representation.
Definition: SparseMatrix.h:121
Definition: matrices.h:74
const unsigned int EvalBeforeNestingBit
Definition: Constants.h:74
const unsigned int RowMajorBit
Definition: Constants.h:70
return int(ret)+1
Eigen::DenseIndex ret
Definition: level1_cplx_impl.h:43
const Scalar * a
Definition: level2_cplx_impl.h:32
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
constexpr int size_at_compile_time(int rows, int cols)
Definition: XprHelper.h:373
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
const unsigned int HereditaryBits
Definition: Constants.h:198
const int HugeCost
Definition: Constants.h:48
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
KroneckerProduct< A, B > kroneckerProduct(const MatrixBase< A > &a, const MatrixBase< B > &b)
Definition: KroneckerTensorProduct.h:252
Extend namespace for flags.
Definition: fsi_chan_precond_driver.cc:56
Definition: Eigen_Colamd.h:49
double Zero
Definition: pseudosolid_node_update_elements.cc:35
Definition: EigenBase.h:33
Definition: Constants.h:534
Determines whether the given binary operation of two numeric types is allowed and what the scalar ret...
Definition: XprHelper.h:1043
std::conditional_t< Evaluate, PlainObject, typename ref_selector< T >::type > type
Definition: XprHelper.h:549
std::conditional_t<(sizeof(I1)< sizeof(I2)), I2, I1 > type
Definition: XprHelper.h:146
Template functor to compute the product of two scalars.
Definition: BinaryFunctors.h:73
MatrixXpr XprKind
Definition: KroneckerTensorProduct.h:202
ScalarBinaryOpTraits< typename Lhs::Scalar, typename Rhs::Scalar >::ReturnType Scalar
Definition: KroneckerTensorProduct.h:205
remove_all_t< Lhs_ > Lhs
Definition: KroneckerTensorProduct.h:203
promote_index_type< typename Lhs::StorageIndex, typename Rhs::StorageIndex >::type StorageIndex
Definition: KroneckerTensorProduct.h:209
SparseMatrix< Scalar, 0, StorageIndex > ReturnType
Definition: KroneckerTensorProduct.h:227
remove_all_t< Rhs_ > Rhs
Definition: KroneckerTensorProduct.h:204
cwise_promote_storage_type< typename traits< Lhs >::StorageKind, typename traits< Rhs >::StorageKind, scalar_product_op< typename Lhs::Scalar, typename Rhs::Scalar > >::ret StorageKind
Definition: KroneckerTensorProduct.h:208
ScalarBinaryOpTraits< typename Lhs::Scalar, typename Rhs::Scalar >::ReturnType Scalar
Definition: KroneckerTensorProduct.h:187
Matrix< Scalar, Rows, Cols > ReturnType
Definition: KroneckerTensorProduct.h:197
remove_all_t< Rhs_ > Rhs
Definition: KroneckerTensorProduct.h:186
remove_all_t< Lhs_ > Lhs
Definition: KroneckerTensorProduct.h:185
promote_index_type< typename Lhs::StorageIndex, typename Rhs::StorageIndex >::type StorageIndex
Definition: KroneckerTensorProduct.h:188
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2