level3_impl.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 #include <iostream>
10 #include "common.h"
11 
13 (const char *opa, const char *opb, const int *m, const int *n, const int *k, const RealScalar *palpha,
14  const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc,
15  const int *ldc) {
16  // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " <<
17  // *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
18  using Eigen::ColMajor;
19  using Eigen::DenseIndex;
20  using Eigen::Dynamic;
21  using Eigen::RowMajor;
22  typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex,
25  static const functype func[12] = {
26  // array index: NOTR | (NOTR << 2)
28  ColMajor, 1>::run),
29  // array index: TR | (NOTR << 2)
31  ColMajor, 1>::run),
32  // array index: ADJ | (NOTR << 2)
34  ColMajor, 1>::run),
35  0,
36  // array index: NOTR | (TR << 2)
38  ColMajor, 1>::run),
39  // array index: TR | (TR << 2)
41  ColMajor, 1>::run),
42  // array index: ADJ | (TR << 2)
44  ColMajor, 1>::run),
45  0,
46  // array index: NOTR | (ADJ << 2)
48  ColMajor, 1>::run),
49  // array index: TR | (ADJ << 2)
51  ColMajor, 1>::run),
52  // array index: ADJ | (ADJ << 2)
54  ColMajor, 1>::run),
55  0};
56 
57  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
58  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
59  Scalar *c = reinterpret_cast<Scalar *>(pc);
60  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
61  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
62 
63  int info = 0;
64  if (OP(*opa) == INVALID)
65  info = 1;
66  else if (OP(*opb) == INVALID)
67  info = 2;
68  else if (*m < 0)
69  info = 3;
70  else if (*n < 0)
71  info = 4;
72  else if (*k < 0)
73  info = 5;
74  else if (*lda < std::max(1, (OP(*opa) == NOTR) ? *m : *k))
75  info = 8;
76  else if (*ldb < std::max(1, (OP(*opb) == NOTR) ? *k : *n))
77  info = 10;
78  else if (*ldc < std::max(1, *m))
79  info = 13;
80  if (info) return xerbla_(SCALAR_SUFFIX_UP "GEMM ", &info);
81 
82  if (*m == 0 || *n == 0) return;
83 
84  if (beta != Scalar(1)) {
85  if (beta == Scalar(0))
86  matrix(c, *m, *n, *ldc).setZero();
87  else
88  matrix(c, *m, *n, *ldc) *= beta;
89  }
90 
91  if (*k == 0) return;
92 
94  true);
95 
96  int code = OP(*opa) | (OP(*opb) << 2);
97  func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0);
98 }
99 
100 EIGEN_BLAS_FUNC(trsm)
101 (const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
102  const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) {
103  // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " "
104  // << *palpha << " " << *lda << " " << *ldb<< "\n";
105  using Eigen::ColMajor;
106  using Eigen::DenseIndex;
107  using Eigen::Dynamic;
108  using Eigen::Lower;
109  using Eigen::OnTheLeft;
110  using Eigen::OnTheRight;
111  using Eigen::RowMajor;
112  using Eigen::UnitDiag;
113  using Eigen::Upper;
114  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex,
116  static const functype func[32] = {
117  // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
119  1>::run),
120  // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
122  1>::run),
123  // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
125  1>::run),
126  0,
127  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
129  1>::run),
130  // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
132  1>::run),
133  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
135  1>::run),
136  0,
137  // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
139  1>::run),
140  // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
142  1>::run),
143  // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
145  1>::run),
146  0,
147  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
149  1>::run),
150  // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
152  1>::run),
153  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
155  1>::run),
156  0,
157  // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
159  ColMajor, 1>::run),
160  // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
162  ColMajor, 1>::run),
163  // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
165  ColMajor, 1>::run),
166  0,
167  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
169  ColMajor, 1>::run),
170  // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
172  ColMajor, 1>::run),
173  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
175  ColMajor, 1>::run),
176  0,
177  // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
179  ColMajor, 1>::run),
180  // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
182  ColMajor, 1>::run),
183  // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
185  ColMajor, 1>::run),
186  0,
187  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
189  ColMajor, 1>::run),
190  // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
192  ColMajor, 1>::run),
193  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
195  ColMajor, 1>::run),
196  0};
197 
198  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
199  Scalar *b = reinterpret_cast<Scalar *>(pb);
200  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
201 
202  int info = 0;
203  if (SIDE(*side) == INVALID)
204  info = 1;
205  else if (UPLO(*uplo) == INVALID)
206  info = 2;
207  else if (OP(*opa) == INVALID)
208  info = 3;
209  else if (DIAG(*diag) == INVALID)
210  info = 4;
211  else if (*m < 0)
212  info = 5;
213  else if (*n < 0)
214  info = 6;
215  else if (*lda < std::max(1, (SIDE(*side) == LEFT) ? *m : *n))
216  info = 9;
217  else if (*ldb < std::max(1, *m))
218  info = 11;
219  if (info) return xerbla_(SCALAR_SUFFIX_UP "TRSM ", &info);
220 
221  if (*m == 0 || *n == 0) return;
222 
223  int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
224 
225  if (SIDE(*side) == LEFT) {
227  false);
228  func[code](*m, *n, a, *lda, b, 1, *ldb, blocking);
229  } else {
231  false);
232  func[code](*n, *m, a, *lda, b, 1, *ldb, blocking);
233  }
234 
235  if (alpha != Scalar(1)) matrix(b, *m, *n, *ldb) *= alpha;
236 }
237 
238 // b = alpha*op(a)*b for side = 'L'or'l'
239 // b = alpha*b*op(a) for side = 'R'or'r'
241 (const char *side, const char *uplo, const char *opa, const char *diag, const int *m, const int *n,
242  const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb) {
243  // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " "
244  // << *lda << " " << *ldb << " " << *palpha << "\n";
245  using Eigen::ColMajor;
246  using Eigen::DenseIndex;
247  using Eigen::Dynamic;
248  using Eigen::Lower;
249  using Eigen::RowMajor;
250  using Eigen::UnitDiag;
251  using Eigen::Upper;
252  typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex,
253  Scalar *, DenseIndex, DenseIndex, const Scalar &,
255  static const functype func[32] = {
256  // array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
258  false, ColMajor, 1>::run),
259  // array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
261  false, ColMajor, 1>::run),
262  // array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
264  false, ColMajor, 1>::run),
265  0,
266  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
268  ColMajor, false, ColMajor, 1>::run),
269  // array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
271  RowMajor, false, ColMajor, 1>::run),
272  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
274  RowMajor, Conj, ColMajor, 1>::run),
275  0,
276  // array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
278  false, ColMajor, 1>::run),
279  // array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
281  false, ColMajor, 1>::run),
282  // array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
284  false, ColMajor, 1>::run),
285  0,
286  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
288  ColMajor, false, ColMajor, 1>::run),
289  // array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
291  RowMajor, false, ColMajor, 1>::run),
292  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
294  RowMajor, Conj, ColMajor, 1>::run),
295  0,
296  // array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
298  ColMajor, false, ColMajor, 1>::run),
299  // array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
301  ColMajor, false, ColMajor, 1>::run),
302  // array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
304  ColMajor, false, ColMajor, 1>::run),
305  0,
306  // array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
308  ColMajor, false, ColMajor, 1>::run),
309  // array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
311  RowMajor, false, ColMajor, 1>::run),
312  // array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
314  RowMajor, Conj, ColMajor, 1>::run),
315  0,
316  // array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
318  ColMajor, false, ColMajor, 1>::run),
319  // array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
321  ColMajor, false, ColMajor, 1>::run),
322  // array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
324  ColMajor, false, ColMajor, 1>::run),
325  0,
326  // array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
328  ColMajor, false, ColMajor, 1>::run),
329  // array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
331  RowMajor, false, ColMajor, 1>::run),
332  // array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
334  RowMajor, Conj, ColMajor, 1>::run),
335  0};
336 
337  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
338  Scalar *b = reinterpret_cast<Scalar *>(pb);
339  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
340 
341  int info = 0;
342  if (SIDE(*side) == INVALID)
343  info = 1;
344  else if (UPLO(*uplo) == INVALID)
345  info = 2;
346  else if (OP(*opa) == INVALID)
347  info = 3;
348  else if (DIAG(*diag) == INVALID)
349  info = 4;
350  else if (*m < 0)
351  info = 5;
352  else if (*n < 0)
353  info = 6;
354  else if (*lda < std::max(1, (SIDE(*side) == LEFT) ? *m : *n))
355  info = 9;
356  else if (*ldb < std::max(1, *m))
357  info = 11;
358  if (info) return xerbla_(SCALAR_SUFFIX_UP "TRMM ", &info);
359 
360  int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
361 
362  if (*m == 0 || *n == 0) return;
363 
364  // FIXME find a way to avoid this copy
366  matrix(b, *m, *n, *ldb).setZero();
367 
368  if (SIDE(*side) == LEFT) {
370  false);
371  func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, 1, *ldb, alpha, blocking);
372  } else {
374  false);
375  func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, 1, *ldb, alpha, blocking);
376  }
377 }
378 
379 // c = alpha*a*b + beta*c for side = 'L' or 'l'
380 // c = alpha*b*a + beta*c for side = 'R' or 'r'
382 (const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa,
383  const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
384  // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb
385  // << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
386  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
387  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
388  Scalar *c = reinterpret_cast<Scalar *>(pc);
389  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
390  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
391 
392  int info = 0;
393  if (SIDE(*side) == INVALID)
394  info = 1;
395  else if (UPLO(*uplo) == INVALID)
396  info = 2;
397  else if (*m < 0)
398  info = 3;
399  else if (*n < 0)
400  info = 4;
401  else if (*lda < std::max(1, (SIDE(*side) == LEFT) ? *m : *n))
402  info = 7;
403  else if (*ldb < std::max(1, *m))
404  info = 9;
405  else if (*ldc < std::max(1, *m))
406  info = 12;
407  if (info) return xerbla_(SCALAR_SUFFIX_UP "SYMM ", &info);
408 
409  if (beta != Scalar(1)) {
410  if (beta == Scalar(0))
411  matrix(c, *m, *n, *ldc).setZero();
412  else
413  matrix(c, *m, *n, *ldc) *= beta;
414  }
415 
416  if (*m == 0 || *n == 0) return;
417 
418  int size = (SIDE(*side) == LEFT) ? (*m) : (*n);
419  using Eigen::ColMajor;
420  using Eigen::DenseIndex;
421  using Eigen::Dynamic;
422  using Eigen::Lower;
423  using Eigen::RowMajor;
424  using Eigen::Upper;
425 #if ISCOMPLEX
426  // FIXME add support for symmetric complex matrix
428  if (UPLO(*uplo) == UP) {
429  matA.triangularView<Upper>() = matrix(a, size, size, *lda);
430  matA.triangularView<Lower>() = matrix(a, size, size, *lda).transpose();
431  } else if (UPLO(*uplo) == LO) {
432  matA.triangularView<Lower>() = matrix(a, size, size, *lda);
433  matA.triangularView<Upper>() = matrix(a, size, size, *lda).transpose();
434  }
435  if (SIDE(*side) == LEFT)
436  matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb);
437  else if (SIDE(*side) == RIGHT)
438  matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA;
439 #else
441  false);
442 
443  if (SIDE(*side) == LEFT)
444  if (UPLO(*uplo) == UP)
446  ColMajor, 1>::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha,
447  blocking);
448  else if (UPLO(*uplo) == LO)
450  ColMajor, 1>::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha,
451  blocking);
452  else
453  return;
454  else if (SIDE(*side) == RIGHT)
455  if (UPLO(*uplo) == UP)
457  ColMajor, 1>::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha,
458  blocking);
459  else if (UPLO(*uplo) == LO)
461  ColMajor, 1>::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha,
462  blocking);
463  else
464  return;
465  else
466  return;
467 #endif
468 }
469 
470 // c = alpha*a*a' + beta*c for op = 'N'or'n'
471 // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
473 (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
474  const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
475  // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " "
476  // << *pbeta << " " << *ldc << "\n";
477  using Eigen::ColMajor;
478  using Eigen::DenseIndex;
479  using Eigen::Dynamic;
480  using Eigen::Lower;
481  using Eigen::RowMajor;
482  using Eigen::Upper;
483 #if !ISCOMPLEX
484  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *,
486  static const functype func[8] = {
487  // array index: NOTR | (UP << 2)
489  Conj, ColMajor, 1, Upper>::run),
490  // array index: TR | (UP << 2)
492  Conj, ColMajor, 1, Upper>::run),
493  // array index: ADJ | (UP << 2)
495  false, ColMajor, 1, Upper>::run),
496  0,
497  // array index: NOTR | (LO << 2)
499  Conj, ColMajor, 1, Lower>::run),
500  // array index: TR | (LO << 2)
502  Conj, ColMajor, 1, Lower>::run),
503  // array index: ADJ | (LO << 2)
505  false, ColMajor, 1, Lower>::run),
506  0};
507 #endif
508 
509  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
510  Scalar *c = reinterpret_cast<Scalar *>(pc);
511  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
512  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
513 
514  int info = 0;
515  if (UPLO(*uplo) == INVALID)
516  info = 1;
517  else if (OP(*op) == INVALID || (ISCOMPLEX && OP(*op) == ADJ))
518  info = 2;
519  else if (*n < 0)
520  info = 3;
521  else if (*k < 0)
522  info = 4;
523  else if (*lda < std::max(1, (OP(*op) == NOTR) ? *n : *k))
524  info = 7;
525  else if (*ldc < std::max(1, *n))
526  info = 10;
527  if (info) return xerbla_(SCALAR_SUFFIX_UP "SYRK ", &info);
528 
529  if (beta != Scalar(1)) {
530  if (UPLO(*uplo) == UP)
531  if (beta == Scalar(0))
532  matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
533  else
534  matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
535  else if (beta == Scalar(0))
536  matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
537  else
538  matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
539  }
540 
541  if (*n == 0 || *k == 0) return;
542 
543 #if ISCOMPLEX
544  // FIXME add support for symmetric complex matrix
545  if (UPLO(*uplo) == UP) {
546  if (OP(*op) == NOTR)
547  matrix(c, *n, *n, *ldc).triangularView<Upper>() +=
548  alpha * matrix(a, *n, *k, *lda) * matrix(a, *n, *k, *lda).transpose();
549  else
550  matrix(c, *n, *n, *ldc).triangularView<Upper>() +=
551  alpha * matrix(a, *k, *n, *lda).transpose() * matrix(a, *k, *n, *lda);
552  } else {
553  if (OP(*op) == NOTR)
554  matrix(c, *n, *n, *ldc).triangularView<Lower>() +=
555  alpha * matrix(a, *n, *k, *lda) * matrix(a, *n, *k, *lda).transpose();
556  else
557  matrix(c, *n, *n, *ldc).triangularView<Lower>() +=
558  alpha * matrix(a, *k, *n, *lda).transpose() * matrix(a, *k, *n, *lda);
559  }
560 #else
562  false);
563 
564  int code = OP(*op) | (UPLO(*uplo) << 2);
565  func[code](*n, *k, a, *lda, a, *lda, c, 1, *ldc, alpha, blocking);
566 #endif
567 }
568 
569 // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
570 // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
571 EIGEN_BLAS_FUNC(syr2k)
572 (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
573  const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
574  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
575  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
576  Scalar *c = reinterpret_cast<Scalar *>(pc);
577  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
578  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
579 
580  // std::cerr << "in syr2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " "
581  // << *ldb << " " << beta << " " << *ldc << "\n";
582 
583  int info = 0;
584  if (UPLO(*uplo) == INVALID)
585  info = 1;
586  else if (OP(*op) == INVALID || (ISCOMPLEX && OP(*op) == ADJ))
587  info = 2;
588  else if (*n < 0)
589  info = 3;
590  else if (*k < 0)
591  info = 4;
592  else if (*lda < std::max(1, (OP(*op) == NOTR) ? *n : *k))
593  info = 7;
594  else if (*ldb < std::max(1, (OP(*op) == NOTR) ? *n : *k))
595  info = 9;
596  else if (*ldc < std::max(1, *n))
597  info = 12;
598  if (info) return xerbla_(SCALAR_SUFFIX_UP "SYR2K", &info);
599 
600  using Eigen::Lower;
601  using Eigen::Upper;
602  if (beta != Scalar(1)) {
603  if (UPLO(*uplo) == UP)
604  if (beta == Scalar(0))
605  matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
606  else
607  matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
608  else if (beta == Scalar(0))
609  matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
610  else
611  matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
612  }
613 
614  if (*k == 0) return;
615 
616  if (OP(*op) == NOTR) {
617  if (UPLO(*uplo) == UP) {
618  matrix(c, *n, *n, *ldc).triangularView<Upper>() +=
619  alpha * matrix(a, *n, *k, *lda) * matrix(b, *n, *k, *ldb).transpose() +
620  alpha * matrix(b, *n, *k, *ldb) * matrix(a, *n, *k, *lda).transpose();
621  } else if (UPLO(*uplo) == LO)
622  matrix(c, *n, *n, *ldc).triangularView<Lower>() +=
623  alpha * matrix(a, *n, *k, *lda) * matrix(b, *n, *k, *ldb).transpose() +
624  alpha * matrix(b, *n, *k, *ldb) * matrix(a, *n, *k, *lda).transpose();
625  } else if (OP(*op) == TR || OP(*op) == ADJ) {
626  if (UPLO(*uplo) == UP)
627  matrix(c, *n, *n, *ldc).triangularView<Upper>() +=
628  alpha * matrix(a, *k, *n, *lda).transpose() * matrix(b, *k, *n, *ldb) +
629  alpha * matrix(b, *k, *n, *ldb).transpose() * matrix(a, *k, *n, *lda);
630  else if (UPLO(*uplo) == LO)
631  matrix(c, *n, *n, *ldc).triangularView<Lower>() +=
632  alpha * matrix(a, *k, *n, *lda).transpose() * matrix(b, *k, *n, *ldb) +
633  alpha * matrix(b, *k, *n, *ldb).transpose() * matrix(a, *k, *n, *lda);
634  }
635 }
636 
637 #if ISCOMPLEX
638 
639 // c = alpha*a*b + beta*c for side = 'L' or 'l'
640 // c = alpha*b*a + beta*c for side = 'R' or 'r'
641 EIGEN_BLAS_FUNC(hemm)
642 (const char *side, const char *uplo, const int *m, const int *n, const RealScalar *palpha, const RealScalar *pa,
643  const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
644  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
645  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
646  Scalar *c = reinterpret_cast<Scalar *>(pc);
647  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
648  Scalar beta = *reinterpret_cast<const Scalar *>(pbeta);
649 
650  // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " "
651  // << beta << " " << *ldc << "\n";
652 
653  int info = 0;
654  if (SIDE(*side) == INVALID)
655  info = 1;
656  else if (UPLO(*uplo) == INVALID)
657  info = 2;
658  else if (*m < 0)
659  info = 3;
660  else if (*n < 0)
661  info = 4;
662  else if (*lda < std::max(1, (SIDE(*side) == LEFT) ? *m : *n))
663  info = 7;
664  else if (*ldb < std::max(1, *m))
665  info = 9;
666  else if (*ldc < std::max(1, *m))
667  info = 12;
668  if (info) return xerbla_(SCALAR_SUFFIX_UP "HEMM ", &info);
669 
670  if (beta == Scalar(0))
671  matrix(c, *m, *n, *ldc).setZero();
672  else if (beta != Scalar(1))
673  matrix(c, *m, *n, *ldc) *= beta;
674 
675  if (*m == 0 || *n == 0) return;
676 
677  using Eigen::ColMajor;
678  using Eigen::DenseIndex;
679  using Eigen::Dynamic;
680  using Eigen::RowMajor;
681  using Eigen::Upper;
682 
683  int size = (SIDE(*side) == LEFT) ? (*m) : (*n);
685  false);
686 
687  if (SIDE(*side) == LEFT) {
688  if (UPLO(*uplo) == UP)
690  ColMajor, 1>::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha,
691  blocking);
692  else if (UPLO(*uplo) == LO)
694  ColMajor, 1>::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha,
695  blocking);
696  else
697  return;
698  } else if (SIDE(*side) == RIGHT) {
699  if (UPLO(*uplo) == UP)
700  matrix(c, *m, *n, *ldc) +=
701  alpha * matrix(b, *m, *n, *ldb) *
702  matrix(a, *n, *n, *lda)
703  .selfadjointView<Upper>(); /*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false,
704 RowMajor,true,Conj, ColMajor, 1>
705 ::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha, blocking);*/
706  else if (UPLO(*uplo) == LO)
708  ColMajor, 1>::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha,
709  blocking);
710  else
711  return;
712  } else {
713  return;
714  }
715 }
716 
717 // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
718 // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
719 EIGEN_BLAS_FUNC(herk)
720 (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
721  const int *lda, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
722  // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " "
723  // << *pbeta << " " << *ldc << "\n";
724  using Eigen::ColMajor;
725  using Eigen::DenseIndex;
726  using Eigen::Dynamic;
727  using Eigen::Lower;
728  using Eigen::RowMajor;
729  using Eigen::StrictlyLower;
730  using Eigen::StrictlyUpper;
731  using Eigen::Upper;
732  typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *,
734  static const functype func[8] = {
735  // array index: NOTR | (UP << 2)
737  Conj, ColMajor, 1, Upper>::run),
738  0,
739  // array index: ADJ | (UP << 2)
741  false, ColMajor, 1, Upper>::run),
742  0,
743  // array index: NOTR | (LO << 2)
745  Conj, ColMajor, 1, Lower>::run),
746  0,
747  // array index: ADJ | (LO << 2)
749  false, ColMajor, 1, Lower>::run),
750  0};
751 
752  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
753  Scalar *c = reinterpret_cast<Scalar *>(pc);
755  RealScalar beta = *pbeta;
756 
757  // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " <<
758  // beta << " " << *ldc << "\n";
759 
760  int info = 0;
761  if (UPLO(*uplo) == INVALID)
762  info = 1;
763  else if ((OP(*op) == INVALID) || (OP(*op) == TR))
764  info = 2;
765  else if (*n < 0)
766  info = 3;
767  else if (*k < 0)
768  info = 4;
769  else if (*lda < std::max(1, (OP(*op) == NOTR) ? *n : *k))
770  info = 7;
771  else if (*ldc < std::max(1, *n))
772  info = 10;
773  if (info) return xerbla_(SCALAR_SUFFIX_UP "HERK ", &info);
774 
775  int code = OP(*op) | (UPLO(*uplo) << 2);
776 
777  if (beta != RealScalar(1)) {
778  if (UPLO(*uplo) == UP)
779  if (beta == Scalar(0))
780  matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
781  else
782  matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
783  else if (beta == Scalar(0))
784  matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
785  else
786  matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
787 
788  if (beta != Scalar(0)) {
789  matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
790  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
791  }
792  }
793 
794  if (*k > 0 && alpha != RealScalar(0)) {
796  false);
797  func[code](*n, *k, a, *lda, a, *lda, c, 1, *ldc, alpha, blocking);
798  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
799  }
800 }
801 
802 // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
803 // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
804 EIGEN_BLAS_FUNC(her2k)
805 (const char *uplo, const char *op, const int *n, const int *k, const RealScalar *palpha, const RealScalar *pa,
806  const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) {
807  const Scalar *a = reinterpret_cast<const Scalar *>(pa);
808  const Scalar *b = reinterpret_cast<const Scalar *>(pb);
809  Scalar *c = reinterpret_cast<Scalar *>(pc);
810  Scalar alpha = *reinterpret_cast<const Scalar *>(palpha);
811  RealScalar beta = *pbeta;
812 
813  // std::cerr << "in her2k " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " "
814  // << *ldb << " " << beta << " " << *ldc << "\n";
815 
816  int info = 0;
817  if (UPLO(*uplo) == INVALID)
818  info = 1;
819  else if ((OP(*op) == INVALID) || (OP(*op) == TR))
820  info = 2;
821  else if (*n < 0)
822  info = 3;
823  else if (*k < 0)
824  info = 4;
825  else if (*lda < std::max(1, (OP(*op) == NOTR) ? *n : *k))
826  info = 7;
827  else if (*ldb < std::max(1, (OP(*op) == NOTR) ? *n : *k))
828  info = 9;
829  else if (*ldc < std::max(1, *n))
830  info = 12;
831  if (info) return xerbla_(SCALAR_SUFFIX_UP "HER2K", &info);
832 
833  using Eigen::Lower;
834  using Eigen::StrictlyLower;
835  using Eigen::StrictlyUpper;
836  using Eigen::Upper;
837  if (beta != RealScalar(1)) {
838  if (UPLO(*uplo) == UP)
839  if (beta == Scalar(0))
840  matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
841  else
842  matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
843  else if (beta == Scalar(0))
844  matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
845  else
846  matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
847 
848  if (beta != Scalar(0)) {
849  matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
850  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
851  }
852  } else if (*k > 0 && alpha != Scalar(0))
853  matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
854 
855  if (*k == 0) return;
856 
857  if (OP(*op) == NOTR) {
858  if (UPLO(*uplo) == UP) {
859  matrix(c, *n, *n, *ldc).triangularView<Upper>() +=
860  alpha * matrix(a, *n, *k, *lda) * matrix(b, *n, *k, *ldb).adjoint() +
861  Eigen::numext::conj(alpha) * matrix(b, *n, *k, *ldb) * matrix(a, *n, *k, *lda).adjoint();
862  } else if (UPLO(*uplo) == LO)
863  matrix(c, *n, *n, *ldc).triangularView<Lower>() +=
864  alpha * matrix(a, *n, *k, *lda) * matrix(b, *n, *k, *ldb).adjoint() +
865  Eigen::numext::conj(alpha) * matrix(b, *n, *k, *ldb) * matrix(a, *n, *k, *lda).adjoint();
866  } else if (OP(*op) == ADJ) {
867  if (UPLO(*uplo) == UP)
868  matrix(c, *n, *n, *ldc).triangularView<Upper>() +=
869  alpha * matrix(a, *k, *n, *lda).adjoint() * matrix(b, *k, *n, *ldb) +
870  Eigen::numext::conj(alpha) * matrix(b, *k, *n, *ldb).adjoint() * matrix(a, *k, *n, *lda);
871  else if (UPLO(*uplo) == LO)
872  matrix(c, *n, *n, *ldc).triangularView<Lower>() +=
873  alpha * matrix(a, *k, *n, *lda).adjoint() * matrix(b, *k, *n, *ldb) +
874  Eigen::numext::conj(alpha) * matrix(b, *k, *n, *ldb).adjoint() * matrix(a, *k, *n, *lda);
875  }
876 }
877 
878 #endif // ISCOMPLEX
AnnoyingScalar conj(const AnnoyingScalar &x)
Definition: AnnoyingScalar.h:133
SCALAR Scalar
Definition: bench_gemm.cpp:45
EIGEN_DONT_INLINE void gemm(const A &a, const B &b, C &c)
Definition: bench_gemm.cpp:158
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:46
#define ISCOMPLEX
Definition: blas/complex_double.cpp:14
#define SCALAR_SUFFIX_UP
Definition: blas/complex_double.cpp:12
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR Index outerStride() const EIGEN_NOEXCEPT
Definition: Eigen/Eigen/src/Core/Matrix.h:390
constexpr EIGEN_DEVICE_FUNC const Scalar * data() const
Definition: PlainObjectBase.h:273
Definition: GeneralMatrixMatrix.h:223
Definition: GeneralMatrixMatrix.h:226
@ Conj
Definition: common.h:73
#define UP
Definition: common.h:46
#define LEFT
Definition: common.h:43
#define TR
Definition: common.h:40
#define SIDE(X)
Definition: common.h:57
#define OP(X)
Definition: common.h:54
#define NOTR
Definition: common.h:39
#define ADJ
Definition: common.h:41
#define INVALID
Definition: common.h:52
#define UPLO(X)
Definition: common.h:59
#define DIAG(X)
Definition: common.h:61
#define LO
Definition: common.h:47
#define max(a, b)
Definition: datatypes.h:23
@ StrictlyLower
Definition: Constants.h:223
@ UnitDiag
Definition: Constants.h:215
@ StrictlyUpper
Definition: Constants.h:225
@ Lower
Definition: Constants.h:211
@ Upper
Definition: Constants.h:213
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
@ OnTheLeft
Definition: Constants.h:331
@ OnTheRight
Definition: Constants.h:333
EIGEN_BLAS_FUNC(EIGEN_CAT(REAL_SCALAR_SUFFIX, scal))(int *n
*ldc setZero()
const char const char const int const int const int const RealScalar const RealScalar const int const RealScalar const int const RealScalar RealScalar const int * ldc
Definition: level3_impl.h:15
Eigen::Matrix< Scalar, Dynamic, Dynamic, ColMajor > tmp
Definition: level3_impl.h:365
const char * opa
Definition: level3_impl.h:13
const char const char const int const int const int const RealScalar * palpha
Definition: level3_impl.h:13
matrix(b, *m, *n, *ldb).setZero()
const char const char const int const int const int const RealScalar const RealScalar const int * lda
Definition: level3_impl.h:14
const char const char const int const int * n
Definition: level3_impl.h:13
int size
Definition: level3_impl.h:418
int code
Definition: level3_impl.h:96
const char const char const int * m
Definition: level3_impl.h:13
Eigen::Matrix< Scalar, Dynamic, Dynamic, ColMajor > matA(size, size)
const Scalar * b
Definition: level3_impl.h:58
const char const char const int const int const int const RealScalar const RealScalar const int const RealScalar const int const RealScalar RealScalar * pc
Definition: level3_impl.h:14
const char const char * op
Definition: level3_impl.h:473
const char const char * uplo
Definition: level3_impl.h:101
const char const char const int const int const int const RealScalar const RealScalar const int const RealScalar const int * ldb
Definition: level3_impl.h:14
int info
Definition: level3_impl.h:63
const char const char const int const int const int * k
Definition: level3_impl.h:13
const Scalar * a
Definition: level3_impl.h:57
void(* functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, Scalar, Eigen::internal::level3_blocking< Scalar, Scalar > &, Eigen::internal::GemmParallelInfo< DenseIndex > *)
Definition: level3_impl.h:22
Scalar alpha
Definition: level3_impl.h:60
Scalar * c
Definition: level3_impl.h:59
const char const char const int const int const int const RealScalar const RealScalar const int const RealScalar * pb
Definition: level3_impl.h:14
const char const char const int const int const int const RealScalar const RealScalar const int const RealScalar const int const RealScalar * pbeta
Definition: level3_impl.h:14
const char * side
Definition: level3_impl.h:101
Scalar beta
Definition: level3_impl.h:61
const char const char const char const char * diag
Definition: level3_impl.h:101
const char const char * opb
Definition: level3_impl.h:13
const char const char const int const int const int const RealScalar const RealScalar * pa
Definition: level3_impl.h:14
EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex
Definition: Meta.h:75
const int Dynamic
Definition: Constants.h:25
void symm(int size=Size, int othersize=OtherSize)
Definition: product_symm.cpp:13
void syrk(const MatrixType &m)
Definition: product_syrk.cpp:13
void trmm(int rows=get_random_size< Scalar >(), int cols=get_random_size< Scalar >(), int otherCols=OtherCols==Dynamic ? get_random_size< Scalar >() :OtherCols)
Definition: product_trmm.cpp:20
Definition: Parallelizer.h:106
Definition: GeneralMatrixMatrixTriangular.h:39
Definition: SelfadjointMatrixMatrix.h:264
Definition: TriangularMatrixMatrix.h:49
Definition: SolveTriangular.h:27
Definition: benchGeometry.cpp:21
void run(const string &dir_name, LinearSolver *linear_solver_pt, const unsigned nel_1d, bool mess_up_order)
Definition: two_d_poisson_compare_solvers.cc:317
EIGEN_WEAK_LINKING void xerbla_(const char *msg, int *info)
Definition: xerbla.cpp:14