diff options
author | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-28 14:28:31 +0100 |
---|---|---|
committer | Idriss Chaouch <idriss.chaouch@arm.com> | 2023-08-31 11:26:28 +0100 |
commit | 98e383eadf4e670d057ad725c7fe7924fea8e36b (patch) | |
tree | 35acac15aa69ab405887289cb9674d388f06f96b /src/armnn/Network.cpp | |
parent | 2be039bce38a4fa436e8310dfe14ebfff20d57bd (diff) | |
download | armnn-98e383eadf4e670d057ad725c7fe7924fea8e36b.tar.gz |
IVGCVSW-7525 Add broadcast_to operator
Signed-off-by: Idriss Chaouch <idriss.chaouch@arm.com>
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I94ec5f9120b2d736fdf98d00ec5137a4efd739b8
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 7f4ef6b1b6..d2b14cd045 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -657,6 +657,12 @@ IConnectableLayer* INetwork::AddTileLayer(const TileDescriptor &descriptor, return pNetworkImpl->AddTileLayer(descriptor, name); } +IConnectableLayer* INetwork::AddBroadcastToLayer(const BroadcastToDescriptor& descriptor, + const char* name) +{ + return pNetworkImpl->AddBroadcastToLayer(descriptor, name); +} + void INetwork::ExecuteStrategy(IStrategy& strategy) const { return pNetworkImpl->ExecuteStrategy(strategy); @@ -1929,8 +1935,10 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph, optGraph.InferTensorInfos(); } - // Perform AddBroadcastReshapeLayer optimisation + // Perform BroadcastToOptimizationLayer and then AddBroadcastReshapeLayer optimisation using namespace optimizations; + Optimizer::Pass(optGraph, MakeOptimizations(BroadcastToOptimizationLayer())); + Optimizer::Pass(optGraph, MakeOptimizations(AddBroadcastReshapeLayer())); if(options.GetShapeInferenceMethod() == ShapeInferenceMethod::ValidateOnly) @@ -1961,6 +1969,7 @@ IOptimizedNetworkPtr Optimize(const Graph& inGraph, FoldPadIntoConvolution2d(), FoldPadIntoDepthwiseConvolution2d(), FoldPadIntoPooling2d(), + BroadcastToOptimizationLayer(), PermuteAndBatchToSpaceAsDepthToSpace(), TransposeAndBatchToSpaceAsDepthToSpace(), FuseBatchNormIntoConvolution2DFloat32(), @@ -3045,6 +3054,11 @@ IConnectableLayer* NetworkImpl::AddPrecompiledLayer(const PreCompiledDescriptor& return layer; } +IConnectableLayer* NetworkImpl::AddBroadcastToLayer(const BroadcastToDescriptor &desc, const char *name) +{ + return m_Graph->AddLayer<BroadcastToLayer>(desc, name); +} + void NetworkImpl::ExecuteStrategy(IStrategy& strategy) const { for (auto layer : GetGraph()) |