33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
37 #include "../InternalHeaderCheck.h"
// Template declaration used as the dispatch point for the BLAS ?TRMM bindings in this file.
// Only two pieces of it are visible in this listing: the template parameter list below and the
// end of a template-argument list that finishes with `1, BuiltIn> {}` (an empty body).
// NOTE(review): the line that names the struct and its base class is absent from this listing.
// Presumably this is the primary template of product_triangular_matrix_matrix_trmm, deriving
// from the BuiltIn product_triangular_matrix_matrix kernel -- confirm against the full header.
43 template <
typename Scalar,
typename Index,
int Mode,
bool LhsIsTriangular,
int LhsStorageOrder,
bool ConjugateLhs,
44 int RhsStorageOrder,
bool ConjugateRhs,
int ResStorageOrder>
47 RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
// Emits a partial specialization of product_triangular_matrix_matrix for the given Scalar and
// triangular side, restricted to a ColMajor result, a result inner increment of 1 and the
// `Specialized` tag. Its run() asserts that resIncr == 1 and then forwards the sizes, pointers,
// strides and alpha to product_triangular_matrix_matrix_trmm<...>::run, whose argument list has
// no resIncr.
// NOTE(review): the end of the forwarded call and the closing braces are not present in this
// listing; confirm the remaining arguments against the full header.
50 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
51 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
52 struct product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \
53 RhsStorageOrder, ConjugateRhs, ColMajor, 1, Specialized> { \
54 static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride, \
55 const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, \
56 Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) { \
57 EIGEN_ONLY_USED_FOR_DEBUG(resIncr); /* resIncr is only read by the assert on the next line */ \
58 eigen_assert(resIncr == 1); \
59 product_triangular_matrix_matrix_trmm<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \
60 RhsStorageOrder, ConjugateRhs, ColMajor>::run(_rows, _cols, _depth, _lhs, \
61 lhsStride, _rhs, rhsStride, \
62 res, resStride, alpha, \
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Emits the product_triangular_matrix_matrix_trmm specialization for a triangular LEFT-hand
// side (LhsIsTriangular == true) and a ColMajor result; the BLAS routine is called with
// side = 'L'.
//   EIGTYPE   - Eigen scalar type of lhs, rhs, res and alpha.
//   BLASTYPE  - pointer element type that alpha, a and b are cast to for the BLAS call.
//   EIGPREFIX - suffix pasted onto `MatrixX` to name the dense temporary / result map type.
//   BLASFUNC  - the BLAS function that is invoked.
// Visible flow of run():
//   - returns immediately when any of _rows, _cols, _depth is 0;
//   - when rows != depth (non-square triangular factor) it either calls the BuiltIn
//     product_triangular_matrix_matrix kernel (nthr == 1 and the excess over diagSize is below
//     half of diagSize), or copies the triangular view of lhs into a dense temporary and calls
//     general_matrix_matrix_product on it;
//   - otherwise b_tmp is filled from rhs, a_tmp may be built from lhs with its diagonal
//     zeroed or set to ones, BLASFUNC is called with b as the non-const matrix argument, and
//     b_tmp is added into res.
// NOTE(review): this listing omits many original lines (enum braces, the declarations of cols,
// nthr, a, b and a_tmp, and several if/else heads), so the exact branch structure must be
// confirmed against the full header before relying on the summary above.
77 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
78 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
79 struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, \
80 RhsStorageOrder, ConjugateRhs, ColMajor> { \
82 IsLower = (Mode & Lower) == Lower, \
83 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, /* 1 only when Mode has neither ZeroDiag nor UnitDiag */ \
84 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
85 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
86 LowUp = IsLower ? Lower : Upper, \
87 conjA = ((LhsStorageOrder == ColMajor) && ConjugateLhs) ? 1 : 0 /* ColMajor gets transa 'N' below, so conjugation goes through a_tmp */ \
90 static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
91 Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
92 level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
93 if (_rows == 0 || _cols == 0 || _depth == 0) return; /* empty product: nothing to compute */ \
94 Index diagSize = (std::min)(_rows, _depth); \
95 Index rows = IsLower ? _rows : diagSize; \
96 Index depth = IsLower ? diagSize : _depth; \
99 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
100 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
103 if (rows != depth) { /* non-square triangular factor */ \
107 if (((nthr == 1) && (((std::max)(rows, depth) - diagSize) / (double)diagSize < 0.5))) { /* excess over diagSize, relative to diagSize */ \
109 product_triangular_matrix_matrix<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, RhsStorageOrder, \
110 ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, _depth, _lhs, \
111 lhsStride, _rhs, rhsStride, res, \
112 1, resStride, alpha, blocking); \
116 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs, rows, depth, OuterStride<>(lhsStride)); \
117 MatrixLhs aa_tmp = lhsMap.template triangularView<Mode>(); /* dense copy of the triangular view */ \
118 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
119 gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \
121 general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
122 ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, aa_tmp.data(), aStride, \
123 _rhs, rhsStride, res, 1, resStride, alpha, \
130 char side = 'L', transa, uplo, diag = 'N'; \
133 BlasIndex m, n, lda, ldb; \
136 m = convert_index<BlasIndex>(diagSize); \
137 n = convert_index<BlasIndex>(cols); \
140 transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; /* RowMajor lhs is passed as a (conjugate-)transposed matrix */ \
143 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \
144 MatrixX##EIGPREFIX b_tmp; \
147 b_tmp = rhs.conjugate(); \
151 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
154 uplo = IsLower ? 'L' : 'U'; \
155 if (LhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; /* RowMajor storage swaps the referenced triangle */ \
157 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \
160 if ((conjA != 0) || (SetDiag == 0)) { /* a modified copy of lhs is needed */ \
162 a_tmp = lhs.conjugate(); \
166 a_tmp.diagonal().setZero(); \
167 else if (IsUnitDiag) \
168 a_tmp.diagonal().setOnes(); \
170 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
173 lda = convert_index<BlasIndex>(lhsStride); \
177 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
178 &lda, (BLASTYPE*)b, &ldb); \
181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \
182 res_tmp = res_tmp + b_tmp; /* accumulate b_tmp into the caller's result */ \
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Emits the product_triangular_matrix_matrix_trmm specialization for a triangular RIGHT-hand
// side (LhsIsTriangular == false) and a ColMajor result; the BLAS routine is called with
// side = 'R'.
//   EIGTYPE   - Eigen scalar type of lhs, rhs, res and alpha.
//   BLASTYPE  - pointer element type that alpha, a and b are cast to for the BLAS call.
//   EIGPREFIX - suffix pasted onto `MatrixX` to name the dense temporary / result map type.
//   BLASFUNC  - the BLAS function that is invoked.
// Visible flow of run():
//   - returns immediately when any of _rows, _cols, _depth is 0;
//   - when cols != depth (non-square triangular factor) it either calls the BuiltIn
//     product_triangular_matrix_matrix kernel (nthr == 1 and the excess over diagSize is below
//     half of diagSize), or copies the triangular view of rhs into a dense temporary and calls
//     general_matrix_matrix_product with it as the right operand;
//   - otherwise b_tmp is filled from lhs, a_tmp may be built from rhs with its diagonal
//     zeroed or set to ones, BLASFUNC is called with b as the non-const matrix argument, and
//     b_tmp is added into res.
// NOTE(review): this listing omits many original lines (enum braces, the declarations of nthr,
// a, b and a_tmp, the tail of the BuiltIn call, and several if/else heads), so the exact branch
// structure must be confirmed against the full header before relying on the summary above.
199 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
200 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
201 struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \
202 RhsStorageOrder, ConjugateRhs, ColMajor> { \
204 IsLower = (Mode & Lower) == Lower, \
205 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, /* 1 only when Mode has neither ZeroDiag nor UnitDiag */ \
206 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
207 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
208 LowUp = IsLower ? Lower : Upper, \
209 conjA = ((RhsStorageOrder == ColMajor) && ConjugateRhs) ? 1 : 0 /* ColMajor gets transa 'N' below, so conjugation goes through a_tmp */ \
212 static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
213 Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
214 level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
215 if (_rows == 0 || _cols == 0 || _depth == 0) return; /* empty product: nothing to compute */ \
216 Index diagSize = (std::min)(_cols, _depth); \
217 Index rows = _rows; \
218 Index depth = IsLower ? _depth : diagSize; \
219 Index cols = IsLower ? diagSize : _cols; \
221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
222 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
225 if (cols != depth) { /* non-square triangular factor */ \
228 if ((nthr == 1) && (((std::max)(cols, depth) - diagSize) / (double)diagSize < 0.5)) { /* excess over diagSize, relative to diagSize */ \
230 product_triangular_matrix_matrix<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \
231 RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, \
240 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs, depth, cols, OuterStride<>(rhsStride)); \
241 MatrixRhs aa_tmp = rhsMap.template triangularView<Mode>(); /* dense copy of the triangular view */ \
242 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
243 gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \
245 general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
246 ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, _lhs, lhsStride, \
247 aa_tmp.data(), aStride, res, 1, resStride, \
248 alpha, gemm_blocking, 0); \
254 char side = 'R', transa, uplo, diag = 'N'; \
257 BlasIndex m, n, lda, ldb; \
260 m = convert_index<BlasIndex>(rows); \
261 n = convert_index<BlasIndex>(diagSize); \
264 transa = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; /* RowMajor rhs is passed as a (conjugate-)transposed matrix */ \
267 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \
268 MatrixX##EIGPREFIX b_tmp; \
271 b_tmp = lhs.conjugate(); \
275 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
278 uplo = IsLower ? 'L' : 'U'; \
279 if (RhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; /* RowMajor storage swaps the referenced triangle */ \
281 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \
284 if ((conjA != 0) || (SetDiag == 0)) { /* a modified copy of rhs is needed */ \
286 a_tmp = rhs.conjugate(); \
290 a_tmp.diagonal().setZero(); \
291 else if (IsUnitDiag) \
292 a_tmp.diagonal().setOnes(); \
294 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
297 lda = convert_index<BlasIndex>(rhsStride); \
301 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
302 &lda, (BLASTYPE*)b, &ldb); \
305 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \
306 res_tmp = res_tmp + b_tmp; /* accumulate b_tmp into the caller's result */ \
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
Definition: TriangularMatrixMatrix_BLAS.h:199
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
Definition: TriangularMatrixMatrix_BLAS.h:77
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
Definition: TriangularMatrixMatrix_BLAS.h:50
int BLASFUNC() ztrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
int BLASFUNC() dtrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
int BLASFUNC() ctrmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
int BLASFUNC() strmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
SCALAR Scalar
Definition: bench_gemm.cpp:45
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
Definition: cxx11_tensor_map.cpp:237
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::complex< double > dcomplex
Definition: MKL_support.h:128
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
std::complex< float > scomplex
Definition: MKL_support.h:129
Definition: Eigen_Colamd.h:49
Definition: TriangularMatrixMatrix_BLAS.h:47
Definition: TriangularMatrixMatrix.h:49