// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include "InferenceTest.hpp" #include #include #include "../src/armnn/Profiling.hpp" #include #include #include #include #include using namespace std; using namespace std::chrono; using namespace armnn::test; namespace armnn { namespace test { /// Parse the command line of an ArmNN (or referencetests) inference test program. /// \return false if any error occurred during options processing, otherwise true bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, InferenceTestOptions& outParams) { cxxopts::Options options("InferenceTest", "Inference iteration parameters"); try { // Adds generic options needed for all inference tests. options .allow_unrecognised_options() .add_options() ("h,help", "Display help messages") ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.", cxxopts::value(outParams.m_IterationCount)->default_value("0")) ("inference-times-file", "If non-empty, each individual inference time will be recorded and output to this file", cxxopts::value(outParams.m_InferenceTimesFile)->default_value("")) ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.", cxxopts::value(outParams.m_EnableProfiling)->default_value("0")); std::vector required; //to be passed as reference to derived inference tests // Adds options specific to the ITestCaseProvider. testCaseProvider.AddCommandLineOptions(options, required); auto result = options.parse(argc, argv); if (result.count("help")) { std::cout << options.help() << std::endl; return false; } CheckRequiredOptions(result, required); } catch (const cxxopts::OptionException& e) { std::cerr << e.what() << std::endl << options.help() << std::endl; return false; } catch (const std::exception& e) { // Coverity points out that default_value(...) can throw a bad_lexical_cast, // and that desc.add_options() can throw boost::io::too_few_args. // They really won't in any of these cases. ARMNN_ASSERT_MSG(false, "Caught unexpected exception"); std::cerr << "Fatal internal error: " << e.what() << std::endl; return false; } if (!testCaseProvider.ProcessCommandLineOptions(outParams)) { return false; } return true; } bool ValidateDirectory(std::string& dir) { if (dir.empty()) { std::cerr << "No directory specified" << std::endl; return false; } if (dir[dir.length() - 1] != '/') { dir += "/"; } if (!fs::exists(dir)) { std::cerr << "Given directory " << dir << " does not exist" << std::endl; return false; } if (!fs::is_directory(dir)) { std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl; return false; } return true; } bool InferenceTest(const InferenceTestOptions& params, const std::vector& defaultTestCaseIds, IInferenceTestCaseProvider& testCaseProvider) { #if !defined (NDEBUG) if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn. { ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate."; } #endif double totalTime = 0; unsigned int nbProcessed = 0; bool success = true; // Opens the file to write inference times too, if needed. ofstream inferenceTimesFile; const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty(); if (recordInferenceTimes) { inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out); if (!inferenceTimesFile.good()) { ARMNN_LOG(error) << "Failed to open inference times file for writing: " << params.m_InferenceTimesFile; return false; } } // Create a profiler and register it for the current thread. std::unique_ptr profiler = std::make_unique(); ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); // Enable profiling if requested. profiler->EnableProfiling(params.m_EnableProfiling); // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer std::unique_ptr warmupTestCase = testCaseProvider.GetTestCase(0); if (warmupTestCase == nullptr) { ARMNN_LOG(error) << "Failed to load test case"; return false; } try { warmupTestCase->Run(); } catch (const TestFrameworkException& testError) { ARMNN_LOG(error) << testError.what(); return false; } const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount : static_cast(defaultTestCaseIds.size()); for (; nbProcessed < nbTotalToProcess; nbProcessed++) { const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed]; std::unique_ptr testCase = testCaseProvider.GetTestCase(testCaseId); if (testCase == nullptr) { ARMNN_LOG(error) << "Failed to load test case"; return false; } time_point predictStart; time_point predictEnd; TestCaseResult result = TestCaseResult::Ok; try { predictStart = high_resolution_clock::now(); testCase->Run(); predictEnd = high_resolution_clock::now(); // duration will convert the time difference into seconds as a double by default. double timeTakenS = duration(predictEnd - predictStart).count(); totalTime += timeTakenS; // Outputss inference times, if needed. if (recordInferenceTimes) { inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl; } result = testCase->ProcessResult(params); } catch (const TestFrameworkException& testError) { ARMNN_LOG(error) << testError.what(); result = TestCaseResult::Abort; } switch (result) { case TestCaseResult::Ok: break; case TestCaseResult::Abort: return false; case TestCaseResult::Failed: // This test failed so we will fail the entire program eventually, but keep going for now. success = false; break; default: ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult"); return false; } } const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f; ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds"; ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Average time per test case: " << averageTimePerTestCaseMs << " ms"; // if profiling is enabled print out the results if (profiler && profiler->IsProfilingEnabled()) { profiler->Print(std::cout); } if (!success) { ARMNN_LOG(error) << "One or more test cases failed"; return false; } return testCaseProvider.OnInferenceTestFinished(); } } // namespace test } // namespace armnn