aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
blob: 729b76ad6bf075616b99a329776669625650c0ff (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"
#include "NetworkUtils.hpp"

namespace armnn
{
namespace optimizations
{

class ConvertFp32NetworkToFp16Impl
{
public:

    void Run(Graph& graph, Layer& layer) const
    {
        if(layer.GetType() == LayerType::Input)
        {
            // if the outputs of this layer are DataType::Float32
            // add a ConvertFloat32ToFloat16 layer after each of the outputs
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp32ToFp16LayersAfter(graph, layer);
            }
        }
        else if (layer.GetType() == LayerType::Output)
        {
            // if the inputs of this layer are DataType::Float32
            // add a ConvertFloat16ToFloat32 layer before each of the inputs
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp16ToFp32LayersBefore(graph, layer);
            }
        }
        else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)
        {
            // if the inputs/outputs of this layer are DataType::Float32
            // change the data type for all inputs and outputs to DataType::Float16
            for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
            {
                // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection
                // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer
                Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
                if (base.GetType() != LayerType::Input)
                {
                    TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
                    if (convertInfo.GetDataType() == DataType::Float32)
                    {
                        convertInfo.SetDataType(DataType::Float16);
                        input->GetConnection()->SetTensorInfo(convertInfo);
                    }
                }
            }

            // change outputs to DataType::Float16
            for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
            {
                TensorInfo convertInfo = output->GetTensorInfo();
                if (convertInfo.GetDataType() == DataType::Float32)
                {
                    convertInfo.SetDataType(DataType::Float16);
                    output->SetTensorInfo(convertInfo);
                }
            }
        }
    }

protected:
    ConvertFp32NetworkToFp16Impl() = default;
    ~ConvertFp32NetworkToFp16Impl() = default;
};

using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>;

} // namespace optimizations
} // namespace armnn