From c394a6d17008f876c73e94883f0c59aeedfe73f0 Mon Sep 17 00:00:00 2001 From: Matthew Bentham Date: Mon, 24 Jun 2019 12:51:25 +0100 Subject: IVGCVSW-3307 Don't assume TensorInfo::Map() can be called before Execute() Change-Id: I445c69d2e99d8c93622e739af61f721e61b0f90f Signed-off-by: Matthew Bentham --- src/backends/reference/workloads/BaseIterator.hpp | 42 ++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) (limited to 'src/backends/reference/workloads/BaseIterator.hpp') diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index 26b0179e71..5583fe79ad 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -8,6 +8,8 @@ #include #include +#include + namespace armnn { @@ -35,6 +37,8 @@ public: virtual ~Decoder() {} + virtual void Reset(void*) = 0; + virtual IType Get() const = 0; }; @@ -46,6 +50,8 @@ public: virtual ~Encoder() {} + virtual void Reset(void*) = 0; + virtual void Set(IType right) = 0; virtual IType Get() const = 0; @@ -55,30 +61,40 @@ template class TypedIterator : public Base { public: - TypedIterator(T* data) + TypedIterator(T* data = nullptr) : m_Iterator(data), m_Start(data) {} + void Reset(void* data) override + { + m_Iterator = reinterpret_cast(data); + m_Start = m_Iterator; + } + TypedIterator& operator++() override { + BOOST_ASSERT(m_Iterator); ++m_Iterator; return *this; } TypedIterator& operator+=(const unsigned int increment) override { + BOOST_ASSERT(m_Iterator); m_Iterator += increment; return *this; } TypedIterator& operator-=(const unsigned int increment) override { + BOOST_ASSERT(m_Iterator); m_Iterator -= increment; return *this; } TypedIterator& operator[](const unsigned int index) override { + BOOST_ASSERT(m_Iterator); m_Iterator = m_Start + index; return *this; } @@ -94,6 +110,9 @@ public: QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset) : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + QASymm8Decoder(const float scale, const int32_t offset) + : QASymm8Decoder(nullptr, scale, offset) {} + float Get() const override { return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); @@ -110,6 +129,9 @@ public: QSymm16Decoder(const int16_t* data, const float scale, const int32_t offset) : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + QSymm16Decoder(const float scale, const int32_t offset) + : QSymm16Decoder(nullptr, scale, offset) {} + float Get() const override { return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); @@ -126,6 +148,9 @@ public: FloatDecoder(const float* data) : TypedIterator(data) {} + FloatDecoder() + : FloatDecoder(nullptr) {} + float Get() const override { return *m_Iterator; @@ -138,6 +163,9 @@ public: ScaledInt32Decoder(const int32_t* data, const float scale) : TypedIterator(data), m_Scale(scale) {} + ScaledInt32Decoder(const float scale) + : ScaledInt32Decoder(nullptr, scale) {} + float Get() const override { return static_cast(*m_Iterator) * m_Scale; @@ -153,6 +181,9 @@ public: QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset) : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + QASymm8Encoder(const float scale, const int32_t offset) + : QASymm8Encoder(nullptr, scale, offset) {} + void Set(float right) override { *m_Iterator = armnn::Quantize(right, m_Scale, m_Offset); @@ -174,6 +205,9 @@ public: QSymm16Encoder(int16_t* data, const float scale, const int32_t offset) : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + QSymm16Encoder(const float scale, const int32_t offset) + : QSymm16Encoder(nullptr, scale, offset) {} + void Set(float right) override { *m_Iterator = armnn::Quantize(right, m_Scale, m_Offset); @@ -195,6 +229,9 @@ public: FloatEncoder(float* data) : TypedIterator(data) {} + FloatEncoder() + : FloatEncoder(nullptr) {} + void Set(float right) override { *m_Iterator = right; @@ -212,6 +249,9 @@ public: BooleanEncoder(uint8_t* data) : TypedIterator(data) {} + BooleanEncoder() + : BooleanEncoder(nullptr) {} + void Set(bool right) override { *m_Iterator = right; -- cgit v1.2.1