aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/test/UnitTests.hpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r--src/armnn/test/UnitTests.hpp79
1 files changed, 79 insertions, 0 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
new file mode 100644
index 0000000000..040048ad99
--- /dev/null
+++ b/src/armnn/test/UnitTests.hpp
@@ -0,0 +1,79 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "Logging.hpp"
+#include "armnn/Utils.hpp"
+#include "backends/RefWorkloadFactory.hpp"
+#include "backends/test/LayerTests.hpp"
+#include <boost/test/unit_test.hpp>
+
+inline void ConfigureLoggingTest()
+{
+ // Configure logging for both the ARMNN library and this test program
+ armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
+ armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal);
+}
+
+// The following macros require the caller to have defined FactoryType, with one of the following using statements:
+//
+// using FactoryType = armnn::RefWorkloadFactory;
+// using FactoryType = armnn::ClWorkloadFactory;
+// using FactoryType = armnn::NeonWorkloadFactory;
+
+/// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported.
+/// If the test reports itself as not supported then the tensors are not compared.
+/// Additionally this checks that the supportedness reported by the test matches the name of the test.
+/// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
+/// This is useful because it clarifies that the feature being tested is not actually supported
+/// (a passed test with the name of a feature would imply that feature was supported).
+/// If support is added for a feature, the test case will fail because the name incorrectly contains UNSUPPORTED.
+/// If support is removed for a feature, the test case will fail because the name doesn't contain UNSUPPORTED.
+template <typename T, std::size_t n>
+void CompareTestResultIfSupported(const std::string& testName, LayerTestResult<T, n> testResult)
+{
+ bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
+ BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.supported,
+ "The test name does not match the supportedness it is reporting");
+ if (testResult.supported)
+ {
+ BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected));
+ }
+}
+
+template<typename FactoryType, typename TFuncPtr, typename... Args>
+void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
+{
+ FactoryType workloadFactory;
+ auto testResult = (*testFunction)(workloadFactory, args...);
+ CompareTestResultIfSupported(testName, testResult);
+}
+
+#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
+ BOOST_AUTO_TEST_CASE(TestName) \
+ { \
+ RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
+template<typename FactoryType, typename TFuncPtr, typename... Args>
+void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
+{
+ FactoryType workloadFactory;
+ armnn::RefWorkloadFactory refWorkloadFactory;
+ auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
+ CompareTestResultIfSupported(testName, testResult);
+}
+
+#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
+ BOOST_AUTO_TEST_CASE(TestName) \
+ { \
+ CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
+#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
+ BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
+ { \
+ CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }