InnerProduct.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2024 Charlie Schlosser <cs.schlosser@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_INNER_PRODUCT_EVAL_H
11 #define EIGEN_INNER_PRODUCT_EVAL_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
18 namespace internal {
19 
20 // Recursively searches for the largest SIMD packet type whose size does not exceed Size; if no such type exists, the smallest packet type is used.
22  bool Stop =
25 
26 template <typename Scalar, int Size, typename Packet>
29 };
30 
31 template <typename Scalar, int Size, typename Packet>
33  using type = Packet;
34 };
35 
36 template <typename Scalar, int Size>
38 
39 template <typename Scalar>
42 };
43 
44 template <typename Lhs, typename Rhs>
49 #ifndef EIGEN_NO_DEBUG
50  static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, const Rhs& rhs) {
51  eigen_assert((lhs.size() == rhs.size()) && "Inner product: lhs and rhs vectors must have same size");
52  }
53 #else
54  static EIGEN_DEVICE_FUNC void run(const Lhs&, const Rhs&) {}
55 #endif
56 };
57 
58 template <typename Func, typename Lhs, typename Rhs>
61  SizeAtCompileTime = min_size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime),
63 
64  using Scalar = typename Func::result_type;
66 
67  static constexpr bool Vectorize =
68  bool(LhsFlags & RhsFlags & PacketAccessBit) && Func::PacketAccess &&
70 
72  Func func = Func())
73  : m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) {
75  }
76 
78 
80  return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index));
81  }
82 
84  return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index));
85  }
86 
87  template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
89  return m_func.packet(m_lhs.template packet<LhsMode, PacketType>(index),
90  m_rhs.template packet<RhsMode, PacketType>(index));
91  }
92 
93  template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
95  return m_func.packet(value, m_lhs.template packet<LhsMode, PacketType>(index),
96  m_rhs.template packet<RhsMode, PacketType>(index));
97  }
98 
99  const Func m_func;
103 };
104 
105 template <typename Evaluator, bool Vectorize = Evaluator::Vectorize>
107 
108 // Scalar loop: specialization selected when the Vectorize template argument is false.
    // Reduces one coefficient at a time: an empty range yields Scalar(0); otherwise the accumulator is
    // seeded from index 0 with the one-argument eval.coeff() and every later index k is folded in with
    // the two-argument eval.coeff(result, k).
    // NOTE(review): listing line 112 is absent from this extract. Per the symbol index further down it is
    // `static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator &eval)`; the statements
    // below are that function's body.
109 template <typename Evaluator>
110 struct inner_product_impl<Evaluator, false> {
111  using Scalar = typename Evaluator::Scalar;
113  const Index size = eval.size();
114  if (size == 0) return Scalar(0);  // nothing to reduce
115 
116  Scalar result = eval.coeff(0);  // seed with the first term instead of adding it to a zero
117  for (Index k = 1; k < size; k++) {
118  result = eval.coeff(result, k);  // fold term k into the running value
119  }
120 
121  return result;
122  }
123 };
124 
125 // Vector loop: specialization selected when the Vectorize template argument is true.
    // Reduces with up to four independent packet accumulators (presult0..presult3), merges them with
    // padd, collapses the result to a Scalar with predux, then finishes the leftover coefficients one at
    // a time. Inputs shorter than one packet are delegated to the scalar specialization.
    // NOTE(review): listing line 132 is absent from this extract. Per the symbol index further down it is
    // `static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator &eval)`; the statements
    // after PacketSize are that function's body.
126 template <typename Evaluator>
127 struct inner_product_impl<Evaluator, true> {
128  using UnsignedIndex = std::make_unsigned_t<Index>;
129  using Scalar = typename Evaluator::Scalar;
130  using Packet = typename Evaluator::Packet;
131  static constexpr int PacketSize = unpacket_traits<Packet>::size;
133  const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
134  if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);  // not even one full packet
135 
    // packetEnd / quadEnd: presumably size rounded down to a multiple of PacketSize / 4 * PacketSize
    // (inferred from the name round_down -- confirm against numext::round_down).
136  const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
137  const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
138  const UnsignedIndex numPackets = size / PacketSize;  // count of whole packets
    // Whole packets lying between quadEnd and packetEnd; at most 3 given the rounding assumed above.
139  const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
140 
    // presult1..presult3 are assigned only when enough packets exist; every later read of them sits
    // behind a numPackets test at least as strict, so an unassigned accumulator is never read.
141  Packet presult0, presult1, presult2, presult3;
142 
    // Seed each accumulator from one of the first (up to) four packets via the one-argument packet().
