From 4fcda0101ec3d110c1d6d7bee5c83416b645528a Mon Sep 17 00:00:00 2001 From: telsoa01 Date: Fri, 9 Mar 2018 14:13:49 +0000 Subject: Release 18.02 Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6 --- .../optimizations/OptimizeConsecutiveReshapes.hpp | 60 ++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp (limited to 'src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp') diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp new file mode 100644 index 0000000000..deb49c6884 --- /dev/null +++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp @@ -0,0 +1,60 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Optimization.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class OptimizeConsecutiveReshapesImpl +{ +public: + /// Run for every connection between a base RashapeLayer and a child ReshapeLayer. + /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. + void Run(Graph& graph, InputSlot& connection) const + { + auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); + auto& child = connection.GetOwningLayer(); + + BOOST_ASSERT(base.GetType() == LayerType::Reshape); + BOOST_ASSERT(child.GetType() == LayerType::Reshape); + + OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); + + const TensorInfo& inInfo = parentOut->GetTensorInfo(); + const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo(); + + if (inInfo.GetShape() != outInfo.GetShape()) + { + // Insert equivalent reshape before base layer + const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); + const ReshapeDescriptor descriptor{outInfo.GetShape()}; + auto& newReshape = *graph.InsertNewLayer(base.GetInputSlot(0), descriptor, name.c_str()); + // Set tensor info for new layer + newReshape.GetOutputHandler().SetTensorInfo(outInfo); + // Reconnect base with original parent + newReshape.GetOutputSlot().MoveAllConnections(*parentOut); + // Parent is now the new layer + parentOut = &newReshape.GetOutputSlot(); + } + + // Move connections in child output to parent layer. + // Child layer will be removed as it's left unconnected. + // Base layer will be removed if left unconnected. + child.GetOutputSlot().MoveAllConnections(*parentOut); + } + +protected: + OptimizeConsecutiveReshapesImpl() = default; + ~OptimizeConsecutiveReshapesImpl() = default; +}; + +using OptimizeConsecutiveReshapes = OptimizeForConnection; + +} // namespace optimizations +} // namespace armnn -- cgit v1.2.1