aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp')
-rw-r--r--  src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp  80
1 files changed, 80 insertions, 0 deletions
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
new file mode 100644
index 0000000000..a4df05c18a
--- /dev/null
+++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
@@ -0,0 +1,80 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "Optimization.hpp"
+#include "NetworkUtils.hpp"
+
+namespace armnn
+{
+namespace optimizations
+{
+
+class ConvertFp32NetworkToFp16Impl
+{
+public:
+
+ void Run(Graph& graph, Layer& layer) const
+ {
+ if(layer.GetType() == LayerType::Input)
+ {
+ // if the outputs of this layer are DataType::Float32
+ // add a ConvertFloat32ToFloat16 layer after each of the outputs
+ if (layer.GetDataType() == DataType::Float32)
+ {
+ InsertConvertFp32ToFp16LayersAfter(graph, layer);
+ }
+ }
+ else if (layer.GetType() == LayerType::Output)
+ {
+ // if the inputs of this layer are DataType::Float32
+ // add a ConvertFloat16ToFloat32 layer before each of the inputs
+ if (layer.GetDataType() == DataType::Float32)
+ {
+ InsertConvertFp16ToFp32LayersBefore(graph, layer);
+ }
+ }
+ else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)
+ {
+ // if the inputs/outputs of this layer are DataType::Float32
+ // change the data type for all inputs and outputs to DataType::Float16
+ for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
+ {
+ // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection
+ // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer
+ Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
+ if (base.GetType() != LayerType::Input)
+ {
+ TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
+ if (convertInfo.GetDataType() == DataType::Float32)
+ {
+ convertInfo.SetDataType(DataType::Float16);
+ input->GetConnection()->SetTensorInfo(convertInfo);
+ }
+ }
+ }
+
+ // change outputs to DataType::Float16
+ for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
+ {
+ TensorInfo convertInfo = output->GetTensorInfo();
+ if (convertInfo.GetDataType() == DataType::Float32)
+ {
+ convertInfo.SetDataType(DataType::Float16);
+ output->SetTensorInfo(convertInfo);
+ }
+ }
+ }
+ }
+
+protected:
+ ConvertFp32NetworkToFp16Impl() = default;
+ ~ConvertFp32NetworkToFp16Impl() = default;
+};
+
// Wraps the per-layer implementation so it can be registered as a graph optimization;
// OptimizeForType is presumed to apply the Impl's Run() to each Layer — see Optimization.hpp.
using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>;
+
+} // namespace optimizations
+} // namespace armnn