From cbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Mon, 22 Jun 2020 15:58:32 +0100 Subject: IVGCVSW-5007 Implement an Int32 reference Elementwise workload Signed-off-by: Finn Williams Change-Id: I6592169b74ac4294bc09647879aec0718c641f91 --- src/backends/reference/workloads/BaseIterator.hpp | 35 ++++++++++++++++++++++ src/backends/reference/workloads/Decoders.hpp | 18 +++++++++++ .../reference/workloads/ElementwiseFunction.cpp | 7 +++++ src/backends/reference/workloads/Encoders.hpp | 18 +++++++++++ .../reference/workloads/RefElementwiseWorkload.cpp | 24 +++++++++++++++ .../reference/workloads/RefElementwiseWorkload.hpp | 18 +++++++---- 6 files changed, 114 insertions(+), 6 deletions(-) (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp index be20644ab7..1f4f2da717 100644 --- a/src/backends/reference/workloads/BaseIterator.hpp +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -274,6 +274,21 @@ public: } }; +class Int32ToInt32tDecoder : public TypedIterator> +{ +public: + Int32ToInt32tDecoder(const int32_t* data) + : TypedIterator(data){} + + Int32ToInt32tDecoder() + : Int32ToInt32tDecoder(nullptr) {} + + int32_t Get() const override + { + return *m_Iterator; + } +}; + class BooleanDecoder : public TypedIterator> { public: @@ -470,6 +485,26 @@ public: } }; +class Int32ToInt32tEncoder : public TypedIterator> +{ +public: + Int32ToInt32tEncoder(int32_t* data) + : TypedIterator(data){} + + Int32ToInt32tEncoder() + : Int32ToInt32tEncoder(nullptr) {} + + void Set(int32_t right) override + { + *m_Iterator = right; + } + + int32_t Get() const override + { + return *m_Iterator; + } +}; + class BooleanEncoder : public TypedIterator> { public: diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp index deb3b1f4b2..08e0140fad 100644 --- a/src/backends/reference/workloads/Decoders.hpp +++ b/src/backends/reference/workloads/Decoders.hpp @@ -149,4 +149,22 @@ inline std::unique_ptr> MakeDecoder(const TensorInfo& info, const return nullptr; } +template<> +inline std::unique_ptr> MakeDecoder(const TensorInfo& info, const void* data) +{ + switch(info.GetDataType()) + { + case DataType::Signed32: + { + return std::make_unique(static_cast(data)); + } + default: + { + ARMNN_ASSERT_MSG(false, "Unsupported Data Type!"); + break; + } + } + return nullptr; +} + } //namespace armnn diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index 5687cf5861..afae188bd6 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -46,6 +46,13 @@ template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; + // Comparison template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp index c0524a7719..a2d565ec4a 100644 --- a/src/backends/reference/workloads/Encoders.hpp +++ b/src/backends/reference/workloads/Encoders.hpp @@ -114,4 +114,22 @@ inline std::unique_ptr> MakeEncoder(const TensorInfo& info, void* return nullptr; } +template<> +inline std::unique_ptr> MakeEncoder(const TensorInfo& info, void* data) +{ + switch(info.GetDataType()) + { + case DataType::Signed32: + { + return std::make_unique(static_cast(data)); + } + default: + { + ARMNN_ASSERT_MSG(false, "Unsupported Data Type!"); + break; + } + } + return nullptr; +} + } //namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 18bf0a7ad9..60acbd6252 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -67,22 +67,46 @@ template class armnn::RefElementwiseWorkload, armnn::AdditionQueueDescriptor, armnn::StringMapping::RefAdditionWorkload_Execute>; +template class armnn::RefElementwiseWorkload, + armnn::AdditionQueueDescriptor, + armnn::StringMapping::RefAdditionWorkload_Execute>; + template class armnn::RefElementwiseWorkload, armnn::SubtractionQueueDescriptor, armnn::StringMapping::RefSubtractionWorkload_Execute>; +template class armnn::RefElementwiseWorkload, + armnn::SubtractionQueueDescriptor, + armnn::StringMapping::RefSubtractionWorkload_Execute>; + template class armnn::RefElementwiseWorkload, armnn::MultiplicationQueueDescriptor, armnn::StringMapping::RefMultiplicationWorkload_Execute>; +template class armnn::RefElementwiseWorkload, + armnn::MultiplicationQueueDescriptor, + armnn::StringMapping::RefMultiplicationWorkload_Execute>; + template class armnn::RefElementwiseWorkload, armnn::DivisionQueueDescriptor, armnn::StringMapping::RefDivisionWorkload_Execute>; +template class armnn::RefElementwiseWorkload, + armnn::DivisionQueueDescriptor, + armnn::StringMapping::RefDivisionWorkload_Execute>; + template class armnn::RefElementwiseWorkload, armnn::MaximumQueueDescriptor, armnn::StringMapping::RefMaximumWorkload_Execute>; +template class armnn::RefElementwiseWorkload, + armnn::MaximumQueueDescriptor, + armnn::StringMapping::RefMaximumWorkload_Execute>; + template class armnn::RefElementwiseWorkload, armnn::MinimumQueueDescriptor, armnn::StringMapping::RefMinimumWorkload_Execute>; + +template class armnn::RefElementwiseWorkload, + armnn::MinimumQueueDescriptor, + armnn::StringMapping::RefMinimumWorkload_Execute>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 264ddce2de..03683b1a06 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -35,33 +35,39 @@ private: std::unique_ptr> m_Output; }; +template using RefAdditionWorkload = - RefElementwiseWorkload, + RefElementwiseWorkload, AdditionQueueDescriptor, StringMapping::RefAdditionWorkload_Execute>; +template using RefSubtractionWorkload = - RefElementwiseWorkload, + RefElementwiseWorkload, SubtractionQueueDescriptor, StringMapping::RefSubtractionWorkload_Execute>; +template using RefMultiplicationWorkload = - RefElementwiseWorkload, + RefElementwiseWorkload, MultiplicationQueueDescriptor, StringMapping::RefMultiplicationWorkload_Execute>; +template using RefDivisionWorkload = - RefElementwiseWorkload, + RefElementwiseWorkload, DivisionQueueDescriptor, StringMapping::RefDivisionWorkload_Execute>; +template using RefMaximumWorkload = - RefElementwiseWorkload, + RefElementwiseWorkload, MaximumQueueDescriptor, StringMapping::RefMaximumWorkload_Execute>; +template using RefMinimumWorkload = - RefElementwiseWorkload, + RefElementwiseWorkload, MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; -- cgit v1.2.1