aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp42
1 files changed, 41 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 26b0179e71..5583fe79ad 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -8,6 +8,8 @@
#include <armnn/ArmNN.hpp>
#include <ResolveType.hpp>
+#include <boost/assert.hpp>
+
namespace armnn
{
@@ -35,6 +37,8 @@ public:
virtual ~Decoder() {}
+ virtual void Reset(void*) = 0;
+
virtual IType Get() const = 0;
};
@@ -46,6 +50,8 @@ public:
virtual ~Encoder() {}
+ virtual void Reset(void*) = 0;
+
virtual void Set(IType right) = 0;
virtual IType Get() const = 0;
@@ -55,30 +61,40 @@ template<typename T, typename Base>
class TypedIterator : public Base
{
public:
- TypedIterator(T* data)
+ TypedIterator(T* data = nullptr)
: m_Iterator(data), m_Start(data)
{}
+ void Reset(void* data) override
+ {
+ m_Iterator = reinterpret_cast<T*>(data);
+ m_Start = m_Iterator;
+ }
+
TypedIterator& operator++() override
{
+ BOOST_ASSERT(m_Iterator);
++m_Iterator;
return *this;
}
TypedIterator& operator+=(const unsigned int increment) override
{
+ BOOST_ASSERT(m_Iterator);
m_Iterator += increment;
return *this;
}
TypedIterator& operator-=(const unsigned int increment) override
{
+ BOOST_ASSERT(m_Iterator);
m_Iterator -= increment;
return *this;
}
TypedIterator& operator[](const unsigned int index) override
{
+ BOOST_ASSERT(m_Iterator);
m_Iterator = m_Start + index;
return *this;
}
@@ -94,6 +110,9 @@ public:
QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
: TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+ QASymm8Decoder(const float scale, const int32_t offset)
+ : QASymm8Decoder(nullptr, scale, offset) {}
+
float Get() const override
{
return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
@@ -110,6 +129,9 @@ public:
QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset)
: TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+ QSymm16Decoder(const float scale, const int32_t offset)
+ : QSymm16Decoder(nullptr, scale, offset) {}
+
float Get() const override
{
return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
@@ -126,6 +148,9 @@ public:
FloatDecoder(const float* data)
: TypedIterator(data) {}
+ FloatDecoder()
+ : FloatDecoder(nullptr) {}
+
float Get() const override
{
return *m_Iterator;
@@ -138,6 +163,9 @@ public:
ScaledInt32Decoder(const int32_t* data, const float scale)
: TypedIterator(data), m_Scale(scale) {}
+ ScaledInt32Decoder(const float scale)
+ : ScaledInt32Decoder(nullptr, scale) {}
+
float Get() const override
{
return static_cast<float>(*m_Iterator) * m_Scale;
@@ -153,6 +181,9 @@ public:
QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
: TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+ QASymm8Encoder(const float scale, const int32_t offset)
+ : QASymm8Encoder(nullptr, scale, offset) {}
+
void Set(float right) override
{
*m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
@@ -174,6 +205,9 @@ public:
QSymm16Encoder(int16_t* data, const float scale, const int32_t offset)
: TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+ QSymm16Encoder(const float scale, const int32_t offset)
+ : QSymm16Encoder(nullptr, scale, offset) {}
+
void Set(float right) override
{
*m_Iterator = armnn::Quantize<int16_t>(right, m_Scale, m_Offset);
@@ -195,6 +229,9 @@ public:
FloatEncoder(float* data)
: TypedIterator(data) {}
+ FloatEncoder()
+ : FloatEncoder(nullptr) {}
+
void Set(float right) override
{
*m_Iterator = right;
@@ -212,6 +249,9 @@ public:
BooleanEncoder(uint8_t* data)
: TypedIterator(data) {}
+ BooleanEncoder()
+ : BooleanEncoder(nullptr) {}
+
void Set(bool right) override
{
*m_Iterator = right;