TriangularMatrixMatrix_BLAS.h
Go to the documentation of this file.
1 /*
2  Copyright (c) 2011, Intel Corporation. All rights reserved.
3 
4  Redistribution and use in source and binary forms, with or without modification,
5  are permitted provided that the following conditions are met:
6 
7  * Redistributions of source code must retain the above copyright notice, this
8  list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright notice,
10  this list of conditions and the following disclaimer in the documentation
11  and/or other materials provided with the distribution.
12  * Neither the name of Intel Corporation nor the names of its contributors may
13  be used to endorse or promote products derived from this software without
14  specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27  ********************************************************************************
28  * Content : Eigen bindings to BLAS F77
29  * Triangular matrix * matrix product functionality based on ?TRMM.
30  ********************************************************************************
31 */
32 
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
35 
36 // IWYU pragma: private
37 #include "../InternalHeaderCheck.h"
38 
39 namespace Eigen {
40 
41 namespace internal {
42 
// Dispatch target used by the BLAS-enabled specializations of product_triangular_matrix_matrix.
// This primary template only inherits the BuiltIn (non-BLAS) triangular product kernel, so any
// combination of template arguments that is not specialized further down in this file keeps the
// default Eigen implementation.
43 template <typename Scalar, typename Index, int Mode, bool LhsIsTriangular, int LhsStorageOrder, bool ConjugateLhs,
44  int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder>
45 struct product_triangular_matrix_matrix_trmm
46  : product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs,
47  RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
48 
49 // try to go to BLAS specialization
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) defines the partial specialization
//   product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, ..., ColMajor, 1, Specialized>
// for one scalar type and one side of the product.
// Its run() checks resIncr == 1 through eigen_assert and then forwards all remaining arguments
// unchanged to product_triangular_matrix_matrix_trmm<...>::run, which takes no resIncr parameter.
// Only the column-major result case (ResStorageOrder == ColMajor) is covered by this macro.
50 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
51  template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
52  struct product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \
53  RhsStorageOrder, ConjugateRhs, ColMajor, 1, Specialized> { \
54  static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride, \
55  const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, \
56  Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) { \
57  EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
58  eigen_assert(resIncr == 1); \
59  product_triangular_matrix_matrix_trmm<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \
60  RhsStorageOrder, ConjugateRhs, ColMajor>::run(_rows, _cols, _depth, _lhs, \
61  lhsStride, _rhs, rhsStride, \
62  res, resStride, alpha, \
63  blocking); \
64  } \
65  };
66 
// Route the triangular product to the TRMM-based path for every scalar type that has a
// product_triangular_matrix_matrix_trmm specialization in this file (double, dcomplex, float,
// scomplex), once with the triangular factor on the left (true) and once on the right (false).
67 EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
68 EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
69 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
70 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
71 EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
72 EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
73 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
74 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
75 
76 // implements col-major += alpha * op(triangular) * op(general)
//
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) specializes
// product_triangular_matrix_matrix_trmm for LhsIsTriangular == true and a ColMajor result.
//   EIGTYPE   - scalar type of the operands and of the result.
//   BLASTYPE  - type the alpha, a and b pointers are cast to when BLASFUNC is called.
//   EIGPREFIX - suffix pasted onto "MatrixX" to name the dense temporary / result map type.
//   BLASFUNC  - the ?TRMM routine to call.
//
// run() does the following:
//   * returns immediately when _rows, _cols or _depth is 0;
//   * sets diagSize = min(_rows, _depth) and trims the triangular factor to rows x depth
//     (Lower: _rows x diagSize, otherwise: diagSize x _depth);
//   * if that factor is not square, BLASFUNC is not used: with nthr fixed to 1, a factor whose
//     excess ((max(rows, depth) - diagSize) / diagSize) is below 0.5 goes to the BuiltIn Eigen
//     kernel, anything else is copied into a dense matrix through triangularView<Mode>() and
//     multiplied with general_matrix_matrix_product;
//   * if it is square, the rhs is copied (conjugated when ConjugateRhs) into b_tmp, the lhs is
//     copied only when it must be conjugated or its diagonal must be overwritten (ZeroDiag /
//     UnitDiag), BLASFUNC is called with side 'L' and diag 'N', and b_tmp is then added to res.
//     NOTE(review): this relies on BLASFUNC overwriting b_tmp with the scaled product, as the
//     ?TRMM interface is expected to do - confirm against the BLAS reference.
// For a RowMajor lhs the transpose flag is 'T' (or 'C' when ConjugateLhs) and uplo is flipped.
77 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
78  template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
79  struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, \
80  RhsStorageOrder, ConjugateRhs, ColMajor> { \
81  enum { \
82  IsLower = (Mode & Lower) == Lower, \
83  SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
84  IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
85  IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
86  LowUp = IsLower ? Lower : Upper, \
87  conjA = ((LhsStorageOrder == ColMajor) && ConjugateLhs) ? 1 : 0 \
88  }; \
89  \
90  static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
91  Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
92  level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
93  if (_rows == 0 || _cols == 0 || _depth == 0) return; \
94  Index diagSize = (std::min)(_rows, _depth); \
95  Index rows = IsLower ? _rows : diagSize; \
96  Index depth = IsLower ? diagSize : _depth; \
97  Index cols = _cols; \
98  \
99  typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
100  typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
101  \
102  /* Non-square case - does not fit BLAS ?TRMM. Fall back to the default triangular product or to a GEMM product */ \
103  if (rows != depth) { \
104  /* FIXME handle mkl_domain_get_max_threads */ \
105  /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1; \
106  \
107  if (((nthr == 1) && (((std::max)(rows, depth) - diagSize) / (double)diagSize < 0.5))) { \
108  /* Most likely no benefit to call TRMM or GEMM from BLAS */ \
109  product_triangular_matrix_matrix<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, RhsStorageOrder, \
110  ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, _depth, _lhs, \
111  lhsStride, _rhs, rhsStride, res, \
112  1, resStride, alpha, blocking); \
113  /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
114  } else { \
115  /* It makes sense to call GEMM on a dense copy of the triangular factor */ \
116  Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs, rows, depth, OuterStride<>(lhsStride)); \
117  MatrixLhs aa_tmp = lhsMap.template triangularView<Mode>(); \
118  BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
119  gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \
120  _depth, 1, true); \
121  general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
122  ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, aa_tmp.data(), aStride, \
123  _rhs, rhsStride, res, 1, resStride, alpha, \
124  gemm_blocking, 0); \
125  \
126  /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
127  } \
128  return; \
129  } \
130  char side = 'L', transa, uplo, diag = 'N'; \
131  EIGTYPE* b; \
132  const EIGTYPE* a; \
133  BlasIndex m, n, lda, ldb; \
134  \
135  /* Set m, n */ \
136  m = convert_index<BlasIndex>(diagSize); \
137  n = convert_index<BlasIndex>(cols); \
138  \
139  /* Set trans */ \
140  transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
141  \
142  /* Set b, ldb */ \
143  Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \
144  MatrixX##EIGPREFIX b_tmp; \
145  \
146  if (ConjugateRhs) \
147  b_tmp = rhs.conjugate(); \
148  else \
149  b_tmp = rhs; \
150  b = b_tmp.data(); \
151  ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
152  \
153  /* Set uplo */ \
154  uplo = IsLower ? 'L' : 'U'; \
155  if (LhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
156  /* Set a, lda */ \
157  Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \
158  MatrixLhs a_tmp; \
159  \
160  if ((conjA != 0) || (SetDiag == 0)) { \
161  if (conjA) \
162  a_tmp = lhs.conjugate(); \
163  else \
164  a_tmp = lhs; \
165  if (IsZeroDiag) \
166  a_tmp.diagonal().setZero(); \
167  else if (IsUnitDiag) \
168  a_tmp.diagonal().setOnes(); \
169  a = a_tmp.data(); \
170  lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
171  } else { \
172  a = _lhs; \
173  lda = convert_index<BlasIndex>(lhsStride); \
174  } \
175  /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
176  /* call ?trmm*/ \
177  BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
178  &lda, (BLASTYPE*)b, &ldb); \
179  \
180  /* Add op(a_triangular)*b into res*/ \
181  Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \
182  res_tmp = res_tmp + b_tmp; \
183  } \
184  };
185 
// Instantiate the left-side (triangular * general) TRMM specialization for the four scalar types.
// With EIGEN_USE_MKL the routines are named without a trailing underscore and the complex
// pointers are cast to MKL_Complex16 / MKL_Complex8; otherwise the underscore-suffixed symbols
// are used and complex pointers are cast to double / float.
186 #ifdef EIGEN_USE_MKL
187 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
188 EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
189 EIGEN_BLAS_TRMM_L(float, float, f, strmm)
190 EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
191 #else
192 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
193 EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
194 EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
195 EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
196 #endif
197 
198 // implements col-major += alpha * op(general) * op(triangular)
//
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) specializes
// product_triangular_matrix_matrix_trmm for LhsIsTriangular == false (the rhs is the triangular
// factor) and a ColMajor result.
//   EIGTYPE   - scalar type of the operands and of the result.
//   BLASTYPE  - type the alpha, a and b pointers are cast to when BLASFUNC is called.
//   EIGPREFIX - suffix pasted onto "MatrixX" to name the dense temporary / result map type.
//   BLASFUNC  - the ?TRMM routine to call.
//
// run() does the following:
//   * returns immediately when _rows, _cols or _depth is 0;
//   * sets diagSize = min(_cols, _depth) and trims the triangular factor to depth x cols
//     (Lower: _depth x diagSize, otherwise: diagSize x _cols);
//   * if that factor is not square, BLASFUNC is not used: with nthr fixed to 1, a factor whose
//     excess ((max(cols, depth) - diagSize) / diagSize) is below 0.5 goes to the BuiltIn Eigen
//     kernel, anything else is copied into a dense matrix through triangularView<Mode>() and
//     multiplied with general_matrix_matrix_product (the lhs is passed through uncopied);
//   * if it is square, the lhs is copied (conjugated when ConjugateLhs) into b_tmp, the rhs is
//     copied only when it must be conjugated or its diagonal must be overwritten (ZeroDiag /
//     UnitDiag), BLASFUNC is called with side 'R' and diag 'N', and b_tmp is then added to res.
//     NOTE(review): this relies on BLASFUNC overwriting b_tmp with the scaled product, as the
//     ?TRMM interface is expected to do - confirm against the BLAS reference.
// For a RowMajor rhs the transpose flag is 'T' (or 'C' when ConjugateRhs) and uplo is flipped.
199 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
200  template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
201  struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \
202  RhsStorageOrder, ConjugateRhs, ColMajor> { \
203  enum { \
204  IsLower = (Mode & Lower) == Lower, \
205  SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
206  IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
207  IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
208  LowUp = IsLower ? Lower : Upper, \
209  conjA = ((RhsStorageOrder == ColMajor) && ConjugateRhs) ? 1 : 0 \
210  }; \
211  \
212  static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
213  Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
214  level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
215  if (_rows == 0 || _cols == 0 || _depth == 0) return; \
216  Index diagSize = (std::min)(_cols, _depth); \
217  Index rows = _rows; \
218  Index depth = IsLower ? _depth : diagSize; \
219  Index cols = IsLower ? diagSize : _cols; \
220  \
221  typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
222  typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
223  \
224  /* Non-square case - does not fit BLAS ?TRMM. Fall back to the default triangular product or to a GEMM product */ \
225  if (cols != depth) { \
226  int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
227  \
228  if ((nthr == 1) && (((std::max)(cols, depth) - diagSize) / (double)diagSize < 0.5)) { \
229  /* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
230  product_triangular_matrix_matrix<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \
231  RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, \
232  _depth, _lhs, \
233  lhsStride, _rhs, \
234  rhsStride, res, \
235  1, resStride, \
236  alpha, blocking); \
237  /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
238  } else { \
239  /* It makes sense to call GEMM on a dense copy of the triangular factor */ \
240  Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs, depth, cols, OuterStride<>(rhsStride)); \
241  MatrixRhs aa_tmp = rhsMap.template triangularView<Mode>(); \
242  BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
243  gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \
244  _depth, 1, true); \
245  general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
246  ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, _lhs, lhsStride, \
247  aa_tmp.data(), aStride, res, 1, resStride, \
248  alpha, gemm_blocking, 0); \
249  \
250  /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
251  } \
252  return; \
253  } \
254  char side = 'R', transa, uplo, diag = 'N'; \
255  EIGTYPE* b; \
256  const EIGTYPE* a; \
257  BlasIndex m, n, lda, ldb; \
258  \
259  /* Set m, n */ \
260  m = convert_index<BlasIndex>(rows); \
261  n = convert_index<BlasIndex>(diagSize); \
262  \
263  /* Set trans */ \
264  transa = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
265  \
266  /* Set b, ldb */ \
267  Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \
268  MatrixX##EIGPREFIX b_tmp; \
269  \
270  if (ConjugateLhs) \
271  b_tmp = lhs.conjugate(); \
272  else \
273  b_tmp = lhs; \
274  b = b_tmp.data(); \
275  ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
276  \
277  /* Set uplo */ \
278  uplo = IsLower ? 'L' : 'U'; \
279  if (RhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
280  /* Set a, lda */ \
281  Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \
282  MatrixRhs a_tmp; \
283  \
284  if ((conjA != 0) || (SetDiag == 0)) { \
285  if (conjA) \
286  a_tmp = rhs.conjugate(); \
287  else \
288  a_tmp = rhs; \
289  if (IsZeroDiag) \
290  a_tmp.diagonal().setZero(); \
291  else if (IsUnitDiag) \
292  a_tmp.diagonal().setOnes(); \
293  a = a_tmp.data(); \
294  lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
295  } else { \
296  a = _rhs; \
297  lda = convert_index<BlasIndex>(rhsStride); \
298  } \
299  /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
300  /* call ?trmm*/ \
301  BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
302  &lda, (BLASTYPE*)b, &ldb); \
303  \
304  /* Add op(a_triangular)*b into res*/ \
305  Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \
306  res_tmp = res_tmp + b_tmp; \
307  } \
308  };
309 
// Instantiate the right-side (general * triangular) TRMM specialization for the four scalar types.
// With EIGEN_USE_MKL the routines are named without a trailing underscore and the complex
// pointers are cast to MKL_Complex16 / MKL_Complex8; otherwise the underscore-suffixed symbols
// are used and complex pointers are cast to double / float.
310 #ifdef EIGEN_USE_MKL
311 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
312 EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
313 EIGEN_BLAS_TRMM_R(float, float, f, strmm)
314 EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
315 #else
316 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
317 EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
318 EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
319 EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
320 #endif
321 } // end namespace internal
322 
323 } // end namespace Eigen
324 
325 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
Definition: TriangularMatrixMatrix_BLAS.h:199
#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
Definition: TriangularMatrixMatrix_BLAS.h:77
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
Definition: TriangularMatrixMatrix_BLAS.h:50
int BLASFUNC() ztrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
int BLASFUNC() dtrmm(char *, char *, char *, char *, int *, int *, double *, double *, int *, double *, int *)
int BLASFUNC() ctrmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
int BLASFUNC() strmm(char *, char *, char *, char *, int *, int *, float *, float *, int *, float *, int *)
SCALAR Scalar
Definition: bench_gemm.cpp:45
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
Definition: cxx11_tensor_map.cpp:237
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
std::complex< double > dcomplex
Definition: MKL_support.h:128
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
std::complex< float > scomplex
Definition: MKL_support.h:129
Definition: Eigen_Colamd.h:49
Definition: TriangularMatrixMatrix_BLAS.h:47
Definition: TriangularMatrixMatrix.h:49