// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "Optimization.hpp" #include namespace armnn { namespace optimizations { class MovePermuteUpImpl { public: /// Run for every connection between a base Layer (any) and a child PermuteLayer. If the type /// of the base layer allows it, it moves the permutation to the inputs of the base layer. /// I.e., adds equivalent permutations before the inputs of the base layer and moves the /// connections in the output of the child permute layer to the output of the base layer. void Run(Graph& graph, InputSlot& connection) const { OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); if (baseOutput.GetNumConnections() == 1U) { Layer& base = baseOutput.GetOwningLayer(); if (CanMovePermuteToInputs(base)) { auto permute = boost::polymorphic_downcast(&connection.GetOwningLayer()); const PermutationVector& perm = permute->GetPermutation(); // Inserts an equivalent permute before every input of the base layer. for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput) { // Inserts a new permute layer. const std::string name = std::string("moved_up-") + permute->GetName(); PermuteLayer& permLayer = *graph.InsertNewLayer(*baseInput, perm, name.c_str()); // Sets output tensor info for the new layer. OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot(); const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm); permLayer.GetOutputHandler().SetTensorInfo(permOutInfo); } // Sets permuted output tensor info const TensorInfo& childOutInfo = permute->GetOutputHandler().GetTensorInfo(); base.GetOutputHandler().SetTensorInfo(childOutInfo); // Bypasses permute. It will be removed as it's left unconnected. permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot()); } } } protected: MovePermuteUpImpl() = default; ~MovePermuteUpImpl() = default; private: static bool CanMovePermuteToInputs(const Layer& base) { switch (base.GetType()) { case LayerType::Activation: case LayerType::Addition: case LayerType::FakeQuantization: case LayerType::Floor: case LayerType::MemCopy: case LayerType::Multiplication: return true; default: return false; } } }; using MovePermuteUp = OptimizeForConnection; } // namespace optimizations } // namespace armnn