MathFunctionsImpl.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com)
5 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_MATHFUNCTIONSIMPL_H
12 #define EIGEN_MATHFUNCTIONSIMPL_H
13 
14 // IWYU pragma: private
15 #include "./InternalHeaderCheck.h"
16 
17 namespace Eigen {
18 
19 namespace internal {
20 
35 template <typename Packet, int Steps>
37  static_assert(Steps > 0, "Steps must be at least 1.");
38  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_a_recip) {
39  using Scalar = typename unpacket_traits<Packet>::type;
40  const Packet two = pset1<Packet>(Scalar(2));
41  // Refine the approximation using one Newton-Raphson step:
42  // x_{i} = x_{i-1} * (2 - a * x_{i-1})
44  const Packet tmp = pnmadd(a, x, two);
45  // If tmp is NaN, it means that a is either +/-0 or +/-Inf.
46  // In this case return the approximation directly.
47  const Packet is_not_nan = pcmp_eq(tmp, tmp);
48  return pselect(is_not_nan, pmul(x, tmp), x);
49  }
50 };
51 
52 template <typename Packet>
54  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& /*unused*/, const Packet& approx_rsqrt) {
55  return approx_rsqrt;
56  }
57 };
58 
74 template <typename Packet, int Steps>
76  static_assert(Steps > 0, "Steps must be at least 1.");
78  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
79  constexpr Scalar kMinusHalf = Scalar(-1) / Scalar(2);
80  const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
81  const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
82 
83  Packet inv_sqrt = approx_rsqrt;
84  for (int step = 0; step < Steps; ++step) {
85  // Refine the approximation using one Newton-Raphson step:
86  // h_n = (x * inv_sqrt) * inv_sqrt - 1 (so that h_n is nearly 0).
87  // inv_sqrt = inv_sqrt - 0.5 * inv_sqrt * h_n
88  Packet r2 = pmul(a, inv_sqrt);
89  Packet half_r = pmul(inv_sqrt, cst_minus_half);
90  Packet h_n = pmadd(r2, inv_sqrt, cst_minus_one);
91  inv_sqrt = pmadd(half_r, h_n, inv_sqrt);
92  }
93 
94  // If x is NaN, then either:
95  // 1) the input is NaN
96  // 2) zero and infinity were multiplied
97  // In either of these cases, return approx_rsqrt
98  return pselect(pisnan(inv_sqrt), approx_rsqrt, inv_sqrt);
99  }
100 };
101 
102 template <typename Packet>
104  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& /*unused*/, const Packet& approx_rsqrt) {
105  return approx_rsqrt;
106  }
107 };
108 
124 template <typename Packet, int Steps = 1>
126  static_assert(Steps > 0, "Steps must be at least 1.");
127 
128  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
129  using Scalar = typename unpacket_traits<Packet>::type;
130  const Packet one_point_five = pset1<Packet>(Scalar(1.5));
131  const Packet minus_half = pset1<Packet>(Scalar(-0.5));
132  // If a is inf or zero, return a directly.
133  const Packet inf_mask = pcmp_eq(a, pset1<Packet>(NumTraits<Scalar>::infinity()));
134  const Packet return_a = por(pcmp_eq(a, pzero(a)), inf_mask);
135  // Do a single step of Newton's iteration for reciprocal square root:
136  // x_{n+1} = x_n * (1.5 + (-0.5 * x_n) * (a * x_n))).
137  // The Newton's step is computed this way to avoid over/under-flows.
138  Packet rsqrt = pmul(approx_rsqrt, pmadd(pmul(minus_half, approx_rsqrt), pmul(a, approx_rsqrt), one_point_five));
139  for (int step = 1; step < Steps; ++step) {
140  rsqrt = pmul(rsqrt, pmadd(pmul(minus_half, rsqrt), pmul(a, rsqrt), one_point_five));
141  }
142 
143  // Return sqrt(x) = x * rsqrt(x) for non-zero finite positive arguments.
144  // Return a itself for 0 or +inf, NaN for negative arguments.
145  return pselect(return_a, a, pmul(a, rsqrt));
146  }
147 };
148 
149 template <typename RealScalar>
151  // IEEE IEC 6059 special cases.
154 
156  RealScalar p, qp;
157  p = numext::maxi(x, y);
158  if (numext::is_exactly_zero(p)) return RealScalar(0);
159  qp = numext::mini(y, x) / p;
160  return p * sqrt(RealScalar(1) + qp * qp);
161 }
162 
163 template <typename Scalar>
164 struct hypot_impl {
166  static EIGEN_DEVICE_FUNC inline RealScalar run(const Scalar& x, const Scalar& y) {
168  return positive_real_hypot<RealScalar>(abs(x), abs(y));
169  }
170 };
171 
172 // Generic complex sqrt implementation that correctly handles corner cases
173 // according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt
174 template <typename T>
175 EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) {
176  // Computes the principal sqrt of the input.
177  //
178  // For a complex square root of the number x + i*y. We want to find real
179  // numbers u and v such that
180  // (u + i*v)^2 = x + i*y <=>
181  // u^2 - v^2 + i*2*u*v = x + i*v.
182  // By equating the real and imaginary parts we get:
183  // u^2 - v^2 = x
184  // 2*u*v = y.
185  //
186  // For x >= 0, this has the numerically stable solution
187  // u = sqrt(0.5 * (x + sqrt(x^2 + y^2)))
188  // v = y / (2 * u)
189  // and for x < 0,
190  // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2)))
191  // u = y / (2 * v)
192  //
193  // Letting w = sqrt(0.5 * (|x| + |z|)),
194  // if x == 0: u = w, v = sign(y) * w
195  // if x > 0: u = w, v = y / (2 * w)
196  // if x < 0: u = |y| / (2 * w), v = sign(y) * w
197 
198  const T x = numext::real(z);
199  const T y = numext::imag(z);
200  const T zero = T(0);
201  const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y)));
202 
203  return (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y)
204  : numext::is_exactly_zero(x) ? std::complex<T>(w, y < zero ? -w : w)
205  : x > zero ? std::complex<T>(w, y / (2 * w))
206  : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w);
207 }
208 
209 // Generic complex rsqrt implementation.
210 template <typename T>
211 EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) {
212  // Computes the principal reciprocal sqrt of the input.
213  //
214  // For a complex reciprocal square root of the number z = x + i*y. We want to
215  // find real numbers u and v such that
216  // (u + i*v)^2 = 1 / (x + i*y) <=>
217  // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*v/|z|^2.
218  // By equating the real and imaginary parts we get:
219  // u^2 - v^2 = x/|z|^2
220  // 2*u*v = y/|z|^2.
221  //
222  // For x >= 0, this has the numerically stable solution
223  // u = sqrt(0.5 * (x + |z|)) / |z|
224  // v = -y / (2 * u * |z|)
225  // and for x < 0,
226  // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z|
227  // u = -y / (2 * v * |z|)
228  //
229  // Letting w = sqrt(0.5 * (|x| + |z|)),
230  // if x == 0: u = w / |z|, v = -sign(y) * w / |z|
231  // if x > 0: u = w / |z|, v = -y / (2 * w * |z|)
232  // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z|
233 
234  const T x = numext::real(z);
235  const T y = numext::imag(z);
236  const T zero = T(0);
237 
238  const T abs_z = numext::hypot(x, y);
239  const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z));
240  const T woz = w / abs_z;
241  // Corner cases consistent with 1/sqrt(z) on gcc/clang.
242  return numext::is_exactly_zero(abs_z) ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN())
243  : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero)
244  : numext::is_exactly_zero(x) ? std::complex<T>(woz, y < zero ? woz : -woz)
245  : x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z))
246  : std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz);
247 }
248 
249 template <typename T>
250 EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) {
251  // Computes complex log.
252  T a = numext::abs(z);
254  T b = atan2(z.imag(), z.real());
255  return std::complex<T>(numext::log(a), b);
256 }
257 
258 } // end namespace internal
259 
260 } // end namespace Eigen
261 
262 #endif // EIGEN_MATHFUNCTIONSIMPL_H
AnnoyingScalar abs(const AnnoyingScalar &x)
Definition: AnnoyingScalar.h:135
AnnoyingScalar imag(const AnnoyingScalar &)
Definition: AnnoyingScalar.h:132
AnnoyingScalar sqrt(const AnnoyingScalar &x)
Definition: AnnoyingScalar.h:134
Eigen::Triplet< double > T
Definition: EigenUnitTest.cpp:11
#define EIGEN_USING_STD(FUNC)
Definition: Macros.h:1090
#define EIGEN_DEVICE_FUNC
Definition: Macros.h:892
#define EIGEN_STRONG_INLINE
Definition: Macros.h:834
RowVector3d w
Definition: Matrix_resize_int.cpp:3
float * p
Definition: Tutorial_Map_using.cpp:9
Scalar * b
Definition: benchVecAdd.cpp:17
SCALAR Scalar
Definition: bench_gemm.cpp:45
NumTraits< Scalar >::Real RealScalar
Definition: bench_gemm.cpp:46
float real
Definition: datatypes.h:10
const Scalar * a
Definition: level2_cplx_impl.h:32
Eigen::Matrix< Scalar, Dynamic, Dynamic, ColMajor > tmp
Definition: level3_impl.h:365
EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f &)
Definition: AVX/PacketMath.h:774
EIGEN_STRONG_INLINE Packet8f pisnan(const Packet8f &a)
Definition: AVX/PacketMath.h:1034
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RealScalar positive_real_hypot(const RealScalar &x, const RealScalar &y)
Definition: MathFunctionsImpl.h:150
const Scalar & y
Definition: RandomImpl.h:36
EIGEN_STRONG_INLINE Packet8h por(const Packet8h &a, const Packet8h &b)
Definition: AVX/PacketMath.h:2309
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
Definition: AltiVec/PacketMath.h:1218
EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf &a, const Packet4cf &b)
Definition: AVX/Complex.h:88
EIGEN_DEVICE_FUNC std::complex< T > complex_log(const std::complex< T > &z)
Definition: MathFunctionsImpl.h:250
EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf &a, const Packet2cf &b)
Definition: AltiVec/Complex.h:353
EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f &a, const Packet4f &b, const Packet4f &c)
Definition: LSX/PacketMath.h:827
EIGEN_DEVICE_FUNC std::complex< T > complex_sqrt(const std::complex< T > &a_x)
Definition: MathFunctionsImpl.h:175
EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f &mask, const Packet4f &a, const Packet4f &b)
Definition: AltiVec/PacketMath.h:1474
EIGEN_DEVICE_FUNC std::complex< T > complex_rsqrt(const std::complex< T > &a_x)
Definition: MathFunctionsImpl.h:211
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T log(const T &x)
Definition: MathFunctions.h:1332
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool() isinf(const Eigen::bfloat16 &h)
Definition: BFloat16.h:747
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
Definition: MathFunctions.h:926
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::enable_if_t< NumTraits< T >::IsSigned||NumTraits< T >::IsComplex, typename NumTraits< T >::Real > abs(const T &x)
Definition: MathFunctions.h:1355
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T rsqrt(const T &x)
Definition: MathFunctions.h:1327
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool() isnan(const Eigen::bfloat16 &h)
Definition: BFloat16.h:742
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool is_exactly_zero(const X &x)
Definition: Meta.h:592
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float sqrt(const float &x)
Definition: arch/SSE/MathFunctions.h:69
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
Definition: MathFunctions.h:920
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
AutoDiffScalar< Matrix< typename internal::traits< internal::remove_all_t< DerTypeA > >::Scalar, Dynamic, 1 > > atan2(const AutoDiffScalar< DerTypeA > &a, const AutoDiffScalar< DerTypeB > &b)
Definition: AutoDiffScalar.h:558
Definition: Eigen_Colamd.h:49
list x
Definition: plotDoE.py:28
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition: NumTraits.h:217
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet &, const Packet &approx_rsqrt)
Definition: MathFunctionsImpl.h:54
Definition: MathFunctionsImpl.h:36
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet &a, const Packet &approx_a_recip)
Definition: MathFunctionsImpl.h:38
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet &, const Packet &approx_rsqrt)
Definition: MathFunctionsImpl.h:104
Definition: MathFunctionsImpl.h:75
typename unpacket_traits< Packet >::type Scalar
Definition: MathFunctionsImpl.h:77
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet &a, const Packet &approx_rsqrt)
Definition: MathFunctionsImpl.h:78
Definition: MathFunctionsImpl.h:125
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet &a, const Packet &approx_rsqrt)
Definition: MathFunctionsImpl.h:128
Definition: MathFunctionsImpl.h:164
NumTraits< Scalar >::Real RealScalar
Definition: MathFunctionsImpl.h:165
static EIGEN_DEVICE_FUNC RealScalar run(const Scalar &x, const Scalar &y)
Definition: MathFunctionsImpl.h:166
EIGEN_DONT_INLINE Scalar zero()
Definition: svd_common.h:232
Definition: ZVector/PacketMath.h:50