ArmNN
 22.05
InferenceTest.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6 
9 
10 #include "../src/armnn/Profiling.hpp"
11 #include <cxxopts/cxxopts.hpp>
12 
13 #include <fstream>
14 #include <iostream>
15 #include <iomanip>
16 #include <array>
17 
18 using namespace std;
19 using namespace std::chrono;
20 using namespace armnn::test;
21 
22 namespace armnn
23 {
24 namespace test
25 {
26 /// Parse the command line of an ArmNN (or referencetests) inference test program.
27 /// \return false if any error occurred during options processing, otherwise true
28 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
29  InferenceTestOptions& outParams)
30 {
31  cxxopts::Options options("InferenceTest", "Inference iteration parameters");
32 
33  try
34  {
35  // Adds generic options needed for all inference tests.
36  options
37  .allow_unrecognised_options()
38  .add_options()
39  ("h,help", "Display help messages")
40  ("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
41  cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
42  ("inference-times-file",
43  "If non-empty, each individual inference time will be recorded and output to this file",
44  cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
45  ("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
46  cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
47 
48  std::vector<std::string> required; //to be passed as reference to derived inference tests
49 
50  // Adds options specific to the ITestCaseProvider.
51  testCaseProvider.AddCommandLineOptions(options, required);
52 
53  auto result = options.parse(argc, argv);
54 
55  if (result.count("help"))
56  {
57  std::cout << options.help() << std::endl;
58  return false;
59  }
60 
61  CheckRequiredOptions(result, required);
62 
63  }
64  catch (const cxxopts::OptionException& e)
65  {
66  std::cerr << e.what() << std::endl << options.help() << std::endl;
67  return false;
68  }
69  catch (const std::exception& e)
70  {
71  ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
72  std::cerr << "Fatal internal error: " << e.what() << std::endl;
73  return false;
74  }
75 
76  if (!testCaseProvider.ProcessCommandLineOptions(outParams))
77  {
78  return false;
79  }
80 
81  return true;
82 }
83 
/// Check that @p dir names an existing directory.
///
/// \param dir Directory path. It is normalised in place to end with a '/' (this happens before the
///            existence checks, so it is appended even when validation then fails).
/// \return true if @p dir is non-empty, exists and is a directory. Otherwise a diagnostic is written to
///         std::cerr and false is returned.
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Normalise the path so that it always ends with a separator.
    if (dir.back() != '/')
    {
        dir.push_back('/');
    }

    if (fs::exists(dir))
    {
        if (fs::is_directory(dir))
        {
            return true;
        }
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
        return false;
    }

    std::cerr << "Given directory " << dir << " does not exist" << std::endl;
    return false;
}
111 
113  const std::vector<unsigned int>& defaultTestCaseIds,
114  IInferenceTestCaseProvider& testCaseProvider)
115 {
116 #if !defined (NDEBUG)
117  if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
118  {
119  ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
120  }
121 #endif
122 
123  double totalTime = 0;
124  unsigned int nbProcessed = 0;
125  bool success = true;
126 
127  // Opens the file to write inference times too, if needed.
128  ofstream inferenceTimesFile;
129  const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
130  if (recordInferenceTimes)
131  {
132  inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
133  if (!inferenceTimesFile.good())
134  {
135  ARMNN_LOG(error) << "Failed to open inference times file for writing: "
136  << params.m_InferenceTimesFile;
137  return false;
138  }
139  }
140 
141  // Create a profiler and register it for the current thread.
142  std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
143  ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
144 
145  // Enable profiling if requested.
146  profiler->EnableProfiling(params.m_EnableProfiling);
147 
148  // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
149  std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
150  if (warmupTestCase == nullptr)
151  {
152  ARMNN_LOG(error) << "Failed to load test case";
153  return false;
154  }
155 
156  try
157  {
158  warmupTestCase->Run();
159  }
160  catch (const TestFrameworkException& testError)
161  {
162  ARMNN_LOG(error) << testError.what();
163  return false;
164  }
165 
166  const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
167  : static_cast<unsigned int>(defaultTestCaseIds.size());
168 
169  for (; nbProcessed < nbTotalToProcess; nbProcessed++)
170  {
171  const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
172  std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
173 
174  if (testCase == nullptr)
175  {
176  ARMNN_LOG(error) << "Failed to load test case";
177  return false;
178  }
179 
180  time_point<high_resolution_clock> predictStart;
181  time_point<high_resolution_clock> predictEnd;
182 
183  TestCaseResult result = TestCaseResult::Ok;
184 
185  try
186  {
187  predictStart = high_resolution_clock::now();
188 
189  testCase->Run();
190 
191  predictEnd = high_resolution_clock::now();
192 
193  // duration<double> will convert the time difference into seconds as a double by default.
194  double timeTakenS = duration<double>(predictEnd - predictStart).count();
195  totalTime += timeTakenS;
196 
197  // Outputss inference times, if needed.
198  if (recordInferenceTimes)
199  {
200  inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
201  }
202 
203  result = testCase->ProcessResult(params);
204 
205  }
206  catch (const TestFrameworkException& testError)
207  {
208  ARMNN_LOG(error) << testError.what();
209  result = TestCaseResult::Abort;
210  }
211 
212  switch (result)
213  {
214  case TestCaseResult::Ok:
215  break;
216  case TestCaseResult::Abort:
217  return false;
218  case TestCaseResult::Failed:
219  // This test failed so we will fail the entire program eventually, but keep going for now.
220  success = false;
221  break;
222  default:
223  ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
224  return false;
225  }
226  }
227 
228  const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
229 
230  ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
231  "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
232  ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
233  "Average time per test case: " << averageTimePerTestCaseMs << " ms";
234 
235  // if profiling is enabled print out the results
236  if (profiler && profiler->IsProfilingEnabled())
237  {
238  profiler->Print(std::cout);
239  }
240 
241  if (!success)
242  {
243  ARMNN_LOG(error) << "One or more test cases failed";
244  return false;
245  }
246 
247  return testCaseProvider.OnInferenceTestFinished();
248 }
249 
250 } // namespace test
251 
252 } // namespace armnn
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
bool CheckRequiredOptions(const cxxopts::ParseResult &result, const std::vector< std::string > &required)
Ensure all mandatory command-line parameters have been passed to cxxopts.
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205
Copyright (c) 2021 ARM Limited and Contributors.
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required)
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
virtual std::unique_ptr< IInferenceTestCase > GetTestCase(unsigned int testCaseId)=0
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
bool ValidateDirectory(std::string &dir)