diff options
author | Samuel Yap <samuel.yap@arm.com> | 2022-07-06 15:36:03 +0100 |
---|---|---|
committer | Samuel Yap <samuel.yap@arm.com> | 2022-07-22 16:52:38 +0100 |
commit | 4b7a34dd92eb3f736e05ac6623fd147ecd8636b1 (patch) | |
tree | c33e5820f89e359c80d8773288e8adb075735039 /include/armnn/Descriptors.hpp | |
parent | 16929a2b432232f7a34fcbd1f1b0fe1212500206 (diff) | |
download | armnn-4b7a34dd92eb3f736e05ac6623fd147ecd8636b1.tar.gz |
IVGCVSW-7109: Add Batch MatMul front end support - Reference
* Descriptors added for BatchMatMul
* Layer definition added
* Input validation added (will likely change when opt. param support comes in)
* Ref workload implementation for BatchMatMul added (will also change with opt. param support)
* Ref layer tests made for BatchMatMul
* CMake and other build files updated
Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r-- | include/armnn/Descriptors.hpp | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 628d045529..38e3c61500 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1550,4 +1550,58 @@ struct ChannelShuffleDescriptor : BaseDescriptor uint32_t m_Axis; }; +/// A BatchMatMulDescriptor for the BatchMatMul operator +struct BatchMatMulDescriptor : BaseDescriptor +{ + BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(), + Optional<DataLayout> dataLayoutY = EmptyOptional(), + std::vector<unsigned int> transposeX = {}, + std::vector<unsigned int> transposeY = {}, + std::vector<unsigned int> adjointX = {}, + std::vector<unsigned int> adjointY = {}) + : m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) + , m_TransposeX(transposeX) + , m_TransposeY(transposeY) + , m_AdjointX(adjointX) + , m_AdjointY(adjointY) + {} + + bool operator ==(const BatchMatMulDescriptor &rhs) const + { + return m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY && + m_TransposeX == rhs.m_TransposeX && + m_TransposeY == rhs.m_TransposeY && + m_AdjointX == rhs.m_AdjointX && + m_AdjointY == rhs.m_AdjointY; + } + + /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout) + Optional<DataLayout> m_DataLayoutX; + Optional<DataLayout> m_DataLayoutY; + + /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing) + /// Transpose and Adjoint can not both be set to true for the same tensor at the same time + std::vector<unsigned int> m_TransposeX; + std::vector<unsigned int> m_TransposeY; + + /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) + /// Transpose and Adjoint can not both be set to true for the same tensor at the same time + std::vector<unsigned int> m_AdjointX; + std::vector<unsigned int> m_AdjointY; + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul( + const BatchMatMulDescriptor& desc, + const TensorShape& tensorXShape, + const TensorShape& tensorYShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul( + const BatchMatMulDescriptor& desc, + const TensorShape& inputXShape, + const TensorShape& inputYShape); +}; + } // namespace armnn |