aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/BFloat16.hpp
blob: 52344db8101546ad81672ac5ba8fd027e90f6e87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ostream>
#include <cmath>
#include <stdint.h>

namespace armnn
{
class BFloat16
{
public:
    BFloat16()
    : m_Value(0)
    {}

    BFloat16(const BFloat16& v) = default;

    explicit BFloat16(uint16_t v)
    : m_Value(v)
    {}

    explicit BFloat16(float v)
    {
        m_Value = Float32ToBFloat16(v).Val();
    }

    operator float() const
    {
        return ToFloat32();
    }

    BFloat16& operator=(const BFloat16& other) = default;

    BFloat16& operator=(float v)
    {
        m_Value = Float32ToBFloat16(v).Val();
        return *this;
    }

    bool operator==(const BFloat16& r) const
    {
        return m_Value == r.Val();
    }

    static BFloat16 Float32ToBFloat16(const float v)
    {
        if (std::isnan(v))
        {
            return Nan();
        }
        else
        {
            // Round value to the nearest even
            // Float32
            // S EEEEEEEE MMMMMMLRMMMMMMMMMMMMMMM
            // BFloat16
            // S EEEEEEEE MMMMMML
            // LSB (L): Least significat bit of BFloat16 (last bit of the Mantissa of BFloat16)
            // R: Rounding bit
            // LSB = 0, R = 0 -> round down
            // LSB = 1, R = 0 -> round down
            // LSB = 0, R = 1, all the rest = 0 -> round down
            // LSB = 1, R = 1 -> round up
            // LSB = 0, R = 1 -> round up
            const uint32_t* u32 = reinterpret_cast<const uint32_t*>(&v);
            uint16_t u16 = static_cast<uint16_t>(*u32 >> 16u);
            // Mark the LSB
            const uint16_t lsb = u16 & 0x0001;
            // Mark the error to be truncate (the rest of 16 bits of FP32)
            const uint16_t error = static_cast<uint16_t>((*u32 & 0x0000FFFF));
            if ((error > 0x8000 || (error == 0x8000 && lsb == 1)))
            {
                u16++;
            }
            BFloat16 b(u16);
            return b;
        }
    }

    float ToFloat32() const
    {
        const uint32_t u32 = static_cast<uint32_t>(m_Value << 16u);
        const float* f32 = reinterpret_cast<const float*>(&u32);
        return *f32;
    }

    uint16_t Val() const
    {
        return m_Value;
    }

    static BFloat16 Max()
    {
        uint16_t max = 0x7F7F;
        return BFloat16(max);
    }

    static BFloat16 Nan()
    {
        uint16_t nan = 0x7FC0;
        return BFloat16(nan);
    }

    static BFloat16 Inf()
    {
        uint16_t infVal = 0x7F80;
        return BFloat16(infVal);
    }

private:
    uint16_t m_Value;
};

inline std::ostream& operator<<(std::ostream& os, const BFloat16& b)
{
    os << b.ToFloat32() << "(0x" << std::hex << b.Val() << ")";
    return os;
}

} //namespace armnn