aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp')
-rw-r--r--src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp57
1 files changed, 20 insertions, 37 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
index b0b613c137..a3d2b2796f 100644
--- a/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
@@ -4,18 +4,10 @@
//
#include "EqualTestImpl.hpp"
-#include "ElementwiseTestImpl.hpp"
-#include <Half.hpp>
+#include "ComparisonTestImpl.hpp"
-template<>
-std::unique_ptr<armnn::IWorkload> CreateWorkload<armnn::EqualQueueDescriptor>(
- const armnn::IWorkloadFactory& workloadFactory,
- const armnn::WorkloadInfo& info,
- const armnn::EqualQueueDescriptor& descriptor)
-{
- return workloadFactory.CreateEqual(descriptor, info);
-}
+#include <Half.hpp>
LayerTestResult<uint8_t, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
@@ -39,9 +31,10 @@ LayerTestResult<uint8_t, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFac
std::vector<uint8_t> output({ 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1 });
- return ElementwiseTestHelper<4, armnn::EqualQueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape,
input0,
shape,
@@ -62,9 +55,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1ElementTest(
std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0, 0, 0});
- return ElementwiseTestHelper<4, armnn::EqualQueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -88,9 +82,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorTest(
std::vector<uint8_t> output({ 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4, armnn::EqualQueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -117,12 +112,10 @@ LayerTestResult<uint8_t, 4> EqualFloat16Test(
std::vector<uint8_t> output({ 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape,
input0,
shape,
@@ -148,12 +141,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1ElementFloat16Test(
std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -179,12 +170,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorFloat16Test(
std::vector<uint8_t> output({ 1, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -209,12 +198,10 @@ LayerTestResult<uint8_t, 4> EqualUint8Test(
std::vector<uint8_t> output({ 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape,
input0,
shape,
@@ -238,12 +225,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1ElementUint8Test(
std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -267,12 +252,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorUint8Test(
std::vector<uint8_t> output({ 1, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,