aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-12-20 12:48:02 +0000
committermike.kelly <mike.kelly@arm.com>2023-12-20 19:40:53 +0000
commita7bd3fa77431f0e99698f24e325b85df459d3b10 (patch)
treee6af0301ea73f0514fa599a174801a5aa7ae0d77
parent399d1001eba374d266c30e1e3239d6321e731339 (diff)
downloadarmnn-a7bd3fa77431f0e99698f24e325b85df459d3b10.tar.gz
IVGCVSW-7830 Remove Reshape where possible
* Remove reshape on ClBackend * Remove unnecessary restriction on NeonBackend remove Reshape Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I79940c9f8609d19b79f2fbe55225ffc8f0d90c25
-rw-r--r--src/backends/cl/ClBackend.cpp13
-rw-r--r--src/backends/cl/test/ClEndToEndTests.cpp15
-rw-r--r--src/backends/neon/NeonBackend.cpp5
3 files changed, 28 insertions, 5 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index 1d8ae21e5d..6d191a594b 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -673,6 +673,19 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
}
}
+ // Remove Reshape where possible
+ if (base.GetType() == LayerType::Reshape)
+ {
+ ReshapeLayer* baseLayer = PolymorphicDowncast<ReshapeLayer*>(&base);
+
+ // Cannot remove a Reshape if it's connected to any layer that has an NCHW layout
+ if (ConnectedToLayerWithNCHW(baseLayer))
+ {
+ continue;
+ }
+ RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
+ }
+
// Special case to fuse padding into average pooling 2d for quantized datatype.
// Required to be done as a backend specific optimization as Neon does not support this special case.
if (base.GetType() == LayerType::Pooling2d)
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 2436a8223f..78d2dea90d 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -673,4 +673,19 @@ TEST_CASE("ClForceImportWithMisalignedInputAndOutputBuffersEndToEndTest")
ForceImportWithMisalignedInputAndOutputBuffersEndToEndTest(clDefaultBackends);
}
+TEST_CASE("ClReshapeRemovalSimpleCaseEndToEnd")
+{
+ ReshapeRemovalEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
+TEST_CASE("ClReshapeRemovalNCHWFirstEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(clDefaultBackends, false, true);
+}
+
+TEST_CASE("ClReshapeRemovalNCHWSecondEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(clDefaultBackends, false, false);
+}
+
}
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp
index 7311098631..ebce1a69de 100644
--- a/src/backends/neon/NeonBackend.cpp
+++ b/src/backends/neon/NeonBackend.cpp
@@ -519,11 +519,6 @@ OptimizationViews NeonBackend::OptimizeSubgraphView(const SubgraphView& subgraph
{
continue;
}
- // Cannot remove a Reshape if it's connected to a SplitterLayer
- if (ConnectedToLayerType(baseLayer, LayerType::Splitter))
- {
- continue;
- }
RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
}