aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r--src/armnn/test/UnitTests.hpp20
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 9b750b5b33..8d5c7055e7 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -12,7 +12,7 @@
inline void ConfigureLoggingTest()
{
- // Configure logging for both the ARMNN library and this test program
+ // Configures logging for both the ARMNN library and this test program.
armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal);
armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal);
}
@@ -43,9 +43,27 @@ void CompareTestResultIfSupported(const std::string& testName, const LayerTestRe
}
}
+template <typename T, std::size_t n>
+void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult)
+{
+ bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos;
+ for (unsigned int i = 0; i < testResult.size(); ++i)
+ {
+ BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].supported,
+ "The test name does not match the supportedness it is reporting");
+ if (testResult[i].supported)
+ {
+ BOOST_TEST(CompareTensors(testResult[i].output, testResult[i].outputExpected));
+ }
+ }
+}
+
template<typename FactoryType, typename TFuncPtr, typename... Args>
void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{
+ std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
+ armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
+
FactoryType workloadFactory;
auto testResult = (*testFunction)(workloadFactory, args...);
CompareTestResultIfSupported(testName, testResult);