10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
14 #ifdef EIGEN_USE_THREADS
17 #include "./InternalHeaderCheck.h"
21 template <
typename Indices,
typename LeftArgType,
typename RightArgType,
typename OutputKernelType>
22 struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
24 :
public TensorContractionEvaluatorBase<TensorEvaluator<
25 const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice>> {
26 typedef ThreadPoolDevice Device;
28 typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
29 typedef TensorContractionEvaluatorBase<Self> Base;
31 typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>
XprType;
32 typedef std::remove_const_t<typename XprType::Scalar>
Scalar;
43 typedef std::conditional_t<static_cast<int>(
Layout) ==
static_cast<int>(
ColMajor), LeftArgType, RightArgType>
45 typedef std::conditional_t<static_cast<int>(
Layout) ==
static_cast<int>(
ColMajor), RightArgType, LeftArgType>
48 static constexpr
int LDims =
49 internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>
::value;
50 static constexpr
int RDims =
51 internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>
::value;
54 typedef array<Index, LDims> left_dim_mapper_t;
55 typedef array<Index, RDims> right_dim_mapper_t;
57 typedef array<Index, ContractDims> contract_t;
58 typedef array<
Index, LDims - ContractDims> left_nocontract_t;
59 typedef array<
Index, RDims - ContractDims> right_nocontract_t;
61 static constexpr
int NumDims = LDims + RDims - 2 * ContractDims;
66 typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
67 typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
68 typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
70 typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
71 typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
75 template <
int Alignment>
76 void evalProduct(
Scalar* buffer)
const {
77 evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
80 template <
typename EvalToCallback,
int Alignment>
81 void evalProductAsync(
Scalar* buffer, EvalToCallback done)
const {
82 evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
85 template <
typename DoneCallback,
int Alignment>
86 void evalProductImpl(
Scalar* buffer, DoneCallback done)
const {
104 const Index m = this->m_i_size;
105 const Index n = this->m_j_size;
106 const Index k = this->m_k_size;
107 if (
m == 0 ||
n == 0 ||
k == 0)
return;
132 bool shard_by_col = shardByCol(
m,
n, 2);
138 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(
k,
m,
n,
144 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(
k,
m,
n,
155 const TensorOpCost cost = contractionCost(
m,
n, bm, bn, bk, shard_by_col,
false);
158 int num_threads_by_k = numThreadsInnerDim(
m,
n,
k);
159 if (shardByInnerDim(
m,
n,
k, num_threads, num_threads_by_k)) {
162 if (IsEvalInSyncMode) {
163 EvalShardedByInnerDimContext<DoneCallback> ctx(
this, num_threads_by_k, buffer,
m,
n,
k, std::move(done));
164 ctx.template run<Alignment>();
167 new EvalShardedByInnerDimContext<DoneCallback>(
this, num_threads_by_k, buffer,
m,
n,
k, std::move(done));
168 ctx->template runAsync<Alignment>();
176 if (
n == 1) num_threads = 1;
178 if (num_threads == 1) {
180 if (!IsEvalInSyncMode) done();
185 shard_by_col = shardByCol(
m,
n, num_threads);
187 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(
188 k,
m,
n, num_threads);
193 internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(
194 k,
m,
n, num_threads);
214 gm = coarsenM(
m,
n, bm, bn, bk, gn, num_threads, shard_by_col);
215 gn = coarsenN(
m,
n, bm, bn, bk, gm, num_threads, shard_by_col);
217 gn = coarsenN(
m,
n, bm, bn, bk, gm, num_threads, shard_by_col);
218 gm = coarsenM(
m,
n, bm, bn, bk, gn, num_threads, shard_by_col);
228 const Index sharding_dim_tasks = shard_by_col ? nn : nm;
229 const int num_worker_threads = this->
m_device.numThreadsInPool();
234 const float oversharding_factor = num_worker_threads <= 4 ? 8.0
235 : num_worker_threads <= 8 ? 4.0
236 : num_worker_threads <= 16 ? 2.0
237 : num_worker_threads <= 32 ? 1.0
238 : num_worker_threads <= 64 ? 0.8
241 const bool parallelize_by_sharding_dim_only = sharding_dim_tasks >= oversharding_factor * num_worker_threads;
250 bool parallel_pack = num_threads >= nm * nn;
253 parallel_pack =
true;
256 if ((shard_by_col ? nm : nn) == 1) parallel_pack =
false;
259 if (parallelize_by_sharding_dim_only) parallel_pack =
false;
262 if (IsEvalInSyncMode) {
263 #define CONTEXT_ARGS \
264 (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
265 parallelize_by_sharding_dim_only, NoCallback()) \
271 #define CONTEXT_ARGS \
272 (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
273 parallelize_by_sharding_dim_only, std::move(done))
284 void operator()() {
eigen_assert(
false &&
"NoCallback should never be called"); }
289 template <
typename DoneCallback,
typename Context>
290 class EvalParallelNotification;
293 template <
typename Context>
294 class EvalParallelNotification<NoCallback, Context> {
296 EvalParallelNotification(Context*, NoCallback) {}
297 void Notify() { done_.Notify(); }
298 void Wait() { done_.Wait(); }
305 template <
typename DoneCallback,
typename Context>
306 class EvalParallelNotification {
308 EvalParallelNotification(Context* ctx, DoneCallback done) : ctx_(ctx), done_(std::move(done)) {}
314 DoneCallback done_copy = std::move(done_);
334 template <
typename DoneCallback,
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
335 bool rhs_inner_dim_reordered,
int Alignment>
336 class EvalParallelContext {
338 typedef internal::TensorContractionInputMapper<LhsScalar,
Index,
internal::Lhs, LeftEvaluator, left_nocontract_t,
340 lhs_inner_dim_contiguous,
false,
Unaligned>
342 typedef internal::TensorContractionInputMapper<RhsScalar,
Index,
internal::Rhs, RightEvaluator, right_nocontract_t,
344 rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
Unaligned>
347 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
349 typedef internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
350 TensorContractionKernel;
352 typedef typename TensorContractionKernel::LhsBlock LhsBlock;
353 typedef typename TensorContractionKernel::RhsBlock RhsBlock;
354 typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
358 bool shard_by_col,
bool parallel_pack,
bool parallelize_by_sharding_dim_only, DoneCallback done)
359 : created_by_thread_id_(std::this_thread::get_id()),
360 done_(this, std::move(done)),
362 lhs_(self->m_leftImpl, self->m_left_nocontract_strides, self->m_i_strides, self->m_left_contracting_strides,
364 rhs_(self->m_rightImpl, self->m_right_nocontract_strides, self->m_j_strides,
365 self->m_right_contracting_strides, self->m_k_strides),
368 output_kernel_(self->m_output_kernel),
369 tensor_contraction_params_(self->m_tensor_contraction_params),
370 num_threads_(num_threads),
371 shard_by_col_(shard_by_col),
372 parallel_pack_(parallel_pack),
373 parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
387 kernel_(m_, k_, n_, bm_, bk_, bn_),
388 num_thread_local_allocations_(0),
392 thread_local_capacity(2 * (parallelize_by_sharding_dim_only_ ? device_.numThreadsInPool() : 0)),
395 lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity, {*
this}, {*
this}),
396 rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0, {*
this}, {*
this}) {
398 eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
406 x == 0 ? 1 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) + (
x ==
P - 1 ? nm_ * nn_ : 0);
407 state_packing_ready_[
x] = parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
408 state_kernel_[
x] =
new std::atomic<uint8_t>*[nm_];
410 state_kernel_[
x][
m] =
new std::atomic<uint8_t>[nn_];
415 state_kernel_[
x][
m][
n].store((
x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1), std::memory_order_relaxed);
420 packed_mem_ = kernel_.allocateSlices(
424 std::min<Index>(nk_,
P - 1),
425 packed_lhs_, packed_rhs_);
427 if (parallelize_by_sharding_dim_only_) {
428 const int num_worker_threads = device_.numThreadsInPool();
431 can_use_thread_local_packed_ =
new std::atomic<bool>[nn_];
432 for (
int i = 0;
i < nn_; ++
i) can_use_thread_local_packed_[
i].store(
true, std::memory_order_relaxed);
434 Index num_blocks = num_worker_threads * gn_;
435 thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
440 nullptr, &rhs_thread_local_pre_allocated_);
443 can_use_thread_local_packed_ =
new std::atomic<bool>[nm_];
444 for (
int i = 0;
i < nm_; ++
i) can_use_thread_local_packed_[
i].store(
true, std::memory_order_relaxed);
446 Index num_blocks = num_worker_threads * gm_;
447 thread_local_pre_alocated_mem_ = kernel_.allocateSlices(
451 1, &lhs_thread_local_pre_allocated_,
457 ~EvalParallelContext() {
459 for (
Index m = 0;
m < nm_;
m++)
delete[] state_kernel_[
x][
m];
460 delete[] state_kernel_[
x];
462 kernel_.deallocate(device_, packed_mem_);
463 if (parallelize_by_sharding_dim_only_) {
464 kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
465 delete[] can_use_thread_local_packed_;
490 std::thread::id created_by_thread_id_;
494 EvalParallelNotification<DoneCallback, EvalParallelContext> done_;
496 const Device& device_;
500 OutputMapper output_;
501 OutputKernelType output_kernel_;
502 TensorContractionParams tensor_contraction_params_;
503 const int num_threads_;
504 const bool shard_by_col_;
505 const bool parallel_pack_;
506 const bool parallelize_by_sharding_dim_only_;
527 TensorContractionKernel kernel_;
563 static constexpr
Index P = 3;
566 BlockMemHandle packed_mem_;
567 std::vector<LhsBlock> packed_lhs_[
P - 1];
568 std::vector<RhsBlock> packed_rhs_[
P - 1];
588 BlockMemHandle thread_local_pre_alocated_mem_;
592 std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
593 std::vector<RhsBlock> rhs_thread_local_pre_allocated_;
596 std::atomic<int> num_thread_local_allocations_;
597 const int thread_local_capacity;
605 template <
typename BlockType>
606 class ThreadLocalBlocks {
608 ThreadLocalBlocks() =
default;
610 ThreadLocalBlocks(BlockType* base,
size_t grain_size)
611 : is_pre_allocated_(true), thread_local_pre_allocated_base_(base), grain_size_(grain_size) {}
613 ThreadLocalBlocks(BlockMemHandle mem_handle, std::vector<BlockType> blocks)
614 : is_pre_allocated_(false), mem_handle_(std::move(mem_handle)), blocks_(std::move(blocks)) {}
616 BlockType&
block(
int grain_index) {
619 return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index] : blocks_[grain_index];
622 void Release(EvalParallelContext& ctx)
const {
623 if (!is_pre_allocated_) {
624 ctx.kernel_.deallocate(ctx.device_, mem_handle_);
628 size_t size()
const {
return is_pre_allocated_ ? grain_size_ : blocks_.size(); }
631 bool is_pre_allocated_;
634 BlockType* thread_local_pre_allocated_base_ =
nullptr;
635 size_t grain_size_ = 0;
638 BlockMemHandle mem_handle_{};
639 std::vector<BlockType> blocks_;
648 template <
typename BlockType,
bool is_rhs>
649 class ThreadLocalBlocksInitialize {
652 static_assert(kIsLhs || kIsRhs,
"Unknown block type");
654 using Blocks = ThreadLocalBlocks<BlockType>;
657 ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
658 : ctx_(ctx), num_worker_threads_(ctx_.device_.numThreadsInPool()) {}
660 void operator()(Blocks& blocks) {
661 const int n = ctx_.num_thread_local_allocations_.fetch_add(1, std::memory_order_relaxed);
663 if (
n >= num_worker_threads_) {
664 ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
666 ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_,
n, blocks);
675 template <
bool pack_rhs,
typename EvalCtx = EvalParallelContext>
676 struct ThreadLocalBlocksAllocator;
678 template <
typename EvalCtx>
679 struct ThreadLocalBlocksAllocator<true, EvalCtx> {
680 static void allocate(EvalCtx& ctx, Blocks& blocks) {
681 std::vector<RhsBlock> rhs_blocks;
682 BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(ctx.device_,
686 nullptr, &rhs_blocks);
688 blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
691 static void reuse(EvalCtx& ctx,
int index, Blocks& blocks) {
692 RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
693 blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
697 template <
typename EvalCtx>
698 struct ThreadLocalBlocksAllocator<false, EvalCtx> {
699 static void allocate(EvalCtx& ctx, Blocks& blocks) {
700 std::vector<LhsBlock> lhs_blocks;
701 BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(ctx.device_,
705 &lhs_blocks,
nullptr);
707 blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
710 static void reuse(EvalCtx& ctx,
int index, Blocks& blocks) {
711 LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
712 blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
716 EvalParallelContext& ctx_;
717 const int num_worker_threads_;
720 template <
typename BlockType>
721 class ThreadLocalBlocksRelease {
723 using Blocks = ThreadLocalBlocks<BlockType>;
724 ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
725 void operator()(Blocks& blocks) { blocks.Release(ctx_); }
728 EvalParallelContext& ctx_;
732 using ThreadLocalLhsInit = ThreadLocalBlocksInitialize<LhsBlock,
false>;
733 using ThreadLocalRhsInit = ThreadLocalBlocksInitialize<RhsBlock,
true>;
736 using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
737 using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;
749 std::atomic<bool>* can_use_thread_local_packed_;
751 std::atomic<uint8_t>** state_kernel_[
P];
756 std::atomic<Index> state_packing_ready_[
P];
757 std::atomic<Index> state_switch_[
P];
760 if (use_thread_local) {
762 ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.
local();
766 internal::convert_index<int>(grain_index));
768 return packed_lhs_[
k % (
P - 1)][
m1];
773 if (use_thread_local) {
775 ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.
local();
777 Index grain_index = n1 -
n * gn_;
779 internal::convert_index<int>(grain_index));
781 return packed_rhs_[
k % (
P - 1)][n1];
796 bool use_thread_local =
false;
798 if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
799 can_use_thread_local_packed_[
m].
load(std::memory_order_relaxed)) {
800 if (state_kernel_[
k %
P][
m][0].
load(std::memory_order_relaxed) == 1) {
801 use_thread_local =
true;
807 can_use_thread_local_packed_[
m].store(
false, std::memory_order_relaxed);
811 const Index mend =
m * gm_ + gm(
m);
813 kernel_.packLhs(&packed_lhs(
m,
k,
m1, use_thread_local), lhs_.getSubMapper(
m1 * bm_,
k * bk_), bk(
k), bm(
m1));
815 if (!parallel_pack_ && shard_by_col_) {
819 signal_switch(
k + 1);
820 for (
Index n = nn_ - 1;
n >= 0;
n--) {
821 bool sync = parallelize_by_sharding_dim_only_ ||
n == 0;
822 signal_kernel(
m,
n,
k, sync, use_thread_local);
828 bool use_thread_local =
false;
830 if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
831 can_use_thread_local_packed_[
n].
load(std::memory_order_relaxed)) {
832 if (state_kernel_[
k %
P][0][
n].
load(std::memory_order_relaxed) == 1) {
833 use_thread_local =
true;
839 can_use_thread_local_packed_[
n].store(
false, std::memory_order_relaxed);
843 const Index nend =
n * gn_ + gn(
n);
844 for (
Index n1 =
n * gn_; n1 < nend; n1++) {
845 if (!TensorContractionKernel::HasBeta &&
k == 0) {
855 std::fill_n(buffer_ + n1 * bn_ * m_, bn(n1) * m_,
Scalar(0));
857 kernel_.packRhs(&packed_rhs(
n,
k, n1, use_thread_local), rhs_.getSubMapper(
k * bk_, n1 * bn_), bk(
k), bn(n1));
860 if (parallel_pack_ || shard_by_col_) {
861 signal_switch(
k + 1);
862 for (
Index m = nm_ - 1;
m >= 0;
m--) {
863 bool sync = parallelize_by_sharding_dim_only_ ||
m == 0;
864 signal_kernel(
m,
n,
k, sync, use_thread_local);
876 const Index nend =
n * gn_ + gn(
n);
877 const Index mend =
m * gm_ + gm(
m);
884 for (
Index n1 =
n * gn_; n1 < nend; n1++) {
886 const auto output_mapper = output_.getSubMapper(
m1 * bm_, n1 * bn_);
887 kernel_.invoke(output_mapper, packed_lhs(
m,
k,
m1, !shard_by_col_ && use_thread_local),
888 packed_rhs(
n,
k, n1, shard_by_col_ && use_thread_local), bm(
m1), bk(
k), bn(n1),
alpha,
beta);
892 output_kernel_(output_mapper, tensor_contraction_params_,
m1 * bm_, n1 * bn_, bm(
m1), bn(n1));
898 for (
Index n1 =
n * gn_; n1 < nend; n1++) {
899 const auto output_mapper = output_.getSubMapper(
m1 * bm_, n1 * bn_);
900 kernel_.invoke(output_mapper, packed_lhs(
m,
k,
m1, !shard_by_col_ && use_thread_local),
901 packed_rhs(
n,
k, n1, shard_by_col_ && use_thread_local), bm(
m1), bk(
k), bn(n1),
alpha,
beta);
905 output_kernel_(output_mapper, tensor_contraction_params_,
m1 * bm_, n1 * bn_, bm(
m1), bn(n1));
909 signal_kernel(
m,
n,
k + 1,
false,
false);
910 signal_switch(
k + 2);
913 void signal_packing(
Index k) {
915 Index s = state_packing_ready_[
k %
P].fetch_sub(1);
918 state_packing_ready_[
k %
P] = shard_by_col_ ? nm_ : nn_;
919 enqueue_packing(
k, shard_by_col_);
923 std::atomic<uint8_t>* state = &state_kernel_[
k %
P][
m][
n];
926 if (
s != 1 && state->fetch_sub(1) != 1) {
930 state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
932 kernel(
m,
n,
k, use_thread_local);
935 device_.enqueueNoNotification([
this,
m,
n,
k, use_thread_local]() {
936 kernel(
m,
n,
k, use_thread_local);
942 Index s = state_switch_[
k %
P].fetch_sub(
v);
948 state_switch_[
k %
P] = (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) + nm_ * nn_;
952 if (parallel_pack_) {
953 enqueue_packing(
k, !shard_by_col_);
954 enqueue_packing(
k, shard_by_col_);
955 }
else if (shard_by_col_) {
956 enqueue_packing(
k,
false);
958 enqueue_packing(
k,
true);
966 }
else if (
k == nk_) {
967 signal_switch(
k + 1, parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
974 void enqueue_packing(
Index k,
bool rhs) { enqueue_packing_helper(0, rhs ? nn_ : nm_,
k, rhs); }
985 device_.enqueueNoNotification([
this, mid,
end,
k, rhs]() {
986 enqueue_packing_helper(mid,
end,
k, rhs);
999 bool pack_async = (
start == 0) && (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
1000 (
k > 0 || std::this_thread::get_id() == created_by_thread_id_);
1003 device_.enqueueNoNotification([
this,
start,
end,
k, rhs]() {
1004 enqueue_packing_helper(
start,
end,
k, rhs);
1007 enqueue_packing_helper(
start,
end,
k, rhs);
1013 Index bm(
Index m)
const {
return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
1014 Index bn(
Index n)
const {
return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
1015 Index bk(
Index k)
const {
return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
1017 Index gm(
Index m)
const {
return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
1018 Index gn(
Index n)
const {
return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
1020 EvalParallelContext(
const EvalParallelContext&) =
delete;
1021 void operator=(
const EvalParallelContext&) =
delete;
1024 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
1025 using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
1026 rhs_inner_dim_reordered, Alignment>;
1035 template <
typename DoneCallback>
1036 struct EvalShardedByInnerDimContext {
1037 EvalShardedByInnerDimContext(
const Self*
self,
int num_threads,
Scalar* result_buffer,
Index m_size,
Index n_size,
1038 Index k_size, DoneCallback done_callback)
1040 m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
1041 m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
1042 m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
1043 result(result_buffer),
1047 done(std::move(done_callback)),
1048 buffer_size_bytes(
m *
n * sizeof(
Scalar)),
1049 block_size(blockSize(
k, num_threads)),
1053 l0_state(l0_ranges),
1054 block_buffers(num_blocks) {
1056 for (
int i = 0;
i < l0_ranges; ++
i) {
1057 const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size,
i);
1058 l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
1062 for (
Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
1063 Scalar* buf = block_idx == 0 ? result :
static_cast<Scalar*
>(evaluator->m_device.allocate(buffer_size_bytes));
1064 block_buffers.emplace_back(buf);
1068 ~EvalShardedByInnerDimContext() {
1069 for (
Index i = 1;
i < num_blocks; ++
i) {
1070 evaluator->m_device.deallocate(block_buffers[
i]);
1074 template <
int Alignment>
1076 Barrier barrier(internal::convert_index<int>(num_blocks));
1077 eval<Alignment>(barrier, 0, num_blocks);
1081 aggregateL0Blocks<Alignment>();
1084 applyOutputKernel();
1087 template <
int Alignment>
1089 evalAsync<Alignment>(0, num_blocks);
1097 const Self* evaluator;
1100 bool m_lhs_inner_dim_contiguous;
1101 bool m_rhs_inner_dim_contiguous;
1102 bool m_rhs_inner_dim_reordered;
1116 Index buffer_size_bytes;
1122 std::atomic<int> num_pending_blocks;
1140 static const Index l0_size = 4;
1144 MaxSizeVector<std::atomic<int>> l0_state;
1147 MaxSizeVector<Scalar*> block_buffers;
1149 template <
int Alignment>
1151 Scalar* buf = block_buffers[block_idx];
1155 internal::convert_index<int>(num_blocks)));
1158 const Index l0_index = block_idx / l0_size;
1159 const int v = l0_state[l0_index].fetch_sub(1);
1165 const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
1166 const Index dst_block_idx = l0_index * l0_size;
1168 if (rng_size == l0_size) {
1169 addAllToBuffer<Alignment>(
m *
n,
1170 block_buffers[dst_block_idx + 1],
1171 block_buffers[dst_block_idx + 2],
1172 block_buffers[dst_block_idx + 3],
1173 block_buffers[dst_block_idx]);
1176 for (
int i = 1;
i < rng_size; ++
i) {
1177 addToBuffer<Alignment>(
m *
n,
1178 block_buffers[dst_block_idx +
i],
1179 block_buffers[dst_block_idx]);
1186 template <
int Alignment>
1187 void aggregateL0Blocks()
const {
1190 for (; l0_index + 2 < l0_ranges; l0_index += 3) {
1191 addAllToBuffer<Alignment>(
m *
n,
1192 block_buffers[(l0_index + 0) * l0_size],
1193 block_buffers[(l0_index + 1) * l0_size],
1194 block_buffers[(l0_index + 2) * l0_size],
1198 for (; l0_index < l0_ranges; ++l0_index) {
1199 addToBuffer<Alignment>(
m *
n, block_buffers[l0_index * l0_size], block_buffers[0]);
1203 void applyOutputKernel()
const {
1204 typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
1205 evaluator->m_output_kernel(OutputMapper(result,
m), evaluator->m_tensor_contraction_params,
1210 Index actualBlockSize(
Index block_idx)
const {
1211 return block_idx + 1 < num_blocks ? block_size :
k + block_size - block_size * num_blocks;
1217 return range_idx + 1 < num_ranges ? range_size : num_blocks + range_size - range_size * num_ranges;
1220 template <
int Alignment>
1224 const size_t num_packets =
n / output_packet_size;
1225 for (;
i < output_packet_size * num_packets;
i += output_packet_size) {
1226 const PacketReturnType src_val = internal::pload<PacketReturnType>(src_buf +
i);
1227 const PacketReturnType tgt_val = internal::ploadt<PacketReturnType, Alignment>(tgt_buf +
i);
1229 internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf +
i, sum);
1231 for (;
i <
n; ++
i) {
1232 tgt_buf[
i] += src_buf[
i];
1236 template <
int Alignment>
1247 const size_t num_packets =
n / output_packet_size;
1248 for (;
i < output_packet_size * num_packets;
i += output_packet_size) {
1249 const auto src_val0 = pload<PacketReturnType>(src_buf0 +
i);
1250 const auto src_val1 = pload<PacketReturnType>(src_buf1 +
i);
1251 const auto src_val2 = pload<PacketReturnType>(src_buf2 +
i);
1253 const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf +
i);
1254 const auto sum =
padd(
padd(dst_val, src_val0),
padd(src_val1, src_val2));
1256 pstoret<Scalar, PacketReturnType, Alignment>(dst_buf +
i, sum);
1258 for (;
i <
n; ++
i) {
1259 dst_buf[
i] += src_buf0[
i] + src_buf1[
i] + src_buf2[
i];
1263 template <
int Alignment>
1264 void eval(Barrier& barrier,
Index start_block_idx,
Index end_block_idx) {
1265 while (end_block_idx - start_block_idx > 1) {
1266 Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1267 evaluator->m_device.enqueueNoNotification([
this, &barrier, mid_block_idx, end_block_idx]() {
1268 eval<Alignment>(barrier, mid_block_idx, end_block_idx);
1270 end_block_idx = mid_block_idx;
1273 Index block_idx = start_block_idx;
1274 Index block_start = block_idx * block_size;
1275 Index block_end = block_start + actualBlockSize(block_idx);
1277 processBlock<Alignment>(block_idx, block_start, block_end);
1281 template <
int Alignment>
1282 void evalAsync(
Index start_block_idx,
Index end_block_idx) {
1283 while (end_block_idx - start_block_idx > 1) {
1284 Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
1285 evaluator->m_device.enqueueNoNotification(
1286 [
this, mid_block_idx, end_block_idx]() {
1287 evalAsync<Alignment>(mid_block_idx, end_block_idx);
1289 end_block_idx = mid_block_idx;
1292 Index block_idx = start_block_idx;
1294 Index block_start = block_idx * block_size;
1295 Index block_end = block_start + actualBlockSize(block_idx);
1297 processBlock<Alignment>(block_idx, block_start, block_end);
1299 int v = num_pending_blocks.fetch_sub(1);
1304 aggregateL0Blocks<Alignment>();
1307 applyOutputKernel();
1314 DoneCallback done_copy = std::move(done);
1327 static Index blockSize(
Index k,
int num_threads) {
1328 const auto round_up = [=](
Index index) ->
Index {
1329 const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
1330 return numext::div_ceil<Index>(index, kmultiple) * kmultiple;
1333 const Index target_block_size = round_up(numext::div_ceil<Index>(
k, num_threads));
1334 const Index desired_min_block_size = 12 * packet_size;
1336 return numext::mini<Index>(
k, numext::maxi<Index>(desired_min_block_size, target_block_size));
1339 EvalShardedByInnerDimContext(
const EvalShardedByInnerDimContext&) =
delete;
1340 void operator=(
const EvalShardedByInnerDimContext&) =
delete;
1356 if (
m / num_threads >= Traits::nr &&
1358 (
n / num_threads < Traits::nr ||
1361 (
n / num_threads < 4 * Traits::nr && (
n % (num_threads * Traits::nr)) != 0 &&
1363 ((
m % (num_threads * Traits::nr)) == 0 ||
1371 if (
n / num_threads < 16 * Traits::nr && m >
n * 32)
return false;
1385 if (gm1 > nm0)
break;
1387 int res = checkGrain(
m,
n, bm, bn, bk, gm1, gn, gm, gn, num_threads, shard_by_col);
1390 if (
res == 0)
continue;
1404 if (gn1 > nn0)
break;
1405 int res = checkGrain(
m,
n, bm, bn, bk, gm, gn1, gm, gn, num_threads, shard_by_col);
1408 if (
res == 0)
continue;
1417 int num_threads,
bool shard_by_col)
const {
1418 const TensorOpCost cost = contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col,
true);
1422 if (taskSize < 1)
return 1;
1424 if (taskSize > 2)
return -1;
1434 double new_parallelism =
1435 static_cast<double>(new_tasks) / (numext::div_ceil<Index>(new_tasks, num_threads) * num_threads);
1437 double old_parallelism =
1438 static_cast<double>(old_tasks) / (numext::div_ceil<Index>(old_tasks, num_threads) * num_threads);
1439 if (new_parallelism > old_parallelism || new_parallelism == 1)
return 1;
1444 bool prepacked)
const {
1447 const double kd =
static_cast<double>(bk);
1448 double compute_bandwidth = computeBandwidth(
false, bm, bn, bk);
1450 TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth,
true, packed_size);
1452 cost += TensorOpCost(0,
sizeof(
CoeffReturnType), 0,
true, output_packet_size);
1460 TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(
true) * (kd /
n);
1461 TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(
true) * (kd /
m);
1465 lhsCost.dropMemoryCost();
1467 rhsCost.dropMemoryCost();
1468 return cost + lhsCost + rhsCost;
1473 static bool shardByInnerDim(
Index m,
Index n,
Index k,
int num_threads,
int num_threads_by_k) {
1474 std::ptrdiff_t bufsize =
m *
n *
sizeof(
Scalar);
1475 bool shard_by_k =
false;
1477 num_threads_by_k < 2 ||
1478 num_threads_by_k < num_threads ||
1481 k / num_threads_by_k < 2 * Traits::nr) {
1485 (
k / num_threads_by_k > 8 * Traits::nr &&
1488 (
numext::mini(
m,
n) < 2 * Traits::nr || num_threads_by_k > num_threads))) {
1497 TensorOpCost cost(0, 0, (computeBandwidth(
true,
m,
n,
k) *
m) *
n,
true, output_packet_size);
1499 cost += TensorOpCost(0,
sizeof(
CoeffReturnType), 0,
true, output_packet_size);
1500 TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(
true) *
m;
1501 TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(
true) *
n;
1504 lhsCost.dropMemoryCost();
1505 return cost + lhsCost + rhsCost;
1510 TensorOpCost cost = contractionCostPerInnerDim(
m,
n,
k);
1514 double reduction_cost =
1516 int num_threads = 1;
1517 double min_cost = total_parallel_cost;
1518 double kPerThreadOverHead = 3000;
1519 double kFixedOverHead = 100000;
1520 for (
int nt = 2; nt <= this->
m_device.numThreads(); nt += 2) {
1521 double sequential_cost = kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
1522 double parallel_cost = total_parallel_cost / nt + sequential_cost;
1523 if (parallel_cost < min_cost) {
1525 min_cost = parallel_cost;
1531 double computeBandwidth(
bool shard_by_col,
Index bm,
Index bn,
Index bk)
const {
1535 double computeBandwidth = bk == 1 ? 4.0
1536 : (shard_by_col ? bn : bm) < Traits::nr || (shard_by_col ? bm : bn) < Traits::mr ? 2.0
1538 #ifndef EIGEN_VECTORIZE_FMA
1543 if (computeBandwidth == 0.5) computeBandwidth = 1.0;
1545 return computeBandwidth;
Array< int, Dynamic, 1 > v
Definition: Array_initializer_list_vector_cxx11.cpp:1
int i
Definition: BiCGSTAB_step_by_step.cpp:9
const unsigned n
Definition: CG3DPackingUnitTest.cpp:11
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
EIGEN_ALWAYS_INLINE Packet2cf padd(Packet2cf &a, std::complex< float > &b)
Definition: MatrixVectorProduct.h:1277
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
void load(Archive &ar, ParticleHandler &handl)
Definition: Particles.h:21
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)
Definition: TensorContraction.h:607
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN)
Definition: TensorContraction.h:640
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int numThreads(double output_size, const TensorOpCost &cost_per_coeff, int max_threads)
Definition: TensorCostModel.h:154
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double taskSize(double output_size, const TensorOpCost &cost_per_coeff)
Definition: TensorCostModel.h:166
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double totalCost(double output_size, const TensorOpCost &cost_per_coeff)
Definition: TensorCostModel.h:170
Definition: ThreadLocal.h:112
T & local()
Definition: ThreadLocal.h:137
static constexpr lastp1_t end
Definition: IndexedViewHelper.h:79
@ Unaligned
Definition: Constants.h:235
@ ColMajor
Definition: Constants.h:318
RealScalar s
Definition: level1_cplx_impl.h:130
RealScalar alpha
Definition: level1_cplx_impl.h:151
int * m
Definition: level2_cplx_impl.h:294
Scalar beta
Definition: level2_cplx_impl.h:36
char char char int int * k
Definition: level2_impl.h:374
char char * op
Definition: level2_impl.h:374
EIGEN_DEVICE_FUNC Packet padd(const Packet &a, const Packet &b)
Definition: GenericPacketMath.h:318
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
EIGEN_DEVICE_FUNC IndexDest convert_index(const IndexSrc &idx)
Definition: XprHelper.h:63
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:1355
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet ploadt(const typename unpacket_traits< Packet >::type *from)
Definition: GenericPacketMath.h:1334
EIGEN_DEVICE_FUNC Packet pload(const typename unpacket_traits< Packet >::type *from)
Definition: GenericPacketMath.h:752
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
Definition: MathFunctions.h:926
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_CONSTEXPR T div_ceil(T a, T b)
Definition: MathFunctions.h:1251
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition: MathFunctions.h:920
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
auto run(Kernel kernel, Args &&... args) -> decltype(kernel(args...))
Definition: gpu_test_helper.h:414
std::array< T, N > array
Definition: EmulateArray.h:231
squared absolute value
Definition: GlobalFunctions.h:87
std::ptrdiff_t l2CacheSize()
Definition: products/GeneralBlockPanelKernel.h:3127
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
std::ptrdiff_t l3CacheSize()
Definition: products/GeneralBlockPanelKernel.h:3135
double P
Uniform pressure.
Definition: TwenteMeshGluing.cpp:77
Definition: Eigen_Colamd.h:49
void start(const unsigned &i)
(Re-)start i-th timer
Definition: oomph_utilities.cc:243
list x
Definition: plotDoE.py:28
CwiseBinaryOp< internal::scalar_sum_op< double, double >, const CpyMatrixXd, const CpyMatrixXd > XprType
Definition: nestbyvalue.cpp:15
internal::nested_eval< T, 1 >::type eval(const T &xpr)
Definition: sparse_permutations.cpp:47
internal::packet_traits< Scalar >::type type
Definition: TensorMeta.h:48
static constexpr int Layout
Definition: TensorEvaluator.h:46
Derived::Scalar Scalar
Definition: TensorEvaluator.h:33
const Device EIGEN_DEVICE_REF m_device
Definition: TensorEvaluator.h:170
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived &m, const Device &device)
Definition: TensorEvaluator.h:66
Derived::Scalar CoeffReturnType
Definition: TensorEvaluator.h:34
Derived XprType
Definition: TensorEvaluator.h:37
Derived::Index Index
Definition: TensorEvaluator.h:32
PacketType< CoeffReturnType, Device >::type PacketReturnType
Definition: TensorEvaluator.h:35
Derived::Dimensions Dimensions
Definition: TensorEvaluator.h:36
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc &desc, TensorBlockScratch &scratch, bool=false) const
Definition: TensorEvaluator.h:147
static constexpr Index value
Definition: Meta.h:306
@ size
Definition: GenericPacketMath.h:113
@ size
Definition: GenericPacketMath.h:139