aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp')
-rw-r--r--src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
new file mode 100644
index 0000000000..deb49c6884
--- /dev/null
+++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "Optimization.hpp"
+
+namespace armnn
+{
+namespace optimizations
+{
+
+class OptimizeConsecutiveReshapesImpl
+{
+public:
+ /// Run for every connection between a base RashapeLayer and a child ReshapeLayer.
+ /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
+ void Run(Graph& graph, InputSlot& connection) const
+ {
+ auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
+ auto& child = connection.GetOwningLayer();
+
+ BOOST_ASSERT(base.GetType() == LayerType::Reshape);
+ BOOST_ASSERT(child.GetType() == LayerType::Reshape);
+
+ OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
+
+ const TensorInfo& inInfo = parentOut->GetTensorInfo();
+ const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
+
+ if (inInfo.GetShape() != outInfo.GetShape())
+ {
+ // Insert equivalent reshape before base layer
+ const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
+ const ReshapeDescriptor descriptor{outInfo.GetShape()};
+ auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
+ // Set tensor info for new layer
+ newReshape.GetOutputHandler().SetTensorInfo(outInfo);
+ // Reconnect base with original parent
+ newReshape.GetOutputSlot().MoveAllConnections(*parentOut);
+ // Parent is now the new layer
+ parentOut = &newReshape.GetOutputSlot();
+ }
+
+ // Move connections in child output to parent layer.
+ // Child layer will be removed as it's left unconnected.
+ // Base layer will be removed if left unconnected.
+ child.GetOutputSlot().MoveAllConnections(*parentOut);
+ }
+
+protected:
+ OptimizeConsecutiveReshapesImpl() = default;
+ ~OptimizeConsecutiveReshapesImpl() = default;
+};
+
+using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;
+
+} // namespace optimizations
+} // namespace armnn