ArmNN
 22.08
armnn::test Namespace Reference

Classes

class  ClassifierTestCase
 
class  ClassifierTestCaseProvider
 
class  IInferenceTestCase
 
class  IInferenceTestCaseProvider
 
class  InferenceModelTestCase
 
struct  InferenceTestOptions
 
class  TestFrameworkException
 

Enumerations

enum  TestCaseResult { Ok, Failed, Abort }
 

Functions

bool ParseCommandLine (int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
 Parse the command line of an ArmNN (or referencetests) inference test program. More...
 
bool ValidateDirectory (std::string &dir)
 
bool InferenceTest (const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
 
template<typename TConstructTestCaseProvider >
int InferenceTestMain (int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
 
template<typename TDatabase , typename TParser , typename TConstructDatabaseCallable >
int ClassifierInferenceTestMain (int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
 

Enumeration Type Documentation

◆ TestCaseResult

enum class TestCaseResult
Enumerator
Ok 

The test completed without any errors.

Failed 

The test failed (e.g. the prediction didn't match the validation file). This will eventually fail the whole program but the remaining test cases will still be run.

Abort 

The test failed with a fatal error. The remaining tests will not be run.

Definition at line 73 of file InferenceTest.hpp.

74 {
75  /// The test completed without any errors.
76  Ok,
77  /// The test failed (e.g. the prediction didn't match the validation file).
78  /// This will eventually fail the whole program but the remaining test cases will still be run.
79  Failed,
80  /// The test failed with a fatal error. The remaining tests will not be run.
81  Abort
82 };
The test failed with a fatal error. The remaining tests will not be run.
The test completed without any errors.

Function Documentation

◆ ClassifierInferenceTestMain()

int ClassifierInferenceTestMain ( int  argc,
char *  argv[],
const char *  modelFilename,
bool  isModelBinary,
const char *  inputBindingName,
const char *  outputBindingName,
const std::vector< unsigned int > &  defaultTestCaseIds,
TConstructDatabaseCallable  constructDatabase,
const armnn::TensorShape *  inputTensorShape = nullptr 
)

Definition at line 362 of file InferenceTest.inl.

References ARMNN_ASSERT, InferenceTestMain(), Params::m_ModelPath, and ValidateDirectory().

Referenced by main().

372 {
373  ARMNN_ASSERT(modelFilename);
374  ARMNN_ASSERT(inputBindingName);
375  ARMNN_ASSERT(outputBindingName);
376 
377  return InferenceTestMain(argc, argv, defaultTestCaseIds,
378  [=]
379  ()
380  {
383 
384  return make_unique<TestCaseProvider>(constructDatabase,
385  [&]
386  (const InferenceTestOptions &commonOptions,
387  typename InferenceModel::CommandLineOptions modelOptions)
388  {
389  if (!ValidateDirectory(modelOptions.m_ModelDir))
390  {
391  return std::unique_ptr<InferenceModel>();
392  }
393 
394  typename InferenceModel::Params modelParams;
395  modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
396  modelParams.m_InputBindings = { inputBindingName };
397  modelParams.m_OutputBindings = { outputBindingName };
398 
399  if (inputTensorShape)
400  {
401  modelParams.m_InputShapes.push_back(*inputTensorShape);
402  }
403 
404  modelParams.m_IsModelBinary = isModelBinary;
405  modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
406  modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
407  modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
408 
409  return std::make_unique<InferenceModel>(modelParams,
410  commonOptions.m_EnableProfiling,
411  commonOptions.m_DynamicBackendsPath);
412  });
413  });
414 }
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
bool ValidateDirectory(std::string &dir)
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)

◆ InferenceTest()

bool InferenceTest ( const InferenceTestOptions &  params,
const std::vector< unsigned int > &  defaultTestCaseIds,
IInferenceTestCaseProvider &  testCaseProvider 
)

Definition at line 112 of file InferenceTest.cpp.

References ARMNN_ASSERT_MSG, ARMNN_LOG, armnn::error, IInferenceTestCaseProvider::GetTestCase(), armnn::info, InferenceTestOptions::m_EnableProfiling, InferenceTestOptions::m_InferenceTimesFile, InferenceTestOptions::m_IterationCount, IInferenceTestCaseProvider::OnInferenceTestFinished(), armnn::warning, and Exception::what().

Referenced by InferenceTestMain().

115 {
116 #if !defined (NDEBUG)
117  if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
118  {
119  ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
120  }
121 #endif
122 
123  double totalTime = 0;
124  unsigned int nbProcessed = 0;
125  bool success = true;
126 
127  // Opens the file to write inference times too, if needed.
128  ofstream inferenceTimesFile;
129  const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
130  if (recordInferenceTimes)
131  {
132  inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
133  if (!inferenceTimesFile.good())
134  {
135  ARMNN_LOG(error) << "Failed to open inference times file for writing: "
136  << params.m_InferenceTimesFile;
137  return false;
138  }
139  }
140 
141  // Create a profiler and register it for the current thread.
142  std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
143  ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
144 
145  // Enable profiling if requested.
146  profiler->EnableProfiling(params.m_EnableProfiling);
147 
148  // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
149  std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
150  if (warmupTestCase == nullptr)
151  {
152  ARMNN_LOG(error) << "Failed to load test case";
153  return false;
154  }
155 
156  try
157  {
158  warmupTestCase->Run();
159  }
160  catch (const TestFrameworkException& testError)
161  {
162  ARMNN_LOG(error) << testError.what();
163  return false;
164  }
165 
166  const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
167  : static_cast<unsigned int>(defaultTestCaseIds.size());
168 
169  for (; nbProcessed < nbTotalToProcess; nbProcessed++)
170  {
171  const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
172  std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
173 
174  if (testCase == nullptr)
175  {
176  ARMNN_LOG(error) << "Failed to load test case";
177  return false;
178  }
179 
180  time_point<high_resolution_clock> predictStart;
181  time_point<high_resolution_clock> predictEnd;
182 
183  TestCaseResult result = TestCaseResult::Ok;
184 
185  try
186  {
187  predictStart = high_resolution_clock::now();
188 
189  testCase->Run();
190 
191  predictEnd = high_resolution_clock::now();
192 
193  // duration<double> will convert the time difference into seconds as a double by default.
194  double timeTakenS = duration<double>(predictEnd - predictStart).count();
195  totalTime += timeTakenS;
196 
197  // Outputss inference times, if needed.
198  if (recordInferenceTimes)
199  {
200  inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
201  }
202 
203  result = testCase->ProcessResult(params);
204 
205  }
206  catch (const TestFrameworkException& testError)
207  {
208  ARMNN_LOG(error) << testError.what();
209  result = TestCaseResult::Abort;
210  }
211 
212  switch (result)
213  {
214  case TestCaseResult::Ok:
215  break;
216  case TestCaseResult::Abort:
217  return false;
218  case TestCaseResult::Failed:
219  // This test failed so we will fail the entire program eventually, but keep going for now.
220  success = false;
221  break;
222  default:
223  ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
224  return false;
225  }
226  }
227 
228  const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
229 
230  ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
231  "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
232  ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
233  "Average time per test case: " << averageTimePerTestCaseMs << " ms";
234 
235  // if profiling is enabled print out the results
236  if (profiler && profiler->IsProfilingEnabled())
237  {
238  profiler->Print(std::cout);
239  }
240 
241  if (!success)
242  {
243  ARMNN_LOG(error) << "One or more test cases failed";
244  return false;
245  }
246 
247  return testCaseProvider.OnInferenceTestFinished();
248 }
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
virtual std::unique_ptr< IInferenceTestCase > GetTestCase(unsigned int testCaseId)=0

◆ InferenceTestMain()

int InferenceTestMain ( int  argc,
char *  argv[],
const std::vector< unsigned int > &  defaultTestCaseIds,
TConstructTestCaseProvider  constructTestCaseProvider 
)

Definition at line 311 of file InferenceTest.inl.

References ARMNN_LOG, armnn::ConfigureLogging(), armnn::Debug, armnn::fatal, InferenceTest(), armnn::Info, ParseCommandLine(), and Exception::what().

Referenced by ClassifierInferenceTestMain(), and main().

315 {
316  // Configures logging for both the ARMNN library and this test program.
317 #ifdef NDEBUG
319 #else
321 #endif
322  armnn::ConfigureLogging(true, true, level);
323 
324  try
325  {
326  std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
327  if (!testCaseProvider)
328  {
329  return 1;
330  }
331 
332  InferenceTestOptions inferenceTestOptions;
333  if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
334  {
335  return 1;
336  }
337 
338  const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
339  return success ? 0 : 1;
340  }
341  catch (armnn::Exception const& e)
342  {
343  ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
344  return 1;
345  }
346 }
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
Definition: Utils.cpp:18
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
LogSeverity
Definition: Utils.hpp:14

◆ ParseCommandLine()

bool ParseCommandLine ( int  argc,
char **  argv,
IInferenceTestCaseProvider &  testCaseProvider,
InferenceTestOptions &  outParams 
)

Parse the command line of an ArmNN (or referencetests) inference test program.

Returns
false if any error occurred during options processing, otherwise true

Definition at line 28 of file InferenceTest.cpp.

References IInferenceTestCaseProvider::AddCommandLineOptions(), ARMNN_ASSERT_MSG, CheckRequiredOptions(), InferenceTestOptions::m_EnableProfiling, InferenceTestOptions::m_InferenceTimesFile, InferenceTestOptions::m_IterationCount, and IInferenceTestCaseProvider::ProcessCommandLineOptions().

Referenced by InferenceTestMain().

30 {
31  cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32 
33  try
34  {
35  // Adds generic options needed for all inference tests.
36  options
37  .allow_unrecognised_options()
38  .add_options()
39  ("h,help", "Display help messages")
40  ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41  cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42  ("inference-times-file",
43  "If non-empty, each individual inference time will be recorded and output to this file",
44  cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45  ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46  cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47 
48  std::vector<std::string> required; //to be passed as reference to derived inference tests
49 
50  // Adds options specific to the ITestCaseProvider.
51  testCaseProvider.AddCommandLineOptions(options, required);
52 
53  auto result = options.parse(argc, argv);
54 
55  if (result.count("help"))
56  {
57  std::cout << options.help() << std::endl;
58  return false;
59  }
60 
61  CheckRequiredOptions(result, required);
62 
63  }
64  catch (const cxxopts::OptionException& e)
65  {
66  std::cerr << e.what() << std::endl << options.help() << std::endl;
67  return false;
68  }
69  catch (const std::exception& e)
70  {
71  ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
72  std::cerr << "Fatal internal error: " << e.what() << std::endl;
73  return false;
74  }
75 
76  if (!testCaseProvider.ProcessCommandLineOptions(outParams))
77  {
78  return false;
79  }
80 
81  return true;
82 }
bool CheckRequiredOptions(const cxxopts::ParseResult &result, const std::vector< std::string > &required)
Ensure all mandatory command-line parameters have been passed to cxxopts.
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)

◆ ValidateDirectory()

bool ValidateDirectory ( std::string &  dir)

Definition at line 84 of file InferenceTest.cpp.

Referenced by ClassifierInferenceTestMain(), main(), ClassifierTestCaseProvider< TDatabase, InferenceModel >::ProcessCommandLineOptions(), and YoloTestCaseProvider< Model >::ProcessCommandLineOptions().

85 {
86  if (dir.empty())
87  {
88  std::cerr << "No directory specified" << std::endl;
89  return false;
90  }
91 
92  if (dir[dir.length() - 1] != '/')
93  {
94  dir += "/";
95  }
96 
97  if (!fs::exists(dir))
98  {
99  std::cerr << "Given directory " << dir << " does not exist" << std::endl;
100  return false;
101  }
102 
103  if (!fs::is_directory(dir))
104  {
105  std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
106  return false;
107  }
108 
109  return true;
110 }