diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 39 | ||||
-rw-r--r-- | src/backends/reference/workloads/Decoders.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/Encoders.hpp | 4 |
3 files changed, 47 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index c48201837b..3f0144670f 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -194,6 +194,23 @@ private: const int32_t m_Offset; }; +class BFloat16Decoder : public TypedIterator<const BFloat16, Decoder<float>> +{ +public: + BFloat16Decoder(const BFloat16* data) + : TypedIterator(data) {} + + BFloat16Decoder() + : BFloat16Decoder(nullptr) {} + + float Get() const override + { + float val = 0.f; + armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val); + return val; + } +}; + class Float16Decoder : public TypedIterator<const Half, Decoder<float>> { public: @@ -355,6 +372,28 @@ private: const int32_t m_Offset; }; +class BFloat16Encoder : public TypedIterator<armnn::BFloat16, Encoder<float>> +{ +public: + BFloat16Encoder(armnn::BFloat16* data) + : TypedIterator(data) {} + + BFloat16Encoder() + : BFloat16Encoder(nullptr) {} + + void Set(float right) override + { + armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(&right, 1, m_Iterator); + } + + float Get() const override + { + float val = 0.f; + armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(m_Iterator, 1, &val); + return val; + } +}; + class Float16Encoder : public TypedIterator<Half, Encoder<float>> { public: diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp index 6a8c756048..83c57c1169 100644 --- a/src/backends/reference/workloads/Decoders.hpp +++ b/src/backends/reference/workloads/Decoders.hpp @@ -102,6 +102,10 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const info.GetQuantizationScale(), info.GetQuantizationOffset()); } + case DataType::BFloat16: + { + return std::make_unique<BFloat16Decoder>(static_cast<const BFloat16*>(data)); + } case DataType::Float16: { return std::make_unique<Float16Decoder>(static_cast<const Half*>(data)); diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index f52297602f..e93987da31 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -75,6 +75,10 @@ inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* { return std::make_unique<Int32Encoder>(static_cast<int32_t*>(data)); } + case armnn::DataType::BFloat16: + { + return std::make_unique<BFloat16Encoder>(static_cast<armnn::BFloat16*>(data)); + } case armnn::DataType::Float16: { return std::make_unique<Float16Encoder>(static_cast<Half*>(data)); |