SparsePermutation.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2012 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSE_PERMUTATION_H
11 #define EIGEN_SPARSE_PERMUTATION_H
12 
13 // This file implements sparse * permutation products
14 
15 // IWYU pragma: private
16 #include "./InternalHeaderCheck.h"
17 
18 namespace Eigen {
19 
20 namespace internal {
21 
/** \internal
 * Gives read access to \a xpr as a \a PlainObjectType.
 *
 * Primary template (NeedEval == true, i.e. \a ExpressionType differs from \a PlainObjectType):
 * \a xpr is evaluated into a new \a PlainObjectType owned by the helper (this may allocate).
 */
template <typename ExpressionType, typename PlainObjectType,
          bool NeedEval = !std::is_same<ExpressionType, PlainObjectType>::value>
struct XprHelper {
  XprHelper(const ExpressionType& xpr) : m_xpr(xpr) {}
  inline const PlainObjectType& xpr() const { return m_xpr; }
  // this is a new PlainObjectType initialized by xpr
  const PlainObjectType m_xpr;
};
/** \internal
 * Specialization for NeedEval == false (\a xpr already is a \a PlainObjectType): no copy is made,
 * the helper only stores a reference, so \a xpr must outlive the helper.
 */
template <typename ExpressionType, typename PlainObjectType>
struct XprHelper<ExpressionType, PlainObjectType, false> {
  XprHelper(const ExpressionType& xpr) : m_xpr(xpr) {}
  inline const PlainObjectType& xpr() const { return m_xpr; }
  // this is a reference to xpr
  const PlainObjectType& m_xpr;
};
37 
38 template <typename PermDerived, bool NeedInverseEval>
39 struct PermHelper {
40  using IndicesType = typename PermDerived::IndicesType;
43  PermHelper(const PermDerived& perm) : m_perm(perm.inverse()) {}
44  inline const type& perm() const { return m_perm; }
45  // this is a new PermutationMatrix initialized by perm.inverse()
46  const type m_perm;
47 };
48 template <typename PermDerived>
49 struct PermHelper<PermDerived, false> {
50  using type = PermDerived;
51  PermHelper(const PermDerived& perm) : m_perm(perm) {}
52  inline const type& perm() const { return m_perm; }
53  // this is a reference to perm
54  const type& m_perm;
55 };
56 
57 template <typename ExpressionType, int Side, bool Transposed>
58 struct permutation_matrix_product<ExpressionType, Side, Transposed, SparseShape> {
61 
63  using StorageIndex = typename MatrixTypeCleaned::StorageIndex;
64 
65  // the actual "return type" is `Dest`. this is a temporary type
68 
69  static constexpr bool NeedOuterPermutation = ExpressionType::IsRowMajor ? Side == OnTheLeft : Side == OnTheRight;
70  static constexpr bool NeedInversePermutation = Transposed ? Side == OnTheLeft : Side == OnTheRight;
71 
72  template <typename Dest, typename PermutationType>
73  static inline void permute_outer(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
74  // if ExpressionType is not ReturnType, evaluate `xpr` (allocation)
75  // otherwise, just reference `xpr`
76  // TODO: handle trivial expressions such as CwiseBinaryOp without temporary
77  const TmpHelper tmpHelper(xpr);
78  const ReturnType& tmp = tmpHelper.xpr();
79 
80  ReturnType result(tmp.rows(), tmp.cols());
81 
82  for (Index j = 0; j < tmp.outerSize(); j++) {
83  Index jp = perm.indices().coeff(j);
84  Index jsrc = NeedInversePermutation ? jp : j;
85  Index jdst = NeedInversePermutation ? j : jp;
86  Index begin = tmp.outerIndexPtr()[jsrc];
87  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
88  result.outerIndexPtr()[jdst + 1] += end - begin;
89  }
90 
91  std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
92  result.resizeNonZeros(result.nonZeros());
93 
94  for (Index j = 0; j < tmp.outerSize(); j++) {
95  Index jp = perm.indices().coeff(j);
96  Index jsrc = NeedInversePermutation ? jp : j;
97  Index jdst = NeedInversePermutation ? j : jp;
98  Index begin = tmp.outerIndexPtr()[jsrc];
99  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[jsrc + 1] : begin + tmp.innerNonZeroPtr()[jsrc];
100  Index target = result.outerIndexPtr()[jdst];
101  smart_copy(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target);
102  smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
103  }
104  dst = std::move(result);
105  }
106 
107  template <typename Dest, typename PermutationType>
108  static inline void permute_inner(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
110  using InnerPermType = typename InnerPermHelper::type;
111 
112  // if ExpressionType is not ReturnType, evaluate `xpr` (allocation)
113  // otherwise, just reference `xpr`
114  // TODO: handle trivial expressions such as CwiseBinaryOp without temporary
115  const TmpHelper tmpHelper(xpr);
116  const ReturnType& tmp = tmpHelper.xpr();
117 
118  // if inverse permutation of inner indices is requested, calculate perm.inverse() (allocation)
119  // otherwise, just reference `perm`
120  const InnerPermHelper permHelper(perm);
121  const InnerPermType& innerPerm = permHelper.perm();
122 
123  ReturnType result(tmp.rows(), tmp.cols());
124 
125  for (Index j = 0; j < tmp.outerSize(); j++) {
126  Index begin = tmp.outerIndexPtr()[j];
127  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
128  result.outerIndexPtr()[j + 1] += end - begin;
129  }
130 
131  std::partial_sum(result.outerIndexPtr(), result.outerIndexPtr() + result.outerSize() + 1, result.outerIndexPtr());
132  result.resizeNonZeros(result.nonZeros());
133 
134  for (Index j = 0; j < tmp.outerSize(); j++) {
135  Index begin = tmp.outerIndexPtr()[j];
136  Index end = tmp.isCompressed() ? tmp.outerIndexPtr()[j + 1] : begin + tmp.innerNonZeroPtr()[j];
137  Index target = result.outerIndexPtr()[j];
138  std::transform(tmp.innerIndexPtr() + begin, tmp.innerIndexPtr() + end, result.innerIndexPtr() + target,
139  [&innerPerm](StorageIndex i) { return innerPerm.indices().coeff(i); });
140  smart_copy(tmp.valuePtr() + begin, tmp.valuePtr() + end, result.valuePtr() + target);
141  }
142  // the inner indices were permuted, and must be sorted
143  result.sortInnerIndices();
144  dst = std::move(result);
145  }
146 
147  template <typename Dest, typename PermutationType, bool DoOuter = NeedOuterPermutation,
148  std::enable_if_t<DoOuter, int> = 0>
149  static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
150  permute_outer(dst, perm, xpr);
151  }
152 
153  template <typename Dest, typename PermutationType, bool DoOuter = NeedOuterPermutation,
154  std::enable_if_t<!DoOuter, int> = 0>
155  static inline void run(Dest& dst, const PermutationType& perm, const ExpressionType& xpr) {
156  permute_inner(dst, perm, xpr);
157  }
158 };
159 
160 } // namespace internal
161 
162 namespace internal {
163 
164 template <int ProductTag>
166  typedef Sparse ret;
167 };
168 template <int ProductTag>
170  typedef Sparse ret;
171 };
172 
// TODO: the following two specializations are only needed to define the right temporary type through
// typename permutation_matrix_product<Xpr, Side, false, SparseShape>::ReturnType,
// whereas it should be handled correctly by traits<Product<> >::PlainObject.
176 
177 template <typename Lhs, typename Rhs, int ProductTag>
179  : public evaluator<typename permutation_matrix_product<Rhs, OnTheLeft, false, SparseShape>::ReturnType> {
183 
184  enum { Flags = Base::Flags | EvalBeforeNestingBit };
185 
186  explicit product_evaluator(const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
187  internal::construct_at<Base>(this, m_result);
189  }
190 
191  protected:
193 };
194 
195 template <typename Lhs, typename Rhs, int ProductTag>
197  : public evaluator<typename permutation_matrix_product<Lhs, OnTheRight, false, SparseShape>::ReturnType> {
201 
202  enum { Flags = Base::Flags | EvalBeforeNestingBit };
203 
204  explicit product_evaluator(const XprType& xpr) : m_result(xpr.rows(), xpr.cols()) {
205  ::new (static_cast<Base*>(this)) Base(m_result);
207  }
208 
209  protected:
211 };
212 
213 } // end namespace internal
214 
217 template <typename SparseDerived, typename PermDerived>
221 }
222 
225 template <typename SparseDerived, typename PermDerived>
229 }
230 
233 template <typename SparseDerived, typename PermutationType>
237 }
238 
241 template <typename SparseDerived, typename PermutationType>
244  return Product<Inverse<PermutationType>, SparseDerived, AliasFreeProduct>(tperm.derived(), matrix.derived());
245 }
246 
247 } // end namespace Eigen
248 
#endif  // EIGEN_SPARSE_PERMUTATION_H
int i
Definition: BiCGSTAB_step_by_step.cpp:9
Side
Definition: Side.h:9
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
SCALAR Scalar
Definition: bench_gemm.cpp:45
Base class for permutations.
Definition: PermutationMatrix.h:49
constexpr EIGEN_DEVICE_FUNC Derived & derived()
Definition: EigenBase.h:49
Permutation matrix.
Definition: PermutationMatrix.h:280
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
Definition: PlainObjectBase.h:192
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
Definition: PlainObjectBase.h:191
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:202
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned & lhs() const
Definition: Product.h:230
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned & rhs() const
Definition: Product.h:231
void sortInnerIndices(Index begin, Index end)
Definition: SparseCompressedBase.h:144
Base class of any sparse matrices or sparse expressions.
Definition: SparseMatrixBase.h:30
A versatile sparse matrix representation.
Definition: SparseMatrix.h:121
Index nonZeros() const
Definition: SparseCompressedBase.h:64
Index outerSize() const
Definition: SparseMatrix.h:166
const Scalar * valuePtr() const
Definition: SparseMatrix.h:171
void resizeNonZeros(Index size)
Definition: SparseMatrix.h:754
const StorageIndex * outerIndexPtr() const
Definition: SparseMatrix.h:189
const StorageIndex * innerIndexPtr() const
Definition: SparseMatrix.h:180
Eigen::Map< Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor >, 0, Eigen::OuterStride<> > matrix(T *data, int rows, int cols, int stride)
Definition: common.h:85
EIGEN_DONT_INLINE void transform(const Transformation &t, Data &data)
Definition: geometry.cpp:25
static constexpr lastp1_t end
Definition: IndexedViewHelper.h:79
@ AliasFreeProduct
Definition: Constants.h:505
@ OnTheLeft
Definition: Constants.h:331
@ OnTheRight
Definition: Constants.h:333
const unsigned int EvalBeforeNestingBit
Definition: Constants.h:74
void inverse(const MatrixType &m)
Definition: inverse.cpp:64
Eigen::Matrix< Scalar, Dynamic, Dynamic, ColMajor > tmp
Definition: level3_impl.h:365
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
EIGEN_DEVICE_FUNC void smart_copy(const T *start, const T *end, T *target)
Definition: Memory.h:569
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
EIGEN_DEVICE_FUNC const Product< MatrixDerived, PermutationDerived, AliasFreeProduct > operator*(const MatrixBase< MatrixDerived > &matrix, const PermutationBase< PermutationDerived > &permutation)
Definition: PermutationMatrix.h:471
Extend namespace for flags.
Definition: fsi_chan_precond_driver.cc:56
type
Definition: compute_granudrum_aor.py:141
Definition: Eigen_Colamd.h:49
constexpr EIGEN_DEVICE_FUNC Derived & derived()
Definition: EigenBase.h:49
Definition: Constants.h:564
Definition: Constants.h:528
Definition: Constants.h:570
Definition: Constants.h:522
PermHelper(const PermDerived &perm)
Definition: SparsePermutation.h:51
const type & m_perm
Definition: SparsePermutation.h:54
const type & perm() const
Definition: SparsePermutation.h:52
PermDerived type
Definition: SparsePermutation.h:50
Definition: SparsePermutation.h:39
PermHelper(const PermDerived &perm)
Definition: SparsePermutation.h:43
const type & perm() const
Definition: SparsePermutation.h:44
typename PermDerived::IndicesType IndicesType
Definition: SparsePermutation.h:40
typename IndicesType::Scalar PermutationIndex
Definition: SparsePermutation.h:41
const type m_perm
Definition: SparsePermutation.h:46
XprHelper(const ExpressionType &xpr)
Definition: SparsePermutation.h:32
const PlainObjectType & m_xpr
Definition: SparsePermutation.h:35
const PlainObjectType & xpr() const
Definition: SparsePermutation.h:33
Definition: SparsePermutation.h:24
const PlainObjectType & xpr() const
Definition: SparsePermutation.h:26
const PlainObjectType m_xpr
Definition: SparsePermutation.h:28
XprHelper(const ExpressionType &xpr)
Definition: SparsePermutation.h:25
Definition: CoreEvaluators.h:104
Definition: ProductEvaluators.h:78
@ value
Definition: Meta.h:206
std::conditional_t< Evaluate, PlainObject, typename ref_selector< T >::type > type
Definition: XprHelper.h:549
static void permute_inner(Dest &dst, const PermutationType &perm, const ExpressionType &xpr)
Definition: SparsePermutation.h:108
typename MatrixTypeCleaned::Scalar Scalar
Definition: SparsePermutation.h:62
remove_all_t< MatrixType > MatrixTypeCleaned
Definition: SparsePermutation.h:60
typename MatrixTypeCleaned::StorageIndex StorageIndex
Definition: SparsePermutation.h:63
static void run(Dest &dst, const PermutationType &perm, const ExpressionType &xpr)
Definition: SparsePermutation.h:149
static void permute_outer(Dest &dst, const PermutationType &perm, const ExpressionType &xpr)
Definition: SparsePermutation.h:73
typename nested_eval< ExpressionType, 1 >::type MatrixType
Definition: SparsePermutation.h:59
Definition: ProductEvaluators.h:965
permutation_matrix_product< Lhs, OnTheRight, false, SparseShape >::ReturnType PlainObject
Definition: SparsePermutation.h:199
permutation_matrix_product< Rhs, OnTheLeft, false, SparseShape >::ReturnType PlainObject
Definition: SparsePermutation.h:181
Definition: ForwardDeclarations.h:221
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2