diff options
author | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-28 14:28:31 +0100 |
---|---|---|
committer | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-31 11:26:28 +0100 |
commit | 98e383eadf4e670d057ad725c7fe7924fea8e36b (patch) | |
tree | 35acac15aa69ab405887289cb9674d388f06f96b /include | |
parent | 2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff) | |
download | armnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz |
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com>
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/BackendHelper.hpp | 6 | ||||
-rw-r--r-- | include/armnn/Descriptors.hpp | 19 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 7 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 3 | ||||
-rw-r--r-- | include/armnn/backends/WorkloadData.hpp | 5 |
6 files changed, 40 insertions, 1 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp index 986f854636..b61f010b0f 100644 --- a/include/armnn/BackendHelper.hpp +++ b/include/armnn/BackendHelper.hpp @@ -73,6 +73,12 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()); + + bool IsBroadcastToSupported(const TensorInfo& input, + const TensorInfo& output, + const BroadcastToDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported); + bool IsCastSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()); diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 30eaefd83b..bf40b35ae9 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1656,4 +1656,23 @@ struct TileDescriptor : BaseDescriptor std::vector<uint32_t> m_Multiples; }; +struct BroadcastToDescriptor : BaseDescriptor +{ + BroadcastToDescriptor() + : m_BroadcastToShape() + {} + + explicit BroadcastToDescriptor(const TensorShape& shape) + : m_BroadcastToShape(shape) + {} + + bool operator ==(const BroadcastToDescriptor& rhs) const + { + return m_BroadcastToShape == rhs.m_BroadcastToShape; + } + + /// Target shape value. + TensorShape m_BroadcastToShape; +}; + } // namespace armnn diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index 4b9a3e5060..4b0b70c2d3 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -14,6 +14,7 @@ struct ArgMinMaxDescriptor; struct BatchMatMulDescriptor; struct BatchNormalizationDescriptor; struct BatchToSpaceNdDescriptor; +struct BroadcastToDescriptor; struct ChannelShuffleDescriptor; struct ComparisonDescriptor; struct Convolution2dDescriptor; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index c2c76e3d97..64fdab6bd0 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -857,6 +857,13 @@ public: IConnectableLayer* AddTileLayer(const TileDescriptor& descriptor, const char* name = nullptr); + /// Add a BroadcastTo layer to the network + /// @param descriptor - Parameters for the BroadcastTo operation + /// @param name - Optional name for the layer + /// @return - Interface for configuring the layer + IConnectableLayer* AddBroadcastToLayer(const BroadcastToDescriptor& descriptor, + const char* name = nullptr); + void ExecuteStrategy(IStrategy& strategy) const; protected: diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 7cb3a859c7..933e7da412 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -481,6 +481,7 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>; X(ReverseV2) \ X(Tile) \ X(Fused) \ + X(BroadcastTo) \ // New layers should be added at last position to minimize instability. @@ -492,7 +493,7 @@ enum class LayerType LIST_OF_LAYER_TYPE #undef X FirstLayer = Activation, - LastLayer = Fused + LastLayer = BroadcastTo }; const char* GetLayerTypeAsCString(LayerType type); diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp index 86796cbcc0..a90a1abd65 100644 --- a/include/armnn/backends/WorkloadData.hpp +++ b/include/armnn/backends/WorkloadData.hpp @@ -765,4 +765,9 @@ struct TileQueueDescriptor : QueueDescriptorWithParameters<TileDescriptor> void Validate(const WorkloadInfo& workloadInfo) const; }; +struct BroadcastToQueueDescriptor : QueueDescriptorWithParameters<BroadcastToDescriptor> +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } // namespace armnn |