33 #ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
37 #include "../InternalHeaderCheck.h"
49 template <
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
54 #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
55 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
56 struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor, Specialized> { \
57 static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, Index rhsIncr, \
58 Scalar* res_, Index resIncr, Scalar alpha) { \
59 triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor>::run( \
60 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
63 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
64 struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor, Specialized> { \
65 static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, Index rhsIncr, \
66 Scalar* res_, Index resIncr, Scalar alpha) { \
67 triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor>::run( \
68 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
78 #define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
79 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
80 struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor> { \
82 IsLower = (Mode & Lower) == Lower, \
83 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
84 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
85 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
86 LowUp = IsLower ? Lower : Upper \
88 static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
89 Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
90 if (rows_ == 0 || cols_ == 0) return; \
91 if (ConjLhs || IsZeroDiag) { \
92 triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor, BuiltIn>::run( \
93 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
96 Index size = (std::min)(rows_, cols_); \
97 Index rows = IsLower ? rows_ : size; \
98 Index cols = IsLower ? size : cols_; \
100 typedef VectorX##EIGPREFIX VectorRhs; \
104 Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, cols, InnerStride<>(rhsIncr)); \
107 x_tmp = rhs.conjugate(); \
114 char trans, uplo, diag; \
115 BlasIndex m, n, lda, incx, incy; \
120 n = convert_index<BlasIndex>(size); \
121 lda = convert_index<BlasIndex>(lhsStride); \
123 incy = convert_index<BlasIndex>(resIncr); \
127 uplo = IsLower ? 'L' : 'U'; \
128 diag = IsUnitDiag ? 'U' : 'N'; \
131 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx); \
134 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)x, &incx, \
135 (BLASTYPE*)res_, &incy); \
137 if (size < (std::max)(rows, cols)) { \
139 x_tmp = rhs.conjugate(); \
144 y = res_ + size * resIncr; \
146 m = convert_index<BlasIndex>(rows - size); \
147 n = convert_index<BlasIndex>(size); \
151 a = lhs_ + size * lda; \
152 m = convert_index<BlasIndex>(size); \
153 n = convert_index<BlasIndex>(cols - size); \
155 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
156 &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), \
157 (BLASTYPE*)y, &incy); \
175 #define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
176 template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
177 struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor> { \
179 IsLower = (Mode & Lower) == Lower, \
180 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
181 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
182 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
183 LowUp = IsLower ? Lower : Upper \
185 static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
186 Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
187 if (rows_ == 0 || cols_ == 0) return; \
189 triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor, BuiltIn>::run( \
190 rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
193 Index size = (std::min)(rows_, cols_); \
194 Index rows = IsLower ? rows_ : size; \
195 Index cols = IsLower ? size : cols_; \
197 typedef VectorX##EIGPREFIX VectorRhs; \
201 Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, cols, InnerStride<>(rhsIncr)); \
204 x_tmp = rhs.conjugate(); \
211 char trans, uplo, diag; \
212 BlasIndex m, n, lda, incx, incy; \
217 n = convert_index<BlasIndex>(size); \
218 lda = convert_index<BlasIndex>(lhsStride); \
220 incy = convert_index<BlasIndex>(resIncr); \
223 trans = ConjLhs ? 'C' : 'T'; \
224 uplo = IsLower ? 'U' : 'L'; \
225 diag = IsUnitDiag ? 'U' : 'N'; \
228 BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx); \
231 BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)x, &incx, \
232 (BLASTYPE*)res_, &incy); \
234 if (size < (std::max)(rows, cols)) { \
236 x_tmp = rhs.conjugate(); \
241 y = res_ + size * resIncr; \
242 a = lhs_ + size * lda; \
243 m = convert_index<BlasIndex>(rows - size); \
244 n = convert_index<BlasIndex>(size); \
249 m = convert_index<BlasIndex>(size); \
250 n = convert_index<BlasIndex>(cols - size); \
252 BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
253 &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta), \
254 (BLASTYPE*)y, &incy); \
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)
Definition: TriangularMatrixVector_BLAS.h:175
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)
Definition: TriangularMatrixVector_BLAS.h:54
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)
Definition: TriangularMatrixVector_BLAS.h:78
#define _(A, B)
Definition: cfortran.h:132
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
Definition: cxx11_tensor_map.cpp:237
RealScalar s
Definition: level1_cplx_impl.h:130
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::complex< double > dcomplex
Definition: MKL_support.h:128
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
std::complex< float > scomplex
Definition: MKL_support.h:129
int c
Definition: calibrate.py:100
Definition: Eigen_Colamd.h:49
Definition: TriangularMatrixVector_BLAS.h:52
Definition: TriangularMatrixVector.h:22