10 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11 #define EIGEN_GENERAL_MATRIX_VECTOR_H
14 #include "../InternalHeaderCheck.h"
22 template <
int N,
typename T1,
typename T2,
typename T3>
27 template <
typename T1,
typename T2,
typename T3>
32 template <
typename T1,
typename T2,
typename T3>
37 template <
typename LhsScalar,
typename RhsScalar,
int PacketSize_ = GEMVPacketFull>
41 #define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
42 typedef typename gemv_packet_cond< \
43 packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
44 typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix
49 #undef PACKET_DECL_COND_POSTFIX
60 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar>
LhsPacket;
61 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar>
RhsPacket;
62 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar>
ResPacket;
78 template <
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
79 typename RhsMapper,
bool ConjugateRhs,
int Version>
81 ConjugateRhs, Version> {
105 template <
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
106 typename RhsMapper,
bool ConjugateRhs,
int Version>
123 const Index lhsStride = lhs.stride();
127 ResPacketSize = Traits::ResPacketSize,
128 ResPacketSizeHalf = HalfTraits::ResPacketSize,
129 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
130 LhsPacketSize = Traits::LhsPacketSize,
131 HasHalf = (
int)ResPacketSizeHalf < (
int)ResPacketSize,
132 HasQuarter = (
int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
135 const Index n8 =
rows - 8 * ResPacketSize + 1;
136 const Index n4 =
rows - 4 * ResPacketSize + 1;
137 const Index n3 =
rows - 3 * ResPacketSize + 1;
138 const Index n2 =
rows - 2 * ResPacketSize + 1;
139 const Index n1 =
rows - 1 * ResPacketSize + 1;
140 const Index n_half =
rows - 1 * ResPacketSizeHalf + 1;
141 const Index n_quarter =
rows - 1 * ResPacketSizeQuarter + 1;
144 const Index block_cols =
cols < 128 ?
cols : (lhsStride *
sizeof(LhsScalar) < 32000 ? 16 : 4);
149 for (
Index j2 = 0; j2 <
cols; j2 += block_cols) {
152 for (;
i < n8;
i += ResPacketSize * 8) {
158 for (
Index j = j2;
j < jend;
j += 1) {
160 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 0,
j), b0, c0);
161 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 1,
j), b0, c1);
162 c2 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 2,
j), b0, c2);
163 c3 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 3,
j), b0, c3);
164 c4 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 4,
j), b0, c4);
165 c5 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 5,
j), b0, c5);
166 c6 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 6,
j), b0, c6);
167 c7 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 7,
j), b0, c7);
182 for (
Index j = j2;
j < jend;
j += 1) {
184 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 0,
j), b0, c0);
185 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 1,
j), b0, c1);
186 c2 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 2,
j), b0, c2);
187 c3 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 3,
j), b0, c3);
194 i += ResPacketSize * 4;
200 for (
Index j = j2;
j < jend;
j += 1) {
202 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 0,
j), b0, c0);
203 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 1,
j), b0, c1);
204 c2 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 2,
j), b0, c2);
210 i += ResPacketSize * 3;
215 for (
Index j = j2;
j < jend;
j += 1) {
217 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 0,
j), b0, c0);
218 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + LhsPacketSize * 1,
j), b0, c1);
222 i += ResPacketSize * 2;
226 for (
Index j = j2;
j < jend;
j += 1) {
228 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 0,
j), b0, c0);
233 if (HasHalf &&
i < n_half) {
235 for (
Index j = j2;
j < jend;
j += 1) {
237 c0 = pcj_half.
pmadd(lhs.template load<LhsPacketHalf, LhsAlignment>(
i + 0,
j), b0, c0);
240 pmadd(c0, palpha_half, ploadu<ResPacketHalf>(
res +
i + ResPacketSizeHalf * 0)));
241 i += ResPacketSizeHalf;
243 if (HasQuarter &&
i < n_quarter) {
245 for (
Index j = j2;
j < jend;
j += 1) {
247 c0 = pcj_quarter.
pmadd(lhs.template load<LhsPacketQuarter, LhsAlignment>(
i + 0,
j), b0, c0);
250 pmadd(c0, palpha_quarter, ploadu<ResPacketQuarter>(
res +
i + ResPacketSizeQuarter * 0)));
251 i += ResPacketSizeQuarter;
255 for (
Index j = j2;
j < jend;
j += 1) c0 += cj.
pmul(lhs(
i,
j), rhs(
j, 0));
271 template <
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
272 typename RhsMapper,
bool ConjugateRhs,
int Version>
274 ConjugateRhs, Version> {
298 template <
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
299 typename RhsMapper,
bool ConjugateRhs,
int Version>
316 const Index n8 = lhs.stride() *
sizeof(LhsScalar) > 32000 ? 0 :
rows - 7;
323 ResPacketSize = Traits::ResPacketSize,
324 ResPacketSizeHalf = HalfTraits::ResPacketSize,
325 ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
326 LhsPacketSize = Traits::LhsPacketSize,
327 LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
328 LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
329 HasHalf = (
int)ResPacketSizeHalf < (
int)ResPacketSize,
330 HasQuarter = (
int)ResPacketSizeQuarter < (
int)ResPacketSizeHalf
334 const Index fullColBlockEnd = LhsPacketSize * (UnsignedIndex(
cols) / LhsPacketSize);
335 const Index halfColBlockEnd = LhsPacketSizeHalf * (UnsignedIndex(
cols) / LhsPacketSizeHalf);
336 const Index quarterColBlockEnd = LhsPacketSizeQuarter * (UnsignedIndex(
cols) / LhsPacketSizeQuarter);
339 for (;
i < n8;
i += 8) {
345 for (
Index j = 0;
j < fullColBlockEnd;
j += LhsPacketSize) {
346 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j, 0);
348 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 0,
j), b0, c0);
349 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 1,
j), b0, c1);
350 c2 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 2,
j), b0, c2);
351 c3 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 3,
j), b0, c3);
352 c4 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 4,
j), b0, c4);
353 c5 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 5,
j), b0, c5);
354 c6 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 6,
j), b0, c6);
355 c7 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 7,
j), b0, c7);
367 RhsScalar b0 = rhs(
j, 0);
369 cc0 += cj.
pmul(lhs(
i + 0,
j), b0);
370 cc1 += cj.
pmul(lhs(
i + 1,
j), b0);
371 cc2 += cj.
pmul(lhs(
i + 2,
j), b0);
372 cc3 += cj.
pmul(lhs(
i + 3,
j), b0);
373 cc4 += cj.
pmul(lhs(
i + 4,
j), b0);
374 cc5 += cj.
pmul(lhs(
i + 5,
j), b0);
375 cc6 += cj.
pmul(lhs(
i + 6,
j), b0);
376 cc7 += cj.
pmul(lhs(
i + 7,
j), b0);
387 for (;
i < n4;
i += 4) {
391 for (
Index j = 0;
j < fullColBlockEnd;
j += LhsPacketSize) {
392 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j, 0);
394 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 0,
j), b0, c0);
395 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 1,
j), b0, c1);
396 c2 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 2,
j), b0, c2);
397 c3 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 3,
j), b0, c3);
405 RhsScalar b0 = rhs(
j, 0);
407 cc0 += cj.
pmul(lhs(
i + 0,
j), b0);
408 cc1 += cj.
pmul(lhs(
i + 1,
j), b0);
409 cc2 += cj.
pmul(lhs(
i + 2,
j), b0);
410 cc3 += cj.
pmul(lhs(
i + 3,
j), b0);
417 for (;
i < n2;
i += 2) {
420 for (
Index j = 0;
j < fullColBlockEnd;
j += LhsPacketSize) {
421 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j, 0);
423 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 0,
j), b0, c0);
424 c1 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i + 1,
j), b0, c1);
430 RhsScalar b0 = rhs(
j, 0);
432 cc0 += cj.
pmul(lhs(
i + 0,
j), b0);
433 cc1 += cj.
pmul(lhs(
i + 1,
j), b0);
443 for (
Index j = 0;
j < fullColBlockEnd;
j += LhsPacketSize) {
444 RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(
j, 0);
445 c0 = pcj.
pmadd(lhs.template load<LhsPacket, LhsAlignment>(
i,
j), b0, c0);
449 for (
Index j = fullColBlockEnd;
j < halfColBlockEnd;
j += LhsPacketSizeHalf) {
450 RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(
j, 0);
451 c0_h = pcj_half.
pmadd(lhs.template load<LhsPacketHalf, LhsAlignment>(
i,
j), b0, c0_h);
456 for (
Index j = halfColBlockEnd;
j < quarterColBlockEnd;
j += LhsPacketSizeQuarter) {
458 c0_q = pcj_quarter.
pmadd(lhs.template load<LhsPacketQuarter, LhsAlignment>(
i,
j), b0, c0_q);
463 cc0 += cj.
pmul(lhs(
i,
j), rhs(
j, 0));
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define eigen_internal_assert(x)
Definition: Macros.h:916
#define EIGEN_UNUSED_VARIABLE(var)
Definition: Macros.h:966
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define EIGEN_DONT_INLINE
Definition: Macros.h:853
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
int rows
Definition: Tutorial_commainit_02.cpp:1
int cols
Definition: Tutorial_commainit_02.cpp:1
#define _(A, B)
Definition: cfortran.h:132
Definition: GeneralMatrixVector.h:38
PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_)
PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_)
std::conditional_t< Vectorizable, ResPacket_, ResScalar > ResPacket
Definition: GeneralMatrixVector.h:62
@ ResPacketSize
Definition: GeneralMatrixVector.h:57
@ RhsPacketSize
Definition: GeneralMatrixVector.h:56
@ Vectorizable
Definition: GeneralMatrixVector.h:53
@ LhsPacketSize
Definition: GeneralMatrixVector.h:55
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixVector.h:39
PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_)
std::conditional_t< Vectorizable, LhsPacket_, LhsScalar > LhsPacket
Definition: GeneralMatrixVector.h:60
std::conditional_t< Vectorizable, RhsPacket_, RhsScalar > RhsPacket
Definition: GeneralMatrixVector.h:61
@ Unaligned
Definition: Constants.h:235
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
RealScalar * palpha
Definition: level1_cplx_impl.h:147
RealScalar alpha
Definition: level1_cplx_impl.h:151
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
Definition: AltiVec/PacketMath.h:1218
GEMVPacketSizeType
Definition: GeneralMatrixVector.h:20
@ GEMVPacketFull
Definition: GeneralMatrixVector.h:20
@ GEMVPacketHalf
Definition: GeneralMatrixVector.h:20
@ GEMVPacketQuarter
Definition: GeneralMatrixVector.h:20
EIGEN_DEVICE_FUNC unpacket_traits< Packet >::type predux(const Packet &a)
Definition: GenericPacketMath.h:1232
EIGEN_DEVICE_FUNC void pstoreu(Scalar *to, const Packet &from)
Definition: GenericPacketMath.h:911
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition: MathFunctions.h:920
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
auto run(Kernel kernel, Args &&... args) -> decltype(kernel(args...))
Definition: gpu_test_helper.h:414
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
type
Definition: compute_granudrum_aor.py:141
Definition: Eigen_Colamd.h:49
Determines whether the given binary operation of two numeric types is allowed and what the scalar ret...
Definition: XprHelper.h:1043
Definition: ConjHelper.h:71
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType &x, const RhsType &y, const ResultType &c) const
Definition: ConjHelper.h:74
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType &x, const RhsType &y) const
Definition: ConjHelper.h:79
T1 type
Definition: GeneralMatrixVector.h:29
T2 type
Definition: GeneralMatrixVector.h:34
Definition: GeneralMatrixVector.h:23
T3 type
Definition: GeneralMatrixVector.h:24
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixVector.h:86
Traits::LhsPacket LhsPacket
Definition: GeneralMatrixVector.h:88
Traits::RhsPacket RhsPacket
Definition: GeneralMatrixVector.h:89
QuarterTraits::LhsPacket LhsPacketQuarter
Definition: GeneralMatrixVector.h:96
Traits::ResPacket ResPacket
Definition: GeneralMatrixVector.h:90
gemv_traits< LhsScalar, RhsScalar, GEMVPacketHalf > HalfTraits
Definition: GeneralMatrixVector.h:83
HalfTraits::ResPacket ResPacketHalf
Definition: GeneralMatrixVector.h:94
gemv_traits< LhsScalar, RhsScalar, GEMVPacketQuarter > QuarterTraits
Definition: GeneralMatrixVector.h:84
QuarterTraits::RhsPacket RhsPacketQuarter
Definition: GeneralMatrixVector.h:97
QuarterTraits::ResPacket ResPacketQuarter
Definition: GeneralMatrixVector.h:98
HalfTraits::LhsPacket LhsPacketHalf
Definition: GeneralMatrixVector.h:92
HalfTraits::RhsPacket RhsPacketHalf
Definition: GeneralMatrixVector.h:93
gemv_traits< LhsScalar, RhsScalar > Traits
Definition: GeneralMatrixVector.h:82
QuarterTraits::RhsPacket RhsPacketQuarter
Definition: GeneralMatrixVector.h:290
gemv_traits< LhsScalar, RhsScalar > Traits
Definition: GeneralMatrixVector.h:275
Traits::RhsPacket RhsPacket
Definition: GeneralMatrixVector.h:282
gemv_traits< LhsScalar, RhsScalar, GEMVPacketHalf > HalfTraits
Definition: GeneralMatrixVector.h:276
ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType ResScalar
Definition: GeneralMatrixVector.h:279
HalfTraits::RhsPacket RhsPacketHalf
Definition: GeneralMatrixVector.h:286
Traits::ResPacket ResPacket
Definition: GeneralMatrixVector.h:283
Traits::LhsPacket LhsPacket
Definition: GeneralMatrixVector.h:281
HalfTraits::LhsPacket LhsPacketHalf
Definition: GeneralMatrixVector.h:285
HalfTraits::ResPacket ResPacketHalf
Definition: GeneralMatrixVector.h:287
gemv_traits< LhsScalar, RhsScalar, GEMVPacketQuarter > QuarterTraits
Definition: GeneralMatrixVector.h:277
QuarterTraits::ResPacket ResPacketQuarter
Definition: GeneralMatrixVector.h:291
QuarterTraits::LhsPacket LhsPacketQuarter
Definition: GeneralMatrixVector.h:289
Definition: BlasUtil.h:42
Definition: GenericPacketMath.h:134
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2