From 98e383eadf4e670d057ad725c7fe7924fea8e36b Mon Sep 17 00:00:00 2001 From: Idriss Chaouch Date: Mon, 28 Aug 2023 14:28:31 +0100 Subject: IVGCVSW-7525 Add broadcast_to operator Signed-off-by: Idriss Chaouch Signed-off-by: Narumol Prangnawarat Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8 --- include/armnn/BackendHelper.hpp | 6 ++++++ include/armnn/Descriptors.hpp | 19 +++++++++++++++++++ include/armnn/DescriptorsFwd.hpp | 1 + include/armnn/INetwork.hpp | 7 +++++++ include/armnn/Types.hpp | 3 ++- include/armnn/backends/WorkloadData.hpp | 5 +++++ 6 files changed, 40 insertions(+), 1 deletion(-) (limited to 'include') diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp index 986f854636..b61f010b0f 100644 --- a/include/armnn/BackendHelper.hpp +++ b/include/armnn/BackendHelper.hpp @@ -73,6 +73,12 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()); + + bool IsBroadcastToSupported(const TensorInfo& input, + const TensorInfo& output, + const BroadcastToDescriptor& descriptor, + Optional reasonIfUnsupported); + bool IsCastSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()); diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 30eaefd83b..bf40b35ae9 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1656,4 +1656,23 @@ struct TileDescriptor : BaseDescriptor std::vector m_Multiples; }; +struct BroadcastToDescriptor : BaseDescriptor +{ + BroadcastToDescriptor() + : m_BroadcastToShape() + {} + + explicit BroadcastToDescriptor(const TensorShape& shape) + : m_BroadcastToShape(shape) + {} + + bool operator ==(const BroadcastToDescriptor& rhs) const + { + return m_BroadcastToShape == rhs.m_BroadcastToShape; + } + + /// Target shape value. + TensorShape m_BroadcastToShape; +}; + } // namespace armnn diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index 4b9a3e5060..4b0b70c2d3 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -14,6 +14,7 @@ struct ArgMinMaxDescriptor; struct BatchMatMulDescriptor; struct BatchNormalizationDescriptor; struct BatchToSpaceNdDescriptor; +struct BroadcastToDescriptor; struct ChannelShuffleDescriptor; struct ComparisonDescriptor; struct Convolution2dDescriptor; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index c2c76e3d97..64fdab6bd0 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -857,6 +857,13 @@ public: IConnectableLayer* AddTileLayer(const TileDescriptor& descriptor, const char* name = nullptr); + /// Add a BroadcastTo layer to the network + /// @param descriptor - Parameters for the BroadcastTo operation + /// @param name - Optional name for the layer + /// @return - Interface for configuring the layer + IConnectableLayer* AddBroadcastToLayer(const BroadcastToDescriptor& descriptor, + const char* name = nullptr); + void ExecuteStrategy(IStrategy& strategy) const; protected: diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 7cb3a859c7..933e7da412 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -481,6 +481,7 @@ using InferenceTimingPair = std::pair; X(ReverseV2) \ X(Tile) \ X(Fused) \ + X(BroadcastTo) \ // New layers should be added at last position to minimize instability. @@ -492,7 +493,7 @@ enum class LayerType LIST_OF_LAYER_TYPE #undef X FirstLayer = Activation, - LastLayer = Fused + LastLayer = BroadcastTo }; const char* GetLayerTypeAsCString(LayerType type); diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp index 86796cbcc0..a90a1abd65 100644 --- a/include/armnn/backends/WorkloadData.hpp +++ b/include/armnn/backends/WorkloadData.hpp @@ -765,4 +765,9 @@ struct TileQueueDescriptor : QueueDescriptorWithParameters void Validate(const WorkloadInfo& workloadInfo) const; }; +struct BroadcastToQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } // namespace armnn -- cgit v1.2.1