11 #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
12 #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
15 #include "../InternalHeaderCheck.h"
21 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStride,
30 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStride,
39 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStride,
47 TriMapper
tri(_tri, triStride);
48 OtherMapper other(_other, otherStride, otherIncr);
58 Index s = TriStorageOrder ==
RowMajor ? (IsLower ? 0 :
i + 1) : IsLower ?
i + 1 :
i - rs;
61 for (
Index j = 0;
j < otherSize; ++
j) {
65 typename OtherMapper::LinearMapper
r = other.getLinearMapper(
s,
j);
66 for (
Index i3 = 0; i3 <
k; ++i3)
b +=
conj(l[i3]) *
r(i3);
68 other(
i,
j) = (other(
i,
j) -
b) *
a;
73 typename OtherMapper::LinearMapper
r = other.getLinearMapper(
s,
j);
74 typename TriMapper::LinearMapper l =
tri.getLinearMapper(
s,
i);
75 for (
Index i3 = 0; i3 < rs; ++i3)
r(i3) -=
b *
conj(l(i3));
81 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStride,
90 LhsMapper lhs(_other, otherStride, otherIncr);
91 RhsMapper rhs(_tri, triStride);
93 enum { RhsStorageOrder = TriStorageOrder, IsLower = (Mode &
Lower) ==
Lower };
99 typename LhsMapper::LinearMapper
r = lhs.getLinearMapper(0,
j);
100 for (
Index k3 = 0; k3 <
k; ++k3) {
102 typename LhsMapper::LinearMapper
a = lhs.getLinearMapper(0, IsLower ?
j + 1 + k3 : k3);
107 for (
Index i = 0;
i < otherSize; ++
i)
r(
i) *= inv_rjj;
114 int OtherInnerStride>
121 OtherInnerStride>::
run(
size,
cols,
tri, triStride, _other, otherIncr, otherStride, blocking);
127 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
133 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
141 std::ptrdiff_t l1, l2, l3;
144 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_L_KERNELS && EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
151 if (
size < avx512_trsm_cutoff<Scalar>(l2,
cols, L2Cap)) {
153 size,
cols, _tri, triStride, _other, 1, otherStride);
161 TriMapper
tri(_tri, triStride);
162 OtherMapper other(_other, otherStride, otherIncr);
171 std::size_t sizeA = kc * mc;
172 std::size_t sizeB = kc *
cols;
185 Index subcols =
cols > 0 ? l2 / (4 *
sizeof(
Scalar) * std::max<Index>(otherStride,
size)) : 0;
186 subcols = std::max<Index>((subcols / Traits::nr) * Traits::nr, Traits::nr);
188 for (
Index k2 = IsLower ? 0 :
size; IsLower ? k2 < size : k2 > 0; IsLower ? k2 += kc : k2 -= kc) {
204 for (
Index j2 = 0; j2 <
cols; j2 += subcols) {
207 for (
Index k1 = 0; k1 < actual_kc; k1 += SmallPanelWidth) {
208 Index actualPanelWidth = std::min<Index>(actual_kc - k1, SmallPanelWidth);
211 Index i = IsLower ? k2 + k1 : k2 - k1;
212 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_L_KERNELS
215 i = IsLower ? k2 + k1 : k2 - k1 - actualPanelWidth;
219 actualPanelWidth, actual_cols, _tri +
i + (
i)*triStride, triStride,
220 _other +
i * OtherInnerStride + j2 * otherStride, otherIncr, otherStride);
223 Index lengthTarget = actual_kc - k1 - actualPanelWidth;
224 Index startBlock = IsLower ? k2 + k1 : k2 - k1 - actualPanelWidth;
225 Index blockBOffset = IsLower ? k1 : lengthTarget;
228 pack_rhs(blockB + actual_kc * j2, other.getSubMapper(startBlock, j2), actualPanelWidth, actual_cols, actual_kc,
232 if (lengthTarget > 0) {
233 Index startTarget = IsLower ? k2 + k1 + actualPanelWidth : k2 - actual_kc;
235 pack_lhs(blockA,
tri.getSubMapper(startTarget, startBlock), actualPanelWidth, lengthTarget);
237 gebp_kernel(other.getSubMapper(startTarget, j2), blockA, blockB + actual_kc * j2, lengthTarget,
238 actualPanelWidth, actual_cols,
Scalar(-1), actualPanelWidth, actual_kc, 0, blockBOffset);
250 pack_lhs(blockA,
tri.getSubMapper(i2, IsLower ? k2 : k2 - kc), actual_kc, actual_mc);
252 gebp_kernel(other.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
cols,
Scalar(-1), -1, -1, 0, 0);
261 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
268 template <
typename Scalar,
typename Index,
int Mode,
bool Conjugate,
int TriStorageOrder,
int OtherInnerStr
ide>
276 #if defined(EIGEN_VECTORIZE_AVX512) && EIGEN_USE_AVX512_TRSM_R_KERNELS && EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
280 std::ptrdiff_t l1, l2, l3;
283 if (
size < avx512_trsm_cutoff<Scalar>(l2,
rows, L2Cap)) {
285 size,
rows, _tri, triStride, _other, 1, otherStride);
293 LhsMapper lhs(_other, otherStride, otherIncr);
294 RhsMapper rhs(_tri, triStride);
298 RhsStorageOrder = TriStorageOrder,
306 std::size_t sizeA = kc * mc;
307 std::size_t sizeB = kc *
size;
319 for (
Index k2 = IsLower ?
size : 0; IsLower ? k2 > 0 : k2 <
size; IsLower ? k2 -= kc : k2 += kc) {
321 Index actual_k2 = IsLower ? k2 - actual_kc : k2;
323 Index startPanel = IsLower ? 0 : k2 + actual_kc;
324 Index rs = IsLower ? actual_k2 :
size - actual_k2 - actual_kc;
325 Scalar* geb = blockB + actual_kc * actual_kc;
327 if (rs > 0) pack_rhs(geb, rhs.getSubMapper(actual_k2, startPanel), actual_kc, rs);
332 for (
Index j2 = 0; j2 < actual_kc; j2 += SmallPanelWidth) {
333 Index actualPanelWidth = std::min<Index>(actual_kc - j2, SmallPanelWidth);
334 Index actual_j2 = actual_k2 + j2;
335 Index panelOffset = IsLower ? j2 + actualPanelWidth : 0;
336 Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2;
339 pack_rhs_panel(blockB + j2 * actual_kc, rhs.getSubMapper(actual_k2 + panelOffset, actual_j2), panelLength,
340 actualPanelWidth, actual_kc, panelOffset);
344 for (
Index i2 = 0; i2 <
rows; i2 += mc) {
350 for (
Index j2 = IsLower ? (actual_kc - ((actual_kc % SmallPanelWidth) ?
Index(actual_kc % SmallPanelWidth)
351 :
Index(SmallPanelWidth)))
353 IsLower ? j2 >= 0 : j2 < actual_kc; IsLower ? j2 -= SmallPanelWidth : j2 += SmallPanelWidth) {
354 Index actualPanelWidth = std::min<Index>(actual_kc - j2, SmallPanelWidth);
355 Index absolute_j2 = actual_k2 + j2;
356 Index panelOffset = IsLower ? j2 + actualPanelWidth : 0;
357 Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2;
360 if (panelLength > 0) {
361 gebp_kernel(lhs.getSubMapper(i2, absolute_j2), blockA, blockB + j2 * actual_kc, actual_mc, panelLength,
362 actualPanelWidth,
Scalar(-1), actual_kc, actual_kc,
363 panelOffset, panelOffset);
369 true>::kernel(actualPanelWidth, actual_mc,
370 _tri + absolute_j2 + absolute_j2 * triStride, triStride,
371 _other + i2 * OtherInnerStride + absolute_j2 * otherStride,
372 otherIncr, otherStride);
375 pack_lhs_panel(blockA, lhs.getSubMapper(i2, absolute_j2), actualPanelWidth, actual_mc, actual_kc, j2);
380 gebp_kernel(lhs.getSubMapper(i2, startPanel), blockA, geb, actual_mc, actual_kc, rs,
Scalar(-1), -1, -1, 0, 0);
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define EIGEN_DONT_INLINE
Definition: Macros.h:853
#define EIGEN_IF_CONSTEXPR(X)
Definition: Macros.h:1306
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:806
Tridiagonalization< MatrixXf > tri
Definition: Tridiagonalization_compute.cpp:1
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
Scalar * b
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:46
Definition: ForwardDeclarations.h:102
Definition: BlasUtil.h:304
Definition: BlasUtil.h:443
Definition: products/GeneralBlockPanelKernel.h:397
Definition: GeneralMatrixMatrix.h:226
RhsScalar * blockB()
Definition: GeneralMatrixMatrix.h:246
LhsScalar * blockA()
Definition: GeneralMatrixMatrix.h:245
Index mc() const
Definition: GeneralMatrixMatrix.h:241
Index kc() const
Definition: GeneralMatrixMatrix.h:243
#define min(a, b)
Definition: datatypes.h:22
static constexpr lastp1_t end
Definition: IndexedViewHelper.h:79
@ UnitDiag
Definition: Constants.h:215
@ Lower
Definition: Constants.h:211
@ Upper
Definition: Constants.h:213
@ GetAction
Definition: Constants.h:516
@ Specialized
Definition: Constants.h:311
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
@ OnTheLeft
Definition: Constants.h:331
@ OnTheRight
Definition: Constants.h:333
RealScalar s
Definition: level1_cplx_impl.h:130
const Scalar * a
Definition: level2_cplx_impl.h:32
char char char int int * k
Definition: level2_impl.h:374
constexpr int plain_enum_max(A a, B b)
Definition: Meta.h:656
void manage_caching_sizes(Action action, std::ptrdiff_t *l1, std::ptrdiff_t *l2, std::ptrdiff_t *l3)
Definition: products/GeneralBlockPanelKernel.h:86
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
auto run(Kernel kernel, Args &&... args) -> decltype(kernel(args...))
Definition: gpu_test_helper.h:414
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
const AutoDiffScalar< DerType > & conj(const AutoDiffScalar< DerType > &x)
Definition: AutoDiffScalar.h:482
Definition: Eigen_Colamd.h:49
void start(const unsigned &i)
(Re-)start i-th timer
Definition: oomph_utilities.cc:243
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition: NumTraits.h:217
Definition: ConjHelper.h:42
Definition: products/GeneralBlockPanelKernel.h:960
Definition: BlasUtil.h:34
Definition: BlasUtil.h:30
static void run(Index size, Index cols, const Scalar *tri, Index triStride, Scalar *_other, Index otherIncr, Index otherStride, level3_blocking< Scalar, Scalar > &blocking)
Definition: TriangularSolverMatrix.h:116
Definition: SolveTriangular.h:27
Definition: TriangularSolverMatrix.h:23
static void kernel(Index size, Index otherSize, const Scalar *_tri, Index triStride, Scalar *_other, Index otherIncr, Index otherStride)
Definition: TriangularSolverMatrix.h:42
Definition: TriangularSolverMatrix.h:32
static void kernel(Index size, Index otherSize, const Scalar *_tri, Index triStride, Scalar *_other, Index otherIncr, Index otherStride)
Definition: TriangularSolverMatrix.h:84
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2