aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-08-26 16:19:15 +0100
committerFinn Williams <Finn.Williams@arm.com>2020-08-26 16:19:15 +0100
commit8702007eb96dd857776763f045052cc8815d9350 (patch)
tree1a3a7b665f74e0490981c3a8dc55e42b77176fb2 /src/armnn
parentab3bd4d48e75f75c1729a177b35aff61ed0fcd4e (diff)
downloadarmnn-8702007eb96dd857776763f045052cc8815d9350.tar.gz
IVGCVSW-5250 Remove CreateTensorHandle in the test for layers between E-F
* Added new test function to pass in the ITensorHandleFactory Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I9b2e9250200e092541e29796ec53cabd0b677acf
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/test/UnitTests.hpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 2d9c1583d2..058a932d03 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -74,6 +74,24 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
}
+
+template<typename FactoryType, typename TFuncPtr, typename... Args>
+void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args)
+{
+ std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
+ armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
+
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
+ auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
+
+ auto testResult = (*testFunction)(workloadFactory, memoryManager, &tensorHandleFactory, args...);
+ CompareTestResultIfSupported(testName, testResult);
+
+ armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr);
+}
+
#define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \
BOOST_AUTO_TEST_CASE(TestName) \
{ \
@@ -86,6 +104,12 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
+#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
+ BOOST_AUTO_TEST_CASE(TestName) \
+ { \
+ RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
template<typename FactoryType, typename TFuncPtr, typename... Args>
void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{