From 6b47809e7d6c55d20a05d863ce2f09159f381f85 Mon Sep 17 00:00:00 2001
From: Samuel Yap
Date: Wed, 6 Jul 2022 15:36:03 +0100
Subject: IVGCVSW-7109: Add Batch MatMul front end support - Reference

* Descriptors added for BatchMatMul
* Layer definition added
* Input validation added (will likely change when opt. param support comes in)
* Ref workload implementation for BatchMatMul added (will also change with opt. param support)
* Ref layer tests made for BatchMatMul
* CMake and other build files updated

Signed-off-by: Samuel Yap
Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617
---
 include/armnn/BackendHelper.hpp         |  6 ++++
 include/armnn/Descriptors.hpp           | 54 +++++++++++++++++++++++++++++++++
 include/armnn/DescriptorsFwd.hpp        |  1 +
 include/armnn/INetwork.hpp              |  6 ++++
 include/armnn/Types.hpp                 |  3 +-
 include/armnn/backends/WorkloadData.hpp |  5 +++
 6 files changed, 74 insertions(+), 1 deletion(-)

(limited to 'include/armnn')

diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 09c7385d5c..f78b4f80b9 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -43,6 +43,12 @@ public:
                               const ArgMinMaxDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional());
 
+    bool IsBatchMatMulSupported(const TensorInfo& input0,
+                                const TensorInfo& input1,
+                                const TensorInfo& output,
+                                const BatchMatMulDescriptor& descriptor,
+                                Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
     bool IsBatchNormalizationSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        const TensorInfo& mean,
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 628d045529..38e3c61500 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1550,4 +1550,58 @@ struct ChannelShuffleDescriptor : BaseDescriptor
     uint32_t m_Axis;
 };
 
+/// A BatchMatMulDescriptor for the BatchMatMul operator
+struct BatchMatMulDescriptor : BaseDescriptor
+{
+    BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(),
+                          Optional<DataLayout> dataLayoutY = EmptyOptional(),
+                          std::vector<unsigned int> transposeX = {},
+                          std::vector<unsigned int> transposeY = {},
+                          std::vector<unsigned int> adjointX = {},
+                          std::vector<unsigned int> adjointY = {})
+        : m_DataLayoutX(dataLayoutX)
+        , m_DataLayoutY(dataLayoutY)
+        , m_TransposeX(transposeX)
+        , m_TransposeY(transposeY)
+        , m_AdjointX(adjointX)
+        , m_AdjointY(adjointY)
+    {}
+
+    bool operator ==(const BatchMatMulDescriptor &rhs) const
+    {
+        return m_DataLayoutX == rhs.m_DataLayoutX &&
+               m_DataLayoutY == rhs.m_DataLayoutY &&
+               m_TransposeX == rhs.m_TransposeX &&
+               m_TransposeY == rhs.m_TransposeY &&
+               m_AdjointX == rhs.m_AdjointX &&
+               m_AdjointY == rhs.m_AdjointY;
+    }
+
+    /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)
+    Optional<DataLayout> m_DataLayoutX;
+    Optional<DataLayout> m_DataLayoutY;
+
+    /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing)
+    /// Transpose and Adjoint can not both be set to true for the same tensor at the same time
+    std::vector<unsigned int> m_TransposeX;
+    std::vector<unsigned int> m_TransposeY;
+
+    /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint)
+    /// Transpose and Adjoint can not both be set to true for the same tensor at the same time
+    std::vector<unsigned int> m_AdjointX;
+    std::vector<unsigned int> m_AdjointY;
+
+    /// Static helper to get the two axes (for each input) for multiplication
+    static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul(
+        const BatchMatMulDescriptor& desc,
+        const TensorShape& tensorXShape,
+        const TensorShape& tensorYShape);
+
+    /// Static helper to get the axes (for each input) that will not be multiplied together
+    static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul(
+        const BatchMatMulDescriptor& desc,
+        const TensorShape& inputXShape,
+        const TensorShape& inputYShape);
+};
+
 } // namespace armnn
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index ab6c7d235a..c0c1cc238d 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -11,6 +11,7 @@ struct BaseDescriptor;
 
 struct ActivationDescriptor;
 struct ArgMinMaxDescriptor;
+struct BatchMatMulDescriptor;
 struct BatchNormalizationDescriptor;
 struct BatchToSpaceNdDescriptor;
 struct ChannelShuffleDescriptor;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 3d4be1a7fa..349c7e87b5 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -752,6 +752,12 @@ public:
     IConnectableLayer* AddChannelShuffleLayer(const ChannelShuffleDescriptor& descriptor,
                                               const char* name = nullptr);
 
+    /// Add a BatchMatMul layer to the network
+    /// @param descriptor - Parameters for the BatchMatMul operation
+    /// @param name - Optional name for the layer
+    /// @return - Interface for configuring the layer
+    IConnectableLayer* AddBatchMatMulLayer(const BatchMatMulDescriptor& descriptor,
+                                           const char* name = nullptr);
 
     void ExecuteStrategy(IStrategy& strategy) const;
 
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index af75513638..98229df07f 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -458,7 +458,8 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>
     X(ChannelShuffle) \
     X(Convolution3d) \
     X(Pooling3d) \
-    X(GatherNd)\
+    X(GatherNd) \
+    X(BatchMatMul) \
 
 // New layers should be added at last to minimize instability.
 
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
index 1a2f34e21f..00962ed52c 100644
--- a/include/armnn/backends/WorkloadData.hpp
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -785,4 +785,9 @@ struct ChannelShuffleQueueDescriptor : QueueDescriptorWithParameters<ChannelShuffleDescriptor>
 {
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct BatchMatMulQueueDescriptor : QueueDescriptorWithParameters<BatchMatMulDescriptor>
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 } // namespace armnn
-- 
cgit v1.2.1
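For context, below is a minimal sketch of how the AddBatchMatMulLayer API declared in INetwork.hpp above might be exercised once this patch is applied. The surrounding graph-construction calls (INetwork::Create, AddInputLayer, AddOutputLayer, slot Connect/SetTensorInfo) are pre-existing ArmNN APIs, not part of this patch, and the shapes and layer names are illustrative only:

// Sketch: building a graph with the new BatchMatMul front end.
// A default-constructed BatchMatMulDescriptor leaves the layouts as
// EmptyOptional and the transpose/adjoint vectors empty, so the inputs
// are multiplied as-is over their final two dimensions.
#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

int main()
{
    using namespace armnn;

    INetworkPtr network = INetwork::Create();

    BatchMatMulDescriptor descriptor; // defaults: no layout hint, no pre-transpose/adjoint

    IConnectableLayer* inputX      = network->AddInputLayer(0, "inputX");
    IConnectableLayer* inputY      = network->AddInputLayer(1, "inputY");
    IConnectableLayer* batchMatMul = network->AddBatchMatMulLayer(descriptor, "batchMatMul");
    IConnectableLayer* output      = network->AddOutputLayer(0, "output");

    inputX->GetOutputSlot(0).Connect(batchMatMul->GetInputSlot(0));
    inputY->GetOutputSlot(0).Connect(batchMatMul->GetInputSlot(1));
    batchMatMul->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    // Two batches of 2x3 and 3x4 matrices multiply into two 2x4 results.
    inputX->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({2, 2, 3}), DataType::Float32));
    inputY->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({2, 3, 4}), DataType::Float32));
    batchMatMul->GetOutputSlot(0).SetTensorInfo(TensorInfo(TensorShape({2, 2, 4}), DataType::Float32));

    return 0;
}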
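Backend support can likewise be probed through the IsBatchMatMulSupported method added to the layer-support handle in BackendHelper.hpp above. The sketch below assumes the pre-existing GetILayerSupportByBackendId helper and the "CpuRef" reference backend id; the tensor shapes are again illustrative:

// Sketch: querying reference-backend support for BatchMatMul.
#include <armnn/BackendHelper.hpp>
#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>

#include <iostream>
#include <string>

int main()
{
    using namespace armnn;

    TensorInfo x(TensorShape({2, 2, 3}), DataType::Float32);
    TensorInfo y(TensorShape({2, 3, 4}), DataType::Float32);
    TensorInfo out(TensorShape({2, 2, 4}), DataType::Float32);
    BatchMatMulDescriptor descriptor;

    // Obtain the layer-support handle for the reference backend.
    LayerSupportHandle handle = GetILayerSupportByBackendId(BackendId("CpuRef"));

    std::string reason;
    if (handle.IsBatchMatMulSupported(x, y, out, descriptor, Optional<std::string&>(reason)))
    {
        std::cout << "BatchMatMul supported on CpuRef" << std::endl;
    }
    else
    {
        std::cout << "Not supported: " << reason << std::endl;
    }
    return 0;
}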