aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp57
1 files changed, 49 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index c9fd773d5e..18270faf46 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -5,6 +5,8 @@
#pragma once
+#include "FloatingPointConverter.hpp"
+
#include <armnn/ArmNN.hpp>
#include <ResolveType.hpp>
@@ -142,14 +144,31 @@ private:
const int32_t m_Offset;
};
-class FloatDecoder : public TypedIterator<const float, Decoder<float>>
+class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
{
public:
- FloatDecoder(const float* data)
+ Float16Decoder(const Half* data)
: TypedIterator(data) {}
- FloatDecoder()
- : FloatDecoder(nullptr) {}
+ Float16Decoder()
+ : Float16Decoder(nullptr) {}
+
+ float Get() const override
+ {
+ float val = 0.f;
+ armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
+ return val;
+ }
+};
+
+class Float32Decoder : public TypedIterator<const float, Decoder<float>>
+{
+public:
+ Float32Decoder(const float* data)
+ : TypedIterator(data) {}
+
+ Float32Decoder()
+ : Float32Decoder(nullptr) {}
float Get() const override
{
@@ -238,14 +257,36 @@ private:
const int32_t m_Offset;
};
-class FloatEncoder : public TypedIterator<float, Encoder<float>>
+class Float16Encoder : public TypedIterator<Half, Encoder<float>>
+{
+public:
+ Float16Encoder(Half* data)
+ : TypedIterator(data) {}
+
+ Float16Encoder()
+ : Float16Encoder(nullptr) {}
+
+ void Set(float right) override
+ {
+ armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator);
+ }
+
+ float Get() const override
+ {
+ float val = 0.f;
+ armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val);
+ return val;
+ }
+};
+
+class Float32Encoder : public TypedIterator<float, Encoder<float>>
{
public:
- FloatEncoder(float* data)
+ Float32Encoder(float* data)
: TypedIterator(data) {}
- FloatEncoder()
- : FloatEncoder(nullptr) {}
+ Float32Encoder()
+ : Float32Encoder(nullptr) {}
void Set(float right) override
{