diff options
Diffstat (limited to 'tests/InferenceTest.cpp')
-rw-r--r-- | tests/InferenceTest.cpp | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/tests/InferenceTest.cpp b/tests/InferenceTest.cpp index 161481f2cd..477ae4e84e 100644 --- a/tests/InferenceTest.cpp +++ b/tests/InferenceTest.cpp @@ -4,6 +4,7 @@ // #include "InferenceTest.hpp" +#include "../src/armnn/Profiling.hpp" #include <boost/algorithm/string.hpp> #include <boost/numeric/conversion/cast.hpp> #include <boost/log/trivial.hpp> @@ -26,7 +27,6 @@ namespace armnn { namespace test { - /// Parse the command line of an ArmNN (or referencetests) inference test program. /// \return false if any error occurred during options processing, otherwise true bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, @@ -40,15 +40,17 @@ bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCas try { - // Add generic options needed for all inference tests + // Adds generic options needed for all inference tests. desc.add_options() ("help", "Display help messages") ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0), "Sets the number number of inferences to perform. If unset, a default number will be ran.") ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""), - "If non-empty, each individual inference time will be recorded and output to this file"); + "If non-empty, each individual inference time will be recorded and output to this file") + ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0), + "Enables built in profiler. If unset, defaults to off."); - // Add options specific to the ITestCaseProvider + // Adds options specific to the ITestCaseProvider. testCaseProvider.AddCommandLineOptions(desc); } catch (const std::exception& e) @@ -111,7 +113,7 @@ bool InferenceTest(const InferenceTestOptions& params, IInferenceTestCaseProvider& testCaseProvider) { #if !defined (NDEBUG) - if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn + if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn. { BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate."; } @@ -121,7 +123,7 @@ bool InferenceTest(const InferenceTestOptions& params, unsigned int nbProcessed = 0; bool success = true; - // Open the file to write inference times to, if needed + // Opens the file to write inference times too, if needed. ofstream inferenceTimesFile; const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty(); if (recordInferenceTimes) @@ -135,6 +137,13 @@ bool InferenceTest(const InferenceTestOptions& params, } } + // Create a profiler and register it for the current thread. + std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>(); + ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); + + // Enable profiling if requested. + profiler->EnableProfiling(params.m_EnableProfiling); + // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0); if (warmupTestCase == nullptr) @@ -184,7 +193,7 @@ bool InferenceTest(const InferenceTestOptions& params, double timeTakenS = duration<double>(predictEnd - predictStart).count(); totalTime += timeTakenS; - // Output inference times if needed + // Outputss inference times, if needed. if (recordInferenceTimes) { inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl; |