diff options
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 42 |
1 files changed, 27 insertions, 15 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index bb91c4d055..e4a8b96b52 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -14,7 +14,7 @@ #include "TensorHelpers.hpp" -#include <boost/test/unit_test.hpp> +#include <doctest/doctest.h> inline void ConfigureLoggingTest() { @@ -28,7 +28,7 @@ inline void ConfigureLoggingTest() // using FactoryType = armnn::ClWorkloadFactory; // using FactoryType = armnn::NeonWorkloadFactory; -/// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported. +/// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported. /// If the test reports itself as not supported then the tensors are not compared. /// Additionally this checks that the supportedness reported by the test matches the name of the test. /// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name. @@ -40,8 +40,8 @@ template <typename T, std::size_t n> void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult) { bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; - BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported, - "The test name does not match the supportedness it is reporting"); + CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported, + "The test name does not match the supportedness it is reporting"); if (testResult.m_Supported) { auto result = CompareTensors(testResult.m_ActualData, @@ -49,7 +49,7 @@ void CompareTestResultIfSupported(const std::string& testName, const LayerTestRe testResult.m_ActualShape, testResult.m_ExpectedShape, testResult.m_CompareBoolean); - BOOST_TEST(result.m_Result, result.m_Message.str()); + CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } } @@ -59,15 +59,15 @@ void CompareTestResultIfSupported(const std::string& testName, const std::vector bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; for (unsigned int i = 0; i < testResult.size(); ++i) { - BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported, - "The test name does not match the supportedness it is reporting"); + CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported, + "The test name does not match the supportedness it is reporting"); if (testResult[i].m_Supported) { auto result = CompareTensors(testResult[i].m_ActualData, testResult[i].m_ExpectedData, testResult[i].m_ActualShape, testResult[i].m_ExpectedShape); - BOOST_TEST(result.m_Result, result.m_Message.str()); + CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } } } @@ -106,19 +106,31 @@ void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr test } #define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \ - BOOST_AUTO_TEST_CASE(TestName) \ + TEST_CASE(#TestName) \ { \ TestFunction(); \ } #define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \ - BOOST_AUTO_TEST_CASE(TestName) \ + TEST_CASE(#TestName) \ + { \ + RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } + +#define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...) \ + TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \ - BOOST_AUTO_TEST_CASE(TestName) \ + TEST_CASE(#TestName) \ + { \ + RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ + } + +#define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...) \ + TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } @@ -152,25 +164,25 @@ void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncP } #define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \ - BOOST_AUTO_TEST_CASE(TestName) \ + TEST_CASE(#TestName) \ { \ CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \ - BOOST_AUTO_TEST_CASE(TestName) \ + TEST_CASE(#TestName) \ { \ CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \ - BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \ + TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } #define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \ - BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \ + TEST_CASE_FIXTURE(Fixture, #TestName) \ { \ CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \ } |