diff options
author | Mike Kelly <mike.kelly@arm.com> | 2023-07-17 17:49:55 +0100 |
---|---|---|
committer | Mike Kelly <mike.kelly@arm.com> | 2023-07-17 22:55:36 +0100 |
commit | be06f10f79ccbdb9b110342c89c1d70238cc141c (patch) | |
tree | f3048ba9893d60daf8561c402c0fd1860e026b5f /src/backends/neon | |
parent | 02a22e7c84f007f742eb43d221cc37e2bd591edb (diff) | |
download | armnn-be06f10f79ccbdb9b110342c89c1d70238cc141c.tar.gz |
IVGCVSW-7891 Failure in Nightly tests
* Added check to ensure that Reshapes are not removed on Neon if they are
before or after a SplitterLayer and have more than 4 dimensions.
* Moved NCHW check to a function to reduce clutter.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I45d97634484e8dc0ca7675c23481caf84eb3fe90
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/NeonBackend.cpp | 21 |
1 files changed, 5 insertions, 16 deletions
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp index 098b1ff109..60e25672ae 100644 --- a/src/backends/neon/NeonBackend.cpp +++ b/src/backends/neon/NeonBackend.cpp @@ -510,26 +510,15 @@ OptimizationViews NeonBackend::OptimizeSubgraphView(const SubgraphView& subgraph if (base.GetType() == LayerType::Reshape) { ReshapeLayer* baseLayer = PolymorphicDowncast<ReshapeLayer*>(&base); - Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer(); - // Cannot currently remove the Reshape if it's connected to any layer that has an NCHW layout - if (IsNCHW(parentLayer)) + // Cannot remove a Reshape if it's connected to any layer that has an NCHW layout + if (ConnectedToLayerWithNCHW(baseLayer)) { continue; } - bool isNCHW = false; - - for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i) - { - Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer(); - - if (IsNCHW(nextLayer)) - { - isNCHW = true; - break; - } - } - if (isNCHW) + // Cannot remove a Reshape if it's connected to a SplitterLayer through a Tensor that has more than + // 4 dimensions + if (ConnectedToSplitterWithMoreThan4Dims(baseLayer)) { continue; } |