KLUSupport.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Kyle Macfarlan <kyle.macfarlan@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_KLUSUPPORT_H
11 #define EIGEN_KLUSUPPORT_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 /* TODO extract L, extract U, compute det, etc... */
19 
36 inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[],
37  klu_common *Common, double) {
38  return klu_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B,
39  Common);
40 }
41 
42 inline int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double> B[],
43  klu_common *Common, std::complex<double>) {
44  return klu_z_solve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs),
45  &numext::real_ref(B[0]), Common);
46 }
47 
48 inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[],
49  klu_common *Common, double) {
50  return klu_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs), B,
51  Common);
52 }
53 
54 inline int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, std::complex<double> B[],
55  klu_common *Common, std::complex<double>) {
56  return klu_z_tsolve(Symbolic, Numeric, internal::convert_index<int>(ldim), internal::convert_index<int>(nrhs),
57  &numext::real_ref(B[0]), 0, Common);
58 }
59 
60 inline klu_numeric *klu_factor(int Ap[], int Ai[], double Ax[], klu_symbolic *Symbolic, klu_common *Common, double) {
61  return klu_factor(Ap, Ai, Ax, Symbolic, Common);
62 }
63 
64 inline klu_numeric *klu_factor(int Ap[], int Ai[], std::complex<double> Ax[], klu_symbolic *Symbolic,
65  klu_common *Common, std::complex<double>) {
66  return klu_z_factor(Ap, Ai, &numext::real_ref(Ax[0]), Symbolic, Common);
67 }
68 
69 template <typename MatrixType_>
70 class KLU : public SparseSolverBase<KLU<MatrixType_> > {
71  protected:
74 
75  public:
76  using Base::_solve_impl;
77  typedef MatrixType_ MatrixType;
78  typedef typename MatrixType::Scalar Scalar;
80  typedef typename MatrixType::StorageIndex StorageIndex;
87  enum { ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime };
88 
89  public:
90  KLU() : m_dummy(0, 0), mp_matrix(m_dummy) { init(); }
91 
92  template <typename InputMatrixType>
93  explicit KLU(const InputMatrixType &matrix) : mp_matrix(matrix) {
94  init();
95  compute(matrix);
96  }
97 
98  ~KLU() {
99  if (m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
100  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
101  }
102 
103  EIGEN_CONSTEXPR inline Index rows() const EIGEN_NOEXCEPT { return mp_matrix.rows(); }
104  EIGEN_CONSTEXPR inline Index cols() const EIGEN_NOEXCEPT { return mp_matrix.cols(); }
105 
112  eigen_assert(m_isInitialized && "Decomposition is not initialized.");
113  return m_info;
114  }
115 #if 0 // not implemented yet
116  inline const LUMatrixType& matrixL() const
117  {
118  if (m_extractedDataAreDirty) extractData();
119  return m_l;
120  }
121 
122  inline const LUMatrixType& matrixU() const
123  {
124  if (m_extractedDataAreDirty) extractData();
125  return m_u;
126  }
127 
128  inline const IntColVectorType& permutationP() const
129  {
130  if (m_extractedDataAreDirty) extractData();
131  return m_p;
132  }
133 
134  inline const IntRowVectorType& permutationQ() const
135  {
136  if (m_extractedDataAreDirty) extractData();
137  return m_q;
138  }
139 #endif
144  template <typename InputMatrixType>
145  void compute(const InputMatrixType &matrix) {
146  if (m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
147  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
148  grab(matrix.derived());
150  factorize_impl();
151  }
152 
159  template <typename InputMatrixType>
160  void analyzePattern(const InputMatrixType &matrix) {
161  if (m_symbolic) klu_free_symbolic(&m_symbolic, &m_common);
162  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
163 
164  grab(matrix.derived());
165 
167  }
168 
173  inline const klu_common &kluCommon() const { return m_common; }
174 
181  inline klu_common &kluCommon() { return m_common; }
182 
189  template <typename InputMatrixType>
190  void factorize(const InputMatrixType &matrix) {
191  eigen_assert(m_analysisIsOk && "KLU: you must first call analyzePattern()");
192  if (m_numeric) klu_free_numeric(&m_numeric, &m_common);
193 
194  grab(matrix.derived());
195 
196  factorize_impl();
197  }
198 
200  template <typename BDerived, typename XDerived>
202 
203 #if 0 // not implemented yet
204  Scalar determinant() const;
205 
206  void extractData() const;
207 #endif
208 
209  protected:
210  void init() {
212  m_isInitialized = false;
213  m_numeric = 0;
214  m_symbolic = 0;
216 
217  klu_defaults(&m_common);
218  }
219 
222  m_analysisIsOk = false;
223  m_factorizationIsOk = false;
224  m_symbolic = klu_analyze(internal::convert_index<int>(mp_matrix.rows()),
225  const_cast<StorageIndex *>(mp_matrix.outerIndexPtr()),
226  const_cast<StorageIndex *>(mp_matrix.innerIndexPtr()), &m_common);
227  if (m_symbolic) {
228  m_isInitialized = true;
229  m_info = Success;
230  m_analysisIsOk = true;
232  }
233  }
234 
235  void factorize_impl() {
236  m_numeric = klu_factor(const_cast<StorageIndex *>(mp_matrix.outerIndexPtr()),
237  const_cast<StorageIndex *>(mp_matrix.innerIndexPtr()),
238  const_cast<Scalar *>(mp_matrix.valuePtr()), m_symbolic, &m_common, Scalar());
239 
241  m_factorizationIsOk = m_numeric ? 1 : 0;
243  }
244 
245  template <typename MatrixDerived>
248  internal::construct_at(&mp_matrix, A.derived());
249  }
250 
251  void grab(const KLUMatrixRef &A) {
252  if (&(A.derived()) != &mp_matrix) {
255  }
256  }
257 
258  // cached data to reduce reallocation, etc.
259 #if 0 // not implemented yet
260  mutable LUMatrixType m_l;
261  mutable LUMatrixType m_u;
262  mutable IntColVectorType m_p;
263  mutable IntRowVectorType m_q;
264 #endif
265 
268 
269  klu_numeric *m_numeric;
270  klu_symbolic *m_symbolic;
271  klu_common m_common;
276 
277  private:
278  KLU(const KLU &) {}
279 };
280 
281 #if 0 // not implemented yet
// NOTE(review): disabled placeholder (inside #if 0). Apart from asserting "Not Yet Implemented",
// the body below calls umfpack_get_lunz / umfpack_get_numeric, which are UMFPACK routines from
// UmfPackSupport.h, not KLU ones. It appears to be copied from the UmfPack wrapper and presumably
// has to be rewritten against KLU's own extraction API before it can be enabled.
282 template<typename MatrixType>
283 void KLU<MatrixType>::extractData() const
284 {
285  if (m_extractedDataAreDirty)
286  {
287  eigen_assert(false && "KLU: extractData Not Yet Implemented");
288 
289  // get size of the data
290  int lnz, unz, rows, cols, nz_udiag;
291  umfpack_get_lunz(&lnz, &unz, &rows, &cols, &nz_udiag, m_numeric, Scalar());
292 
293  // allocate data
294  m_l.resize(rows,(std::min)(rows,cols));
295  m_l.resizeNonZeros(lnz);
296 
297  m_u.resize((std::min)(rows,cols),cols);
298  m_u.resizeNonZeros(unz);
299 
300  m_p.resize(rows);
301  m_q.resize(cols);
302 
303  // extract
304  umfpack_get_numeric(m_l.outerIndexPtr(), m_l.innerIndexPtr(), m_l.valuePtr(),
305  m_u.outerIndexPtr(), m_u.innerIndexPtr(), m_u.valuePtr(),
306  m_p.data(), m_q.data(), 0, 0, 0, m_numeric);
307 
308  m_extractedDataAreDirty = false;
309  }
310 }
311 
312 template<typename MatrixType>
314 {
315  eigen_assert(false && "KLU: extractData Not Yet Implemented");
316  return Scalar();
317 }
318 #endif
319 
320 template <typename MatrixType>
321 template <typename BDerived, typename XDerived>
323  Index rhsCols = b.cols();
324  EIGEN_STATIC_ASSERT((XDerived::Flags & RowMajorBit) == 0, THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
325  eigen_assert(m_factorizationIsOk &&
326  "The decomposition is not in a valid state for solving, you must first call either compute() or "
327  "analyzePattern()/factorize()");
328 
329  x = b;
330  int info = klu_solve(m_symbolic, m_numeric, b.rows(), rhsCols, x.const_cast_derived().data(),
331  const_cast<klu_common *>(&m_common), Scalar());
332 
333  m_info = info != 0 ? Success : NumericalIssue;
334  return true;
335 }
336 
337 } // end namespace Eigen
338 
339 #endif // EIGEN_KLUSUPPORT_H
#define EIGEN_NOEXCEPT
Definition: Macros.h:1267
#define EIGEN_CONSTEXPR
Definition: Macros.h:758
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STATIC_ASSERT(X, MSG)
Definition: StaticAssert.h:26
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
Scalar * b
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:46
Definition: KLUSupport.h:70
MatrixType_ MatrixType
Definition: KLUSupport.h:77
klu_symbolic * m_symbolic
Definition: KLUSupport.h:270
EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
Definition: KLUSupport.h:103
const klu_common & kluCommon() const
Definition: KLUSupport.h:173
void compute(const InputMatrixType &matrix)
Definition: KLUSupport.h:145
KLUMatrixRef mp_matrix
Definition: KLUSupport.h:267
void factorize(const InputMatrixType &matrix)
Definition: KLUSupport.h:190
void factorize_impl()
Definition: KLUSupport.h:235
int m_analysisIsOk
Definition: KLUSupport.h:274
SparseSolverBase< KLU< MatrixType_ > > Base
Definition: KLUSupport.h:72
~KLU()
Definition: KLUSupport.h:98
void grab(const EigenBase< MatrixDerived > &A)
Definition: KLUSupport.h:246
Ref< const KLUMatrixType, StandardCompressedFormat > KLUMatrixRef
Definition: KLUSupport.h:86
KLU()
Definition: KLUSupport.h:90
KLU(const KLU &)
Definition: KLUSupport.h:278
EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
Definition: KLUSupport.h:104
klu_numeric * m_numeric
Definition: KLUSupport.h:269
MatrixType::StorageIndex StorageIndex
Definition: KLUSupport.h:80
ComputationInfo info() const
Reports whether previous computation was successful.
Definition: KLUSupport.h:111
KLUMatrixType m_dummy
Definition: KLUSupport.h:266
KLU(const InputMatrixType &matrix)
Definition: KLUSupport.h:93
Matrix< Scalar, Dynamic, 1 > Vector
Definition: KLUSupport.h:81
bool _solve_impl(const MatrixBase< BDerived > &b, MatrixBase< XDerived > &x) const
Definition: KLUSupport.h:322
SparseMatrix< Scalar, ColMajor, int > KLUMatrixType
Definition: KLUSupport.h:85
MatrixType::RealScalar RealScalar
Definition: KLUSupport.h:79
@ ColsAtCompileTime
Definition: KLUSupport.h:87
@ MaxColsAtCompileTime
Definition: KLUSupport.h:87
bool m_extractedDataAreDirty
Definition: KLUSupport.h:275
void analyzePattern_impl()
Definition: KLUSupport.h:220
int m_factorizationIsOk
Definition: KLUSupport.h:273
Matrix< int, 1, MatrixType::ColsAtCompileTime > IntRowVectorType
Definition: KLUSupport.h:82
ComputationInfo m_info
Definition: KLUSupport.h:272
klu_common m_common
Definition: KLUSupport.h:271
klu_common & kluCommon()
Definition: KLUSupport.h:181
Matrix< int, MatrixType::RowsAtCompileTime, 1 > IntColVectorType
Definition: KLUSupport.h:83
MatrixType::Scalar Scalar
Definition: KLUSupport.h:78
SparseMatrix< Scalar > LUMatrixType
Definition: KLUSupport.h:84
void grab(const KLUMatrixRef &A)
Definition: KLUSupport.h:251
void init()
Definition: KLUSupport.h:210
void analyzePattern(const InputMatrixType &matrix)
Definition: KLUSupport.h:160
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
A base class for sparse solvers.
Definition: SparseSolverBase.h:67
void _solve_impl(const SparseMatrixBase< Rhs > &b, SparseMatrixBase< Dest > &dest) const
Definition: SparseSolverBase.h:104
bool m_isInitialized
Definition: SparseSolverBase.h:110
Definition: matrices.h:74
Eigen::Map< Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor >, 0, Eigen::OuterStride<> > matrix(T *data, int rows, int cols, int stride)
Definition: common.h:85
#define min(a, b)
Definition: datatypes.h:22
void determinant(const MatrixType &m)
Definition: determinant.cpp:15
ComputationInfo
Definition: Constants.h:438
@ NumericalIssue
Definition: Constants.h:442
@ InvalidInput
Definition: Constants.h:447
@ Success
Definition: Constants.h:440
const unsigned int RowMajorBit
Definition: Constants.h:70
int info
Definition: level2_cplx_impl.h:39
EIGEN_DEVICE_FUNC T * construct_at(T *p, Args &&... args)
Definition: Memory.h:1321
EIGEN_DEVICE_FUNC void destroy_at(T *p)
Definition: Memory.h:1335
EIGEN_DEVICE_FUNC internal::add_const_on_value_type_t< EIGEN_MATHFUNC_RETVAL(real_ref, Scalar)> real_ref(const Scalar &x)
Definition: MathFunctions.h:1051
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
int umfpack_get_numeric(int Lp[], int Lj[], double Lx[], int Up[], int Ui[], double Ux[], int P[], int Q[], double Dx[], int *do_recip, double Rs[], void *Numeric)
Definition: UmfPackSupport.h:232
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
klu_numeric * klu_factor(int Ap[], int Ai[], double Ax[], klu_symbolic *Symbolic, klu_common *Common, double)
Definition: KLUSupport.h:60
int klu_solve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double)
A sparse LU factorization and solver based on KLU.
Definition: KLUSupport.h:36
int klu_tsolve(klu_symbolic *Symbolic, klu_numeric *Numeric, Index ldim, Index nrhs, double B[], klu_common *Common, double)
Definition: KLUSupport.h:48
int umfpack_get_lunz(int *lnz, int *unz, int *n_row, int *n_col, int *nz_udiag, void *Numeric, double)
Definition: UmfPackSupport.h:211
list x
Definition: plotDoE.py:28
Definition: EigenBase.h:33