diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/test/UnitTests.hpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnn/test/UnitTests.hpp')
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 9b750b5b33..8d5c7055e7 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -12,7 +12,7 @@ inline void ConfigureLoggingTest() { - // Configure logging for both the ARMNN library and this test program + // Configures logging for both the ARMNN library and this test program. armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, armnn::LogSeverity::Fatal); } @@ -43,9 +43,27 @@ void CompareTestResultIfSupported(const std::string& testName, const LayerTestRe } } +template <typename T, std::size_t n> +void CompareTestResultIfSupported(const std::string& testName, const std::vector<LayerTestResult<T, n>>& testResult) +{ + bool testNameIndicatesUnsupported = testName.find("UNSUPPORTED") != std::string::npos; + for (unsigned int i = 0; i < testResult.size(); ++i) + { + BOOST_CHECK_MESSAGE(testNameIndicatesUnsupported != testResult[i].supported, + "The test name does not match the supportedness it is reporting"); + if (testResult[i].supported) + { + BOOST_TEST(CompareTensors(testResult[i].output, testResult[i].outputExpected)); + } + } +} + template<typename FactoryType, typename TFuncPtr, typename... Args> void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { + std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>(); + armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); + FactoryType workloadFactory; auto testResult = (*testFunction)(workloadFactory, args...); CompareTestResultIfSupported(testName, testResult); |