diff options
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index 9fe3f15f9b..50475312a5 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -451,4 +451,25 @@ private: std::vector<float> m_Scale; }; -} //namespace armnn
\ No newline at end of file +class ScaledInt32PerAxisDecoder : public PerAxisIterator<const int32_t, Decoder<float>> +{ +public: + ScaledInt32PerAxisDecoder(const int32_t* data, const std::vector<float>& scales, unsigned int axisFactor) + : PerAxisIterator(data, axisFactor), m_Scales(scales) {} + + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scales[m_AxisIndex], 0); + } + + // Get scale of the current value + float GetScale() const + { + return m_Scales[m_AxisIndex]; + } + +private: + std::vector<float> m_Scales; +}; + +} // namespace armnn |