diff options
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 9b750b5b33..8d5c7055e7 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -12,7 +12,7 @@ inline void ConfigureLoggingTest() { - // Configure logging for both the ARMNN library and this test program + // Configures logging for both the ARMNN library and this test program. armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal); } @@ -43,9 +43,27 @@ void CompareTestResultIfSupported(const std::string& testName, const LayerTestRe } } +template <typename T, std::size_t n> +void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult) +{ + bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; + for (unsigned int i = 0; i < testResult.size(); ++i) + { + BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].supported, + "The test name does not match the supportedness it is reporting"); + if (testResult[i].supported) + { + BOOST_TEST(CompareTensors(testResult[i].output, testResult[i].outputExpected)); + } + } +} + template<typename FactoryType, typename TFuncPtr, typename... Args> void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { + std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>(); + armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); + FactoryType workloadFactory; auto testResult = (*testFunction)(workloadFactory, args...); CompareTestResultIfSupported(testName, testResult); |