From e69c399dcee1e75ebf9b2b12f72f3ad628c4e104 Mon Sep 17 00:00:00 2001 From: Matthew Jackson Date: Mon, 9 Sep 2019 14:31:21 +0100 Subject: IVGCVSW-3824 Implement Float 16 Encoder and Decoder * Implement Float 16 Encoder and Decoder * Add Stack Float 16 layer and create workload tests Signed-off-by: Matthew Jackson Change-Id: Ice4678226f4d22c06ebcc6db3052d42ce0c1bd67 --- src/backends/reference/workloads/BaseIterator.hpp | 57 +++++++++++++++++++---- src/backends/reference/workloads/Decoders.hpp | 15 ++++-- src/backends/reference/workloads/Encoders.hpp | 6 ++- 3 files changed, 64 insertions(+), 14 deletions(-) (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index c9fd773d5e..18270faf46 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -5,6 +5,8 @@ #pragma once +#include "FloatingPointConverter.hpp" + #include #include @@ -142,14 +144,31 @@ private: const int32_t m_Offset; }; -class FloatDecoder : public TypedIterator> +class Float16Decoder : public TypedIterator> { public: - FloatDecoder(const float* data) + Float16Decoder(const Half* data) : TypedIterator(data) {} - FloatDecoder() - : FloatDecoder(nullptr) {} + Float16Decoder() + : Float16Decoder(nullptr) {} + + float Get() const override + { + float val = 0.f; + armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val); + return val; + } +}; + +class Float32Decoder : public TypedIterator> +{ +public: + Float32Decoder(const float* data) + : TypedIterator(data) {} + + Float32Decoder() + : Float32Decoder(nullptr) {} float Get() const override { @@ -238,14 +257,36 @@ private: const int32_t m_Offset; }; -class FloatEncoder : public TypedIterator> +class Float16Encoder : public TypedIterator> +{ +public: + Float16Encoder(Half* data) + : TypedIterator(data) {} + + Float16Encoder() + : Float16Encoder(nullptr) {} + + void Set(float right) override + { + armnnUtils::FloatingPointConverter::ConvertFloat32To16(&right, 1, m_Iterator); + } + + float Get() const override + { + float val = 0.f; + armnnUtils::FloatingPointConverter::ConvertFloat16To32(m_Iterator, 1, &val); + return val; + } +}; + +class Float32Encoder : public TypedIterator> { public: - FloatEncoder(float* data) + Float32Encoder(float* data) : TypedIterator(data) {} - FloatEncoder() - : FloatEncoder(nullptr) {} + Float32Encoder() + : Float32Encoder(nullptr) {} void Set(float right) override { diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp index 0101789bec..328a5eb0f7 100644 --- a/src/backends/reference/workloads/Decoders.hpp +++ b/src/backends/reference/workloads/Decoders.hpp @@ -6,6 +6,7 @@ #pragma once #include "BaseIterator.hpp" +#include "FloatingPointConverter.hpp" #include @@ -20,25 +21,29 @@ inline std::unique_ptr> MakeDecoder(const TensorInfo& info, const { switch(info.GetDataType()) { - case armnn::DataType::QuantisedAsymm8: + case DataType::QuantisedAsymm8: { return std::make_unique( static_cast(data), info.GetQuantizationScale(), info.GetQuantizationOffset()); } - case armnn::DataType::QuantisedSymm16: + case DataType::QuantisedSymm16: { return std::make_unique( static_cast(data), info.GetQuantizationScale(), info.GetQuantizationOffset()); } - case armnn::DataType::Float32: + case DataType::Float16: { - return std::make_unique(static_cast(data)); + return std::make_unique(static_cast(data)); } - case armnn::DataType::Signed32: + case DataType::Float32: + { + return std::make_unique(static_cast(data)); + } + case DataType::Signed32: { const float scale = info.GetQuantizationScale(); if (scale == 0.f) diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index f0e40d224b..2b3a11af06 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -38,9 +38,13 @@ inline std::unique_ptr> MakeEncoder(const TensorInfo& info, void* { return std::make_unique(static_cast(data)); } + case armnn::DataType::Float16: + { + return std::make_unique(static_cast(data)); + } case armnn::DataType::Float32: { - return std::make_unique(static_cast(data)); + return std::make_unique(static_cast(data)); } default: { -- cgit v1.2.1