diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-03 17:48:18 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-08 15:48:28 +0000 |
commit | 2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch) | |
tree | 48e73fa1862d17534804d1699bedb76120e88c9f /src/backends/reference/workloads/BaseIterator.hpp | |
parent | 0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff) | |
download | armnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz |
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload
* Removed templating based on the DataType
* Implemented BaseIterator to do decode/encode
Change-Id: I18f299f47ee23772f90152c1146b42f07465e105
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp new file mode 100644 index 0000000000..cfa8ce7e91 --- /dev/null +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -0,0 +1,155 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/ArmNN.hpp> +#include <TypeUtils.hpp> + +namespace armnn +{ + +class BaseIterator +{ +public: + BaseIterator() {} + + virtual ~BaseIterator() {} + + virtual BaseIterator& operator++() = 0; + + virtual BaseIterator& operator+=(const unsigned int increment) = 0; + + virtual BaseIterator& operator-=(const unsigned int increment) = 0; +}; + +class Decoder : public BaseIterator +{ +public: + Decoder() : BaseIterator() {} + + virtual ~Decoder() {} + + virtual float Get() const = 0; +}; + +class Encoder : public BaseIterator +{ +public: + Encoder() : BaseIterator() {} + + virtual ~Encoder() {} + + virtual void Set(const float& right) = 0; +}; + +class ComparisonEncoder : public BaseIterator +{ +public: + ComparisonEncoder() : BaseIterator() {} + + virtual ~ComparisonEncoder() {} + + virtual void Set(bool right) = 0; +}; + +template<typename T, typename Base> +class TypedIterator : public Base +{ +public: + TypedIterator(T* data) + : m_Iterator(data) + {} + + TypedIterator& operator++() override + { + ++m_Iterator; + return *this; + } + + TypedIterator& operator+=(const unsigned int increment) override + { + m_Iterator += increment; + return *this; + } + + TypedIterator& operator-=(const unsigned int increment) override + { + m_Iterator -= increment; + return *this; + } + + T* m_Iterator; +}; + +class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder> +{ +public: + QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset) + : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); + } + +private: + const float m_Scale; + const int32_t m_Offset; +}; + +class FloatDecoder : public TypedIterator<const float, Decoder> +{ +public: + FloatDecoder(const float* data) + : TypedIterator(data) {} + + float Get() const override + { + return *m_Iterator; + } +}; + +class FloatEncoder : public TypedIterator<float, Encoder> +{ +public: + FloatEncoder(float* data) + : TypedIterator(data) {} + + void Set(const float& right) override + { + *m_Iterator = right; + } +}; + +class QASymm8Encoder : public TypedIterator<uint8_t, Encoder> +{ +public: + QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset) + : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + + void Set(const float& right) override + { + *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset); + } + +private: + const float m_Scale; + const int32_t m_Offset; +}; + +class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder> +{ +public: + BooleanEncoder(uint8_t* data) + : TypedIterator(data) {} + + void Set(bool right) override + { + *m_Iterator = right; + } +}; + +} //namespace armnn
\ No newline at end of file |