aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp39
-rw-r--r--src/backends/reference/workloads/Decoders.hpp4
-rw-r--r--src/backends/reference/workloads/Encoders.hpp4
3 files changed, 47 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index c48201837b..3f0144670f 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -194,6 +194,23 @@ private:
const int32_t m_Offset;
};
+class BFloat16Decoder : public TypedIterator<const BFloat16, Decoder<float>>
+{
+public:
+ BFloat16Decoder(const BFloat16* data)
+ : TypedIterator(data) {}
+
+ BFloat16Decoder()
+ : BFloat16Decoder(nullptr) {}
+
+ float Get() const override
+ {
+ float val = 0.f;
+ armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
+ return val;
+ }
+};
+
class Float16Decoder : public TypedIterator<const Half, Decoder<float>>
{
public:
@@ -355,6 +372,28 @@ private:
const int32_t m_Offset;
};
+class BFloat16Encoder : public TypedIterator<armnn::BFloat16, Encoder<float>>
+{
+public:
+ BFloat16Encoder(armnn::BFloat16* data)
+ : TypedIterator(data) {}
+
+ BFloat16Encoder()
+ : BFloat16Encoder(nullptr) {}
+
+ void Set(float right) override
+ {
+ armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(&right, 1, m_Iterator);
+ }
+
+ float Get() const override
+ {
+ float val = 0.f;
+ armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val);
+ return val;
+ }
+};
+
class Float16Encoder : public TypedIterator<Half, Encoder<float>>
{
public:
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index 6a8c756048..83c57c1169 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -102,6 +102,10 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
info.GetQuantizationScale(),
info.GetQuantizationOffset());
}
+ case DataType::BFloat16:
+ {
+ return std::make_unique<BFloat16Decoder>(static_cast<const BFloat16*>(data));
+ }
case DataType::Float16:
{
return std::make_unique<Float16Decoder>(static_cast<const Half*>(data));
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index f52297602f..e93987da31 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -75,6 +75,10 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void*
{
return std::make_unique<Int32Encoder>(static_cast<int32_t*>(data));
}
+ case armnn::DataType::BFloat16:
+ {
+ return std::make_unique<BFloat16Encoder>(static_cast<armnn::BFloat16*>(data));
+ }
case armnn::DataType::Float16:
{
return std::make_unique<Float16Encoder>(static_cast<Half*>(data));