143  presult0 = eval.template packet<Packet>(0 * PacketSize);
144  if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
145  if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
146  if (numPackets >= 4) {
147  presult3 = eval.template packet<Packet>(3 * PacketSize);
148 
    // Main loop: starts at 4 * PacketSize because the first four packets were consumed by the seeding
    // above; each iteration folds four consecutive packets, one per accumulator.
149  for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
150  presult0 = eval.packet(presult0, k + 0 * PacketSize);
151  presult1 = eval.packet(presult1, k + 1 * PacketSize);
152  presult2 = eval.packet(presult2, k + 2 * PacketSize);
153  presult3 = eval.packet(presult3, k + 3 * PacketSize);
154  }
155 
    // Fold the numRemPackets whole packets that follow quadEnd into accumulators 0..2.
156  if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
157  if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
158  if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
159 
160  presult2 = padd(presult2, presult3);  // begin merging: 3 into 2
161  }
162 
    // Finish merging, under the same guards that initialised the accumulators: 2 into 1, then 1 into 0.
163  if (numPackets >= 3) presult1 = padd(presult1, presult2);
164  if (numPackets >= 2) presult0 = padd(presult0, presult1);
165 
166  Scalar result = predux(presult0);  // collapse the merged packet into the scalar accumulator
    // Tail: coefficients in [packetEnd, size) that do not fill a packet are folded in one by one.
167  for (UnsignedIndex k = packetEnd; k < size; k++) {
168  result = eval.coeff(result, k);
169  }
170 
171  return result;
172  }
173 };
174 
175 template <typename Scalar, bool Conj>
177 
178 template <typename Scalar>
179 struct conditional_conj<Scalar, true> {
181  template <typename Packet>
183  return pconj(a);
184  }
185 };
186 
187 template <typename Scalar>
188 struct conditional_conj<Scalar, false> {
190  template <typename Packet>
192  return a;
193  }
194 };
195 
196 template <typename LhsScalar, typename RhsScalar, bool Conj>
200  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar& a, const RhsScalar& b) const {
201  return (conj_helper::coeff(a) * b);
202  }
204  const RhsScalar& b) const {
205  return (conj_helper::coeff(a) * b) + accum;
206  }
207  static constexpr bool PacketAccess = false;
208 };
209 
210 template <typename Scalar, bool Conj>
215  return pmul(conj_helper::coeff(a), b);
216  }
217  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& accum, const Scalar& a, const Scalar& b) const {
218  return pmadd(conj_helper::coeff(a), b, accum);
219  }
220  template <typename Packet>
222  return pmul(conj_helper::packet(a), b);
223  }
224  template <typename Packet>
225  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& accum, const Packet& a, const Packet& b) const {
226  return pmadd(conj_helper::packet(a), b, accum);
227  }
229 };
230 
231 template <typename Lhs, typename Rhs, bool Conj>
233  using LhsScalar = typename traits<Lhs>::Scalar;
234  using RhsScalar = typename traits<Rhs>::Scalar;
237  using result_type = typename Evaluator::Scalar;
239  Evaluator eval(a.derived(), b.derived(), Op());
241  }
242 };
243 
// dot_impl adds no members of its own: it inherits everything from default_inner_product_impl
// instantiated with `true` as the third template argument.
// NOTE(review): that argument is presumably the `Conj` flag, and the scalar functors in this file apply
// conj_helper to the left-hand operand only -- confirm against the (not shown) struct header at listing
// line 232.
244 template <typename Lhs, typename Rhs>
245 struct dot_impl : default_inner_product_impl<Lhs, Rhs, true> {};
246 
247 } // namespace internal
248 } // namespace Eigen
249 
250 #endif // EIGEN_INNER_PRODUCT_EVAL_H
AnnoyingScalar conj(const AnnoyingScalar &x)
Definition: AnnoyingScalar.h:133
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define eigen_assert(x)
Definition: Macros.h:910
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
#define EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(TYPE0, TYPE1)
Definition: StaticAssert.h:60
#define EIGEN_STATIC_ASSERT_VECTOR_ONLY(TYPE)
Definition: StaticAssert.h:36
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
Scalar * b
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
internal::packet_traits< Scalar >::type Packet
Definition: benchmark-blocking-sizes.cpp:54
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:52
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR T value()
Definition: XprHelper.h:161
@ Conj
Definition: common.h:73
const unsigned int PacketAccessBit
Definition: Constants.h:97
const Scalar * a
Definition: level2_cplx_impl.h:32
char char char int int * k
Definition: level2_impl.h:374
EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf &a)
Definition: AltiVec/Complex.h:268
EIGEN_DEVICE_FUNC Packet padd(const Packet &a, const Packet &b)
Definition: GenericPacketMath.h:318
@ Lhs
Definition: TensorContractionMapper.h:20
@ Rhs
Definition: TensorContractionMapper.h:20
constexpr int min_size_prefer_fixed(A a, B b)
Definition: Meta.h:683
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
Definition: AltiVec/PacketMath.h:1218
EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf &a, const Packet4cf &b)
Definition: AVX/Complex.h:88
EIGEN_DEVICE_FUNC unpacket_traits< Packet >::type predux(const Packet &a)
Definition: GenericPacketMath.h:1232
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EIGEN_CONSTEXPR T round_down(T a, U b)
Definition: MathFunctions.h:1266
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
squared absolute value
Definition: GlobalFunctions.h:87
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
const int Dynamic
Definition: Constants.h:25
type
Definition: compute_granudrum_aor.py:141
Definition: Eigen_Colamd.h:49
internal::nested_eval< T, 1 >::type eval(const T &xpr)
Definition: sparse_permutations.cpp:47
Definition: TensorMeta.h:47
Determines whether the given binary operation of two numeric types is allowed and what the scalar ret...
Definition: XprHelper.h:1043
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar &a)
Definition: InnerProduct.h:189
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet &a)
Definition: InnerProduct.h:191
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet &a)
Definition: InnerProduct.h:182
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar &a)
Definition: InnerProduct.h:180
Definition: InnerProduct.h:176
Definition: InnerProduct.h:232
typename traits< Rhs >::Scalar RhsScalar
Definition: InnerProduct.h:234
typename Evaluator::Scalar result_type
Definition: InnerProduct.h:237
typename traits< Lhs >::Scalar LhsScalar
Definition: InnerProduct.h:233
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(const MatrixBase< Lhs > &a, const MatrixBase< Rhs > &b)
Definition: InnerProduct.h:238
Definition: InnerProduct.h:245
Definition: XprHelper.h:427
Definition: CoreEvaluators.h:104
typename packet_traits< Scalar >::type type
Definition: InnerProduct.h:41
typename find_inner_product_packet_helper< Scalar, Size, typename unpacket_traits< Packet >::half >::type type
Definition: InnerProduct.h:28
Definition: InnerProduct.h:37
Definition: InnerProduct.h:45
static EIGEN_DEVICE_FUNC void run(const Lhs &lhs, const Rhs &rhs)
Definition: InnerProduct.h:50
Definition: InnerProduct.h:59
static constexpr int LhsAlignment
Definition: InnerProduct.h:62
static constexpr int SizeAtCompileTime
Definition: InnerProduct.h:61
typename Func::result_type Scalar
Definition: InnerProduct.h:64
const evaluator< Rhs > m_rhs
Definition: InnerProduct.h:101
const Func m_func
Definition: InnerProduct.h:99
typename find_inner_product_packet< Scalar, SizeAtCompileTime >::type Packet
Definition: InnerProduct.h:65
const variable_if_dynamic< Index, SizeAtCompileTime > m_size
Definition: InnerProduct.h:102
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const
Definition: InnerProduct.h:88
static constexpr int RhsAlignment
Definition: InnerProduct.h:62
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const
Definition: InnerProduct.h:77
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const
Definition: InnerProduct.h:79
static constexpr int RhsFlags
Definition: InnerProduct.h:60
static constexpr bool Vectorize
Definition: InnerProduct.h:67
static constexpr int LhsFlags
Definition: InnerProduct.h:60
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE inner_product_evaluator(const Lhs &lhs, const Rhs &rhs, Func func=Func())
Definition: InnerProduct.h:71
const evaluator< Lhs > m_lhs
Definition: InnerProduct.h:100
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(const PacketType &value, Index index) const
Definition: InnerProduct.h:94
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar &value, Index index) const
Definition: InnerProduct.h:83
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator &eval)
Definition: InnerProduct.h:112
typename Evaluator::Scalar Scalar
Definition: InnerProduct.h:111
typename Evaluator::Packet Packet
Definition: InnerProduct.h:130
typename Evaluator::Scalar Scalar
Definition: InnerProduct.h:129
std::make_unsigned_t< Index > UnsignedIndex
Definition: InnerProduct.h:128
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator &eval)
Definition: InnerProduct.h:132
Definition: InnerProduct.h:106
Definition: GenericPacketMath.h:108
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar &accum, const Scalar &a, const Scalar &b) const
Definition: InnerProduct.h:217
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar &a, const Scalar &b) const
Definition: InnerProduct.h:214
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet &a, const Packet &b) const
Definition: InnerProduct.h:221
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet &accum, const Packet &a, const Packet &b) const
Definition: InnerProduct.h:225
Definition: InnerProduct.h:197
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const result_type &accum, const LhsScalar &a, const RhsScalar &b) const
Definition: InnerProduct.h:203
typename ScalarBinaryOpTraits< LhsScalar, RhsScalar >::ReturnType result_type
Definition: InnerProduct.h:198
static constexpr bool PacketAccess
Definition: InnerProduct.h:207
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar &a, const RhsScalar &b) const
Definition: InnerProduct.h:200
Definition: ForwardDeclarations.h:21
Definition: GenericPacketMath.h:134
T half
Definition: GenericPacketMath.h:136
@ size
Definition: GenericPacketMath.h:139
Definition: benchGeometry.cpp:21
void run(const string &dir_name, LinearSolver *linear_solver_pt, const unsigned nel_1d, bool mess_up_order)
Definition: two_d_poisson_compare_solvers.cc:317
Definition: ZVector/PacketMath.h:50