// // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "Optimization.hpp" #include #include namespace armnn { namespace optimizations { class ConvertConstPermuteLayersToConstLayers { public: void Run(Graph& graph, InputSlot& connection) const { Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); Layer& child = connection.GetOwningLayer(); ARMNN_ASSERT(base.GetType() == LayerType::Constant); ARMNN_ASSERT(child.GetType() == LayerType::Permute); if (base.GetDataType() == child.GetDataType()) { switch (base.GetDataType()) { case DataType::Float16: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::Float32: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::QAsymmU8: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::Signed32: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::QSymmS16: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::QSymmS8: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::QAsymmS8: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::BFloat16: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::Signed64: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; case DataType::Boolean: ReplaceConstPermuteLayer(graph, PolymorphicDowncast(&base), PolymorphicDowncast(&child)); break; } } } protected: ConvertConstPermuteLayersToConstLayers() = default; ~ConvertConstPermuteLayersToConstLayers() = default; private: template> static void ReplaceConstPermuteLayer(Graph& graph, ConstantLayer* constantLayer, PermuteLayer* permuteLayer) { IgnoreUnused(graph); /** * This optimisation is to find situations where a constant set of inputs is being provided to a Permute * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we * want to Permute them once and store them in a Const layer to be used everytime as they will not change. */ TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo(); std::vector newValues(outputPermuteInfo.GetNumElements()); armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(), constantLayer->m_LayerOutput->Map(true), newValues.data(), GetDataTypeSize(outputPermuteInfo.GetDataType())); TensorInfo newInfo = outputPermuteInfo; newInfo.SetConstant(true); ConstTensor newInput(newInfo, newValues); constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput)); // Moves connections in permute output to the constant layer. // Permute layer will be removed if left unconnected. permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot()); // Updating the output tensor constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo); ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true); } }; using FusePermuteIntoConstLayer = OptimizeForConnection; } // namespace optimizations } // namespace armnn