GeneralMatrixVector.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
#define EIGEN_GENERAL_MATRIX_VECTOR_H

// IWYU pragma: private
#include "../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {
enum GEMVPacketSizeType { GEMVPacketFull = 0, GEMVPacketHalf, GEMVPacketQuarter };

template <int N, typename T1, typename T2, typename T3>
struct gemv_packet_cond {
  typedef T3 type;
};

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> {
  typedef T1 type;
};

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> {
  typedef T2 type;
};

template <typename LhsScalar, typename RhsScalar, int PacketSize_ = GEMVPacketFull>
class gemv_traits {
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size)                                                \
  typedef typename gemv_packet_cond<                                                                        \
      packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half,  \
      typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix

  PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
  PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
  PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
#undef PACKET_DECL_COND_POSTFIX

 public:
  enum {
    Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable &&
                   int(unpacket_traits<LhsPacket_>::size) == int(unpacket_traits<RhsPacket_>::size),
    LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
    RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
    ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1
  };

  typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
  typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
  typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
};

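// Illustration (hedged, not part of the source): the three packet widths for
// float would, on an AVX512 build, typically resolve as
//
//   gemv_traits<float, float, GEMVPacketFull>::LhsPacket     -> Packet16f
//   gemv_traits<float, float, GEMVPacketHalf>::LhsPacket     -> Packet8f
//   gemv_traits<float, float, GEMVPacketQuarter>::LhsPacket  -> Packet4f
//
// The exact types depend on the enabled instruction set.
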
/* Optimized col-major matrix * vector product:
 * This algorithm processes the matrix per vertical panels,
 * which are then processed horizontally per chunks of 8*PacketSize x 1 vertical segments.
 *
 * Mixing type logic: C += alpha * A * B
 *  | A    | B    | alpha | comments
 *  | real | cplx | cplx  | no vectorization
 *  | real | cplx | real  | alpha is converted to a cplx when calling the run function, no vectorization
 *  | cplx | real | cplx  | invalid, the caller has to do tmp = A * B; C += alpha*tmp
 *  | cplx | real | real  | optimal case, vectorization possible via real-cplx mul
 *
 * The same reasoning applies to the transposed case.
 */
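
// For reference, a minimal scalar sketch of what the vectorized run() below
// computes (res += alpha * A * rhs with A col-major, ignoring conjugation);
// illustrative pseudocode only, not part of the source:
//
//   for (Index j = 0; j < cols; ++j) {
//     ResScalar b = rhs(j, 0);
//     for (Index i = 0; i < rows; ++i) res[i] += alpha * lhs(i, j) * b;
//   }
//
// The kernel instead accumulates alpha-free partial sums in registers and
// multiplies by alpha once per stored packet.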
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper,
                                     ConjugateRhs, Version> {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,
                                                      const RhsMapper& rhs, ResScalar* res, Index resIncr,
                                                      RhsScalar alpha);
};

template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                              Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                            ResScalar* res, Index resIncr, RhsScalar alpha) {
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);

  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;

  const Index lhsStride = lhs.stride();
  // TODO: for padded aligned inputs, we could enable aligned reads
  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = Traits::ResPacketSize,
    ResPacketSizeHalf = HalfTraits::ResPacketSize,
    ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
    HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

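  // Each bound nK below is the exclusive upper limit on the start row i for
  // which a block of K full packets (or one half/quarter packet) still fits:
  // i < nK  <=>  i + K * ResPacketSize <= rows.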
  const Index n8 = rows - 8 * ResPacketSize + 1;
  const Index n4 = rows - 4 * ResPacketSize + 1;
  const Index n3 = rows - 3 * ResPacketSize + 1;
  const Index n2 = rows - 2 * ResPacketSize + 1;
  const Index n1 = rows - 1 * ResPacketSize + 1;
  const Index n_half = rows - 1 * ResPacketSizeHalf + 1;
  const Index n_quarter = rows - 1 * ResPacketSizeQuarter + 1;

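  // The columns are traversed in blocks of block_cols columns; res is
  // re-loaded and re-stored once per block.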
  // TODO: improve the following heuristic:
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 32000 ? 16 : 4);
  ResPacket palpha = pset1<ResPacket>(alpha);
  ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
  ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);
    Index i = 0;
    for (; i < n8; i += ResPacketSize * 8) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0)),
                c4 = pset1<ResPacket>(ResScalar(0)), c5 = pset1<ResPacket>(ResScalar(0)),
                c6 = pset1<ResPacket>(ResScalar(0)), c7 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
        c4 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 4, j), b0, c4);
        c5 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 5, j), b0, c5);
        c6 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 6, j), b0, c6);
        c7 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 7, j), b0, c7);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
      pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));
      pstoreu(res + i + ResPacketSize * 4, pmadd(c4, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 4)));
      pstoreu(res + i + ResPacketSize * 5, pmadd(c5, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 5)));
      pstoreu(res + i + ResPacketSize * 6, pmadd(c6, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 6)));
      pstoreu(res + i + ResPacketSize * 7, pmadd(c7, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 7)));
    }
    if (i < n4) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
      pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));

      i += ResPacketSize * 4;
    }
    if (i < n3) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));

      i += ResPacketSize * 3;
    }
    if (i < n2) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      i += ResPacketSize * 2;
    }
    if (i < n1) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0));
      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      i += ResPacketSize;
    }
    if (HasHalf && i < n_half) {
      ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
      for (Index j = j2; j < jend; j += 1) {
        RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j, 0));
        c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf, LhsAlignment>(i + 0, j), b0, c0);
      }
      pstoreu(res + i + ResPacketSizeHalf * 0,
              pmadd(c0, palpha_half, ploadu<ResPacketHalf>(res + i + ResPacketSizeHalf * 0)));
      i += ResPacketSizeHalf;
    }
    if (HasQuarter && i < n_quarter) {
      ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
      for (Index j = j2; j < jend; j += 1) {
        RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j, 0));
        c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter, LhsAlignment>(i + 0, j), b0, c0);
      }
      pstoreu(res + i + ResPacketSizeQuarter * 0,
              pmadd(c0, palpha_quarter, ploadu<ResPacketQuarter>(res + i + ResPacketSizeQuarter * 0)));
      i += ResPacketSizeQuarter;
    }
    for (; i < rows; ++i) {
      ResScalar c0(0);
      for (Index j = j2; j < jend; j += 1) c0 += cj.pmul(lhs(i, j), rhs(j, 0));
      res[i] += alpha * c0;
    }
  }
}
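
// Usage note (illustrative, not from the source): this kernel is typically
// reached internally when Eigen evaluates a matrix*vector product, e.g.
//
//   Eigen::MatrixXf A(m, n);        // col-major by default
//   Eigen::VectorXf x(n), y(m);
//   y.noalias() += alpha * A * x;
//
// general_matrix_vector_product::run is not meant to be called directly.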

/* Optimized row-major matrix * vector product:
 * This algorithm processes blocks of 8 (then 4, then 2) rows at once, which both reduces
 * the number of loads/stores of the result and reduces instruction dependencies.
 * Moreover, we know that all bands have the same alignment pattern.
 *
 * Mixing type logic:
 *  - alpha is always a complex (or converted to a complex)
 *  - no vectorization
 */
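// For reference, a minimal scalar sketch of the operation this kernel
// implements (one dot product per row, ignoring conjugation); illustrative
// pseudocode only, not part of the source:
//
//   for (Index i = 0; i < rows; ++i) {
//     ResScalar acc(0);
//     for (Index j = 0; j < cols; ++j) acc += lhs(i, j) * rhs(j, 0);
//     res[i * resIncr] += alpha * acc;
//   }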
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper,
                                     ConjugateRhs, Version> {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,
                                                      const RhsMapper& rhs, ResScalar* res, Index resIncr,
                                                      ResScalar alpha);
};

template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                              Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                            ResScalar* res, Index resIncr, ResScalar alpha) {
  // The following copy tells the compiler that lhs's attributes are not modified outside this function
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);

  eigen_internal_assert(rhs.stride() == 1);
  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;

  // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
  // processing 8 rows at once might be counter productive wrt cache.
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? 0 : rows - 7;
  const Index n4 = rows - 3;
  const Index n2 = rows - 1;

  // TODO: for padded aligned inputs, we could enable aligned reads
  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = Traits::ResPacketSize,
    ResPacketSizeHalf = HalfTraits::ResPacketSize,
    ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
    LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
    HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
    HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

  using UnsignedIndex = typename make_unsigned<Index>::type;
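  // Exclusive ends of the column ranges that can be processed with full, half
  // and quarter packets respectively (cols rounded down to a multiple of the
  // corresponding packet size).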
  const Index fullColBlockEnd = LhsPacketSize * (UnsignedIndex(cols) / LhsPacketSize);
  const Index halfColBlockEnd = LhsPacketSizeHalf * (UnsignedIndex(cols) / LhsPacketSizeHalf);
  const Index quarterColBlockEnd = LhsPacketSizeQuarter * (UnsignedIndex(cols) / LhsPacketSizeQuarter);

  Index i = 0;
  for (; i < n8; i += 8) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0)),
              c4 = pset1<ResPacket>(ResScalar(0)), c5 = pset1<ResPacket>(ResScalar(0)),
              c6 = pset1<ResPacket>(ResScalar(0)), c7 = pset1<ResPacket>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 1, j), b0, c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 2, j), b0, c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 3, j), b0, c3);
      c4 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 4, j), b0, c4);
      c5 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 5, j), b0, c5);
      c6 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 6, j), b0, c6);
      c7 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 7, j), b0, c7);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);
    ResScalar cc4 = predux(c4);
    ResScalar cc5 = predux(c5);
    ResScalar cc6 = predux(c6);
    ResScalar cc7 = predux(c7);

    for (Index j = fullColBlockEnd; j < cols; ++j) {
      RhsScalar b0 = rhs(j, 0);

      cc0 += cj.pmul(lhs(i + 0, j), b0);
      cc1 += cj.pmul(lhs(i + 1, j), b0);
      cc2 += cj.pmul(lhs(i + 2, j), b0);
      cc3 += cj.pmul(lhs(i + 3, j), b0);
      cc4 += cj.pmul(lhs(i + 4, j), b0);
      cc5 += cj.pmul(lhs(i + 5, j), b0);
      cc6 += cj.pmul(lhs(i + 6, j), b0);
      cc7 += cj.pmul(lhs(i + 7, j), b0);
    }
    res[(i + 0) * resIncr] += alpha * cc0;
    res[(i + 1) * resIncr] += alpha * cc1;
    res[(i + 2) * resIncr] += alpha * cc2;
    res[(i + 3) * resIncr] += alpha * cc3;
    res[(i + 4) * resIncr] += alpha * cc4;
    res[(i + 5) * resIncr] += alpha * cc5;
    res[(i + 6) * resIncr] += alpha * cc6;
    res[(i + 7) * resIncr] += alpha * cc7;
  }
  for (; i < n4; i += 4) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 1, j), b0, c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 2, j), b0, c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 3, j), b0, c3);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);

    for (Index j = fullColBlockEnd; j < cols; ++j) {
      RhsScalar b0 = rhs(j, 0);

      cc0 += cj.pmul(lhs(i + 0, j), b0);
      cc1 += cj.pmul(lhs(i + 1, j), b0);
      cc2 += cj.pmul(lhs(i + 2, j), b0);
      cc3 += cj.pmul(lhs(i + 3, j), b0);
    }
    res[(i + 0) * resIncr] += alpha * cc0;
    res[(i + 1) * resIncr] += alpha * cc1;
    res[(i + 2) * resIncr] += alpha * cc2;
    res[(i + 3) * resIncr] += alpha * cc3;
  }
  for (; i < n2; i += 2) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 1, j), b0, c1);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);

    for (Index j = fullColBlockEnd; j < cols; ++j) {
      RhsScalar b0 = rhs(j, 0);

      cc0 += cj.pmul(lhs(i + 0, j), b0);
      cc1 += cj.pmul(lhs(i + 1, j), b0);
    }
    res[(i + 0) * resIncr] += alpha * cc0;
    res[(i + 1) * resIncr] += alpha * cc1;
  }
  for (; i < rows; ++i) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0));
    ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
    ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);
      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i, j), b0, c0);
    }
    ResScalar cc0 = predux(c0);
    if (HasHalf) {
      for (Index j = fullColBlockEnd; j < halfColBlockEnd; j += LhsPacketSizeHalf) {
        RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(j, 0);
        c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf, LhsAlignment>(i, j), b0, c0_h);
      }
      cc0 += predux(c0_h);
    }
    if (HasQuarter) {
      for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) {
        RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter, Unaligned>(j, 0);
        c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter, LhsAlignment>(i, j), b0, c0_q);
      }
      cc0 += predux(c0_q);
    }
    for (Index j = quarterColBlockEnd; j < cols; ++j) {
      cc0 += cj.pmul(lhs(i, j), rhs(j, 0));
    }
    res[i * resIncr] += alpha * cc0;
  }
}

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_GENERAL_MATRIX_VECTOR_H