aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp
blob: 6c80e740be720ca59a301b5d1701dac59724852d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "NetworkUtils.hpp"
#include "Optimization.hpp"

#include <armnn/utility/PolymorphicDowncast.hpp>

namespace armnn
{
namespace optimizations
{

/// Converts the Float32 weights of a Convolution2d or FullyConnected layer to BFloat16 in place.
///
/// The layer is downcast to LayerT. Nothing is changed unless the layer type is Convolution2d or
/// FullyConnected, m_Weight is set, and the weight data type is Float32. In that case the weight
/// values are converted and m_Weight is replaced by a new ScopedTensorHandle that holds the
/// BFloat16 data with otherwise identical tensor info.
///
/// @param l  Layer to process; it is downcast to LayerT.
/// @return   The layer as LayerT*, whether or not its weights were converted.
template <typename LayerT>
inline LayerT* ConvertWeight(Layer* l)
{
    LayerT* layer = PolymorphicDowncast<LayerT*>(l);

    // Only these two layer kinds have their weights converted, and only if weights are present.
    const auto layerType = layer->GetType();
    const bool isSupportedType = (layerType == LayerType::Convolution2d) ||
                                 (layerType == LayerType::FullyConnected);
    if (!isSupportedType || !layer->m_Weight)
    {
        return layer;
    }

    const TensorInfo& weightInfo = layer->m_Weight->GetTensorInfo();
    if (weightInfo.GetDataType() != DataType::Float32)
    {
        // Weights that are not Float32 are left as they are.
        return layer;
    }

    const auto numElements = weightInfo.GetNumElements();
    std::vector<BFloat16> bf16Values(numElements);

    armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(
            layer->m_Weight->template GetConstTensor<float>(),
            numElements,
            bf16Values.data());

    // Copy the original tensor info and change only the data type.
    TensorInfo bf16Info(weightInfo);
    bf16Info.SetDataType(DataType::BFloat16);

    // The new handle is built from the temporary tensor before the old handle is released.
    ConstTensor bf16Tensor(bf16Info, bf16Values);
    layer->m_Weight.reset(new ScopedTensorHandle(bf16Tensor));

    return layer;
}

class ConvertFp32NetworkToBf16Impl
{
public:

    void Run(Graph& graph, Layer& layer) const
    {
        // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
        // And also convert weight data type from Float32 to Bfloat16.
        // Do not convert bias data type.
        if (layer.GetType() == LayerType::Convolution2d)
        {
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp32ToBf16LayersBefore(graph,layer);
                ConvertWeight<Convolution2dLayer>(&layer);
            }
        }
        else if (layer.GetType() == LayerType::FullyConnected)
        {
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp32ToBf16LayersBefore(graph,layer);
                ConvertWeight<FullyConnectedLayer>(&layer);
            }
        }
    }

protected:
    ConvertFp32NetworkToBf16Impl() = default;
    ~ConvertFp32NetworkToBf16Impl() = default;
};

// Wraps ConvertFp32NetworkToBf16Impl in OptimizeForType with Layer as the target type.
// NOTE(review): presumably this makes Run() get invoked for every layer in the graph -
// confirm against OptimizeForType in Optimization.hpp.
using Fp32NetworkToBf16Converter = OptimizeForType<Layer, ConvertFp32NetworkToBf16Impl>;

} // namespace optimizations
} // namespace armnn