diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-08-26 16:19:15 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2020-08-26 16:19:15 +0100 |
commit | 8702007eb96dd857776763f045052cc8815d9350 (patch) | |
tree | 1a3a7b665f74e0490981c3a8dc55e42b77176fb2 /src/armnn/test/UnitTests.hpp | |
parent | ab3bd4d48e75f75c1729a177b35aff61ed0fcd4e (diff) | |
download | armnn-8702007eb96dd857776763f045052cc8815d9350.tar.gz |
IVGCVSW-5250 Remove CreateTensorHandle in the test for layers between E-F
* Added new test function to pass in the ITensorHandleFactory
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I9b2e9250200e092541e29796ec53cabd0b677acf
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 2d9c1583d2..058a932d03 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -74,6 +74,24 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr); } + +template<typename FactoryType, typename TFuncPtr, typename... Args> +void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args) +{ + std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>(); + armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); + + auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager(); + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager); + + auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager); + + auto testResult = (*testFunction)(workloadFactory, memoryManager, &tensorHandleFactory, args...); + CompareTestResultIfSupported(testName, testResult); + + armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr); +} + #define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \ BOOST_AUTO_TEST_CASE(TestName) \ { \ @@ -86,6 +104,12 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } +#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \ + BOOST_AUTO_TEST_CASE(TestName) \ + { \ + RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } + template<typename FactoryType, typename TFuncPtr, typename... Args> void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { |