aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-06 14:45:57 +0000
committerJan Eilers <jan.eilers@arm.com>2020-03-09 16:13:56 +0000
commit8832522f47b701f5f042069e7bf8deae9b75d449 (patch)
treef217ab7fbda860a947eba88c9508eb1ac1b1d670 /src/backends/reference/workloads/BaseIterator.hpp
parent97da5e2e6c8aaaf4249af60e8305431315226f15 (diff)
downloadarmnn-8832522f47b701f5f042069e7bf8deae9b75d449.tar.gz
IVGCVSW-4517 Implement BFloat16 Encoder and Decoder
* Add ConvertFloat32ToBFloat16 * Add ConvertBFloat16ToFloat32 * Add BFloat16Encoder * Add BFloat16Decoder * Unit tests Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I198888384c923aba28cfbed09a02edc6f8194b3e
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp39
1 files changed, 39 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index c48201837b..3f0144670f 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -194,6 +194,23 @@ private:
const int32_t m_Offset;
};
+class BFloat16Decoder : public TypedIterator<const BFloat16, Decoder<float>>
+{
+public:
+ BFloat16Decoder(const BFloat16* data)
+ : TypedIterator(data) {}
+
+ BFloat16Decoder()
+ : BFloat16Decoder(nullptr) {}
+
+ float Get() const override
+ {
+ float val = 0.f;
+ armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
+ return val;
+ }
+};
+
class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
{
public:
@@ -355,6 +372,28 @@ private:
const int32_t m_Offset;
};
+class BFloat16Encoder : public TypedIterator<armnn::BFloat16, Encoder<float>>
+{
+public:
+ BFloat16Encoder(armnn::BFloat16* data)
+ : TypedIterator(data) {}
+
+ BFloat16Encoder()
+ : BFloat16Encoder(nullptr) {}
+
+ void Set(float right) override
+ {
+ armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(&right, 1, m_Iterator);
+ }
+
+ float Get() const override
+ {
+ float val = 0.f;
+ armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
+ return val;
+ }
+};
+
class Float16Encoder : public TypedIterator<Half, Encoder<float>>
{
public: