aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r--src/armnn/test/UnitTests.hpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 058a932d03..c049f578fc 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -7,6 +7,8 @@
#include <armnn/Logging.hpp>
#include <armnn/Utils.hpp>
#include <reference/RefWorkloadFactory.hpp>
+#include <reference/test/RefWorkloadFactoryHelper.hpp>
+
#include <backendsCommon/test/LayerTests.hpp>
#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
#include "TensorHelpers.hpp"
@@ -122,12 +124,34 @@ void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args...
CompareTestResultIfSupported(testName, testResult);
}
+template<typename FactoryType, typename TFuncPtr, typename... Args>
+void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
+{
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
+ armnn::RefWorkloadFactory refWorkloadFactory;
+ auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
+ auto refTensorHandleFactory =
+ RefWorkloadFactoryHelper::GetTensorHandleFactory(memoryManager);
+
+ auto testResult = (*testFunction)(
+ workloadFactory, memoryManager, refWorkloadFactory, &tensorHandleFactory, &refTensorHandleFactory, args...);
+ CompareTestResultIfSupported(testName, testResult);
+}
+
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
BOOST_AUTO_TEST_CASE(TestName) \
{ \
CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
+#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
+ BOOST_AUTO_TEST_CASE(TestName) \
+ { \
+ CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
{ \