diff options
author | Samuel Yap <samuel.yap@arm.com> | 2022-08-08 14:07:42 +0100 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-08-30 17:03:33 +0100 |
commit | dc8ed9d75e54e914a970e137900930fa64a0782b (patch) | |
tree | 8bcaedaae81a6afbdbe3c9a4e69e45840f18cdb4 /include/armnn | |
parent | 9c9d5b9d796d243d88bd7a7aebb2e7e6c467e3a4 (diff) | |
download | armnn-dc8ed9d75e54e914a970e137900930fa64a0782b.tar.gz |
IVGCVSW-7105: BatchMatMul Optional Parameter Support
* Added transpose parameters to pre-transpose each input tensor's slices
* Added adjoint parameters to pre-adjoint each input tensor's slices
* Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor)
* Updated input validation and output shape inference for parameters
* Additional layer unit tests for parameters added
* Versionings incremented
Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667
Diffstat (limited to 'include/armnn')
-rw-r--r-- | include/armnn/Descriptors.hpp | 71 | ||||
-rw-r--r-- | include/armnn/Version.hpp | 2 |
2 files changed, 46 insertions, 27 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 38e3c61500..493ce65976 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -1553,55 +1553,74 @@ struct ChannelShuffleDescriptor : BaseDescriptor /// A BatchMatMulDescriptor for the BatchMatMul operator struct BatchMatMulDescriptor : BaseDescriptor { - BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(), - Optional<DataLayout> dataLayoutY = EmptyOptional(), - std::vector<unsigned int> transposeX = {}, - std::vector<unsigned int> transposeY = {}, - std::vector<unsigned int> adjointX = {}, - std::vector<unsigned int> adjointY = {}) - : m_DataLayoutX(dataLayoutX) - , m_DataLayoutY(dataLayoutY) - , m_TransposeX(transposeX) + BatchMatMulDescriptor(bool transposeX = false, + bool transposeY = false, + bool adjointX = false, + bool adjointY = false, + DataLayout dataLayoutX = DataLayout::NCHW, + DataLayout dataLayoutY = DataLayout::NCHW) + : m_TransposeX(transposeX) , m_TransposeY(transposeY) , m_AdjointX(adjointX) , m_AdjointY(adjointY) + , m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) {} bool operator ==(const BatchMatMulDescriptor &rhs) const { - return m_DataLayoutX == rhs.m_DataLayoutX && - m_DataLayoutY == rhs.m_DataLayoutY && - m_TransposeX == rhs.m_TransposeX && + return m_TransposeX == rhs.m_TransposeX && m_TransposeY == rhs.m_TransposeY && m_AdjointX == rhs.m_AdjointX && - m_AdjointY == rhs.m_AdjointY; + m_AdjointY == rhs.m_AdjointY && + m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY; } - /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout) - Optional<DataLayout> m_DataLayoutX; - Optional<DataLayout> m_DataLayoutY; - - /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing) + /// Transpose the slices of each input tensor /// Transpose and Adjoint can not both be set to true for the same tensor at the same time - std::vector<unsigned int> m_TransposeX; - std::vector<unsigned int> m_TransposeY; + bool m_TransposeX; + bool m_TransposeY; - /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) + /// Adjoint the slices of each input tensor /// Transpose and Adjoint can not both be set to true for the same tensor at the same time - std::vector<unsigned int> m_AdjointX; - std::vector<unsigned int> m_AdjointY; + bool m_AdjointX; + bool m_AdjointY; - /// Static helper to get the two axes (for each input) for multiplication + /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout) + DataLayout m_DataLayoutX; + DataLayout m_DataLayoutY; + + ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable " + "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul( const BatchMatMulDescriptor& desc, const TensorShape& tensorXShape, const TensorShape& tensorYShape); - /// Static helper to get the axes (for each input) that will not be multiplied together + ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable " + "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul( const BatchMatMulDescriptor& desc, const TensorShape& inputXShape, const TensorShape& inputYShape); + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair<unsigned int, unsigned int> GetAxesToMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::vector<unsigned int> GetAxesNotMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes which will be transposed + static PermutationVector GetPermuteVec( + DataLayout dataLayout, + const TensorShape& tensorShape); }; } // namespace armnn diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp index 7951eacf1d..7fdb20ade5 100644 --- a/include/armnn/Version.hpp +++ b/include/armnn/Version.hpp @@ -10,7 +10,7 @@ #define STRINGIFY_MACRO(s) #s // ArmNN version components -#define ARMNN_MAJOR_VERSION 30 +#define ARMNN_MAJOR_VERSION 31 #define ARMNN_MINOR_VERSION 0 #define ARMNN_PATCH_VERSION 0 |