diff options
Diffstat (limited to 'include/armnn')
-rw-r--r-- | include/armnn/BackendHelper.hpp | 6 | ||||
-rw-r--r-- | include/armnn/Descriptors.hpp | 54 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 6 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 3 | ||||
-rw-r--r-- | include/armnn/backends/WorkloadData.hpp | 5 |
6 files changed, 74 insertions, 1 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp index 09c7385d5c..f78b4f80b9 100644 --- a/include/armnn/BackendHelper.hpp +++ b/include/armnn/BackendHelper.hpp @@ -43,6 +43,12 @@ public: const ArgMinMaxDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()); + bool IsBatchMatMulSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const BatchMatMulDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()); + bool IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 628d045529..38e3c61500 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1550,4 +1550,58 @@ struct ChannelShuffleDescriptor : BaseDescriptor uint32_t m_Axis; }; +/// A BatchMatMulDescriptor for the BatchMatMul operator +struct BatchMatMulDescriptor : BaseDescriptor +{ + BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(), + Optional<DataLayout> dataLayoutY = EmptyOptional(), + std::vector<unsigned int> transposeX = {}, + std::vector<unsigned int> transposeY = {}, + std::vector<unsigned int> adjointX = {}, + std::vector<unsigned int> adjointY = {}) + : m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) + , m_TransposeX(transposeX) + , m_TransposeY(transposeY) + , m_AdjointX(adjointX) + , m_AdjointY(adjointY) + {} + + bool operator ==(const BatchMatMulDescriptor &rhs) const + { + return m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY && + m_TransposeX == rhs.m_TransposeX && + m_TransposeY == rhs.m_TransposeY && + m_AdjointX == rhs.m_AdjointX && + m_AdjointY == rhs.m_AdjointY; + } + + /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout) + Optional<DataLayout> m_DataLayoutX; + Optional<DataLayout> m_DataLayoutY; + + /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing) + /// Transpose and Adjoint can not both be set to true for the same tensor at the same time + std::vector<unsigned int> m_TransposeX; + std::vector<unsigned int> m_TransposeY; + + /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) + /// Transpose and Adjoint can not both be set to true for the same tensor at the same time + std::vector<unsigned int> m_AdjointX; + std::vector<unsigned int> m_AdjointY; + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul( + const BatchMatMulDescriptor& desc, + const TensorShape& tensorXShape, + const TensorShape& tensorYShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul( + const BatchMatMulDescriptor& desc, + const TensorShape& inputXShape, + const TensorShape& inputYShape); +}; + } // namespace armnn diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index ab6c7d235a..c0c1cc238d 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -11,6 +11,7 @@ struct BaseDescriptor; struct ActivationDescriptor; struct ArgMinMaxDescriptor; +struct BatchMatMulDescriptor; struct BatchNormalizationDescriptor; struct BatchToSpaceNdDescriptor; struct ChannelShuffleDescriptor; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 3d4be1a7fa..349c7e87b5 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -752,6 +752,12 @@ public: IConnectableLayer* AddChannelShuffleLayer(const ChannelShuffleDescriptor& descriptor, const char* name = nullptr); + /// Add a BatchMatMul layer to the network + /// @param descriptor - Parameters for the BatchMatMul operation + /// @param name - Optional name for the layer + /// @return - Interface for configuring the layer + IConnectableLayer* AddBatchMatMulLayer(const BatchMatMulDescriptor& descriptor, + const char* name = nullptr); void ExecuteStrategy(IStrategy& strategy) const; diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index af75513638..98229df07f 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -458,7 +458,8 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>; X(ChannelShuffle) \ X(Convolution3d) \ X(Pooling3d) \ - X(GatherNd)\ + X(GatherNd) \ + X(BatchMatMul) \ // New layers should be added at last to minimize instability. diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp index 1a2f34e21f..00962ed52c 100644 --- a/include/armnn/backends/WorkloadData.hpp +++ b/include/armnn/backends/WorkloadData.hpp @@ -785,4 +785,9 @@ struct ChannelShuffleQueueDescriptor : QueueDescriptorWithParameters<ChannelShuf void Validate(const WorkloadInfo& workloadInfo) const; }; +struct BatchMatMulQueueDescriptor : QueueDescriptorWithParameters<BatchMatMulDescriptor> +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } // namespace armnn |