aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-08-27 12:57:20 +0100
committerfinn.williams <finn.williams@arm.com>2020-08-27 18:26:00 +0000
commit56785c75037ed0cd377851616634b3129713394b (patch)
treec162b8b78ab18a34bf6e679a5c5b6219e849de94 /src/armnn/test/UnitTests.hpp
parent714fe5bb4cb157769c28b68e3506f970acce2412 (diff)
downloadarmnn-56785c75037ed0cd377851616634b3129713394b.tar.gz
IVGCVSW-5257 'Remove CreateTensorHandle in the test for layers beginning with S'
* Re-factored SoftmaxTestImpl to use TensorHandleFactory to create TensorHandles Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: I83559a89187bbed0d6f34ca589ea81c694bf5683
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r--src/armnn/test/UnitTests.hpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 058a932d03..c049f578fc 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -7,6 +7,8 @@
#include <armnn/Logging.hpp>
#include <armnn/Utils.hpp>
#include <reference/RefWorkloadFactory.hpp>
+#include <reference/test/RefWorkloadFactoryHelper.hpp>
+
#include <backendsCommon/test/LayerTests.hpp>
#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
#include "TensorHelpers.hpp"
@@ -122,12 +124,34 @@ void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args...
CompareTestResultIfSupported(testName, testResult);
}
+template<typename FactoryType, typename TFuncPtr, typename... Args>
+void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
+{
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
+ armnn::RefWorkloadFactory refWorkloadFactory;
+ auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
+ auto refTensorHandleFactory =
+ RefWorkloadFactoryHelper::GetTensorHandleFactory(memoryManager);
+
+ auto testResult = (*testFunction)(
+ workloadFactory, memoryManager, refWorkloadFactory, &tensorHandleFactory, &refTensorHandleFactory, args...);
+ CompareTestResultIfSupported(testName, testResult);
+}
+
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
BOOST_AUTO_TEST_CASE(TestName) \
{ \
CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
+#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
+ BOOST_AUTO_TEST_CASE(TestName) \
+ { \
+ CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
{ \