aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-06-22 15:58:32 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-06-30 13:00:52 +0000
commitcbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09 (patch)
treeeb6e5393726be21213e72f26a676b7c3809fc995 /src/backends/reference/workloads
parent532a29d12d72f54549d8b71edd485c17af65698a (diff)
downloadarmnn-cbd2c230b7ce5f26e2ccccf36b7ad450f6e1ad09.tar.gz
IVGCVSW-5007 Implement an Int32 reference Elementwise workload
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I6592169b74ac4294bc09647879aec0718c641f91
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp35
-rw-r--r--src/backends/reference/workloads/Decoders.hpp18
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp7
-rw-r--r--src/backends/reference/workloads/Encoders.hpp18
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp24
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp18
6 files changed, 114 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index be20644ab7..1f4f2da717 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -274,6 +274,21 @@ public:
}
};
+class Int32ToInt32tDecoder : public TypedIterator<const int32_t, Decoder<int32_t>>
+{
+public:
+ Int32ToInt32tDecoder(const int32_t* data)
+ : TypedIterator(data){}
+
+ Int32ToInt32tDecoder()
+ : Int32ToInt32tDecoder(nullptr) {}
+
+ int32_t Get() const override
+ {
+ return *m_Iterator;
+ }
+};
+
class BooleanDecoder : public TypedIterator<const uint8_t, Decoder<float>>
{
public:
@@ -470,6 +485,26 @@ public:
}
};
+class Int32ToInt32tEncoder : public TypedIterator<int32_t, Encoder<int32_t>>
+{
+public:
+ Int32ToInt32tEncoder(int32_t* data)
+ : TypedIterator(data){}
+
+ Int32ToInt32tEncoder()
+ : Int32ToInt32tEncoder(nullptr) {}
+
+ void Set(int32_t right) override
+ {
+ *m_Iterator = right;
+ }
+
+ int32_t Get() const override
+ {
+ return *m_Iterator;
+ }
+};
+
class BooleanEncoder : public TypedIterator<uint8_t, Encoder<bool>>
{
public:
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index deb3b1f4b2..08e0140fad 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -149,4 +149,22 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
return nullptr;
}
+template<>
+inline std::unique_ptr<Decoder<int32_t>> MakeDecoder(const TensorInfo& info, const void* data)
+{
+ switch(info.GetDataType())
+ {
+ case DataType::Signed32:
+ {
+ return std::make_unique<Int32ToInt32tDecoder>(static_cast<const int32_t*>(data));
+ }
+ default:
+ {
+ ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 5687cf5861..afae188bd6 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -46,6 +46,13 @@ template struct armnn::ElementwiseBinaryFunction<std::divides<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseBinaryFunction<std::plus<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::minus<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::multiplies<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<std::divides<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::maximum<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::minimum<int32_t>>;
+
// Comparison
template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>;
template struct armnn::ElementwiseBinaryFunction<std::greater<float>>;
diff --git a/src/backends/reference/workloads/Encoders.hpp b/src/backends/reference/workloads/Encoders.hpp
index c0524a7719..a2d565ec4a 100644
--- a/src/backends/reference/workloads/Encoders.hpp
+++ b/src/backends/reference/workloads/Encoders.hpp
@@ -114,4 +114,22 @@ inline std::unique_ptr<Encoder<bool>> MakeEncoder(const TensorInfo& info, void*
return nullptr;
}
+template<>
+inline std::unique_ptr<Encoder<int32_t>> MakeEncoder(const TensorInfo& info, void* data)
+{
+ switch(info.GetDataType())
+ {
+ case DataType::Signed32:
+ {
+ return std::make_unique<Int32ToInt32tEncoder>(static_cast<int32_t*>(data));
+ }
+ default:
+ {
+ ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 18bf0a7ad9..60acbd6252 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -67,22 +67,46 @@ template class armnn::RefElementwiseWorkload<std::plus<float>,
armnn::AdditionQueueDescriptor,
armnn::StringMapping::RefAdditionWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::plus<int32_t>,
+ armnn::AdditionQueueDescriptor,
+ armnn::StringMapping::RefAdditionWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<std::minus<float>,
armnn::SubtractionQueueDescriptor,
armnn::StringMapping::RefSubtractionWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::minus<int32_t>,
+ armnn::SubtractionQueueDescriptor,
+ armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<std::multiplies<float>,
armnn::MultiplicationQueueDescriptor,
armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::multiplies<int32_t>,
+ armnn::MultiplicationQueueDescriptor,
+ armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<std::divides<float>,
armnn::DivisionQueueDescriptor,
armnn::StringMapping::RefDivisionWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<std::divides<int32_t>,
+ armnn::DivisionQueueDescriptor,
+ armnn::StringMapping::RefDivisionWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
armnn::MaximumQueueDescriptor,
armnn::StringMapping::RefMaximumWorkload_Execute>;
+template class armnn::RefElementwiseWorkload<armnn::maximum<int32_t>,
+ armnn::MaximumQueueDescriptor,
+ armnn::StringMapping::RefMaximumWorkload_Execute>;
+
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
armnn::MinimumQueueDescriptor,
armnn::StringMapping::RefMinimumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<int32_t>,
+ armnn::MinimumQueueDescriptor,
+ armnn::StringMapping::RefMinimumWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 264ddce2de..03683b1a06 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -35,33 +35,39 @@ private:
std::unique_ptr<Encoder<OutType>> m_Output;
};
+template <typename DataType = float>
using RefAdditionWorkload =
- RefElementwiseWorkload<std::plus<float>,
+ RefElementwiseWorkload<std::plus<DataType>,
AdditionQueueDescriptor,
StringMapping::RefAdditionWorkload_Execute>;
+template <typename DataType = float>
using RefSubtractionWorkload =
- RefElementwiseWorkload<std::minus<float>,
+ RefElementwiseWorkload<std::minus<DataType>,
SubtractionQueueDescriptor,
StringMapping::RefSubtractionWorkload_Execute>;
+template <typename DataType = float>
using RefMultiplicationWorkload =
- RefElementwiseWorkload<std::multiplies<float>,
+ RefElementwiseWorkload<std::multiplies<DataType>,
MultiplicationQueueDescriptor,
StringMapping::RefMultiplicationWorkload_Execute>;
+template <typename DataType = float>
using RefDivisionWorkload =
- RefElementwiseWorkload<std::divides<float>,
+ RefElementwiseWorkload<std::divides<DataType>,
DivisionQueueDescriptor,
StringMapping::RefDivisionWorkload_Execute>;
+template <typename DataType = float>
using RefMaximumWorkload =
- RefElementwiseWorkload<armnn::maximum<float>,
+ RefElementwiseWorkload<armnn::maximum<DataType>,
MaximumQueueDescriptor,
StringMapping::RefMaximumWorkload_Execute>;
+template <typename DataType = float>
using RefMinimumWorkload =
- RefElementwiseWorkload<armnn::minimum<float>,
+ RefElementwiseWorkload<armnn::minimum<DataType>,
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;