10 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
11 #define EIGEN_GENERAL_MATRIX_MATRIX_H
14 #include "../InternalHeaderCheck.h"
20 template <
typename LhsScalar_,
typename RhsScalar_>
21 class level3_blocking;
24 template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
typename RhsScalar,
25 int RhsStorageOrder,
bool ConjugateRhs,
int ResInnerStride>
27 ConjugateRhs,
RowMajor, ResInnerStride> {
38 ResInnerStride>
::run(
cols,
rows, depth, rhs, rhsStride, lhs, lhsStride,
res, resIncr,
45 template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
typename RhsScalar,
46 int RhsStorageOrder,
bool ConjugateRhs,
int ResInnerStride>
48 ConjugateRhs,
ColMajor, ResInnerStride> {
58 LhsMapper lhs(lhs_, lhsStride);
59 RhsMapper rhs(rhs_, rhsStride);
60 ResMapper
res(res_, resStride, resIncr);
72 #if !defined(EIGEN_USE_BLAS) && (defined(EIGEN_HAS_OPENMP) || defined(EIGEN_GEMM_THREADPOOL))
75 int tid =
info->logical_thread_id;
76 int threads =
info->num_threads;
78 LhsScalar* blockA = blocking.
blockA();
81 std::size_t sizeB = kc * nc;
85 for (
Index k = 0;
k < depth;
k += kc) {
90 pack_rhs(blockB, rhs.getSubMapper(
k, 0), actual_kc, nc);
99 while (
info->task_info[tid].users != 0) {
100 std::this_thread::yield();
102 info->task_info[tid].users = threads;
104 pack_lhs(blockA +
info->task_info[tid].lhs_start * actual_kc,
105 lhs.getSubMapper(
info->task_info[tid].lhs_start,
k), actual_kc,
info->task_info[tid].lhs_length);
108 info->task_info[tid].sync =
k;
111 for (
int shift = 0; shift < threads; ++shift) {
112 int i = (tid + shift) % threads;
118 while (
info->task_info[
i].sync !=
k) {
119 std::this_thread::yield();
123 gebp(
res.getSubMapper(
info->task_info[
i].lhs_start, 0), blockA +
info->task_info[
i].lhs_start * actual_kc,
124 blockB,
info->task_info[
i].lhs_length, actual_kc, nc,
alpha);
132 pack_rhs(blockB, rhs.getSubMapper(
k,
j), actual_kc, actual_nc);
135 gebp(
res.getSubMapper(0,
j), blockA, blockB,
rows, actual_kc, actual_nc,
alpha);
140 for (
Index i = 0;
i < threads; ++
i)
info->task_info[
i].users -= 1;
148 std::size_t sizeA = kc * mc;
149 std::size_t sizeB = kc * nc;
154 const bool pack_rhs_once = mc !=
rows && kc == depth && nc ==
cols;
157 for (
Index i2 = 0; i2 <
rows; i2 += mc) {
160 for (
Index k2 = 0; k2 < depth; k2 += kc) {
167 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
170 for (
Index j2 = 0; j2 <
cols; j2 += nc) {
176 if ((!pack_rhs_once) || i2 == 0) pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc);
179 gebp(
res.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc,
alpha);
192 template <
typename Scalar,
typename Index,
typename Gemm,
typename Lhs,
typename Rhs,
typename Dest,
193 typename BlockingType>
221 template <
int StorageOrder,
typename LhsScalar,
typename RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
222 int KcFactor = 1,
bool FiniteAtCompileTime = MaxRows !=
Dynamic && MaxCols !=
Dynamic && MaxDepth !=
Dynamic>
225 template <
typename LhsScalar_,
typename RhsScalar_>
249 template <
int StorageOrder,
typename LhsScalar_,
typename RhsScalar_,
int MaxRows,
int MaxCols,
int MaxDepth,
253 :
public level3_blocking<std::conditional_t<StorageOrder == RowMajor, RhsScalar_, LhsScalar_>,
254 std::conditional_t<StorageOrder == RowMajor, LhsScalar_, RhsScalar_>> {
258 ActualCols =
Transpose ? MaxRows : MaxCols
260 typedef std::conditional_t<Transpose, RhsScalar_, LhsScalar_>
LhsScalar;
261 typedef std::conditional_t<Transpose, LhsScalar_, RhsScalar_>
RhsScalar;
262 enum { SizeA = ActualRows * MaxDepth, SizeB = ActualCols * MaxDepth };
264 #if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
275 this->m_mc = ActualRows;
276 this->m_nc = ActualCols;
277 this->m_kc = MaxDepth;
278 #if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
279 this->m_blockA = m_staticA;
280 this->m_blockB = m_staticB;
296 template <
int StorageOrder,
typename LhsScalar_,
typename RhsScalar_,
int MaxRows,
int MaxCols,
int MaxDepth,
298 class gemm_blocking_space<StorageOrder, LhsScalar_, RhsScalar_, MaxRows, MaxCols, MaxDepth, KcFactor, false>
299 :
public level3_blocking<std::conditional_t<StorageOrder == RowMajor, RhsScalar_, LhsScalar_>,
300 std::conditional_t<StorageOrder == RowMajor, LhsScalar_, RhsScalar_>> {
302 typedef std::conditional_t<Transpose, RhsScalar_, LhsScalar_>
LhsScalar;
303 typedef std::conditional_t<Transpose, LhsScalar_, RhsScalar_>
RhsScalar;
315 computeProductBlockingSizes<LhsScalar, RhsScalar, KcFactor>(this->m_kc, this->m_mc, this->m_nc, num_threads);
319 computeProductBlockingSizes<LhsScalar, RhsScalar, KcFactor>(this->m_kc, this->m_mc,
n, num_threads);
322 m_sizeA = this->m_mc * this->m_kc;
323 m_sizeB = this->m_kc * this->m_nc;
333 computeProductBlockingSizes<LhsScalar, RhsScalar, KcFactor>(this->m_kc,
m, this->m_nc, num_threads);
334 m_sizeA = this->m_mc * this->m_kc;
335 m_sizeB = this->m_kc * this->m_nc;
339 if (this->m_blockA == 0) this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
343 if (this->m_blockB == 0) this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
361 template <
typename Lhs,
typename Rhs>
380 template <
typename Dst>
392 scaleAndAddTo(dst, lhs, rhs,
Scalar(1));
396 template <
typename Dst>
401 scaleAndAddTo(dst, lhs, rhs,
Scalar(1));
404 template <
typename Dst>
409 scaleAndAddTo(dst, lhs, rhs,
Scalar(-1));
412 template <
typename Dest>
414 eigen_assert(dst.rows() == a_lhs.rows() && dst.cols() == a_rhs.cols());
415 if (a_lhs.cols() == 0 || a_lhs.rows() == 0 || a_rhs.cols() == 0)
return;
417 if (dst.cols() == 1) {
419 typename Dest::ColXpr dst_vec(dst.col(0));
422 }
else if (dst.rows() == 1) {
424 typename Dest::RowXpr dst_vec(dst.row(0));
435 Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime, MaxDepthAtCompileTime>
442 bool(LhsBlasTraits::NeedToConjugate),
RhsScalar,
448 BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1,
true);
450 GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(),
int i
Definition: BiCGSTAB_step_by_step.cpp:9
const unsigned n
Definition: CG3DPackingUnitTest.cpp:11
#define EIGEN_GEMM_TO_COEFFBASED_THRESHOLD
Definition: GeneralProduct.h:28
#define eigen_internal_assert(x)
Definition: Macros.h:916
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:966
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER)
Definition: Memory.h:806
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
SCALAR Scalar
Definition: bench_gemm.cpp:45
Expression of the product of two arbitrary matrices or vectors.
Definition: Product.h:202
Expression of the transpose of a matrix.
Definition: Transpose.h:56
Definition: BlasUtil.h:304
Definition: BlasUtil.h:443
Definition: products/GeneralBlockPanelKernel.h:397
LhsPacket LhsPacket4Packing
Definition: products/GeneralBlockPanelKernel.h:440
gemm_blocking_space(Index, Index, Index, Index, bool)
Definition: GeneralMatrixMatrix.h:273
void allocateB()
Definition: GeneralMatrixMatrix.h:292
std::conditional_t< Transpose, LhsScalar_, RhsScalar_ > RhsScalar
Definition: GeneralMatrixMatrix.h:261
std::conditional_t< Transpose, RhsScalar_, LhsScalar_ > LhsScalar
Definition: GeneralMatrixMatrix.h:260
void allocateA()
Definition: GeneralMatrixMatrix.h:291
void allocateAll()
Definition: GeneralMatrixMatrix.h:293
void initParallel(Index, Index, Index, Index)
Definition: GeneralMatrixMatrix.h:289
Index m_sizeB
Definition: GeneralMatrixMatrix.h:306
gemm_blocking_space(Index rows, Index cols, Index depth, Index num_threads, bool l3_blocking)
Definition: GeneralMatrixMatrix.h:309
void allocateA()
Definition: GeneralMatrixMatrix.h:338
std::conditional_t< Transpose, RhsScalar_, LhsScalar_ > LhsScalar
Definition: GeneralMatrixMatrix.h:302
void allocateAll()
Definition: GeneralMatrixMatrix.h:346
~gemm_blocking_space()
Definition: GeneralMatrixMatrix.h:351
void allocateB()
Definition: GeneralMatrixMatrix.h:342
Index m_sizeA
Definition: GeneralMatrixMatrix.h:305
std::conditional_t< Transpose, LhsScalar_, RhsScalar_ > RhsScalar
Definition: GeneralMatrixMatrix.h:303
void initParallel(Index rows, Index cols, Index depth, Index num_threads)
Definition: GeneralMatrixMatrix.h:326
Definition: GeneralMatrixMatrix.h:223
Definition: GeneralMatrixMatrix.h:226
level3_blocking()
Definition: GeneralMatrixMatrix.h:239
RhsScalar * m_blockB
Definition: GeneralMatrixMatrix.h:232
Index nc() const
Definition: GeneralMatrixMatrix.h:242
Index m_nc
Definition: GeneralMatrixMatrix.h:235
LhsScalar * m_blockA
Definition: GeneralMatrixMatrix.h:231
Index m_kc
Definition: GeneralMatrixMatrix.h:236
Index m_mc
Definition: GeneralMatrixMatrix.h:234
RhsScalar * blockB()
Definition: GeneralMatrixMatrix.h:246
RhsScalar_ RhsScalar
Definition: GeneralMatrixMatrix.h:228
LhsScalar * blockA()
Definition: GeneralMatrixMatrix.h:245
Index mc() const
Definition: GeneralMatrixMatrix.h:241
LhsScalar_ LhsScalar
Definition: GeneralMatrixMatrix.h:227
Index kc() const
Definition: GeneralMatrixMatrix.h:243
#define min(a, b)
Definition: datatypes.h:22
@ GemvProduct
Definition: Constants.h:510
@ GemmProduct
Definition: Constants.h:511
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
const unsigned int RowMajorBit
Definition: Constants.h:70
RealScalar alpha
Definition: level1_cplx_impl.h:151
int * m
Definition: level2_cplx_impl.h:294
int info
Definition: level2_cplx_impl.h:39
char char char int int * k
Definition: level2_impl.h:374
EIGEN_DEVICE_FUNC void aligned_delete(T *ptr, std::size_t size)
Definition: Memory.h:430
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
EIGEN_STRONG_INLINE void parallelize_gemm(const Functor &func, Index rows, Index cols, Index, bool)
Definition: Parallelizer.h:108
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ResScalar combine_scalar_factors(const ResScalar &alpha, const Lhs &lhs, const Rhs &rhs)
Definition: BlasUtil.h:609
constexpr int min_size_prefer_fixed(A a, B b)
Definition: Meta.h:683
typename remove_all< T >::type remove_all_t
Definition: Meta.h:142
typename add_const_on_value_type< T >::type add_const_on_value_type_t
Definition: Meta.h:274
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
auto run(Kernel kernel, Args &&... args) -> decltype(kernel(args...))
Definition: gpu_test_helper.h:414
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
const int Dynamic
Definition: Constants.h:25
Definition: Eigen_Colamd.h:49
Definition: Constants.h:540
Determines whether the given binary operation of two numeric types is allowed and what the scalar ret...
Definition: XprHelper.h:1043
Definition: Parallelizer.h:106
Template functor for scalar/packet assignment with addition.
Definition: AssignmentFunctors.h:52
Template functor for scalar/packet assignment.
Definition: AssignmentFunctors.h:25
Definition: BlasUtil.h:459
std::conditional_t< bool(HasUsableDirectAccess), ExtractType, typename ExtractType_::PlainObject > DirectLinearAccessType
Definition: BlasUtil.h:475
Definition: products/GeneralBlockPanelKernel.h:960
Definition: GeneralMatrixMatrix.h:194
Dest & m_dest
Definition: GeneralMatrixMatrix.h:216
void initParallelSession(Index num_threads) const
Definition: GeneralMatrixMatrix.h:198
const Rhs & m_rhs
Definition: GeneralMatrixMatrix.h:215
Gemm::Traits Traits
Definition: GeneralMatrixMatrix.h:211
Scalar m_actualAlpha
Definition: GeneralMatrixMatrix.h:217
BlockingType & m_blocking
Definition: GeneralMatrixMatrix.h:218
gemm_functor(const Lhs &lhs, const Rhs &rhs, Dest &dest, const Scalar &actualAlpha, BlockingType &blocking)
Definition: GeneralMatrixMatrix.h:195
const Lhs & m_lhs
Definition: GeneralMatrixMatrix.h:214
void operator()(Index row, Index rows, Index col=0, Index cols=-1, GemmParallelInfo< Index > *info=0) const
Definition: GeneralMatrixMatrix.h:203
Definition: BlasUtil.h:34
Definition: BlasUtil.h:30
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixMatrix.h:51
static void run(Index rows, Index cols, Index depth, const LhsScalar *lhs_, Index lhsStride, const RhsScalar *rhs_, Index rhsStride, ResScalar *res_, Index resIncr, Index resStride, ResScalar alpha, level3_blocking< LhsScalar, RhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
Definition: GeneralMatrixMatrix.h:52
gebp_traits< LhsScalar, RhsScalar > Traits
Definition: GeneralMatrixMatrix.h:49
static EIGEN_STRONG_INLINE void run(Index rows, Index cols, Index depth, const LhsScalar *lhs, Index lhsStride, const RhsScalar *rhs, Index rhsStride, ResScalar *res, Index resIncr, Index resStride, ResScalar alpha, level3_blocking< RhsScalar, LhsScalar > &blocking, GemmParallelInfo< Index > *info=0)
Definition: GeneralMatrixMatrix.h:31
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixMatrix.h:30
gebp_traits< RhsScalar, LhsScalar > Traits
Definition: GeneralMatrixMatrix.h:28
Definition: BlasUtil.h:38
Definition: ProductEvaluators.h:394
Lhs::Scalar LhsScalar
Definition: GeneralMatrixMatrix.h:365
internal::blas_traits< Lhs > LhsBlasTraits
Definition: GeneralMatrixMatrix.h:368
RhsBlasTraits::DirectLinearAccessType ActualRhsType
Definition: GeneralMatrixMatrix.h:373
LhsBlasTraits::DirectLinearAccessType ActualLhsType
Definition: GeneralMatrixMatrix.h:369
static void subTo(Dst &dst, const Lhs &lhs, const Rhs &rhs)
Definition: GeneralMatrixMatrix.h:405
Product< Lhs, Rhs >::Scalar Scalar
Definition: GeneralMatrixMatrix.h:364
internal::remove_all_t< ActualRhsType > ActualRhsTypeCleaned
Definition: GeneralMatrixMatrix.h:374
Rhs::Scalar RhsScalar
Definition: GeneralMatrixMatrix.h:366
static void evalTo(Dst &dst, const Lhs &lhs, const Rhs &rhs)
Definition: GeneralMatrixMatrix.h:381
internal::blas_traits< Rhs > RhsBlasTraits
Definition: GeneralMatrixMatrix.h:372
static void addTo(Dst &dst, const Lhs &lhs, const Rhs &rhs)
Definition: GeneralMatrixMatrix.h:397
generic_product_impl< Lhs, Rhs, DenseShape, DenseShape, CoeffBasedProductMode > lazyproduct
Definition: GeneralMatrixMatrix.h:378
internal::remove_all_t< ActualLhsType > ActualLhsTypeCleaned
Definition: GeneralMatrixMatrix.h:370
static void scaleAndAddTo(Dest &dst, const Lhs &a_lhs, const Rhs &a_rhs, const Scalar &alpha)
Definition: GeneralMatrixMatrix.h:413
Definition: ProductEvaluators.h:341
Definition: ProductEvaluators.h:78
Template functor for scalar/packet assignment with subtraction.
Definition: AssignmentFunctors.h:73
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
void run(const string &dir_name, LinearSolver *linear_solver_pt, const unsigned nel_1d, bool mess_up_order)
Definition: two_d_poisson_compare_solvers.cc:317