aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceTest.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /tests/InferenceTest.cpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'tests/InferenceTest.cpp')
-rw-r--r--tests/InferenceTest.cpp23
1 files changed, 16 insertions, 7 deletions
diff --git a/tests/InferenceTest.cpp b/tests/InferenceTest.cpp
index 161481f2cd..477ae4e84e 100644
--- a/tests/InferenceTest.cpp
+++ b/tests/InferenceTest.cpp
@@ -4,6 +4,7 @@
//
#include "InferenceTest.hpp"
+#include "../src/armnn/Profiling.hpp"
#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/log/trivial.hpp>
@@ -26,7 +27,6 @@ namespace armnn
{
namespace test
{
-
/// Parse the command line of an ArmNN (or referencetests) inference test program.
/// \return false if any error occurred during options processing, otherwise true
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
@@ -40,15 +40,17 @@ bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCas
try
{
- // Add generic options needed for all inference tests
+ // Adds generic options needed for all inference tests.
desc.add_options()
("help", "Display help messages")
("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
"Sets the number number of inferences to perform. If unset, a default number will be ran.")
("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
- "If non-empty, each individual inference time will be recorded and output to this file");
+ "If non-empty, each individual inference time will be recorded and output to this file")
+ ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
+ "Enables built in profiler. If unset, defaults to off.");
- // Add options specific to the ITestCaseProvider
+ // Adds options specific to the ITestCaseProvider.
testCaseProvider.AddCommandLineOptions(desc);
}
catch (const std::exception& e)
@@ -111,7 +113,7 @@ bool InferenceTest(const InferenceTestOptions& params,
IInferenceTestCaseProvider& testCaseProvider)
{
#if !defined (NDEBUG)
- if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn
+ if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
{
BOOST_LOG_TRIVIAL(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
}
@@ -121,7 +123,7 @@ bool InferenceTest(const InferenceTestOptions& params,
unsigned int nbProcessed = 0;
bool success = true;
- // Open the file to write inference times to, if needed
+ // Opens the file to write inference times too, if needed.
ofstream inferenceTimesFile;
const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
if (recordInferenceTimes)
@@ -135,6 +137,13 @@ bool InferenceTest(const InferenceTestOptions& params,
}
}
+ // Create a profiler and register it for the current thread.
+ std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
+ ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
+
+ // Enable profiling if requested.
+ profiler->EnableProfiling(params.m_EnableProfiling);
+
// Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
if (warmupTestCase == nullptr)
@@ -184,7 +193,7 @@ bool InferenceTest(const InferenceTestOptions& params,
double timeTakenS = duration<double>(predictEnd - predictStart).count();
totalTime += timeTakenS;
- // Output inference times if needed
+ // Outputss inference times, if needed.
if (recordInferenceTimes)
{
inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;