diff options
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r-- | include/armnn/Descriptors.hpp | 71 |
1 files changed, 45 insertions, 26 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 38e3c61500..493ce65976 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -1553,55 +1553,74 @@ struct ChannelShuffleDescriptor : BaseDescriptor /// A BatchMatMulDescriptor for the BatchMatMul operator struct BatchMatMulDescriptor : BaseDescriptor { - BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(), - Optional<DataLayout> dataLayoutY = EmptyOptional(), - std::vector<unsigned int> transposeX = {}, - std::vector<unsigned int> transposeY = {}, - std::vector<unsigned int> adjointX = {}, - std::vector<unsigned int> adjointY = {}) - : m_DataLayoutX(dataLayoutX) - , m_DataLayoutY(dataLayoutY) - , m_TransposeX(transposeX) + BatchMatMulDescriptor(bool transposeX = false, + bool transposeY = false, + bool adjointX = false, + bool adjointY = false, + DataLayout dataLayoutX = DataLayout::NCHW, + DataLayout dataLayoutY = DataLayout::NCHW) + : m_TransposeX(transposeX) , m_TransposeY(transposeY) , m_AdjointX(adjointX) , m_AdjointY(adjointY) + , m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) {} bool operator ==(const BatchMatMulDescriptor &rhs) const { - return m_DataLayoutX == rhs.m_DataLayoutX && - m_DataLayoutY == rhs.m_DataLayoutY && - m_TransposeX == rhs.m_TransposeX && + return m_TransposeX == rhs.m_TransposeX && m_TransposeY == rhs.m_TransposeY && m_AdjointX == rhs.m_AdjointX && - m_AdjointY == rhs.m_AdjointY; + m_AdjointY == rhs.m_AdjointY && + m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY; } - /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout) - Optional<DataLayout> m_DataLayoutX; - Optional<DataLayout> m_DataLayoutY; - - /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing) + /// Transpose the slices of each input tensor /// Transpose and Adjoint can not both be set to true for the same tensor at the same time - std::vector<unsigned int> m_TransposeX; - std::vector<unsigned int> m_TransposeY; + bool m_TransposeX; + bool m_TransposeY; - /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) + /// Adjoint the slices of each input tensor /// Transpose and Adjoint can not both be set to true for the same tensor at the same time - std::vector<unsigned int> m_AdjointX; - std::vector<unsigned int> m_AdjointY; + bool m_AdjointX; + bool m_AdjointY; - /// Static helper to get the two axes (for each input) for multiplication + /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout) + DataLayout m_DataLayoutX; + DataLayout m_DataLayoutY; + + ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable " + "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul( const BatchMatMulDescriptor& desc, const TensorShape& tensorXShape, const TensorShape& tensorYShape); - /// Static helper to get the axes (for each input) that will not be multiplied together + ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable " + "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul( const BatchMatMulDescriptor& desc, const TensorShape& inputXShape, const TensorShape& inputYShape); + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair<unsigned int, unsigned int> GetAxesToMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::vector<unsigned int> GetAxesNotMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes which will be transposed + static PermutationVector GetPermuteVec( + DataLayout dataLayout, + const TensorShape& tensorShape); }; } // namespace armnn |