aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/MovePermuteUp.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/optimizations/MovePermuteUp.hpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/optimizations/MovePermuteUp.hpp')
-rw-r--r--src/armnn/optimizations/MovePermuteUp.hpp82
1 files changed, 82 insertions, 0 deletions
diff --git a/src/armnn/optimizations/MovePermuteUp.hpp b/src/armnn/optimizations/MovePermuteUp.hpp
new file mode 100644
index 0000000000..8c59986762
--- /dev/null
+++ b/src/armnn/optimizations/MovePermuteUp.hpp
@@ -0,0 +1,82 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "Optimization.hpp"
+#include "Permute.hpp"
+
+namespace armnn
+{
+namespace optimizations
+{
+class MovePermuteUpImpl
+{
+public:
+ /// Run for every connection between a base Layer (any) and a child PermuteLayer. If the type
+ /// of the base layer allows it, it moves the permutation to the inputs of the base layer.
+ /// I.e., adds equivalent permutations before the inputs of the base layer and moves the
+ /// connections in the output of the child permute layer to the output of the base layer.
+ void Run(Graph& graph, InputSlot& connection) const
+ {
+ OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
+
+ if (baseOutput.GetNumConnections() == 1U)
+ {
+ Layer& base = baseOutput.GetOwningLayer();
+
+ if (CanMovePermuteToInputs(base))
+ {
+ auto permute = boost::polymorphic_downcast<PermuteLayer*>(&connection.GetOwningLayer());
+ const PermutationVector& perm = permute->GetPermutation();
+
+ // Insert an equivalent permute before every input of the base layer.
+ for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput)
+ {
+ // Insert new permute layer.
+ const std::string name = std::string("moved_up-") + permute->GetName();
+ PermuteLayer& permLayer = *graph.InsertNewLayer<PermuteLayer>(*baseInput, perm, name.c_str());
+
+ // Set output tensor info for the new layer.
+ OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot();
+ const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm);
+ permLayer.GetOutputHandler().SetTensorInfo(permOutInfo);
+ }
+
+ // Set permuted output tensor info
+ const TensorInfo& childOutInfo = permute->GetOutputHandler().GetTensorInfo();
+ base.GetOutputHandler().SetTensorInfo(childOutInfo);
+
+ // Bypass permute. It will be removed as it's left unconnected.
+ permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot());
+ }
+ }
+ }
+
+protected:
+ MovePermuteUpImpl() = default;
+ ~MovePermuteUpImpl() = default;
+
+private:
+ static bool CanMovePermuteToInputs(const Layer& base)
+ {
+ switch (base.GetType())
+ {
+ case LayerType::Activation:
+ case LayerType::Addition:
+ case LayerType::FakeQuantization:
+ case LayerType::Floor:
+ case LayerType::MemCopy:
+ case LayerType::Multiplication:
+ return true;
+ default:
+ return false;
+ }
+ }
+};
+
+using MovePermuteUp = OptimizeForConnection<Layer, PermuteLayer, MovePermuteUpImpl>;
+
+} // namespace optimizations
+} // namespace armnn