diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 57 | ||||
-rw-r--r-- | src/backends/reference/workloads/Decoders.hpp | 15 | ||||
-rw-r--r-- | src/backends/reference/workloads/Encoders.hpp | 6 |
3 files changed, 64 insertions, 14 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index c9fd773d5e..18270faf46 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -5,6 +5,8 @@ #pragma once +#include "FloatingPointConverter.hpp" + #include <armnn/ArmNN.hpp> #include <ResolveType.hpp> @@ -142,14 +144,31 @@ private: const int32_t m_Offset; }; -class FloatDecoder : public TypedIterator<const float, Decoder<float>> +class Float16Decoder : public TypedIterator<const Half, Decoder<float>> { public: - FloatDecoder(const float* data) + Float16Decoder(const Half* data) : TypedIterator(data) {} - FloatDecoder() - : FloatDecoder(nullptr) {} + Float16Decoder() + : Float16Decoder(nullptr) {} + + float Get() const override + { + float val = 0.f; + armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val); + return val; + } +}; + +class Float32Decoder : public TypedIterator<const float, Decoder<float>> +{ +public: + Float32Decoder(const float* data) + : TypedIterator(data) {} + + Float32Decoder() + : Float32Decoder(nullptr) {} float Get() const override { @@ -238,14 +257,36 @@ private: const int32_t m_Offset; }; -class FloatEncoder : public TypedIterator<float, Encoder<float>> +class Float16Encoder : public TypedIterator<Half, Encoder<float>> +{ +public: + Float16Encoder(Half* data) + : TypedIterator(data) {} + + Float16Encoder() + : Float16Encoder(nullptr) {} + + void Set(float right) override + { + armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator); + } + + float Get() const override + { + float val = 0.f; + armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val); + return val; + } +}; + +class Float32Encoder : public TypedIterator<float, Encoder<float>> { public: - FloatEncoder(float* data) + Float32Encoder(float* data) : TypedIterator(data) {} - FloatEncoder() - : FloatEncoder(nullptr) {} + Float32Encoder() + : Float32Encoder(nullptr) {} void Set(float right) override { diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp index 0101789bec..328a5eb0f7 100644 --- a/src/backends/reference/workloads/Decoders.hpp +++ b/src/backends/reference/workloads/Decoders.hpp @@ -6,6 +6,7 @@ #pragma once #include "BaseIterator.hpp" +#include "FloatingPointConverter.hpp" #include <boost/assert.hpp> @@ -20,25 +21,29 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const { switch(info.GetDataType()) { - case armnn::DataType::QuantisedAsymm8: + case DataType::QuantisedAsymm8: { return std::make_unique<QASymm8Decoder>( static_cast<const uint8_t*>(data), info.GetQuantizationScale(), info.GetQuantizationOffset()); } - case armnn::DataType::QuantisedSymm16: + case DataType::QuantisedSymm16: { return std::make_unique<QSymm16Decoder>( static_cast<const int16_t*>(data), info.GetQuantizationScale(), info.GetQuantizationOffset()); } - case armnn::DataType::Float32: + case DataType::Float16: { - return std::make_unique<FloatDecoder>(static_cast<const float*>(data)); + return std::make_unique<Float16Decoder>(static_cast<const Half*>(data)); } - case armnn::DataType::Signed32: + case DataType::Float32: + { + return std::make_unique<Float32Decoder>(static_cast<const float*>(data)); + } + case DataType::Signed32: { const float scale = info.GetQuantizationScale(); if (scale == 0.f) diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index f0e40d224b..2b3a11af06 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -38,9 +38,13 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* { return std::make_unique<Int32Encoder>(static_cast<int32_t*>(data)); } + case armnn::DataType::Float16: + { + return std::make_unique<Float16Encoder>(static_cast<Half*>(data)); + } case armnn::DataType::Float32: { - return std::make_unique<FloatEncoder>(static_cast<float*>(data)); + return std::make_unique<Float32Encoder>(static_cast<float*>(data)); } default: { |