diff options
Diffstat (limited to 'src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp')
-rw-r--r-- | src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp | 80 |
1 file changed, 80 insertions, 0 deletions
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp new file mode 100644 index 0000000000..a4df05c18a --- /dev/null +++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp @@ -0,0 +1,80 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Optimization.hpp" +#include "NetworkUtils.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class ConvertFp32NetworkToFp16Impl +{ +public: + + void Run(Graph& graph, Layer& layer) const + { + if(layer.GetType() == LayerType::Input) + { + // if the outputs of this layer are DataType::Float32 + // add a ConvertFloat32ToFloat16 layer after each of the outputs + if (layer.GetDataType() == DataType::Float32) + { + InsertConvertFp32ToFp16LayersAfter(graph, layer); + } + } + else if (layer.GetType() == LayerType::Output) + { + // if the inputs of this layer are DataType::Float32 + // add a ConvertFloat16ToFloat32 layer before each of the inputs + if (layer.GetDataType() == DataType::Float32) + { + InsertConvertFp16ToFp32LayersBefore(graph, layer); + } + } + else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32) + { + // if the inputs/outputs of this layer are DataType::Float32 + // change the data type for all inputs and outputs to DataType::Float16 + for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input) + { + // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection + // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer + Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer(); + if (base.GetType() != LayerType::Input) + { + TensorInfo convertInfo = input->GetConnection()->GetTensorInfo(); + if (convertInfo.GetDataType() == DataType::Float32) + { + 
convertInfo.SetDataType(DataType::Float16); + input->GetConnection()->SetTensorInfo(convertInfo); + } + } + } + + // change outputs to DataType::Float16 + for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output) + { + TensorInfo convertInfo = output->GetTensorInfo(); + if (convertInfo.GetDataType() == DataType::Float32) + { + convertInfo.SetDataType(DataType::Float16); + output->SetTensorInfo(convertInfo); + } + } + } + } + +protected: + ConvertFp32NetworkToFp16Impl() = default; + ~ConvertFp32NetworkToFp16Impl() = default; +}; + +using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>; + +} // namespace optimizations +} // namespace armnn |