aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r--src/armnn/test/UnitTests.hpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 44b737cae4..f489ca030c 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -66,8 +66,10 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
- FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory();
- auto testResult = (*testFunction)(workloadFactory, args...);
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
+ auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
CompareTestResultIfSupported(testName, testResult);
}
@@ -80,9 +82,12 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
template<typename FactoryType, typename TFuncPtr, typename... Args>
void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{
- FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory();
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
armnn::RefWorkloadFactory refWorkloadFactory;
- auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
+
+ auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
CompareTestResultIfSupported(testName, testResult);
}