diff options
Diffstat (limited to 'src/armnn/optimizations/MoveTransposeUp.hpp')
-rw-r--r-- | src/armnn/optimizations/MoveTransposeUp.hpp | 83 |
1 file changed, 83 insertions, 0 deletions
diff --git a/src/armnn/optimizations/MoveTransposeUp.hpp b/src/armnn/optimizations/MoveTransposeUp.hpp new file mode 100644 index 0000000000..66543069c8 --- /dev/null +++ b/src/armnn/optimizations/MoveTransposeUp.hpp @@ -0,0 +1,83 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "Optimization.hpp" + +#include <armnnUtils/Transpose.hpp> + +namespace armnn +{ +namespace optimizations +{ +class MoveTransposeUpImpl +{ +public: + /// Run for every connection between a base Layer (any) and a child TransposeLayer. If the type + /// of the base layer allows it, it moves the permutation to the inputs of the base layer. + /// I.e., adds equivalent permutations before the inputs of the base layer and moves the + /// connections in the output of the child transpose layer to the output of the base layer. + void Run(Graph& graph, InputSlot& connection) const + { + OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); + + if (baseOutput.GetNumConnections() == 1U) + { + Layer& base = baseOutput.GetOwningLayer(); + + if (CanMoveTransposeToInputs(base)) + { + auto transpose = boost::polymorphic_downcast<TransposeLayer*>(&connection.GetOwningLayer()); + const PermutationVector& perm = transpose->GetPermutation(); + + // Inserts an equivalent transpose before every input of the base layer. + for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput) + { + // Inserts a new transpose layer. + const std::string name = std::string("moved_up-") + transpose->GetName(); + TransposeLayer& permLayer = *graph.InsertNewLayer<TransposeLayer>(*baseInput, perm, name.c_str()); + + // Sets output tensor info for the new layer. 
+ OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot(); + const TensorInfo permOutInfo = armnnUtils::TransposeTensorShape(parentOutput.GetTensorInfo(), perm); + permLayer.GetOutputHandler().SetTensorInfo(permOutInfo); + } + + // Sets transposed output tensor info + const TensorInfo& childOutInfo = transpose->GetOutputHandler().GetTensorInfo(); + base.GetOutputHandler().SetTensorInfo(childOutInfo); + + // Bypasses transpose. It will be removed as it's left unconnected. + transpose->GetOutputSlot().MoveAllConnections(base.GetOutputSlot()); + } + } + } + +protected: + MoveTransposeUpImpl() = default; + ~MoveTransposeUpImpl() = default; + +private: + static bool CanMoveTransposeToInputs(const Layer& base) + { + switch (base.GetType()) + { + case LayerType::Activation: + case LayerType::Addition: + case LayerType::FakeQuantization: + case LayerType::Floor: + case LayerType::MemCopy: + case LayerType::Multiplication: + return true; + default: + return false; + } + } +}; + +using MoveTransposeUp = OptimizeForConnection<Layer, TransposeLayer, MoveTransposeUpImpl>; + +} // namespace optimizations +} // namespace armnn |