// // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "armnnTestUtils/TensorHelpers.hpp" #include #include #include #include #include #include #include #include #include inline void ConfigureLoggingTest() { // Configures logging for both the ARMNN library and this test program. armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); } // The following macros require the caller to have defined FactoryType, with one of the following using statements: // // using FactoryType = armnn::RefWorkloadFactory; // using FactoryType = armnn::ClWorkloadFactory; // using FactoryType = armnn::NeonWorkloadFactory; /// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported. /// If the test reports itself as not supported then the tensors are not compared. /// Additionally this checks that the supportedness reported by the test matches the name of the test. /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name. /// This is useful because it clarifies that the feature being tested is not actually supported /// (a passed test with the name of a feature would imply that feature was supported). /// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED. /// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED. template void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult& testResult) { bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported, "The test name does not match the supportedness it is reporting"); if (testResult.m_Supported) { auto result = CompareTensors(testResult.m_ActualData, testResult.m_ExpectedData, testResult.m_ActualShape, testResult.m_ExpectedShape, testResult.m_CompareBoolean); CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } } template void CompareTestResultIfSupported(const std::string& testName, const std::vector>& testResult) { bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; for (unsigned int i = 0; i < testResult.size(); ++i) { CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported, "The test name does not match the supportedness it is reporting"); if (testResult[i].m_Supported) { auto result = CompareTensors(testResult[i].m_ActualData, testResult[i].m_ExpectedData, testResult[i].m_ActualShape, testResult[i].m_ExpectedShape); CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } } } template void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { std::unique_ptr profiler = std::make_unique(); armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); auto memoryManager = WorkloadFactoryHelper::GetMemoryManager(); FactoryType workloadFactory = WorkloadFactoryHelper::GetFactory(memoryManager); auto testResult = (*testFunction)(workloadFactory, memoryManager, args...); CompareTestResultIfSupported(testName, testResult); armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr); } template void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args) { std::unique_ptr profiler = std::make_unique(); armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); auto memoryManager = WorkloadFactoryHelper::GetMemoryManager(); FactoryType workloadFactory = WorkloadFactoryHelper::GetFactory(memoryManager); auto tensorHandleFactory = WorkloadFactoryHelper::GetTensorHandleFactory(memoryManager); auto testResult = (*testFunction)(workloadFactory, memoryManager, tensorHandleFactory, args...); CompareTestResultIfSupported(testName, testResult); armnn::ProfilerManager::GetInstance().RegisterProfiler(nullptr); } #define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \ TEST_CASE(#TestName) \ { \ TestFunction(); \ } #define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \ TEST_CASE(#TestName) \ { \ RunTestFunction(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...) \ TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ RunTestFunction(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \ TEST_CASE(#TestName) \ { \ RunTestFunctionUsingTensorHandleFactory(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...) \ TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ RunTestFunctionUsingTensorHandleFactory(#TestName, &TestFunction, ##__VA_ARGS__); \ } template void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { auto memoryManager = WorkloadFactoryHelper::GetMemoryManager(); FactoryType workloadFactory = WorkloadFactoryHelper::GetFactory(memoryManager); armnn::RefWorkloadFactory refWorkloadFactory; auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...); CompareTestResultIfSupported(testName, testResult); } template void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr testFunction, Args... args) { auto memoryManager = WorkloadFactoryHelper::GetMemoryManager(); FactoryType workloadFactory = WorkloadFactoryHelper::GetFactory(memoryManager); auto tensorHandleFactory = WorkloadFactoryHelper::GetTensorHandleFactory(memoryManager); armnn::RefWorkloadFactory refWorkloadFactory; auto refMemoryManager = WorkloadFactoryHelper::GetMemoryManager(); auto refTensorHandleFactory = RefWorkloadFactoryHelper::GetTensorHandleFactory(refMemoryManager); auto testResult = (*testFunction)( workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, args...); CompareTestResultIfSupported(testName, testResult); } #define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \ TEST_CASE(#TestName) \ { \ CompareRefTestFunction(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \ TEST_CASE(#TestName) \ { \ CompareRefTestFunctionUsingTensorHandleFactory(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \ TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ CompareRefTestFunction(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \ TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ CompareRefTestFunctionUsingTensorHandleFactory(#TestName, &TestFunction, ##__VA_ARGS__); \ }