From 5edc8816118fcddb2681379db04c978041ce8b46 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Tue, 5 Nov 2019 18:00:21 +0000 Subject: IVGCVSW-3837 Add support for per-axis quantization to reference Convolution2d workload Signed-off-by: Aron Virginas-Tar Change-Id: I0ac08ba4864d48e6f64c4ac645dad8ea850be112 --- src/backends/reference/workloads/BaseIterator.hpp | 13 ++++++++++++- src/backends/reference/workloads/ConvImpl.cpp | 6 +++--- 2 files changed, 15 insertions(+), 4 deletions(-) (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index 50475312a5..95a31fbdd6 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -11,6 +11,7 @@ #include #include +#include namespace armnn { @@ -22,6 +23,8 @@ public: virtual ~BaseIterator() {} + virtual BaseIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) = 0; + virtual BaseIterator& operator++() = 0; virtual BaseIterator& operator+=(const unsigned int increment) = 0; @@ -101,6 +104,14 @@ public: return *this; } + TypedIterator& SetIndex(unsigned int index, unsigned int axisIndex = 0) override + { + boost::ignore_unused(axisIndex); + BOOST_ASSERT(m_Iterator); + m_Iterator = m_Start + index; + return *this; + } + protected: T* m_Iterator; T* m_Start; @@ -350,7 +361,7 @@ public: {} // This should be called to set index for per-axis Encoder/Decoder - PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) + PerAxisIterator& SetIndex(unsigned int index, unsigned int axisIndex) override { BOOST_ASSERT(m_Iterator); m_Iterator = m_Start + index; diff --git a/src/backends/reference/workloads/ConvImpl.cpp b/src/backends/reference/workloads/ConvImpl.cpp index 92e3b2d7dd..0c13e3ba0d 100644 --- a/src/backends/reference/workloads/ConvImpl.cpp +++ b/src/backends/reference/workloads/ConvImpl.cpp @@ -165,7 +165,7 @@ void Convolve(const TensorShape& rInputShape, } } - rFilterDecoder[filterIndex]; + rFilterDecoder.SetIndex(filterIndex, cOutput); float filterValue = rFilterDecoder.Get(); unsigned int yInput = yOutput * yStride + yFilter * yDilation; @@ -211,7 +211,7 @@ void Convolve(const TensorShape& rInputShape, if (biasEnabled) { - (*pBiasDecoder)[cOutput]; + (*pBiasDecoder).SetIndex(cOutput, cOutput); sum += pBiasDecoder->Get(); } @@ -225,4 +225,4 @@ void Convolve(const TensorShape& rInputShape, } } -} //namespace armnn +} // namespace armnn -- cgit v1.2.1