aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
blob: cfa8ce7e91d3bc4a21134140b362e3b75c0a4e9d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/ArmNN.hpp>
#include <TypeUtils.hpp>

namespace armnn
{

/// Abstract iterator interface shared by all decoders and encoders in this file.
///
/// Only movement is defined here; reading (Decoder::Get) and writing
/// (Encoder::Set / ComparisonEncoder::Set) are added by the sub-interfaces.
class BaseIterator
{
public:
    BaseIterator() = default;

    /// Virtual so an iterator can be destroyed through a BaseIterator pointer.
    virtual ~BaseIterator() = default;

    /// Advances the iterator by one element.
    virtual BaseIterator& operator++() = 0;

    /// Advances the iterator by @p increment elements.
    virtual BaseIterator& operator+=(const unsigned int increment) = 0;

    /// Moves the iterator back by @p increment elements.
    virtual BaseIterator& operator-=(const unsigned int increment) = 0;
};

class Decoder : public BaseIterator
{
public:
    Decoder() : BaseIterator() {}

    virtual ~Decoder() {}

    virtual float Get() const = 0;
};

class Encoder : public BaseIterator
{
public:
    Encoder() : BaseIterator() {}

    virtual ~Encoder() {}

    virtual void Set(const float& right) = 0;
};

class ComparisonEncoder : public BaseIterator
{
public:
    ComparisonEncoder() : BaseIterator() {}

    virtual ~ComparisonEncoder() {}

    virtual void Set(bool right) = 0;
};

/// Implements the movement operations of an iterator interface on top of a raw
/// element pointer.
///
/// @tparam T    Element type pointed at (e.g. `const float` for a decoder,
///              `float` for an encoder).
/// @tparam Base Interface being implemented (Decoder, Encoder or
///              ComparisonEncoder); it must declare virtual operator++,
///              operator+= and operator-= with the signatures overridden here.
///
/// Movement is plain pointer arithmetic on m_Iterator. No bounds checking is
/// performed, so callers must keep the position inside the buffer they supplied.
template<typename T, typename Base>
class TypedIterator : public Base
{
public:
    /// @param data Pointer to the first element; the iterator never frees it.
    /// Explicit so that a raw pointer is never implicitly converted into an
    /// iterator object (derived classes use direct initialization).
    explicit TypedIterator(T* data)
        : m_Iterator(data)
    {}

    /// Moves to the next element.
    TypedIterator& operator++() override
    {
        ++m_Iterator;
        return *this;
    }

    /// Moves forward by @p increment elements.
    TypedIterator& operator+=(const unsigned int increment) override
    {
        m_Iterator += increment;
        return *this;
    }

    /// Moves backward by @p increment elements.
    TypedIterator& operator-=(const unsigned int increment) override
    {
        m_Iterator -= increment;
        return *this;
    }

    /// Current element; dereferenced by the concrete decoders/encoders.
    T* m_Iterator;
};

/// Decoder over QAsymm8 data stored as uint8_t.
///
/// Each element is converted to float via armnn::Dequantize using the scale and
/// offset captured at construction.
class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder>
{
public:
    /// @param data   First element to read; not owned by the decoder.
    /// @param scale  Scale forwarded to armnn::Dequantize.
    /// @param offset Offset forwarded to armnn::Dequantize.
    QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
        : TypedIterator(data)
        , m_Scale(scale)
        , m_Offset(offset)
    {
    }

    float Get() const override
    {
        const uint8_t quantized = *m_Iterator;
        return armnn::Dequantize(quantized, m_Scale, m_Offset);
    }

private:
    const float m_Scale;    ///< Scale passed to armnn::Dequantize.
    const int32_t m_Offset; ///< Offset passed to armnn::Dequantize.
};

/// Decoder over plain float storage; Get() returns the stored value unchanged.
class FloatDecoder : public TypedIterator<const float, Decoder>
{
public:
    /// @param data First element to read; not owned by the decoder.
    FloatDecoder(const float* data)
        : TypedIterator(data)
    {
    }

    float Get() const override
    {
        return m_Iterator[0];
    }
};

/// Encoder over plain float storage; Set() writes the value unchanged.
class FloatEncoder : public TypedIterator<float, Encoder>
{
public:
    /// @param data First element to write; not owned by the encoder.
    FloatEncoder(float* data)
        : TypedIterator(data)
    {
    }

    void Set(const float& right) override
    {
        m_Iterator[0] = right;
    }
};

/// Encoder writing QAsymm8 data stored as uint8_t.
///
/// Each float is converted via armnn::Quantize<uint8_t> using the scale and
/// offset captured at construction.
class QASymm8Encoder : public TypedIterator<uint8_t, Encoder>
{
public:
    /// @param data   First element to write; not owned by the encoder.
    /// @param scale  Scale forwarded to armnn::Quantize.
    /// @param offset Offset forwarded to armnn::Quantize.
    QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
        : TypedIterator(data)
        , m_Scale(scale)
        , m_Offset(offset)
    {
    }

    void Set(const float& right) override
    {
        const uint8_t quantized = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
        *m_Iterator = quantized;
    }

private:
    const float m_Scale;    ///< Scale passed to armnn::Quantize.
    const int32_t m_Offset; ///< Offset passed to armnn::Quantize.
};

/// ComparisonEncoder that stores each boolean as a uint8_t (false -> 0, true -> 1).
class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder>
{
public:
    /// @param data First element to write; not owned by the encoder.
    BooleanEncoder(uint8_t* data)
        : TypedIterator(data)
    {
    }

    void Set(bool right) override
    {
        // bool converts to exactly 0 or 1.
        *m_Iterator = static_cast<uint8_t>(right);
    }
};

} //namespace armnn