diff options
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r-- | include/armnn/Descriptors.hpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 628d045529..38e3c61500 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1550,4 +1550,58 @@ struct ChannelShuffleDescriptor : BaseDescriptor uint32_t m_Axis; }; +/// A BatchMatMulDescriptor for the BatchMatMul operator +struct BatchMatMulDescriptor : BaseDescriptor +{ + BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(), + Optional<DataLayout> dataLayoutY = EmptyOptional(), + std::vector<unsigned int> transposeX = {}, + std::vector<unsigned int> transposeY = {}, + std::vector<unsigned int> adjointX = {}, + std::vector<unsigned int> adjointY = {}) + : m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) + , m_TransposeX(transposeX) + , m_TransposeY(transposeY) + , m_AdjointX(adjointX) + , m_AdjointY(adjointY) + {} + + bool operator ==(const BatchMatMulDescriptor &rhs) const + { + return m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY && + m_TransposeX == rhs.m_TransposeX && + m_TransposeY == rhs.m_TransposeY && + m_AdjointX == rhs.m_AdjointX && + m_AdjointY == rhs.m_AdjointY; + } + + /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout) + Optional<DataLayout> m_DataLayoutX; + Optional<DataLayout> m_DataLayoutY; + + /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing) + /// Transpose and Adjoint can not both be set to true for the same tensor at the same time + std::vector<unsigned int> m_TransposeX; + std::vector<unsigned int> m_TransposeY; + + /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) + /// Transpose and Adjoint can not both be set to true for the same tensor at the same time + std::vector<unsigned int> m_AdjointX; + std::vector<unsigned int> m_AdjointY; + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul( + const BatchMatMulDescriptor& desc, + const TensorShape& tensorXShape, + const TensorShape& tensorYShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul( + const BatchMatMulDescriptor& desc, + const TensorShape& inputXShape, + const TensorShape& inputYShape); +}; + } // namespace armnn |