aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r--src/armnn/test/UnitTests.hpp42
1 files changed, 27 insertions, 15 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index bb91c4d055..e4a8b96b52 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -14,7 +14,7 @@
#include "TensorHelpers.hpp"
-#include <boost/test/unit_test.hpp>
+#include <doctest/doctest.h>
inline void ConfigureLoggingTest()
{
@@ -28,7 +28,7 @@ inline void ConfigureLoggingTest()
// using FactoryType = armnn::ClWorkloadFactory;
// using FactoryType = armnn::NeonWorkloadFactory;
-/// Executes BOOST_TEST on CompareTensors() return value so that the predicate_result message is reported.
+/// Executes CHECK_MESSAGE on CompareTensors() return value so that the predicate_result message is reported.
/// If the test reports itself as not supported then the tensors are not compared.
/// Additionally this checks that the supportedness reported by the test matches the name of the test.
/// Unsupported tests must be 'tagged' by including "UNSUPPORTED" in their name.
@@ -40,8 +40,8 @@ template <typename T, std::size_t n>
void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
{
bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
- BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
- "The test name does not match the supportedness it is reporting");
+ CHECK_MESSAGE(testNameIndicatesUnsupported != testResult.m_Supported,
+ "The test name does not match the supportedness it is reporting");
if (testResult.m_Supported)
{
auto result = CompareTensors(testResult.m_ActualData,
@@ -49,7 +49,7 @@ void CompareTestResultIfSupported(const std::string& testName, const LayerTestRe
testResult.m_ActualShape,
testResult.m_ExpectedShape,
testResult.m_CompareBoolean);
- BOOST_TEST(result.m_Result, result.m_Message.str());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
}
}
@@ -59,15 +59,15 @@ void CompareTestResultIfSupported(const std::string& testName, const std::vector
bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
for (unsigned int i = 0; i < testResult.size(); ++i)
{
- BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
- "The test name does not match the supportedness it is reporting");
+ CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].m_Supported,
+ "The test name does not match the supportedness it is reporting");
if (testResult[i].m_Supported)
{
auto result = CompareTensors(testResult[i].m_ActualData,
testResult[i].m_ExpectedData,
testResult[i].m_ActualShape,
testResult[i].m_ExpectedShape);
- BOOST_TEST(result.m_Result, result.m_Message.str());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
}
}
}
@@ -106,19 +106,31 @@ void RunTestFunctionUsingTensorHandleFactory(const char* testName, TFuncPtr test
}
#define ARMNN_SIMPLE_TEST_CASE(TestName, TestFunction) \
- BOOST_AUTO_TEST_CASE(TestName) \
+ TEST_CASE(#TestName) \
{ \
TestFunction(); \
}
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
- BOOST_AUTO_TEST_CASE(TestName) \
+ TEST_CASE(#TestName) \
+ { \
+ RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
+#define ARMNN_AUTO_TEST_FIXTURE(TestName, Fixture, TestFunction, ...) \
+ TEST_CASE_FIXTURE(Fixture, #TestName) \
{ \
RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
#define ARMNN_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
- BOOST_AUTO_TEST_CASE(TestName) \
+ TEST_CASE(#TestName) \
+ { \
+ RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
+ }
+
+#define ARMNN_AUTO_TEST_FIXTURE_WITH_THF(TestName, Fixture, TestFunction, ...) \
+ TEST_CASE_FIXTURE(Fixture, #TestName) \
{ \
RunTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
@@ -152,25 +164,25 @@ void CompareRefTestFunctionUsingTensorHandleFactory(const char* testName, TFuncP
}
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
- BOOST_AUTO_TEST_CASE(TestName) \
+ TEST_CASE(#TestName) \
{ \
CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE_WITH_THF(TestName, TestFunction, ...) \
- BOOST_AUTO_TEST_CASE(TestName) \
+ TEST_CASE(#TestName) \
{ \
CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
- BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
+ TEST_CASE_FIXTURE(Fixture, #TestName) \
{ \
CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE_WITH_THF(TestName, Fixture, TestFunction, ...) \
- BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
+ TEST_CASE_FIXTURE(Fixture, #TestName) \
{ \
CompareRefTestFunctionUsingTensorHandleFactory<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
}