aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/cl/ClBackend.cpp13
-rw-r--r--src/backends/cl/test/ClEndToEndTests.cpp15
-rw-r--r--src/backends/neon/NeonBackend.cpp5
3 files changed, 28 insertions, 5 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index 1d8ae21e5d..6d191a594b 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -673,6 +673,19 @@ OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
}
}
+ // Remove Reshape where possible
+ if (base.GetType() == LayerType::Reshape)
+ {
+ ReshapeLayer* baseLayer = PolymorphicDowncast<ReshapeLayer*>(&base);
+
+ // Cannot remove a Reshape if it's connected to any layer that has an NCHW layout
+ if (ConnectedToLayerWithNCHW(baseLayer))
+ {
+ continue;
+ }
+ RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
+ }
+
// Special case to fuse padding into average pooling 2d for quantized datatype.
// Required to be done as a backend specific optimization as Neon does not support this special case.
if (base.GetType() == LayerType::Pooling2d)
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 2436a8223f..78d2dea90d 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -673,4 +673,19 @@ TEST_CASE("ClForceImportWithMisalignedInputAndOutputBuffersEndToEndTest")
ForceImportWithMisalignedInputAndOutputBuffersEndToEndTest(clDefaultBackends);
}
+TEST_CASE("ClReshapeRemovalSimpleCaseEndToEnd")
+{
+ ReshapeRemovalEndToEnd<armnn::DataType::Float32>(clDefaultBackends);
+}
+
+TEST_CASE("ClReshapeRemovalNCHWFirstEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(clDefaultBackends, false, true);
+}
+
+TEST_CASE("ClReshapeRemovalNCHWSecondEndToEnd")
+{
+ ReshapeRemovalNCHWEndToEnd<armnn::DataType::Float32>(clDefaultBackends, false, false);
+}
+
}
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp
index 7311098631..ebce1a69de 100644
--- a/src/backends/neon/NeonBackend.cpp
+++ b/src/backends/neon/NeonBackend.cpp
@@ -519,11 +519,6 @@ OptimizationViews NeonBackend::OptimizeSubgraphView(const SubgraphView& subgraph
{
continue;
}
- // Cannot remove a Reshape if it's connected to a SplitterLayer
- if (ConnectedToLayerType(baseLayer, LayerType::Splitter))
- {
- continue;
- }
RemoveReshapeLayer(baseLayer, untouched, optimizationViews);
}