diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-08-27 12:57:20 +0100 |
---|---|---|
committer | finn.williams <finn.williams@arm.com> | 2020-08-27 18:26:00 +0000 |
commit | 56785c75037ed0cd377851616634b3129713394b (patch) | |
tree | c162b8b78ab18a34bf6e679a5c5b6219e849de94 /src/armnn/test | |
parent | 714fe5bb4cb157769c28b68e3506f970acce2412 (diff) | |
download | armnn-56785c75037ed0cd377851616634b3129713394b.tar.gz |
IVGCVSW-5257 'Remove CreateTensorHandle in the test for layers beginning with S'
* Re-factored SoftmaxTestImpl to use TensorHandleFactory to create TensorHandles
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I83559a89187bbed0d6f34ca589ea81c694bf5683
Diffstat (limited to 'src/armnn/test')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 058a932d03..c049f578fc 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -7,6 +7,8 @@ #include <armnn/Logging.hpp> #include <armnn/Utils.hpp> #include <reference/RefWorkloadFactory.hpp> +#include <reference/test/RefWorkloadFactoryHelper.hpp> + #include <backendsCommon/test/LayerTests.hpp> #include <backendsCommon/test/WorkloadFactoryHelper.hpp> #include "TensorHelpers.hpp" @@ -122,12 +124,34 @@ void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... CompareTestResultIfSupported(testName, testResult); } +template<typename FactoryType, typename TFuncPtr, typename... Args> +void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args) +{ + auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager(); + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager); + + armnn::RefWorkloadFactory refWorkloadFactory; + auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager); + auto refTensorHandleFactory = + RefWorkloadFactoryHelper::GetTensorHandleFactory(memoryManager); + + auto testResult = (*testFunction)( + workloadFactory, memoryManager, refWorkloadFactory, &tensorHandleFactory, &refTensorHandleFactory, args...); + CompareTestResultIfSupported(testName, testResult); +} + #define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \ BOOST_AUTO_TEST_CASE(TestName) \ { \ CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } +#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \ + BOOST_AUTO_TEST_CASE(TestName) \ + { \ + CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } + #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \ BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \ { \ |