aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorIdriss Chaouch <idriss.chaouch@arm.com>2023-08-28 14:28:31 +0100
committerIdriss Chaouch <idriss.chaouch@arm.com>2023-08-31 11:26:28 +0100
commit98e383eadf4e670d057ad725c7fe7924fea8e36b (patch)
tree35acac15aa69ab405887289cb9674d388f06f96b /include
parent2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff)
downloadarmnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com> Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'include')
-rw-r--r--include/armnn/BackendHelper.hpp6
-rw-r--r--include/armnn/Descriptors.hpp19
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/INetwork.hpp7
-rw-r--r--include/armnn/Types.hpp3
-rw-r--r--include/armnn/backends/WorkloadData.hpp5
6 files changed, 40 insertions, 1 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 986f854636..b61f010b0f 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -73,6 +73,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
+ bool IsBroadcastToSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const BroadcastToDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported);
+
bool IsCastSupported(const TensorInfo& input,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 30eaefd83b..bf40b35ae9 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1656,4 +1656,23 @@ struct TileDescriptor : BaseDescriptor
std::vector<uint32_t> m_Multiples;
};
+struct BroadcastToDescriptor : BaseDescriptor
+{
+ BroadcastToDescriptor()
+ : m_BroadcastToShape()
+ {}
+
+ explicit BroadcastToDescriptor(const TensorShape& shape)
+ : m_BroadcastToShape(shape)
+ {}
+
+ bool operator ==(const BroadcastToDescriptor& rhs) const
+ {
+ return m_BroadcastToShape == rhs.m_BroadcastToShape;
+ }
+
+ /// Target shape value.
+ TensorShape m_BroadcastToShape;
+};
+
} // namespace armnn
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 4b9a3e5060..4b0b70c2d3 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -14,6 +14,7 @@ struct ArgMinMaxDescriptor;
struct BatchMatMulDescriptor;
struct BatchNormalizationDescriptor;
struct BatchToSpaceNdDescriptor;
+struct BroadcastToDescriptor;
struct ChannelShuffleDescriptor;
struct ComparisonDescriptor;
struct Convolution2dDescriptor;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index c2c76e3d97..64fdab6bd0 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -857,6 +857,13 @@ public:
IConnectableLayer* AddTileLayer(const TileDescriptor& descriptor,
const char* name = nullptr);
+ /// Add a BroadcastTo layer to the network
+ /// @param descriptor - Parameters for the BroadcastTo operation
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer
+ IConnectableLayer* AddBroadcastToLayer(const BroadcastToDescriptor& descriptor,
+ const char* name = nullptr);
+
void ExecuteStrategy(IStrategy& strategy) const;
protected:
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 7cb3a859c7..933e7da412 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -481,6 +481,7 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(ReverseV2) \
X(Tile) \
X(Fused) \
+ X(BroadcastTo) \
// New layers should be added at last position to minimize instability.
@@ -492,7 +493,7 @@ enum class LayerType
LIST_OF_LAYER_TYPE
#undef X
FirstLayer = Activation,
- LastLayer = Fused
+ LastLayer = BroadcastTo
};
const char* GetLayerTypeAsCString(LayerType type);
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
index 86796cbcc0..a90a1abd65 100644
--- a/include/armnn/backends/WorkloadData.hpp
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -765,4 +765,9 @@ struct TileQueueDescriptor : QueueDescriptorWithParameters<TileDescriptor>
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct BroadcastToQueueDescriptor : QueueDescriptorWithParameters<BroadcastToDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} // namespace armnn