diff options
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp new file mode 100644 index 0000000000..040048ad99 --- /dev/null +++ b/src/armnn/test/UnitTests.hpp @@ -0,0 +1,79 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Logging.hpp" +#include "armnn/Utils.hpp" +#include "backends/RefWorkloadFactory.hpp" +#include "backends/test/LayerTests.hpp" +#include <boost/test/unit_test.hpp> + +inline void ConfigureLoggingTest() +{ + // Configure logging for both the ARMNN library and this test program + armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); + armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal); +} + +// The following macros require the caller to have defined FactoryType, with one of the following using statements: +// +// using FactoryType = armnn::RefWorkloadFactory; +// using FactoryType = armnn::ClWorkloadFactory; +// using FactoryType = armnn::NeonWorkloadFactory; + +/// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported. +/// If the test reports itself as not supported then the tensors are not compared. +/// Additionally this checks that the supportedness reported by the test matches the name of the test. +/// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name. +/// This is useful because it clarifies that the feature being tested is not actually supported +/// (a passed test with the name of a feature would imply that feature was supported). +/// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED. +/// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED. +template <typename T, std::size_t n> +void CompareTestResultIfSupported(const std::string& testName, LayerTestResult<T, n> testResult) +{ + bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; + BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.supported, + "The test name does not match the supportedness it is reporting"); + if (testResult.supported) + { + BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected)); + } +} + +template<typename FactoryType, typename TFuncPtr, typename... Args> +void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) +{ + FactoryType workloadFactory; + auto testResult = (*testFunction)(workloadFactory, args...); + CompareTestResultIfSupported(testName, testResult); +} + +#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \ + BOOST_AUTO_TEST_CASE(TestName) \ + { \ + RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } + +template<typename FactoryType, typename TFuncPtr, typename... Args> +void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args) +{ + FactoryType workloadFactory; + armnn::RefWorkloadFactory refWorkloadFactory; + auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...); + CompareTestResultIfSupported(testName, testResult); +} + +#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \ + BOOST_AUTO_TEST_CASE(TestName) \ + { \ + CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } + +#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \ + BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \ + { \ + CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } |