ei_imklfft_impl.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // This Source Code Form is subject to the terms of the Mozilla
5 // Public License v. 2.0. If a copy of the MPL was not distributed
6 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 
8 #include <mkl_dfti.h>
9 
10 // IWYU pragma: private
11 #include "./InternalHeaderCheck.h"
12 
13 #include <complex>
14 #include <memory>
15 
16 namespace Eigen {
17 namespace internal {
18 namespace imklfft {
19 
// Evaluates EXPR (an MKL DFTI call returning a status code) exactly once and
// asserts, via eigen_assert, that the status equals DFTI_NO_ERROR; ERROR_MSG is
// a string literal shown in the assertion failure.
// EXPR is always evaluated, even when eigen_assert is compiled out, and the
// `(void)status;` keeps the status variable referenced in that configuration so
// builds with -Wunused-variable/-Werror stay clean.
// The expansion already ends with ';', so call sites omit the semicolon.
#define RUN_OR_ASSERT(EXPR, ERROR_MSG)                    \
  {                                                       \
    MKL_LONG status = (EXPR);                             \
    eigen_assert(status == DFTI_NO_ERROR && (ERROR_MSG)); \
    (void)status;                                         \
  };
25 
26 inline MKL_Complex16* complex_cast(const std::complex<double>* p) {
27  return const_cast<MKL_Complex16*>(reinterpret_cast<const MKL_Complex16*>(p));
28 }
29 
30 inline MKL_Complex8* complex_cast(const std::complex<float>* p) {
31  return const_cast<MKL_Complex8*>(reinterpret_cast<const MKL_Complex8*>(p));
32 }
33 
34 /*
35  * Parameters:
36  * precision: enum, Precision of the transform: DFTI_SINGLE or DFTI_DOUBLE.
37  * forward_domain: enum, Forward domain of the transform: DFTI_COMPLEX or
38  * DFTI_REAL. dimension: MKL_LONG Dimension of the transform. sizes: MKL_LONG if
39  * dimension = 1.Length of the transform for a one-dimensional transform. sizes:
40  * Array of type MKL_LONG otherwise. Lengths of each dimension for a
41  * multi-dimensional transform.
42  */
43 inline void configure_descriptor(std::shared_ptr<DFTI_DESCRIPTOR>& handl, enum DFTI_CONFIG_VALUE precision,
44  enum DFTI_CONFIG_VALUE forward_domain, MKL_LONG dimension, MKL_LONG* sizes) {
45  eigen_assert(dimension == 1 || dimension == 2 && "Transformation dimension must be less than 3.");
46 
47  DFTI_DESCRIPTOR_HANDLE res = nullptr;
48  if (dimension == 1) {
49  RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, forward_domain, dimension, *sizes),
50  "DftiCreateDescriptor failed.")
51  handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
52  if (forward_domain == DFTI_REAL) {
53  // Set CCE storage
54  RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX),
55  "DftiSetValue failed.")
56  }
57  } else {
58  RUN_OR_ASSERT(DftiCreateDescriptor(&res, precision, DFTI_COMPLEX, dimension, sizes), "DftiCreateDescriptor failed.")
59  handl.reset(res, [](DFTI_DESCRIPTOR_HANDLE handle) { DftiFreeDescriptor(&handle); });
60  }
61 
62  RUN_OR_ASSERT(DftiSetValue(handl.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE), "DftiSetValue failed.")
63  RUN_OR_ASSERT(DftiCommitDescriptor(handl.get()), "DftiCommitDescriptor failed.")
64 }
65 
// Primary template, intentionally empty: only the float and double
// specializations hold an MKL descriptor and provide transform methods.
template <class T>
struct plan {};
68 
69 template <>
70 struct plan<float> {
71  typedef float scalar_type;
72  typedef MKL_Complex8 complex_type;
73 
74  std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
75 
76  plan() = default;
77 
78  enum DFTI_CONFIG_VALUE precision = DFTI_SINGLE;
79 
80  inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
81  if (m_plan == 0) {
82  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
83  }
84  RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
85  }
86 
87  inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
88  if (m_plan == 0) {
89  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
90  }
91  RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
92  }
93 
94  inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
95  if (m_plan == 0) {
96  configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
97  }
98  RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
99  }
100 
101  inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
102  if (m_plan == 0) {
103  configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
104  }
105  RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
106  }
107 
108  inline void forward2(complex_type* dst, complex_type* src, int n0, int n1) {
109  if (m_plan == 0) {
110  MKL_LONG sizes[2] = {n0, n1};
111  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
112  }
113  RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
114  }
115 
116  inline void inverse2(complex_type* dst, complex_type* src, int n0, int n1) {
117  if (m_plan == 0) {
118  MKL_LONG sizes[2] = {n0, n1};
119  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
120  }
121  RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
122  }
123 };
124 
125 template <>
126 struct plan<double> {
127  typedef double scalar_type;
128  typedef MKL_Complex16 complex_type;
129 
130  std::shared_ptr<DFTI_DESCRIPTOR> m_plan;
131 
132  plan() = default;
133 
134  enum DFTI_CONFIG_VALUE precision = DFTI_DOUBLE;
135 
136  inline void forward(complex_type* dst, complex_type* src, MKL_LONG nfft) {
137  if (m_plan == 0) {
138  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
139  }
140  RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
141  }
142 
143  inline void inverse(complex_type* dst, complex_type* src, MKL_LONG nfft) {
144  if (m_plan == 0) {
145  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 1, &nfft);
146  }
147  RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
148  }
149 
150  inline void forward(complex_type* dst, scalar_type* src, MKL_LONG nfft) {
151  if (m_plan == 0) {
152  configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
153  }
154  RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
155  }
156 
157  inline void inverse(scalar_type* dst, complex_type* src, MKL_LONG nfft) {
158  if (m_plan == 0) {
159  configure_descriptor(m_plan, precision, DFTI_REAL, 1, &nfft);
160  }
161  RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
162  }
163 
164  inline void forward2(complex_type* dst, complex_type* src, int n0, int n1) {
165  if (m_plan == 0) {
166  MKL_LONG sizes[2] = {n0, n1};
167  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
168  }
169  RUN_OR_ASSERT(DftiComputeForward(m_plan.get(), src, dst), "DftiComputeForward failed.")
170  }
171 
172  inline void inverse2(complex_type* dst, complex_type* src, int n0, int n1) {
173  if (m_plan == 0) {
174  MKL_LONG sizes[2] = {n0, n1};
175  configure_descriptor(m_plan, precision, DFTI_COMPLEX, 2, sizes);
176  }
177  RUN_OR_ASSERT(DftiComputeBackward(m_plan.get(), src, dst), "DftiComputeBackward failed.")
178  }
179 };
180 
181 template <typename Scalar_>
182 struct imklfft_impl {
183  typedef Scalar_ Scalar;
184  typedef std::complex<Scalar> Complex;
185 
186  inline void clear() { m_plans.clear(); }
187 
188  // complex-to-complex forward FFT
189  inline void fwd(Complex* dst, const Complex* src, int nfft) {
190  MKL_LONG size = nfft;
191  get_plan(nfft, dst, src).forward(complex_cast(dst), complex_cast(src), size);
192  }
193 
194  // real-to-complex forward FFT
195  inline void fwd(Complex* dst, const Scalar* src, int nfft) {
196  MKL_LONG size = nfft;
197  get_plan(nfft, dst, src).forward(complex_cast(dst), const_cast<Scalar*>(src), nfft);
198  }
199 
200  // 2-d complex-to-complex
201  inline void fwd2(Complex* dst, const Complex* src, int n0, int n1) {
202  get_plan(n0, n1, dst, src).forward2(complex_cast(dst), complex_cast(src), n0, n1);
203  }
204 
205  // inverse complex-to-complex
206  inline void inv(Complex* dst, const Complex* src, int nfft) {
207  MKL_LONG size = nfft;
208  get_plan(nfft, dst, src).inverse(complex_cast(dst), complex_cast(src), nfft);
209  }
210 
211  // half-complex to scalar
212  inline void inv(Scalar* dst, const Complex* src, int nfft) {
213  MKL_LONG size = nfft;
214  get_plan(nfft, dst, src).inverse(const_cast<Scalar*>(dst), complex_cast(src), nfft);
215  }
216 
217  // 2-d complex-to-complex
218  inline void inv2(Complex* dst, const Complex* src, int n0, int n1) {
219  get_plan(n0, n1, dst, src).inverse2(complex_cast(dst), complex_cast(src), n0, n1);
220  }
221 
222  private:
223  std::map<int64_t, plan<Scalar>> m_plans;
224 
225  inline plan<Scalar>& get_plan(int nfft, void* dst, const void* src) {
226  int inplace = dst == src ? 1 : 0;
227  int aligned = ((reinterpret_cast<size_t>(src) & 15) | (reinterpret_cast<size_t>(dst) & 15)) == 0 ? 1 : 0;
228  int64_t key = ((nfft << 2) | (inplace << 1) | aligned) << 1;
229 
230  // Create element if key does not exist.
231  return m_plans[key];
232  }
233 
234  inline plan<Scalar>& get_plan(int n0, int n1, void* dst, const void* src) {
235  int inplace = (dst == src) ? 1 : 0;
236  int aligned = ((reinterpret_cast<size_t>(src) & 15) | (reinterpret_cast<size_t>(dst) & 15)) == 0 ? 1 : 0;
237  int64_t key = (((((int64_t)n0) << 31) | (n1 << 2) | (inplace << 1) | aligned) << 1) + 1;
238 
239  // Create element if key does not exist.
240  return m_plans[key];
241  }
242 };
243 
244 #undef RUN_OR_ASSERT
245 
246 } // namespace imklfft
247 } // namespace internal
248 } // namespace Eigen
#define eigen_assert(x)
Definition: Macros.h:910
cout<< "Here is the matrix m:"<< endl<< m<< endl;Matrix< ptrdiff_t, 3, 1 > res
Definition: PartialRedux_count.cpp:3
float * p
Definition: Tutorial_Map_using.cpp:9
Scalar Scalar int size
Definition: benchVecAdd.cpp:17
std::vector< Array2i > sizes
Definition: dense_solvers.cpp:12
#define RUN_OR_ASSERT(EXPR, ERROR_MSG)
Definition: ei_imklfft_impl.h:20
void inplace(bool square=false, bool SPD=false)
Definition: inplace_decomposition.cpp:18
void configure_descriptor(std::shared_ptr< DFTI_DESCRIPTOR > &handl, enum DFTI_CONFIG_VALUE precision, enum DFTI_CONFIG_VALUE forward_domain, MKL_LONG dimension, MKL_LONG *sizes)
Definition: ei_imklfft_impl.h:43
MKL_Complex16 * complex_cast(const std::complex< double > *p)
Definition: ei_imklfft_impl.h:26
std::int64_t int64_t
Definition: Meta.h:43
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
Definition: Eigen_Colamd.h:49
Definition: ei_imklfft_impl.h:182
plan< Scalar > & get_plan(int n0, int n1, void *dst, const void *src)
Definition: ei_imklfft_impl.h:234
std::complex< Scalar > Complex
Definition: ei_imklfft_impl.h:184
void fwd2(Complex *dst, const Complex *src, int n0, int n1)
Definition: ei_imklfft_impl.h:201
void fwd(Complex *dst, const Complex *src, int nfft)
Definition: ei_imklfft_impl.h:189
std::map< int64_t, plan< Scalar > > m_plans
Definition: ei_imklfft_impl.h:223
void clear()
Definition: ei_imklfft_impl.h:186
void inv(Scalar *dst, const Complex *src, int nfft)
Definition: ei_imklfft_impl.h:212
void inv2(Complex *dst, const Complex *src, int n0, int n1)
Definition: ei_imklfft_impl.h:218
plan< Scalar > & get_plan(int nfft, void *dst, const void *src)
Definition: ei_imklfft_impl.h:225
void fwd(Complex *dst, const Scalar *src, int nfft)
Definition: ei_imklfft_impl.h:195
void inv(Complex *dst, const Complex *src, int nfft)
Definition: ei_imklfft_impl.h:206
Scalar_ Scalar
Definition: ei_imklfft_impl.h:183
void inverse2(complex_type *dst, complex_type *src, int n0, int n1)
Definition: ei_imklfft_impl.h:172
MKL_Complex16 complex_type
Definition: ei_imklfft_impl.h:128
std::shared_ptr< DFTI_DESCRIPTOR > m_plan
Definition: ei_imklfft_impl.h:130
void inverse(scalar_type *dst, complex_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:157
void inverse(complex_type *dst, complex_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:143
void forward2(complex_type *dst, complex_type *src, int n0, int n1)
Definition: ei_imklfft_impl.h:164
void forward(complex_type *dst, complex_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:136
void forward(complex_type *dst, scalar_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:150
double scalar_type
Definition: ei_imklfft_impl.h:127
std::shared_ptr< DFTI_DESCRIPTOR > m_plan
Definition: ei_imklfft_impl.h:74
void forward(complex_type *dst, scalar_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:94
float scalar_type
Definition: ei_imklfft_impl.h:71
void inverse2(complex_type *dst, complex_type *src, int n0, int n1)
Definition: ei_imklfft_impl.h:116
void forward(complex_type *dst, complex_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:80
MKL_Complex8 complex_type
Definition: ei_imklfft_impl.h:72
void inverse(complex_type *dst, complex_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:87
void inverse(scalar_type *dst, complex_type *src, MKL_LONG nfft)
Definition: ei_imklfft_impl.h:101
void forward2(complex_type *dst, complex_type *src, int n0, int n1)
Definition: ei_imklfft_impl.h:108
Definition: ei_imklfft_impl.h:67