aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
authorIdriss Chaouch <idriss.chaouch@arm.com>2023-08-28 14:28:31 +0100
committerIdriss Chaouch <idriss.chaouch@arm.com>2023-08-31 11:26:28 +0100
commit98e383eadf4e670d057ad725c7fe7924fea8e36b (patch)
tree35acac15aa69ab405887289cb9674d388f06f96b /src/armnn/Network.cpp
parent2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff)
downloadarmnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com> Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp16
1 files changed, 15 insertions, 1 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 7f4ef6b1b6..d2b14cd045 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -657,6 +657,12 @@ IConnectableLayer* INetwork::AddTileLayer(const TileDescriptor &descriptor,
return pNetworkImpl->AddTileLayer(descriptor, name);
}
+IConnectableLayer* INetwork::AddBroadcastToLayer(const BroadcastToDescriptor& descriptor,
+ const char* name)
+{
+ return pNetworkImpl->AddBroadcastToLayer(descriptor, name);
+}
+
void INetwork::ExecuteStrategy(IStrategy& strategy) const
{
return pNetworkImpl->ExecuteStrategy(strategy);
@@ -1929,8 +1935,10 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph,
optGraph.InferTensorInfos();
}
- // Perform AddBroadcastReshapeLayer optimisation
+ // Perform BroadcastToOptimizationLayer and then AddBroadcastReshapeLayer optimisation
using namespace optimizations;
+ Optimizer::Pass(optGraph, MakeOptimizations(BroadcastToOptimizationLayer()));
+
Optimizer::Pass(optGraph, MakeOptimizations(AddBroadcastReshapeLayer()));
if(options.GetShapeInferenceMethod() == ShapeInferenceMethod::ValidateOnly)
@@ -1961,6 +1969,7 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph,
FoldPadIntoConvolution2d(),
FoldPadIntoDepthwiseConvolution2d(),
FoldPadIntoPooling2d(),
+ BroadcastToOptimizationLayer(),
PermuteAndBatchToSpaceAsDepthToSpace(),
TransposeAndBatchToSpaceAsDepthToSpace(),
FuseBatchNormIntoConvolution2DFloat32(),
@@ -3045,6 +3054,11 @@ IConnectableLayer* NetworkImpl::AddPrecompiledLayer(const PreCompiledDescriptor&
return layer;
}
+IConnectableLayer* NetworkImpl::AddBroadcastToLayer(const BroadcastToDescriptor &desc, const char *name)
+{
+ return m_Graph->AddLayer<BroadcastToLayer>(desc, name);
+}
+
void NetworkImpl::ExecuteStrategy(IStrategy& strategy) const
{
for (auto layer : GetGraph())