diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp')
-rw-r--r-- | src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp | 80 |
1 file changed, 80 insertions(+), 0 deletions(-)
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp new file mode 100644 index 0000000000..a4df05c18a --- /dev/null +++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp @@ -0,0 +1,80 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Optimization.hpp" +#include "NetworkUtils.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class ConvertFp32NetworkToFp16Impl +{ +public: + + void Run(Graph& graph, Layer& layer) const + { + if(layer.GetType() == LayerType::Input) + { + // if the outputs of this layer are DataType::Float32 + // add a ConvertFloat32ToFloat16 layer after each of the outputs + if (layer.GetDataType() == DataType::Float32) + { + InsertConvertFp32ToFp16LayersAfter(graph, layer); + } + } + else if (layer.GetType() == LayerType::Output) + { + // if the inputs of this layer are DataType::Float32 + // add a ConvertFloat16ToFloat32 layer before each of the inputs + if (layer.GetDataType() == DataType::Float32) + { + InsertConvertFp16ToFp32LayersBefore(graph, layer); + } + } + else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32) + { + // if the inputs/outputs of this layer are DataType::Float32 + // change the data type for all inputs and outputs to DataType::Float16 + for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input) + { + // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection + // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer + Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer(); + if (base.GetType() != LayerType::Input) + { + TensorInfo convertInfo = input->GetConnection()->GetTensorInfo(); + if (convertInfo.GetDataType() == DataType::Float32) + { + 
convertInfo.SetDataType(DataType::Float16); + input->GetConnection()->SetTensorInfo(convertInfo); + } + } + } + + // change outputs to DataType::Float16 + for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output) + { + TensorInfo convertInfo = output->GetTensorInfo(); + if (convertInfo.GetDataType() == DataType::Float32) + { + convertInfo.SetDataType(DataType::Float16); + output->SetTensorInfo(convertInfo); + } + } + } + } + +protected: + ConvertFp32NetworkToFp16Impl() = default; + ~ConvertFp32NetworkToFp16Impl() = default; +}; + +using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>; + +} // namespace optimizations +} // namespace armnn |