diff options
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 44b737cae4..f489ca030c 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -66,8 +66,10 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>(); armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); - FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(); - auto testResult = (*testFunction)(workloadFactory, args...); + auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager(); + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager); + + auto testResult = (*testFunction)(workloadFactory, memoryManager, args...); CompareTestResultIfSupported(testName, testResult); } @@ -80,9 +82,12 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) template<typename FactoryType, typename TFuncPtr, typename... Args> void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { - FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(); + auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager(); + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager); + armnn::RefWorkloadFactory refWorkloadFactory; - auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...); + + auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...); CompareTestResultIfSupported(testName, testResult); } |