TensorIO.h
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_IO_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
// Forward declaration; TensorIOFormat is defined further down in this header.
18 struct TensorIOFormat;
19 
20 namespace internal {
// Forward declaration of the tensor pretty-printer.
// The trailing EnableIf parameter (default void) lets partial specializations be
// selected via std::enable_if_t, e.g. the TensorIOFormatLegacy printer that is only
// enabled for rank != 0.
21 template <typename Tensor, std::size_t rank, typename Format, typename EnableIf = void>
22 struct TensorPrinter;
23 }
24 
25 template <typename Derived_>
27  using Derived = Derived_;
28  TensorIOFormatBase(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
29  const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
30  const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
33  prefix(prefix),
34  suffix(suffix),
36  fill(fill),
38  flags(flags) {
39  init_spacer();
40  }
41 
42  void init_spacer() {
43  if ((flags & DontAlignCols)) return;
44  spacer.resize(prefix.size());
45  spacer[0] = "";
46  int i = int(tenPrefix.length()) - 1;
47  while (i >= 0 && tenPrefix[i] != '\n') {
48  spacer[0] += ' ';
49  i--;
50  }
51 
52  for (std::size_t k = 1; k < prefix.size(); k++) {
53  int j = int(prefix[k].length()) - 1;
54  while (j >= 0 && prefix[k][j] != '\n') {
55  spacer[k] += ' ';
56  j--;
57  }
58  }
59  }
60 
63  std::vector<std::string> prefix;
64  std::vector<std::string> suffix;
65  std::vector<std::string> separator;
66  char fill;
67  int precision;
68  int flags;
69  std::vector<std::string> spacer{};
70 };
71 
72 struct TensorIOFormatNumpy : public TensorIOFormatBase<TensorIOFormatNumpy> {
75  : Base(/*separator=*/{" ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
76  /*flags=*/0, /*tenPrefix=*/"[", /*tenSuffix=*/"]") {}
77 };
78 
79 struct TensorIOFormatNative : public TensorIOFormatBase<TensorIOFormatNative> {
82  : Base(/*separator=*/{", ", ",\n", "\n"}, /*prefix=*/{"", "{"}, /*suffix=*/{"", "}"},
83  /*precision=*/StreamPrecision, /*flags=*/0, /*tenPrefix=*/"{", /*tenSuffix=*/"}") {}
84 };
85 
86 struct TensorIOFormatPlain : public TensorIOFormatBase<TensorIOFormatPlain> {
89  : Base(/*separator=*/{" ", "\n", "\n", ""}, /*prefix=*/{""}, /*suffix=*/{""}, /*precision=*/StreamPrecision,
90  /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
91 };
92 
93 struct TensorIOFormatLegacy : public TensorIOFormatBase<TensorIOFormatLegacy> {
96  : Base(/*separator=*/{", ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
97  /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
98 };
99 
100 struct TensorIOFormat : public TensorIOFormatBase<TensorIOFormat> {
102  TensorIOFormat(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
103  const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
104  const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
106 
107  static inline const TensorIOFormatNumpy Numpy() { return TensorIOFormatNumpy{}; }
108 
109  static inline const TensorIOFormatPlain Plain() { return TensorIOFormatPlain{}; }
110 
111  static inline const TensorIOFormatNative Native() { return TensorIOFormatNative{}; }
112 
113  static inline const TensorIOFormatLegacy Legacy() { return TensorIOFormatLegacy{}; }
114 };
115 
116 template <typename T, int Layout, int rank, typename Format>
118 // specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
119 template <typename T, int rank, typename Format>
120 class TensorWithFormat<T, RowMajor, rank, Format> {
121  public:
122  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
123 
124  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank, Format>& wf) {
125  // Evaluate the expression if needed
128  Evaluator tensor(eval, DefaultDevice());
129  tensor.evalSubExprsIfNeeded(NULL);
131  // Cleanup.
132  tensor.cleanup();
133  return os;
134  }
135 
136  protected:
138  Format t_format;
139 };
140 
141 template <typename T, int rank, typename Format>
142 class TensorWithFormat<T, ColMajor, rank, Format> {
143  public:
144  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
145 
146  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank, Format>& wf) {
147  // Switch to RowMajor storage and print afterwards
148  typedef typename T::Index IndexType;
149  std::array<IndexType, rank> shuffle;
150  std::array<IndexType, rank> id;
151  std::iota(id.begin(), id.end(), IndexType(0));
152  std::copy(id.begin(), id.end(), shuffle.rbegin());
153  auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);
154 
155  // Evaluate the expression if needed
156  typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
157  TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
158  Evaluator tensor(eval, DefaultDevice());
159  tensor.evalSubExprsIfNeeded(NULL);
161  // Cleanup.
162  tensor.cleanup();
163  return os;
164  }
165 
166  protected:
168  Format t_format;
169 };
170 
171 template <typename T, typename Format>
172 class TensorWithFormat<T, ColMajor, 0, Format> {
173  public:
174  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
175 
176  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0, Format>& wf) {
177  // Evaluate the expression if needed
180  Evaluator tensor(eval, DefaultDevice());
181  tensor.evalSubExprsIfNeeded(NULL);
183  // Cleanup.
184  tensor.cleanup();
185  return os;
186  }
187 
188  protected:
190  Format t_format;
191 };
192 
193 namespace internal {
194 
195 // Default scalar printer.
196 template <typename Scalar, typename Format, typename EnableIf = void>
198  static void run(std::ostream& stream, const Scalar& scalar, const Format&) { stream << scalar; }
199 };
200 
201 template <typename Scalar>
202 struct ScalarPrinter<Scalar, TensorIOFormatNumpy, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
203  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNumpy&) {
204  stream << numext::real(scalar) << "+" << numext::imag(scalar) << "j";
205  }
206 };
207 
208 template <typename Scalar>
209 struct ScalarPrinter<Scalar, TensorIOFormatNative, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
210  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNative&) {
211  stream << "{" << numext::real(scalar) << ", " << numext::imag(scalar) << "}";
212  }
213 };
214 
215 template <typename Tensor, std::size_t rank, typename Format, typename EnableIf>
217  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
218 
219  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
220  typedef typename Tensor::Index IndexType;
221 
225  int,
226  std::conditional_t<is_same<Scalar, std::complex<char>>::value ||
230  std::complex<int>, const Scalar&>>
231  PrintType;
232 
233  const IndexType total_size = array_prod(tensor.dimensions());
234 
235  std::streamsize explicit_precision;
236  if (fmt.precision == StreamPrecision) {
237  explicit_precision = 0;
238  } else if (fmt.precision == FullPrecision) {
240  explicit_precision = 0;
241  } else {
242  explicit_precision = significant_decimals_impl<Scalar>::run();
243  }
244  } else {
245  explicit_precision = fmt.precision;
246  }
247 
248  std::streamsize old_precision = 0;
249  if (explicit_precision) old_precision = s.precision(explicit_precision);
250 
251  IndexType width = 0;
252  bool align_cols = !(fmt.flags & DontAlignCols);
253  if (align_cols) {
254  // compute the largest width
255  for (IndexType i = 0; i < total_size; i++) {
256  std::stringstream sstr;
257  sstr.copyfmt(s);
258  ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
259  width = std::max<IndexType>(width, IndexType(sstr.str().length()));
260  }
261  }
262  s << fmt.tenPrefix;
263  for (IndexType i = 0; i < total_size; i++) {
264  std::array<bool, rank> is_at_end{};
265  std::array<bool, rank> is_at_begin{};
266 
267  // is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
268  for (std::size_t k = 0; k < rank; k++) {
269  if ((i + 1) % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
270  std::multiplies<IndexType>())) ==
271  0) {
272  is_at_end[k] = true;
273  }
274  }
275 
276  // is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
277  for (std::size_t k = 0; k < rank; k++) {
278  if (i % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
279  std::multiplies<IndexType>())) ==
280  0) {
281  is_at_begin[k] = true;
282  }
283  }
284 
285  // do we have a line break?
286  bool is_at_begin_after_newline = false;
287  for (std::size_t k = 0; k < rank; k++) {
288  if (is_at_begin[k]) {
289  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
290  if (fmt.separator[separator_index].find('\n') != std::string::npos) {
291  is_at_begin_after_newline = true;
292  }
293  }
294  }
295 
296  bool is_at_end_before_newline = false;
297  for (std::size_t k = 0; k < rank; k++) {
298  if (is_at_end[k]) {
299  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
300  if (fmt.separator[separator_index].find('\n') != std::string::npos) {
301  is_at_end_before_newline = true;
302  }
303  }
304  }
305 
306  std::stringstream suffix, prefix, separator;
307  for (std::size_t k = 0; k < rank; k++) {
308  std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
309  if (is_at_end[k]) {
310  suffix << fmt.suffix[suffix_index];
311  }
312  }
313  for (std::size_t k = 0; k < rank; k++) {
314  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
315  if (is_at_end[k] &&
316  (!is_at_end_before_newline || fmt.separator[separator_index].find('\n') != std::string::npos)) {
317  separator << fmt.separator[separator_index];
318  }
319  }
320  for (std::size_t k = 0; k < rank; k++) {
321  std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
322  if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
323  prefix << fmt.spacer[spacer_index];
324  }
325  }
326  for (int k = rank - 1; k >= 0; k--) {
327  std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
328  if (is_at_begin[k]) {
329  prefix << fmt.prefix[prefix_index];
330  }
331  }
332 
333  s << prefix.str();
334  // So we don't mess around with formatting, output scalar to a string stream, and adjust the width/fill manually.
335  std::stringstream sstr;
336  sstr.copyfmt(s);
337  ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
338  std::string scalar_str = sstr.str();
339  IndexType scalar_width = scalar_str.length();
340  if (width && scalar_width < width) {
341  std::string filler;
342  for (IndexType j = scalar_width; j < width; ++j) {
343  filler.push_back(fmt.fill);
344  }
345  s << filler;
346  }
347  s << scalar_str;
348  s << suffix.str();
349  if (i < total_size - 1) {
350  s << separator.str();
351  }
352  }
353  s << fmt.tenSuffix;
354  if (explicit_precision) s.precision(old_precision);
355  }
356 };
357 
358 template <typename Tensor, std::size_t rank>
359 struct TensorPrinter<Tensor, rank, TensorIOFormatLegacy, std::enable_if_t<rank != 0>> {
361  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
362 
363  static void run(std::ostream& s, const Tensor& tensor, const Format&) {
364  typedef typename Tensor::Index IndexType;
365  // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
366  // (dim(1)*dim(2)*...*dim(rank-1)).
367  const IndexType total_size = internal::array_prod(tensor.dimensions());
368  if (total_size > 0) {
369  const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
371  total_size / first_dim);
372  s << matrix;
373  return;
374  }
375  }
376 };
377 
378 template <typename Tensor, typename Format>
379 struct TensorPrinter<Tensor, 0, Format> {
380  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
381  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
382 
383  std::streamsize explicit_precision;
384  if (fmt.precision == StreamPrecision) {
385  explicit_precision = 0;
386  } else if (fmt.precision == FullPrecision) {
388  explicit_precision = 0;
389  } else {
390  explicit_precision = significant_decimals_impl<Scalar>::run();
391  }
392  } else {
393  explicit_precision = fmt.precision;
394  }
395 
396  std::streamsize old_precision = 0;
397  if (explicit_precision) old_precision = s.precision(explicit_precision);
398  s << fmt.tenPrefix;
399  ScalarPrinter<Scalar, Format>::run(s, tensor.coeff(0), fmt);
400  s << fmt.tenSuffix;
401  if (explicit_precision) s.precision(old_precision);
402  }
403 };
404 
405 } // end namespace internal
406 template <typename T>
407 std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
408  s << t.format(TensorIOFormat::Plain());
409  return s;
410 }
411 } // end namespace Eigen
412 
413 #endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H
AnnoyingScalar imag(const AnnoyingScalar &)
Definition: AnnoyingScalar.h:132
int i
Definition: BiCGSTAB_step_by_step.cpp:9
#define eigen_assert(x)
Definition: Macros.h:910
SCALAR Scalar
Definition: bench_gemm.cpp:45
A matrix or vector expression mapping an existing array of data.
Definition: Map.h:96
The tensor base class.
Definition: TensorBase.h:1026
Definition: TensorForcedEval.h:57
friend std::ostream & operator<<(std::ostream &os, const TensorWithFormat< T, ColMajor, 0, Format > &wf)
Definition: TensorIO.h:176
TensorWithFormat(const T &tensor, const Format &format)
Definition: TensorIO.h:174
Format t_format
Definition: TensorIO.h:190
friend std::ostream & operator<<(std::ostream &os, const TensorWithFormat< T, ColMajor, rank, Format > &wf)
Definition: TensorIO.h:146
TensorWithFormat(const T &tensor, const Format &format)
Definition: TensorIO.h:144
friend std::ostream & operator<<(std::ostream &os, const TensorWithFormat< T, RowMajor, rank, Format > &wf)
Definition: TensorIO.h:124
TensorWithFormat(const T &tensor, const Format &format)
Definition: TensorIO.h:122
Definition: TensorIO.h:117
The tensor class.
Definition: Tensor.h:68
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
Definition: Tensor.h:100
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar & coeff(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const
Definition: Tensor.h:112
static constexpr int Layout
Definition: Tensor.h:81
internal::traits< Self >::Index Index
Definition: Tensor.h:74
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
Definition: Tensor.h:102
Eigen::Map< Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor >, 0, Eigen::OuterStride<> > matrix(T *data, int rows, int cols, int stride)
Definition: common.h:85
float real
Definition: datatypes.h:10
static constexpr lastp1_t end
Definition: IndexedViewHelper.h:79
@ ColMajor
Definition: Constants.h:318
@ RowMajor
Definition: Constants.h:320
RealScalar s
Definition: level1_cplx_impl.h:130
return int(ret)+1
EIGEN_BLAS_FUNC() copy(int *n, RealScalar *px, int *incx, RealScalar *py, int *incy)
Definition: level1_impl.h:32
char char char int int * k
Definition: level2_impl.h:374
EIGEN_STRONG_INLINE Packet2d shuffle(const Packet2d &m, const Packet2d &n, int mask)
Definition: LSX/PacketMath.h:150
constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE auto array_prod(const array< T, N > &arr) -> decltype(array_reduce< product_op, T, N >(arr, static_cast< T >(1)))
Definition: MoreMeta.h:497
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:70
squared absolute value
Definition: GlobalFunctions.h:87
std::ostream & operator<<(std::ostream &s, const DiagonalBase< Derived > &m)
Definition: IO.h:227
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:83
@ StreamPrecision
Definition: IO.h:20
@ FullPrecision
Definition: IO.h:20
@ DontAlignCols
Definition: IO.h:19
Definition: Eigen_Colamd.h:49
std::string string(const unsigned &i)
Definition: oomph_definitions.cc:286
t
Definition: plotPSD.py:36
std::string format(const std::string &str, const std::vector< std::string > &find, const std::vector< std::string > &replace)
Definition: openglsupport.cpp:217
internal::nested_eval< T, 1 >::type eval(const T &xpr)
Definition: sparse_permutations.cpp:47
Definition: TensorDeviceDefault.h:19
Holds information about the various numeric (i.e. scalar) types allowed by Eigen.
Definition: NumTraits.h:217
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:31
Definition: TensorIO.h:26
char fill
Definition: TensorIO.h:66
Derived_ Derived
Definition: TensorIO.h:27
std::vector< std::string > suffix
Definition: TensorIO.h:64
TensorIOFormatBase(const std::vector< std::string > &separator, const std::vector< std::string > &prefix, const std::vector< std::string > &suffix, int precision=StreamPrecision, int flags=0, const std::string &tenPrefix="", const std::string &tenSuffix="", const char fill=' ')
Definition: TensorIO.h:28
std::string tenPrefix
Definition: TensorIO.h:61
std::vector< std::string > spacer
Definition: TensorIO.h:69
std::string tenSuffix
Definition: TensorIO.h:62
int precision
Definition: TensorIO.h:67
void init_spacer()
Definition: TensorIO.h:42
std::vector< std::string > separator
Definition: TensorIO.h:65
int flags
Definition: TensorIO.h:68
std::vector< std::string > prefix
Definition: TensorIO.h:63
Definition: TensorIO.h:93
TensorIOFormatLegacy()
Definition: TensorIO.h:95
Definition: TensorIO.h:79
TensorIOFormatNative()
Definition: TensorIO.h:81
Definition: TensorIO.h:72
TensorIOFormatNumpy()
Definition: TensorIO.h:74
Definition: TensorIO.h:86
TensorIOFormatPlain()
Definition: TensorIO.h:88
Definition: TensorIO.h:100
static const TensorIOFormatPlain Plain()
Definition: TensorIO.h:109
static const TensorIOFormatNumpy Numpy()
Definition: TensorIO.h:107
static const TensorIOFormatNative Native()
Definition: TensorIO.h:111
static const TensorIOFormatLegacy Legacy()
Definition: TensorIO.h:113
TensorIOFormat(const std::vector< std::string > &separator, const std::vector< std::string > &prefix, const std::vector< std::string > &suffix, int precision=StreamPrecision, int flags=0, const std::string &tenPrefix="", const std::string &tenSuffix="", const char fill=' ')
Definition: TensorIO.h:102
static void run(std::ostream &stream, const Scalar &scalar, const TensorIOFormatNative &)
Definition: TensorIO.h:210
static void run(std::ostream &stream, const Scalar &scalar, const TensorIOFormatNumpy &)
Definition: TensorIO.h:203
Definition: TensorIO.h:197
static void run(std::ostream &stream, const Scalar &scalar, const Format &)
Definition: TensorIO.h:198
static void run(std::ostream &s, const Tensor &tensor, const Format &fmt)
Definition: TensorIO.h:380
static void run(std::ostream &s, const Tensor &tensor, const Format &)
Definition: TensorIO.h:363
std::remove_const_t< typename Tensor::Scalar > Scalar
Definition: TensorIO.h:361
Definition: TensorIO.h:216
std::remove_const_t< typename Tensor::Scalar > Scalar
Definition: TensorIO.h:217
static void run(std::ostream &s, const Tensor &tensor, const Format &fmt)
Definition: TensorIO.h:219
Definition: Meta.h:205
static int run()
Definition: IO.h:121
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2