//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"
#include "NetworkUtils.hpp"

namespace armnn
{
namespace optimizations
{

class ConvertFp32NetworkToFp16Impl
{
public:
    void Run(Graph& graph, Layer& layer) const
    {
        if (layer.GetType() == LayerType::Input)
        {
            // if the outputs of this layer are DataType::Float32
            // add a ConvertFloat32ToFloat16 layer after each of the outputs
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp32ToFp16LayersAfter(graph, layer);
            }
        }
        else if (layer.GetType() == LayerType::Output)
        {
            // if the inputs of this layer are DataType::Float32
            // add a ConvertFloat16ToFloat32 layer before each of the inputs
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp16ToFp32LayersBefore(graph, layer);
            }
        }
        else if (layer.GetType() != LayerType::ConvertFp32ToFp16 &&
                 layer.GetType() != LayerType::ConvertFp16ToFp32)
        {
            // if the inputs/outputs of this layer are DataType::Float32
            // change the data type for all inputs and outputs to DataType::Float16
            for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
            {
                // if it is connected to an OutputSlot of the InputLayer, do not change the DataType of the connection;
                // the InputSlots of the current layer will be updated when the conversion layer is inserted after the InputLayer
                Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
                if (base.GetType() != LayerType::Input)
                {
                    TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
                    if (convertInfo.GetDataType() == DataType::Float32)
                    {
                        convertInfo.SetDataType(DataType::Float16);
                        input->GetConnection()->SetTensorInfo(convertInfo);
                    }
                }
            }

            // change outputs to DataType::Float16
            for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
            {
                TensorInfo convertInfo = output->GetTensorInfo();
                if (convertInfo.GetDataType() == DataType::Float32)
                {
                    convertInfo.SetDataType(DataType::Float16);
                    output->SetTensorInfo(convertInfo);
                }
            }
        }
    }

protected:
    ConvertFp32NetworkToFp16Impl() = default;
    ~ConvertFp32NetworkToFp16Impl() = default;
};

using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>;

} // namespace optimizations
} // namespace armnn
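
// Usage sketch (not part of the original header): the Fp32NetworkToFp16Converter alias
// defined above is meant to be run as one pass of the graph optimizer. Assuming the
// ArmNN Optimizer::Pass and MakeOptimizations helpers from Optimizer.hpp, a caller
// reducing an FP32 network to FP16 would do roughly:
//
//     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(
//         armnn::optimizations::Fp32NetworkToFp16Converter()));
//
// The pass leaves Input and Output layers at Float32, inserts ConvertFp32ToFp16 /
// ConvertFp16ToFp32 layers at those boundaries, and retags the intermediate tensor
// infos of all other layers as Float16.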