23.05
A BatchMatMulDescriptor for the BatchMatMul operator.
#include <Descriptors.hpp>
Public Member Functions | |
BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW) | |
bool | operator== (const BatchMatMulDescriptor &rhs) const |
Public Member Functions inherited from BaseDescriptor | |
virtual bool | IsNull () const |
virtual | ~BaseDescriptor ()=default |
Static Public Member Functions | |
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > | GetAxesToMul (const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape) |
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > | GetAxesNotMul (const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape) |
static std::pair< unsigned int, unsigned int > | GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape) |
Static helper to get the two axes (for each input) for multiplication. | |
static std::vector< unsigned int > | GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape) |
Static helper to get the axes (for each input) that will not be multiplied together. | |
static PermutationVector | GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape) |
Static helper to get the axes which will be transposed. | |
Public Attributes | |
bool | m_TransposeX |
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time. | |
bool | m_TransposeY |
bool | m_AdjointX |
Take the adjoint of the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time. | |
bool | m_AdjointY |
DataLayout | m_DataLayoutX |
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout). | |
DataLayout | m_DataLayoutY |
A BatchMatMulDescriptor for the BatchMatMul operator.
Definition at line 1551 of file Descriptors.hpp.
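BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW) |
inline |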
Definition at line 1553 of file Descriptors.hpp.
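As a rough illustration of the constructor defaults and of `operator==`, the sketch below uses a stand-in struct. `DescriptorSketch`, the local `DataLayout` enum and `ExampleUsage` are hypothetical names, not the Arm NN types. It assumes that equality compares all six fields, which is consistent with the six members this page lists as referenced by `operator==`.

```cpp
#include <cassert>

// Stand-in for armnn::DataLayout, limited to the values this page references.
enum class DataLayout { NCHW, NHWC, NDHWC, NCDHW };

// Hypothetical mirror of BatchMatMulDescriptor, for illustration only.
struct DescriptorSketch
{
    DescriptorSketch(bool transposeX = false, bool transposeY = false,
                     bool adjointX = false, bool adjointY = false,
                     DataLayout dataLayoutX = DataLayout::NCHW,
                     DataLayout dataLayoutY = DataLayout::NCHW)
        : m_TransposeX(transposeX), m_TransposeY(transposeY),
          m_AdjointX(adjointX), m_AdjointY(adjointY),
          m_DataLayoutX(dataLayoutX), m_DataLayoutY(dataLayoutY)
    {}

    // Assumed behaviour: descriptors are equal when all six fields match.
    bool operator==(const DescriptorSketch& rhs) const
    {
        return m_TransposeX  == rhs.m_TransposeX  &&
               m_TransposeY  == rhs.m_TransposeY  &&
               m_AdjointX    == rhs.m_AdjointX    &&
               m_AdjointY    == rhs.m_AdjointY    &&
               m_DataLayoutX == rhs.m_DataLayoutX &&
               m_DataLayoutY == rhs.m_DataLayoutY;
    }

    bool m_TransposeX;
    bool m_TransposeY;
    bool m_AdjointX;
    bool m_AdjointY;
    DataLayout m_DataLayoutX;
    DataLayout m_DataLayoutY;
};

// Usage: every argument is optional, so only the leading ones need to be given.
inline void ExampleUsage()
{
    DescriptorSketch defaults;                  // no transpose, no adjoint, NCHW
    DescriptorSketch transposedY(false, true);  // transpose the slices of Y only
    assert(defaults == DescriptorSketch());
    assert(!(defaults == transposedY));
}
```

Because every parameter has a default, a default-constructed descriptor describes a plain batched multiplication of X by Y.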
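std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul (const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape) |
static |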
Definition at line 467 of file Descriptors.cpp.
References BatchMatMulDescriptor::m_DataLayoutX, and BatchMatMulDescriptor::m_DataLayoutY.
Referenced by BatchMatMulQueueDescriptor::Validate().
std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape) |
static |
Static helper to get the axes (for each input) that will not be multiplied together.
Definition at line 497 of file Descriptors.cpp.
References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().
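Since this helper references `GetAxesToMul()` and `GetNumDimensions()`, one plausible reading is that it returns every axis index that is not one of the two multiplied axes. The sketch below works under that assumption. `AxesToMulSketch`, `AxesNotMulSketch` and `ShapeSketch` are hypothetical stand-ins, and the rule used to pick the multiplied axes (last two axes, moved one place toward the front for channel-last layouts) is also an assumption rather than something this page states.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Stand-in for armnn::DataLayout, limited to the values this page references.
enum class DataLayout { NCHW, NHWC, NDHWC, NCDHW };

// Stand-in for TensorShape: one extent per dimension.
using ShapeSketch = std::vector<unsigned int>;

// Assumed rule for the multiplied axes (not confirmed by this page).
std::pair<unsigned int, unsigned int> AxesToMulSketch(DataLayout layout, const ShapeSketch& shape)
{
    const bool channelsLast = (layout == DataLayout::NHWC || layout == DataLayout::NDHWC);
    const unsigned int shift = channelsLast ? 1u : 0u;
    const unsigned int numDims = static_cast<unsigned int>(shape.size());
    assert(numDims >= 2u + shift);  // need enough axes to pick two
    return {numDims - 2u - shift, numDims - 1u - shift};
}

// Collect every axis index that is not one of the two multiplied axes.
std::vector<unsigned int> AxesNotMulSketch(DataLayout layout, const ShapeSketch& shape)
{
    const auto axesToMul = AxesToMulSketch(layout, shape);
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < shape.size(); ++i)
    {
        if (i != axesToMul.first && i != axesToMul.second)
        {
            axesNotMul.push_back(i);
        }
    }
    return axesNotMul;
}
```

Under these assumptions a rank-4 shape gives axes {0, 1} for NCHW and {0, 3} for NHWC.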
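std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul (const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape) |
static |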
Definition at line 459 of file Descriptors.cpp.
References BatchMatMulDescriptor::m_DataLayoutX, and BatchMatMulDescriptor::m_DataLayoutY.
Referenced by BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
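std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape) |
static |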
Static helper to get the two axes (for each input) for multiplication.
Definition at line 476 of file Descriptors.cpp.
References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.
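The references to `GetNumDimensions()` and to the four layout values suggest that the multiplied axes are chosen from the tensor rank and the data layout. The following is a minimal sketch of one way that selection could work, assuming the last two axes are used for channel-first layouts (NCHW, NCDHW) and the pair moves one place toward the front for channel-last layouts (NHWC, NDHWC). All names in the block (`AxesToMulSketch`, `ShapeSketch`, `AxisPair`) are hypothetical, and the second overload only illustrates how a two-input form could apply each input's own layout.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Stand-in for armnn::DataLayout, limited to the values this page references.
enum class DataLayout { NCHW, NHWC, NDHWC, NCDHW };

// Stand-in for TensorShape: one extent per dimension.
using ShapeSketch = std::vector<unsigned int>;
using AxisPair = std::pair<unsigned int, unsigned int>;

// Single-input form: pick the two axes to multiply from rank and layout.
AxisPair AxesToMulSketch(DataLayout layout, const ShapeSketch& shape)
{
    const unsigned int numDims = static_cast<unsigned int>(shape.size());
    AxisPair axes{0u, 0u};
    switch (layout)
    {
        case DataLayout::NHWC:
        case DataLayout::NDHWC:
            // Assumption: channels sit last, so skip the final axis.
            assert(numDims >= 3u);
            axes = {numDims - 3u, numDims - 2u};
            break;
        case DataLayout::NCHW:
        case DataLayout::NCDHW:
            // Assumption: the last two axes form the matrix slices.
            assert(numDims >= 2u);
            axes = {numDims - 2u, numDims - 1u};
            break;
    }
    return axes;
}

// Two-input form: each input is resolved with its own layout.
std::pair<AxisPair, AxisPair> AxesToMulSketch(DataLayout layoutX, const ShapeSketch& shapeX,
                                              DataLayout layoutY, const ShapeSketch& shapeY)
{
    return {AxesToMulSketch(layoutX, shapeX), AxesToMulSketch(layoutY, shapeY)};
}
```

With a rank-4 shape this sketch yields axes (2, 3) for NCHW and (1, 2) for NHWC.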
PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape) |
static |
Static helper to get the axes which will be transposed.
Definition at line 514 of file Descriptors.cpp.
References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().
Referenced by BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
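To make "the axes which will be transposed" concrete, the sketch below builds a permutation that swaps the two multiplied axes and leaves every other axis in place. This is an assumption about the helper's behaviour, based only on its references to `GetAxesToMul()` and `GetNumDimensions()`. The result is returned as a plain `std::vector` rather than a `PermutationVector`, and `AxesToMulSketch`, `PermuteVecSketch` and `ShapeSketch` are hypothetical names.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Stand-in for armnn::DataLayout, limited to the values this page references.
enum class DataLayout { NCHW, NHWC, NDHWC, NCDHW };

// Stand-in for TensorShape: one extent per dimension.
using ShapeSketch = std::vector<unsigned int>;

// Assumed rule for the multiplied axes (not confirmed by this page).
std::pair<unsigned int, unsigned int> AxesToMulSketch(DataLayout layout, const ShapeSketch& shape)
{
    const bool channelsLast = (layout == DataLayout::NHWC || layout == DataLayout::NDHWC);
    const unsigned int shift = channelsLast ? 1u : 0u;
    const unsigned int numDims = static_cast<unsigned int>(shape.size());
    assert(numDims >= 2u + shift);  // need enough axes to pick two
    return {numDims - 2u - shift, numDims - 1u - shift};
}

// Identity mapping, except that the two multiplied axes trade places.
std::vector<unsigned int> PermuteVecSketch(DataLayout layout, const ShapeSketch& shape)
{
    const auto axesToMul = AxesToMulSketch(layout, shape);
    std::vector<unsigned int> mapping;
    mapping.reserve(shape.size());
    for (unsigned int i = 0; i < shape.size(); ++i)
    {
        if (i == axesToMul.first)
        {
            mapping.push_back(axesToMul.second);
        }
        else if (i == axesToMul.second)
        {
            mapping.push_back(axesToMul.first);
        }
        else
        {
            mapping.push_back(i);
        }
    }
    return mapping;
}
```

For a rank-4 shape the sketch produces {0, 1, 3, 2} with NCHW and {0, 2, 1, 3} with NHWC.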
bool operator== (const BatchMatMulDescriptor &rhs) const |
inline |
bool m_AdjointX |
Take the adjoint of the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
Definition at line 1584 of file Descriptors.hpp.
Referenced by armnn::ClBatchMatMulValidate(), ClBatchMatMulWorkload::ClBatchMatMulWorkload(), BatchMatMulLayer::InferOutputShapes(), armnn::NeonBatchMatMulValidate(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), BatchMatMulDescriptor::operator==(), StringifyLayerParameters< BatchMatMulDescriptor >::Serialize(), and BatchMatMulQueueDescriptor::Validate().
bool m_AdjointY |
Definition at line 1585 of file Descriptors.hpp.
Referenced by armnn::ClBatchMatMulValidate(), ClBatchMatMulWorkload::ClBatchMatMulWorkload(), BatchMatMulLayer::InferOutputShapes(), armnn::NeonBatchMatMulValidate(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), BatchMatMulDescriptor::operator==(), StringifyLayerParameters< BatchMatMulDescriptor >::Serialize(), and BatchMatMulQueueDescriptor::Validate().
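The exclusivity rule between Transpose and Adjoint applies per input tensor, so combining a transpose on X with an adjoint on Y is allowed. A minimal sketch of such a check follows. `FlagsSketch`, `IsFlagPairValid`, `AreFlagsValid` and `ExampleUsage` are hypothetical names, and the code does not reproduce how `BatchMatMulQueueDescriptor::Validate()` actually reports a violation.

```cpp
#include <cassert>

// Hypothetical holder for the four boolean members of the descriptor.
struct FlagsSketch
{
    bool m_TransposeX;
    bool m_TransposeY;
    bool m_AdjointX;
    bool m_AdjointY;
};

// Transpose and adjoint are mutually exclusive for a single input.
inline bool IsFlagPairValid(bool transpose, bool adjoint)
{
    return !(transpose && adjoint);
}

// The rule is checked independently for X and for Y.
inline bool AreFlagsValid(const FlagsSketch& flags)
{
    return IsFlagPairValid(flags.m_TransposeX, flags.m_AdjointX) &&
           IsFlagPairValid(flags.m_TransposeY, flags.m_AdjointY);
}

// Usage: transpose on X together with adjoint on Y is fine,
// both flags on the same input is not.
inline void ExampleUsage()
{
    const FlagsSketch mixed{true, false, false, true};
    const FlagsSketch clash{true, false, true, false};
    assert(AreFlagsValid(mixed));
    assert(!AreFlagsValid(clash));
}
```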
DataLayout m_DataLayoutX |
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).
Definition at line 1588 of file Descriptors.hpp.
Referenced by armnn::ClBatchMatMulValidate(), ClBatchMatMulWorkload::ClBatchMatMulWorkload(), BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), BatchMatMulLayer::InferOutputShapes(), armnn::NeonBatchMatMulValidate(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), BatchMatMulDescriptor::operator==(), and BatchMatMulQueueDescriptor::Validate().
DataLayout m_DataLayoutY |
Definition at line 1589 of file Descriptors.hpp.
Referenced by armnn::ClBatchMatMulValidate(), ClBatchMatMulWorkload::ClBatchMatMulWorkload(), BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), BatchMatMulLayer::InferOutputShapes(), armnn::NeonBatchMatMulValidate(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), BatchMatMulDescriptor::operator==(), and BatchMatMulQueueDescriptor::Validate().
bool m_TransposeX |
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
Definition at line 1579 of file Descriptors.hpp.
Referenced by ClBatchMatMulWorkload::ClBatchMatMulWorkload(), BatchMatMulLayer::InferOutputShapes(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), BatchMatMulDescriptor::operator==(), StringifyLayerParameters< BatchMatMulDescriptor >::Serialize(), and BatchMatMulQueueDescriptor::Validate().
bool m_TransposeY |
Definition at line 1580 of file Descriptors.hpp.
Referenced by ClBatchMatMulWorkload::ClBatchMatMulWorkload(), BatchMatMulLayer::InferOutputShapes(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), BatchMatMulDescriptor::operator==(), StringifyLayerParameters< BatchMatMulDescriptor >::Serialize(), and BatchMatMulQueueDescriptor::Validate().