From bc7ffb5e9e5f4c86280b20c65772eb12d8bb140e Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 20 Mar 2020 15:01:01 +0000 Subject: IVGCVSW-4520 Implement BFloat16 Optimizer * Add ReduceFp32ToBf16 to OptimizerOptions * Add ConvertFp32NetworkToBf16 * Add utility functions to insert conversion layers * Add constant conversion BF16 <-> FP32 * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Iaff77e20c721400b052cb37eb9ef6fe16d7abaff --- .../optimizations/ConvertFp32NetworkToBf16.hpp | 81 ++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp (limited to 'src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp') diff --git a/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp new file mode 100644 index 0000000000..d6350c3af3 --- /dev/null +++ b/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp @@ -0,0 +1,81 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "Optimization.hpp" +#include "NetworkUtils.hpp" + +namespace armnn +{ +namespace optimizations +{ + +class ConvertFp32NetworkToBf16Impl +{ +public: + void Run(Graph& graph, Layer& layer) const + { + if(layer.GetType() == LayerType::Input) + { + // if the outputs of this layer are DataType::Float32 + // add a ConvertFloat32ToBFloat16 layer after each of the outputs + if (layer.GetDataType() == DataType::Float32) + { + InsertConvertFp32ToBf16LayersAfter(graph, layer); + } + } + else if (layer.GetType() == LayerType::Output) + { + // if the inputs of this layer are DataType::Float32 + // add a ConvertBFloat16ToFloat32 layer before each of the inputs + if (layer.GetDataType() == DataType::Float32) + { + // NOTE: We need to call InsertConvertBf16ToFp32LayersBefore with expectCorrectInputType = false + // here, otherwise it will expect the inputs to be DataType::BFloat16 + InsertConvertBf16ToFp32LayersBefore(graph, layer, false); + } + } + else if (layer.GetType() != LayerType::ConvertFp32ToBf16 && layer.GetType() != LayerType::ConvertBf16ToFp32) + { + // if the inputs/outputs of this layer are DataType::Float32 + // change the data type for all inputs and outputs to DataType::BFloat16 + for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input) + { + // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection + // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer + Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer(); + if (base.GetType() != LayerType::Input) + { + TensorInfo convertInfo = input->GetConnection()->GetTensorInfo(); + if (convertInfo.GetDataType() == DataType::Float32) + { + convertInfo.SetDataType(DataType::BFloat16); + input->GetConnection()->SetTensorInfo(convertInfo); + } + } + } + + // change outputs to DataType::BFloat16 + for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output) + { + TensorInfo convertInfo = output->GetTensorInfo(); + if (convertInfo.GetDataType() == DataType::Float32) + { + convertInfo.SetDataType(DataType::BFloat16); + output->SetTensorInfo(convertInfo); + } + } + } + } + +protected: + ConvertFp32NetworkToBf16Impl() = default; + ~ConvertFp32NetworkToBf16Impl() = default; +}; + +using Fp32NetworkToBf16Converter = OptimizeForType; + +} // namespace optimizations +} // namespace armnn -- cgit v1.2.1