aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-04-03 17:48:18 +0100
committerSadik Armagan <sadik.armagan@arm.com>2019-04-08 15:48:28 +0000
commit2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch)
tree48e73fa1862d17534804d1699bedb76120e88c9f /src/backends/reference/workloads/BaseIterator.hpp
parent0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff)
downloadarmnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload * Removed templating based on the DataType * Implemented BaseIterator to do decode/encode Change-Id: I18f299f47ee23772f90152c1146b42f07465e105 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp155
1 files changed, 155 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
new file mode 100644
index 0000000000..cfa8ce7e91
--- /dev/null
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -0,0 +1,155 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/ArmNN.hpp>
+#include <TypeUtils.hpp>
+
+namespace armnn
+{
+
+class BaseIterator
+{
+public:
+ BaseIterator() {}
+
+ virtual ~BaseIterator() {}
+
+ virtual BaseIterator& operator++() = 0;
+
+ virtual BaseIterator& operator+=(const unsigned int increment) = 0;
+
+ virtual BaseIterator& operator-=(const unsigned int increment) = 0;
+};
+
+class Decoder : public BaseIterator
+{
+public:
+ Decoder() : BaseIterator() {}
+
+ virtual ~Decoder() {}
+
+ virtual float Get() const = 0;
+};
+
+class Encoder : public BaseIterator
+{
+public:
+ Encoder() : BaseIterator() {}
+
+ virtual ~Encoder() {}
+
+ virtual void Set(const float& right) = 0;
+};
+
+class ComparisonEncoder : public BaseIterator
+{
+public:
+ ComparisonEncoder() : BaseIterator() {}
+
+ virtual ~ComparisonEncoder() {}
+
+ virtual void Set(bool right) = 0;
+};
+
+template<typename T, typename Base>
+class TypedIterator : public Base
+{
+public:
+ TypedIterator(T* data)
+ : m_Iterator(data)
+ {}
+
+ TypedIterator& operator++() override
+ {
+ ++m_Iterator;
+ return *this;
+ }
+
+ TypedIterator& operator+=(const unsigned int increment) override
+ {
+ m_Iterator += increment;
+ return *this;
+ }
+
+ TypedIterator& operator-=(const unsigned int increment) override
+ {
+ m_Iterator -= increment;
+ return *this;
+ }
+
+ T* m_Iterator;
+};
+
+class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder>
+{
+public:
+ QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ float Get() const override
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
+class FloatDecoder : public TypedIterator<const float, Decoder>
+{
+public:
+ FloatDecoder(const float* data)
+ : TypedIterator(data) {}
+
+ float Get() const override
+ {
+ return *m_Iterator;
+ }
+};
+
+class FloatEncoder : public TypedIterator<float, Encoder>
+{
+public:
+ FloatEncoder(float* data)
+ : TypedIterator(data) {}
+
+ void Set(const float& right) override
+ {
+ *m_Iterator = right;
+ }
+};
+
+class QASymm8Encoder : public TypedIterator<uint8_t, Encoder>
+{
+public:
+ QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset)
+ : TypedIterator(data), m_Scale(scale), m_Offset(offset) {}
+
+ void Set(const float& right) override
+ {
+ *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset);
+ }
+
+private:
+ const float m_Scale;
+ const int32_t m_Offset;
+};
+
+class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder>
+{
+public:
+ BooleanEncoder(uint8_t* data)
+ : TypedIterator(data) {}
+
+ void Set(bool right) override
+ {
+ *m_Iterator = right;
+ }
+};
+
+} //namespace armnn \ No newline at end of file