IterativeSolverBase.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_ITERATIVE_SOLVER_BASE_H
11 #define EIGEN_ITERATIVE_SOLVER_BASE_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 template <typename MatrixType>
22  private:
23  template <typename T0>
24  struct any_conversion {
25  template <typename T>
26  any_conversion(const volatile T&);
27  template <typename T>
29  };
30  struct yes {
31  int a[1];
32  };
33  struct no {
34  int a[2];
35  };
36 
37  template <typename T>
38  static yes test(const Ref<const T>&, int);
39  template <typename T>
40  static no test(any_conversion<T>, ...);
41 
42  public:
44  enum { value = sizeof(test<MatrixType>(ms_from, 0)) == sizeof(yes) };
45 };
46 
47 template <typename MatrixType>
50 };
51 
54 
55 // We have an explicit matrix at hand, compatible with Ref<>
56 template <typename MatrixType>
58  public:
60  template <int UpLo>
61  struct ConstSelfAdjointViewReturnType {
62  typedef typename ActualMatrixType::template ConstSelfAdjointViewReturnType<UpLo>::Type Type;
63  };
64 
65  enum { MatrixFree = false };
66 
67  generic_matrix_wrapper() : m_dummy(0, 0), m_matrix(m_dummy) {}
68 
69  template <typename InputType>
70  generic_matrix_wrapper(const InputType& mat) : m_matrix(mat) {}
71 
72  const ActualMatrixType& matrix() const { return m_matrix; }
73 
74  template <typename MatrixDerived>
76  internal::destroy_at(&m_matrix);
77  internal::construct_at(&m_matrix, mat.derived());
78  }
79 
81  if (&(mat.derived()) != &m_matrix) {
82  internal::destroy_at(&m_matrix);
83  internal::construct_at(&m_matrix, mat);
84  }
85  }
86 
87  protected:
88  MatrixType m_dummy; // used to default initialize the Ref<> object
90 };
91 
92 // MatrixType is not compatible with Ref<> -> matrix-free wrapper
93 template <typename MatrixType>
95  public:
97  template <int UpLo>
98  struct ConstSelfAdjointViewReturnType {
100  };
101 
102  enum { MatrixFree = true };
103 
104  generic_matrix_wrapper() : mp_matrix(0) {}
105 
106  generic_matrix_wrapper(const MatrixType& mat) : mp_matrix(&mat) {}
107 
108  const ActualMatrixType& matrix() const { return *mp_matrix; }
109 
110  void grab(const MatrixType& mat) { mp_matrix = &mat; }
111 
112  protected:
114 };
115 
116 } // namespace internal
117 
123 template <typename Derived>
124 class IterativeSolverBase : public SparseSolverBase<Derived> {
125  protected:
127  using Base::m_isInitialized;
128 
129  public:
132  typedef typename MatrixType::Scalar Scalar;
133  typedef typename MatrixType::StorageIndex StorageIndex;
135 
136  enum { ColsAtCompileTime = MatrixType::ColsAtCompileTime, MaxColsAtCompileTime = MatrixType::MaxColsAtCompileTime };
137 
138  public:
139  using Base::derived;
140 
143 
154  template <typename MatrixDerived>
156  init();
157  compute(matrix());
158  }
159 
161 
163 
169  template <typename MatrixDerived>
171  grab(A.derived());
172  m_preconditioner.analyzePattern(matrix());
173  m_isInitialized = true;
174  m_analysisIsOk = true;
175  m_info = m_preconditioner.info();
176  return derived();
177  }
178 
189  template <typename MatrixDerived>
191  eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
192  grab(A.derived());
193  m_preconditioner.factorize(matrix());
194  m_factorizationIsOk = true;
195  m_info = m_preconditioner.info();
196  return derived();
197  }
198 
209  template <typename MatrixDerived>
210  Derived& compute(const EigenBase<MatrixDerived>& A) {
211  grab(A.derived());
212  m_preconditioner.compute(matrix());
213  m_isInitialized = true;
214  m_analysisIsOk = true;
215  m_factorizationIsOk = true;
216  m_info = m_preconditioner.info();
217  return derived();
218  }
219 
221  EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT { return matrix().rows(); }
222 
224  EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT { return matrix().cols(); }
225 
229  RealScalar tolerance() const { return m_tolerance; }
230 
236  Derived& setTolerance(const RealScalar& tolerance) {
238  return derived();
239  }
240 
243 
245  const Preconditioner& preconditioner() const { return m_preconditioner; }
246 
251  Index maxIterations() const { return (m_maxIterations < 0) ? 2 * matrix().cols() : m_maxIterations; }
252 
256  Derived& setMaxIterations(Index maxIters) {
257  m_maxIterations = maxIters;
258  return derived();
259  }
260 
262  Index iterations() const {
263  eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
264  return m_iterations;
265  }
266 
270  RealScalar error() const {
271  eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
272  return m_error;
273  }
274 
280  template <typename Rhs, typename Guess>
281  inline const SolveWithGuess<Derived, Rhs, Guess> solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const {
282  eigen_assert(m_isInitialized && "Solver is not initialized.");
283  eigen_assert(derived().rows() == b.rows() && "solve(): invalid number of rows of the right hand side matrix b");
284  return SolveWithGuess<Derived, Rhs, Guess>(derived(), b.derived(), x0);
285  }
286 
289  eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
290  return m_info;
291  }
292 
294  template <typename Rhs, typename DestDerived>
296  eigen_assert(rows() == b.rows());
297 
298  Index rhsCols = b.cols();
299  Index size = b.rows();
300  DestDerived& dest(aDest.derived());
301  typedef typename DestDerived::Scalar DestScalar;
304  // We do not directly fill dest because sparse expressions have to be free of aliasing issue.
305  // For non square least-square problems, b and dest might not have the same size whereas they might alias
306  // each-other.
307  typename DestDerived::PlainObject tmp(cols(), rhsCols);
308  ComputationInfo global_info = Success;
309  for (Index k = 0; k < rhsCols; ++k) {
310  tb = b.col(k);
311  tx = dest.col(k);
312  derived()._solve_vector_with_guess_impl(tb, tx);
313  tmp.col(k) = tx.sparseView(0);
314 
315  // The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
316  // we need to restore it to the worst value.
317  if (m_info == NumericalIssue)
318  global_info = NumericalIssue;
319  else if (m_info == NoConvergence)
320  global_info = NoConvergence;
321  }
322  m_info = global_info;
323  dest.swap(tmp);
324  }
325 
326  template <typename Rhs, typename DestDerived>
327  std::enable_if_t<Rhs::ColsAtCompileTime != 1 && DestDerived::ColsAtCompileTime != 1> _solve_with_guess_impl(
328  const Rhs& b, MatrixBase<DestDerived>& aDest) const {
329  eigen_assert(rows() == b.rows());
330 
331  Index rhsCols = b.cols();
332  DestDerived& dest(aDest.derived());
333  ComputationInfo global_info = Success;
334  for (Index k = 0; k < rhsCols; ++k) {
335  typename DestDerived::ColXpr xk(dest, k);
336  typename Rhs::ConstColXpr bk(b, k);
337  derived()._solve_vector_with_guess_impl(bk, xk);
338 
339  // The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
340  // we need to restore it to the worst value.
341  if (m_info == NumericalIssue)
342  global_info = NumericalIssue;
343  else if (m_info == NoConvergence)
344  global_info = NoConvergence;
345  }
346  m_info = global_info;
347  }
348 
349  template <typename Rhs, typename DestDerived>
350  std::enable_if_t<Rhs::ColsAtCompileTime == 1 || DestDerived::ColsAtCompileTime == 1> _solve_with_guess_impl(
351  const Rhs& b, MatrixBase<DestDerived>& dest) const {
352  derived()._solve_vector_with_guess_impl(b, dest.derived());
353  }
354 
356  template <typename Rhs, typename Dest>
357  void _solve_impl(const Rhs& b, Dest& x) const {
358  x.setZero();
359  derived()._solve_with_guess_impl(b, x);
360  }
361 
362  protected:
363  void init() {
364  m_isInitialized = false;
365  m_analysisIsOk = false;
366  m_factorizationIsOk = false;
367  m_maxIterations = -1;
369  }
370 
372  typedef typename MatrixWrapper::ActualMatrixType ActualMatrixType;
373 
374  const ActualMatrixType& matrix() const { return m_matrixWrapper.matrix(); }
375 
376  template <typename InputType>
377  void grab(const InputType& A) {
378  m_matrixWrapper.grab(A);
379  }
380 
383 
386 
391 };
392 
393 } // end namespace Eigen
394 
395 #endif // EIGEN_ITERATIVE_SOLVER_BASE_H
Eigen::SparseMatrix< double > mat
Definition: EigenUnitTest.cpp:10
#define EIGEN_NOEXCEPT
Definition: Macros.h:1267
#define EIGEN_CONSTEXPR
Definition: Macros.h:758
#define eigen_assert(x)
Definition: Macros.h:910
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
Scalar * b
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:46
MatrixXf MatrixType
Definition: benchmark-blocking-sizes.cpp:52
Base class for linear iterative solvers.
Definition: IterativeSolverBase.h:124
internal::generic_matrix_wrapper< MatrixType > MatrixWrapper
Definition: IterativeSolverBase.h:371
bool m_analysisIsOk
Definition: IterativeSolverBase.h:390
IterativeSolverBase()
Definition: IterativeSolverBase.h:142
ComputationInfo info() const
Definition: IterativeSolverBase.h:288
bool m_factorizationIsOk
Definition: IterativeSolverBase.h:390
std::enable_if_t< Rhs::ColsAtCompileTime==1||DestDerived::ColsAtCompileTime==1 > _solve_with_guess_impl(const Rhs &b, MatrixBase< DestDerived > &dest) const
Definition: IterativeSolverBase.h:350
RealScalar error() const
Definition: IterativeSolverBase.h:270
Derived & factorize(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:190
Index maxIterations() const
Definition: IterativeSolverBase.h:251
IterativeSolverBase(IterativeSolverBase &&)=default
internal::traits< Derived >::MatrixType MatrixType
Definition: IterativeSolverBase.h:130
ComputationInfo m_info
Definition: IterativeSolverBase.h:389
@ MaxColsAtCompileTime
Definition: IterativeSolverBase.h:136
@ ColsAtCompileTime
Definition: IterativeSolverBase.h:136
IterativeSolverBase(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:155
EIGEN_CONSTEXPR Index rows() const EIGEN_NOEXCEPT
Definition: IterativeSolverBase.h:221
Derived & analyzePattern(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:170
MatrixWrapper::ActualMatrixType ActualMatrixType
Definition: IterativeSolverBase.h:372
MatrixType::RealScalar RealScalar
Definition: IterativeSolverBase.h:134
void grab(const InputType &A)
Definition: IterativeSolverBase.h:377
Preconditioner & preconditioner()
Definition: IterativeSolverBase.h:242
MatrixType::Scalar Scalar
Definition: IterativeSolverBase.h:132
MatrixWrapper m_matrixWrapper
Definition: IterativeSolverBase.h:381
const Preconditioner & preconditioner() const
Definition: IterativeSolverBase.h:245
internal::traits< Derived >::Preconditioner Preconditioner
Definition: IterativeSolverBase.h:131
EIGEN_CONSTEXPR Index cols() const EIGEN_NOEXCEPT
Definition: IterativeSolverBase.h:224
Derived & compute(const EigenBase< MatrixDerived > &A)
Definition: IterativeSolverBase.h:210
void init()
Definition: IterativeSolverBase.h:363
RealScalar m_error
Definition: IterativeSolverBase.h:387
void _solve_impl(const Rhs &b, Dest &x) const
Definition: IterativeSolverBase.h:357
Preconditioner m_preconditioner
Definition: IterativeSolverBase.h:382
MatrixType::StorageIndex StorageIndex
Definition: IterativeSolverBase.h:133
void _solve_with_guess_impl(const Rhs &b, SparseMatrixBase< DestDerived > &aDest) const
Definition: IterativeSolverBase.h:295
~IterativeSolverBase()
Definition: IterativeSolverBase.h:162
SparseSolverBase< Derived > Base
Definition: IterativeSolverBase.h:126
Index m_iterations
Definition: IterativeSolverBase.h:388
Index m_maxIterations
Definition: IterativeSolverBase.h:384
Derived & setTolerance(const RealScalar &tolerance)
Definition: IterativeSolverBase.h:236
bool m_isInitialized
Definition: SparseSolverBase.h:110
std::enable_if_t< Rhs::ColsAtCompileTime !=1 &&DestDerived::ColsAtCompileTime !=1 > _solve_with_guess_impl(const Rhs &b, MatrixBase< DestDerived > &aDest) const
Definition: IterativeSolverBase.h:327
RealScalar tolerance() const
Definition: IterativeSolverBase.h:229
RealScalar m_tolerance
Definition: IterativeSolverBase.h:385
Derived & derived()
Definition: SparseSolverBase.h:76
const SolveWithGuess< Derived, Rhs, Guess > solveWithGuess(const MatrixBase< Rhs > &b, const Guess &x0) const
Definition: IterativeSolverBase.h:281
Index iterations() const
Definition: IterativeSolverBase.h:262
const ActualMatrixType & matrix() const
Definition: IterativeSolverBase.h:374
Derived & setMaxIterations(Index maxIters)
Definition: IterativeSolverBase.h:256
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
The matrix class, also used for vectors and row-vectors.
Definition: Eigen/Eigen/src/Core/Matrix.h:186
A matrix or vector expression mapping an existing expression.
Definition: Ref.h:264
Pseudo expression representing a solving operation.
Definition: SolveWithGuess.h:19
Base class of any sparse matrices or sparse expressions.
Definition: SparseMatrixBase.h:30
const Derived & derived() const
Definition: SparseMatrixBase.h:144
A base class for sparse solvers.
Definition: SparseSolverBase.h:67
bool m_isInitialized
Definition: SparseSolverBase.h:110
Derived & derived()
Definition: SparseSolverBase.h:76
generic_matrix_wrapper()
Definition: IterativeSolverBase.h:67
generic_matrix_wrapper(const InputType &mat)
Definition: IterativeSolverBase.h:70
Ref< const MatrixType > ActualMatrixType
Definition: IterativeSolverBase.h:59
void grab(const EigenBase< MatrixDerived > &mat)
Definition: IterativeSolverBase.h:75
MatrixType m_dummy
Definition: IterativeSolverBase.h:88
void grab(const Ref< const MatrixType > &mat)
Definition: IterativeSolverBase.h:80
const ActualMatrixType & matrix() const
Definition: IterativeSolverBase.h:72
ActualMatrixType m_matrix
Definition: IterativeSolverBase.h:89
const ActualMatrixType & matrix() const
Definition: IterativeSolverBase.h:108
void grab(const MatrixType &mat)
Definition: IterativeSolverBase.h:110
generic_matrix_wrapper(const MatrixType &mat)
Definition: IterativeSolverBase.h:106
MatrixType ActualMatrixType
Definition: IterativeSolverBase.h:96
const ActualMatrixType * mp_matrix
Definition: IterativeSolverBase.h:113
generic_matrix_wrapper()
Definition: IterativeSolverBase.h:104
Definition: IterativeSolverBase.h:53
ComputationInfo
Definition: Constants.h:438
@ NumericalIssue
Definition: Constants.h:442
@ Success
Definition: Constants.h:440
@ NoConvergence
Definition: Constants.h:444
char char char int int * k
Definition: level2_impl.h:374
Eigen::Matrix< Scalar, Dynamic, Dynamic, ColMajor > tmp
Definition: level3_impl.h:365
@ Rhs
Definition: TensorContractionMapper.h:20
EIGEN_DEVICE_FUNC T * construct_at(T *p, Args &&... args)
Definition: Memory.h:1321
EIGEN_DEVICE_FUNC void destroy_at(T *p)
Definition: Memory.h:1335
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
Vector< double > x0(2, 0.0)
Definition: Eigen_Colamd.h:49
double epsilon
Definition: osc_ring_sarah_asymptotics.h:43
list x
Definition: plotDoE.py:28
Definition: EigenBase.h:33
ActualMatrixType::template ConstSelfAdjointViewReturnType< UpLo >::Type Type
Definition: IterativeSolverBase.h:62
Definition: IterativeSolverBase.h:33
int a[2]
Definition: IterativeSolverBase.h:34
Definition: IterativeSolverBase.h:30
int a[1]
Definition: IterativeSolverBase.h:31
Definition: IterativeSolverBase.h:21
@ value
Definition: IterativeSolverBase.h:44
static yes test(const Ref< const T > &, int)
static no test(any_conversion< T >,...)
static MatrixType ms_from
Definition: IterativeSolverBase.h:43
Definition: IterativeSolverBase.h:48
Definition: ForwardDeclarations.h:21