From dc8ed9d75e54e914a970e137900930fa64a0782b Mon Sep 17 00:00:00 2001 From: Samuel Yap Date: Mon, 8 Aug 2022 14:07:42 +0100 Subject: IVGCVSW-7105: BatchMatMul Optional Parameter Support * Added transpose parameters to pre-transpose each input tensor's slices * Added adjoint parameters to pre-adjoint each input tensor's slices * Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor) * Updated input validation and output shape inference for parameters * Additional layer unit tests for parameters added * Versionings incremented Signed-off-by: Samuel Yap Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667 --- include/armnn/Descriptors.hpp | 71 +++++++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 26 deletions(-) (limited to 'include/armnn/Descriptors.hpp') diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 38e3c61500..493ce65976 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -1553,55 +1553,74 @@ struct ChannelShuffleDescriptor : BaseDescriptor /// A BatchMatMulDescriptor for the BatchMatMul operator struct BatchMatMulDescriptor : BaseDescriptor { - BatchMatMulDescriptor(Optional dataLayoutX = EmptyOptional(), - Optional dataLayoutY = EmptyOptional(), - std::vector transposeX = {}, - std::vector transposeY = {}, - std::vector adjointX = {}, - std::vector adjointY = {}) - : m_DataLayoutX(dataLayoutX) - , m_DataLayoutY(dataLayoutY) - , m_TransposeX(transposeX) + BatchMatMulDescriptor(bool transposeX = false, + bool transposeY = false, + bool adjointX = false, + bool adjointY = false, + DataLayout dataLayoutX = DataLayout::NCHW, + DataLayout dataLayoutY = DataLayout::NCHW) + : m_TransposeX(transposeX) , m_TransposeY(transposeY) , m_AdjointX(adjointX) , m_AdjointY(adjointY) + , m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) {} bool operator ==(const BatchMatMulDescriptor &rhs) const { - return m_DataLayoutX == rhs.m_DataLayoutX && - m_DataLayoutY == rhs.m_DataLayoutY && - m_TransposeX == rhs.m_TransposeX && + return m_TransposeX == rhs.m_TransposeX && m_TransposeY == rhs.m_TransposeY && m_AdjointX == rhs.m_AdjointX && - m_AdjointY == rhs.m_AdjointY; + m_AdjointY == rhs.m_AdjointY && + m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY; } - /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout) - Optional m_DataLayoutX; - Optional m_DataLayoutY; - - /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing) + /// Transpose the slices of each input tensor /// Transpose and Adjoint can not both be set to true for the same tensor at the same time - std::vector m_TransposeX; - std::vector m_TransposeY; + bool m_TransposeX; + bool m_TransposeY; - /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) + /// Adjoint the slices of each input tensor /// Transpose and Adjoint can not both be set to true for the same tensor at the same time - std::vector m_AdjointX; - std::vector m_AdjointY; + bool m_AdjointX; + bool m_AdjointY; - /// Static helper to get the two axes (for each input) for multiplication + /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout) + DataLayout m_DataLayoutX; + DataLayout m_DataLayoutY; + + ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable " + "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair, std::pair> GetAxesToMul( const BatchMatMulDescriptor& desc, const TensorShape& tensorXShape, const TensorShape& tensorYShape); - /// Static helper to get the axes (for each input) that will not be multiplied together + ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable " + "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair, std::vector> GetAxesNotMul( const BatchMatMulDescriptor& desc, const TensorShape& inputXShape, const TensorShape& inputYShape); + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair GetAxesToMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::vector GetAxesNotMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes which will be transposed + static PermutationVector GetPermuteVec( + DataLayout dataLayout, + const TensorShape& tensorShape); }; } // namespace armnn -- cgit v1.2.1