diff options
author | Keith Davis <keith.davis@arm.com> | 2019-11-04 08:58:33 +0000 |
---|---|---|
committer | Keith Davis <keith.davis@arm.com> | 2019-11-04 16:46:35 +0000 |
commit | 5236e1d6bcff6ebec7ec10d7d416cc6ead5482dd (patch) | |
tree | 4152c5fcd6b9c11848a02dfa4ff8705a2cfae0a5 /src/backends/reference/workloads/BaseIterator.hpp | |
parent | f71079328ae72a65c91e410b2bd35eabb67cb6d1 (diff) | |
download | armnn-5236e1d6bcff6ebec7ec10d7d416cc6ead5482dd.tar.gz |
IVGCVSW-3835 Create Encoder and Decoder for QSymm8PerAxis
* Add QuantizedSymm8PerAxis to armnn DataType (types.hpp) and
* Add Quantize and Dequantize template for int8 in TypeUtils to be able to compute QSymm8 of the weight
* Create PerAxisIterator for per-axis quantization
* Create QSymm8PerAxisDecoder
* Create QSymm8PerAxisEncoder
Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: Ibcfe0288a197b7ee50b543bdbd77b7edb8a547c2
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r-- | src/backends/reference/workloads/BaseIterator.hpp | 114 |
1 files changed, 113 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index 18270faf46..9fe3f15f9b 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -339,4 +339,116 @@ public: } }; -} //namespace armnn +// PerAxisIterator for per-axis quantization +template<typename T, typename Base> +class PerAxisIterator : public Base +{ +public: + // axisFactor is used to calculate axisIndex + PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0) + : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor) + {} + + // This should be called to set index for per-axis Encoder/Decoder + PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) + { + BOOST_ASSERT(m_Iterator); + m_Iterator = m_Start + index; + m_AxisIndex = axisIndex; + return *this; + } + + void Reset(void* data) override + { + m_Iterator = reinterpret_cast<T*>(data); + m_Start = m_Iterator; + m_AxisIndex = 0; + } + + PerAxisIterator& operator++() override + { + BOOST_ASSERT(m_Iterator); + ++m_Iterator; + m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor; + return *this; + } + + PerAxisIterator& operator+=(const unsigned int increment) override + { + BOOST_ASSERT(m_Iterator); + m_Iterator += increment; + m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor; + return *this; + } + + PerAxisIterator& operator-=(const unsigned int decrement) override + { + BOOST_ASSERT(m_Iterator); + m_Iterator -= decrement; + m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor; + return *this; + } + + PerAxisIterator& operator[](const unsigned int index) override + { + BOOST_ASSERT(m_Iterator); + m_Iterator = m_Start + index; + m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor; + return *this; + } + + protected: + T* m_Iterator; + T* m_Start; + unsigned int m_AxisIndex; + unsigned int m_AxisFactor; +}; + +class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>> +{ +public: + QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor) + : PerAxisIterator(data, axisFactor), m_Scale(scale) {} + + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0); + } + + // Get scale of the current value + float GetScale() const + { + return m_Scale[m_AxisIndex]; + } + +private: + std::vector<float> m_Scale; +}; + +class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>> +{ +public: + QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor) + : PerAxisIterator(data, axisFactor), m_Scale(scale) {} + + void Set(float right) + { + *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0); + } + + float Get() const + { + return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0); + } + + // Get scale of the current value + float GetScale() const + { + return m_Scale[m_AxisIndex]; + } + +private: + std::vector<float> m_Scale; +}; + +} //namespace armnn
\ No newline at end of file |