aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/BaseIterator.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/BaseIterator.hpp')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp114
1 files changed, 113 insertions, 1 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index 18270faf46..9fe3f15f9b 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -339,4 +339,116 @@ public:
}
};
-} //namespace armnn
+// PerAxisIterator for per-axis quantization
+template<typename T, typename Base>
+class PerAxisIterator : public Base
+{
+public:
+ // axisFactor is used to calculate axisIndex
+ PerAxisIterator(T* data = nullptr, unsigned int axisFactor = 0)
+ : m_Iterator(data), m_Start(data), m_AxisIndex(0), m_AxisFactor(axisFactor)
+ {}
+
+ // This should be called to set index for per-axis Encoder/Decoder
+ PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex)
+ {
+ BOOST_ASSERT(m_Iterator);
+ m_Iterator = m_Start + index;
+ m_AxisIndex = axisIndex;
+ return *this;
+ }
+
+ void Reset(void* data) override
+ {
+ m_Iterator = reinterpret_cast<T*>(data);
+ m_Start = m_Iterator;
+ m_AxisIndex = 0;
+ }
+
+ PerAxisIterator& operator++() override
+ {
+ BOOST_ASSERT(m_Iterator);
+ ++m_Iterator;
+ m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+ return *this;
+ }
+
+ PerAxisIterator& operator+=(const unsigned int increment) override
+ {
+ BOOST_ASSERT(m_Iterator);
+ m_Iterator += increment;
+ m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+ return *this;
+ }
+
+ PerAxisIterator& operator-=(const unsigned int decrement) override
+ {
+ BOOST_ASSERT(m_Iterator);
+ m_Iterator -= decrement;
+ m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+ return *this;
+ }
+
+ PerAxisIterator& operator[](const unsigned int index) override
+ {
+ BOOST_ASSERT(m_Iterator);
+ m_Iterator = m_Start + index;
+ m_AxisIndex = static_cast<unsigned int>(*m_Iterator) % m_AxisFactor;
+ return *this;
+ }
+
+ protected:
+ T* m_Iterator;
+ T* m_Start;
+ unsigned int m_AxisIndex;
+ unsigned int m_AxisFactor;
+};
+
+class QSymm8PerAxisDecoder : public PerAxisIterator<const int8_t, Decoder<float>>
+{
+public:
+ QSymm8PerAxisDecoder(const int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
+ : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
+
+ float Get() const override
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
+ }
+
+ // Get scale of the current value
+ float GetScale() const
+ {
+ return m_Scale[m_AxisIndex];
+ }
+
+private:
+ std::vector<float> m_Scale;
+};
+
+class QSymm8PerAxisEncoder : public PerAxisIterator<int8_t, Encoder<float>>
+{
+public:
+ QSymm8PerAxisEncoder(int8_t* data, const std::vector<float>& scale, unsigned int axisFactor)
+ : PerAxisIterator(data, axisFactor), m_Scale(scale) {}
+
+ void Set(float right)
+ {
+ *m_Iterator = armnn::Quantize<int8_t>(right, m_Scale[m_AxisIndex], 0);
+ }
+
+ float Get() const
+ {
+ return armnn::Dequantize(*m_Iterator, m_Scale[m_AxisIndex], 0);
+ }
+
+ // Get scale of the current value
+ float GetScale() const
+ {
+ return m_Scale[m_AxisIndex];
+ }
+
+private:
+ std::vector<float> m_Scale;
+};
+
+} //namespace armnn \ No newline at end of file