BatchMatMul< TensorType > Struct Template Reference

Public Member Functions

DSizes< DenseIndex, 3 > dimensions (const Tensor< float, 3 > &input1, const Tensor< float, 3 > &input2) const
 
template<typename Output , typename Device >
void eval (const Tensor< float, 3 > &input1, const Tensor< float, 3 > &input2, Output &output, const Device &device) const
 
DSizes< DenseIndex, 3 > dimensions (const TensorType &input1, const TensorType &input2) const
 
template<typename Output , typename Device >
void eval (const TensorType &input1, const TensorType &input2, Output &output, const Device &device) const
 

Member Function Documentation

◆ dimensions() [1/2]

template<typename TensorType >
DSizes<DenseIndex, 3> BatchMatMul< TensorType >::dimensions ( const Tensor< float, 3 > &  input1,
const Tensor< float, 3 > &  input2 
) const
inline
58  {
59  DSizes<DenseIndex, 3> result;
60  result[0] = input1.dimension(0);
61  result[1] = input2.dimension(1);
62  result[2] = input2.dimension(2);
63  return result;
64  }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:99
Definition: TensorDimensions.h:161

References Eigen::Tensor< Scalar_, NumIndices_, Options_, IndexType_ >::dimension().

◆ dimensions() [2/2]

template<typename TensorType >
DSizes<DenseIndex, 3> BatchMatMul< TensorType >::dimensions ( const TensorType &  input1,
const TensorType &  input2 
) const
inline
89  {
90  DSizes<DenseIndex, 3> result;
91  result[0] = input1.dimension(0);
92  result[1] = input2.dimension(1);
93  result[2] = input2.dimension(2);
94  return result;
95  }

◆ eval() [1/2]

template<typename TensorType >
template<typename Output , typename Device >
void BatchMatMul< TensorType >::eval ( const Tensor< float, 3 > &  input1,
const Tensor< float, 3 > &  input2,
Output &  output,
const Device &  device 
) const
inline
68  {
70  array<DimPair, 1> dims;
71  dims[0] = DimPair(1, 0);
72  for (int i = 0; i < output.dimension(2); ++i) {
73  output.template chip<2>(i).device(device) = input1.chip<2>(i).contract(input2.chip<2>(i), dims);
74  }
75  }
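int i
Definition: BiCGSTAB_step_by_step.cpp:9
(Note: this cross-reference is spurious. The `i` used in eval() is the local loop counter declared in the for statement on line 72, not the global variable defined in BiCGSTAB_step_by_step.cpp.)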
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorChippingOp< DimId, const Derived > chip(const Index offset) const
Definition: TensorBase.h:1141
The tensor class.
Definition: Tensor.h:68
Tensor< float, 1 >::DimensionPair DimPair
Definition: cxx11_tensor_contraction.cpp:17
std::array< T, N > array
Definition: EmulateArray.h:231
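void output(std::ostream &outfile, const unsigned &nplot)
Overload output function.
Definition: overloaded_element_body.h:490
(Note: this cross-reference is spurious. `output` in eval() is the function's own `Output & output` parameter, not this unrelated free function.)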

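References Eigen::TensorBase< Derived, AccessLevel >::chip(), i, and output(). The last two are spurious automatic links: in this function, `i` is the local loop counter and `output` is the `Output &` parameter.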

◆ eval() [2/2]

template<typename TensorType >
template<typename Output , typename Device >
void BatchMatMul< TensorType >::eval ( const TensorType &  input1,
const TensorType &  input2,
Output &  output,
const Device &  device 
) const
inline
98  {
99  typedef typename TensorType::DimensionPair DimPair;
100  array<DimPair, 1> dims;
101  dims[0] = DimPair(1, 0);
102  for (int64_t i = 0; i < output.dimension(2); ++i) {
103  output.template chip<2>(i).device(device) = input1.template chip<2>(i).contract(input2.template chip<2>(i), dims);
104  }
105  }
std::int64_t int64_t
Definition: Meta.h:43

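References i and output(). Both are spurious automatic links: in this function, `i` is the local `int64_t` loop counter and `output` is the `Output &` parameter.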


The documentation for this struct was generated from the following files: