diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
commit | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch) | |
tree | c9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp | |
download | armnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz |
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp')
-rw-r--r-- | src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp new file mode 100644 index 0000000000..deb49c6884 --- /dev/null +++ b/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp @@ -0,0 +1,60 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Optimization.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class OptimizeConsecutiveReshapesImpl +{ +public: + /// Run for every connection between a base RashapeLayer and a child ReshapeLayer. + /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. + void Run(Graph& graph, InputSlot& connection) const + { + auto& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); + auto& child = connection.GetOwningLayer(); + + BOOST_ASSERT(base.GetType() == LayerType::Reshape); + BOOST_ASSERT(child.GetType() == LayerType::Reshape); + + OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); + + const TensorInfo& inInfo = parentOut->GetTensorInfo(); + const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo(); + + if (inInfo.GetShape() != outInfo.GetShape()) + { + // Insert equivalent reshape before base layer + const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); + const ReshapeDescriptor descriptor{outInfo.GetShape()}; + auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str()); + // Set tensor info for new layer + newReshape.GetOutputHandler().SetTensorInfo(outInfo); + // Reconnect base with original parent + newReshape.GetOutputSlot().MoveAllConnections(*parentOut); + // Parent is now the new layer + parentOut = &newReshape.GetOutputSlot(); + } + + // Move connections in child output to parent layer. + // Child layer will be removed as it's left unconnected. + // Base layer will be removed if left unconnected. + child.GetOutputSlot().MoveAllConnections(*parentOut); + } + +protected: + OptimizeConsecutiveReshapesImpl() = default; + ~OptimizeConsecutiveReshapesImpl() = default; +}; + +using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>; + +} // namespace optimizations +} // namespace armnn |