ArmNN
 20.08
InferenceTest.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "InferenceTest.hpp"
6 
8 #include <Filesystem.hpp>
9 
10 #include "../src/armnn/Profiling.hpp"
11 #include <boost/numeric/conversion/cast.hpp>
12 #include <boost/format.hpp>
13 #include <boost/program_options.hpp>
14 
#include <array>
#include <chrono>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <system_error>
#include <vector>
19 
20 using namespace std;
21 using namespace std::chrono;
22 using namespace armnn::test;
23 
24 namespace armnn
25 {
26 namespace test
27 {
28 /// Parse the command line of an ArmNN (or referencetests) inference test program.
29 /// \return false if any error occurred during options processing, otherwise true
30 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
31  InferenceTestOptions& outParams)
32 {
33  namespace po = boost::program_options;
34 
35  po::options_description desc("Options");
36 
37  try
38  {
39  // Adds generic options needed for all inference tests.
40  desc.add_options()
41  ("help", "Display help messages")
42  ("iterations,i", po::value<unsigned int>(&outParams.m_IterationCount)->default_value(0),
43  "Sets the number number of inferences to perform. If unset, a default number will be ran.")
44  ("inference-times-file", po::value<std::string>(&outParams.m_InferenceTimesFile)->default_value(""),
45  "If non-empty, each individual inference time will be recorded and output to this file")
46  ("event-based-profiling,e", po::value<bool>(&outParams.m_EnableProfiling)->default_value(0),
47  "Enables built in profiler. If unset, defaults to off.");
48 
49  // Adds options specific to the ITestCaseProvider.
50  testCaseProvider.AddCommandLineOptions(desc);
51  }
52  catch (const std::exception& e)
53  {
54  // Coverity points out that default_value(...) can throw a bad_lexical_cast,
55  // and that desc.add_options() can throw boost::io::too_few_args.
56  // They really won't in any of these cases.
57  ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
58  std::cerr << "Fatal internal error: " << e.what() << std::endl;
59  return false;
60  }
61 
62  po::variables_map vm;
63 
64  try
65  {
66  po::store(po::parse_command_line(argc, argv, desc), vm);
67 
68  if (vm.count("help"))
69  {
70  std::cout << desc << std::endl;
71  return false;
72  }
73 
74  po::notify(vm);
75  }
76  catch (po::error& e)
77  {
78  std::cerr << e.what() << std::endl << std::endl;
79  std::cerr << desc << std::endl;
80  return false;
81  }
82 
83  if (!testCaseProvider.ProcessCommandLineOptions(outParams))
84  {
85  return false;
86  }
87 
88  return true;
89 }
90 
/// Checks that \p dir names an existing directory and normalises it to end with '/'.
///
/// Diagnostics are written to std::cerr. The trailing '/' is appended as soon as \p dir is
/// known to be non-empty, so it is present even when a later check fails.
///
/// \param dir  In/out: directory path to validate; a trailing '/' is appended if missing.
/// \return true if \p dir is non-empty, exists and is a directory; false otherwise.
bool ValidateDirectory(std::string& dir)
{
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Query the path exactly as the caller gave it. Checking "file/" instead would fail with
    // ENOTDIR, which filesystem reports as not_found, so a regular file would be misreported
    // as non-existent rather than as "not a directory".
    const std::string pathToCheck = dir;

    if (dir.back() != '/')
    {
        dir += '/';
    }

    // Non-throwing overloads: the throwing ones raise filesystem_error for problems such as
    // permission denied, which would otherwise escape this bool-returning check.
    std::error_code ec;
    const bool exists = fs::exists(pathToCheck, ec);
    if (ec)
    {
        std::cerr << "Failed to access directory " << dir << ": " << ec.message() << std::endl;
        return false;
    }

    if (!exists)
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    if (!fs::is_directory(pathToCheck, ec))
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
        return false;
    }

    return true;
}
118 
120  const std::vector<unsigned int>& defaultTestCaseIds,
121  IInferenceTestCaseProvider& testCaseProvider)
122 {
123 #if !defined (NDEBUG)
124  if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
125  {
126  ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
127  }
128 #endif
129 
130  double totalTime = 0;
131  unsigned int nbProcessed = 0;
132  bool success = true;
133 
134  // Opens the file to write inference times too, if needed.
135  ofstream inferenceTimesFile;
136  const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
137  if (recordInferenceTimes)
138  {
139  inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
140  if (!inferenceTimesFile.good())
141  {
142  ARMNN_LOG(error) << "Failed to open inference times file for writing: "
143  << params.m_InferenceTimesFile;
144  return false;
145  }
146  }
147 
148  // Create a profiler and register it for the current thread.
149  std::unique_ptr<Profiler> profiler = std::make_unique<Profiler>();
150  ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
151 
152  // Enable profiling if requested.
153  profiler->EnableProfiling(params.m_EnableProfiling);
154 
155  // Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
156  std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
157  if (warmupTestCase == nullptr)
158  {
159  ARMNN_LOG(error) << "Failed to load test case";
160  return false;
161  }
162 
163  try
164  {
165  warmupTestCase->Run();
166  }
167  catch (const TestFrameworkException& testError)
168  {
169  ARMNN_LOG(error) << testError.what();
170  return false;
171  }
172 
173  const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
174  : static_cast<unsigned int>(defaultTestCaseIds.size());
175 
176  for (; nbProcessed < nbTotalToProcess; nbProcessed++)
177  {
178  const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
179  std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
180 
181  if (testCase == nullptr)
182  {
183  ARMNN_LOG(error) << "Failed to load test case";
184  return false;
185  }
186 
187  time_point<high_resolution_clock> predictStart;
188  time_point<high_resolution_clock> predictEnd;
189 
190  TestCaseResult result = TestCaseResult::Ok;
191 
192  try
193  {
194  predictStart = high_resolution_clock::now();
195 
196  testCase->Run();
197 
198  predictEnd = high_resolution_clock::now();
199 
200  // duration<double> will convert the time difference into seconds as a double by default.
201  double timeTakenS = duration<double>(predictEnd - predictStart).count();
202  totalTime += timeTakenS;
203 
204  // Outputss inference times, if needed.
205  if (recordInferenceTimes)
206  {
207  inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
208  }
209 
210  result = testCase->ProcessResult(params);
211 
212  }
213  catch (const TestFrameworkException& testError)
214  {
215  ARMNN_LOG(error) << testError.what();
216  result = TestCaseResult::Abort;
217  }
218 
219  switch (result)
220  {
221  case TestCaseResult::Ok:
222  break;
223  case TestCaseResult::Abort:
224  return false;
225  case TestCaseResult::Failed:
226  // This test failed so we will fail the entire program eventually, but keep going for now.
227  success = false;
228  break;
229  default:
230  ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
231  return false;
232  }
233  }
234 
235  const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
236 
237  ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
238  "Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
239  ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
240  "Average time per test case: " << averageTimePerTestCaseMs << " ms";
241 
242  // if profiling is enabled print out the results
243  if (profiler && profiler->IsProfilingEnabled())
244  {
245  profiler->Print(std::cout);
246  }
247 
248  if (!success)
249  {
250  ARMNN_LOG(error) << "One or more test cases failed";
251  return false;
252  }
253 
254  return testCaseProvider.OnInferenceTestFinished();
255 }
256 
257 } // namespace test
258 
259 } // namespace armnn
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
#define ARMNN_LOG(severity)
Definition: Logging.hpp:163
virtual void AddCommandLineOptions(boost::program_options::options_description &options)
Copyright (c) 2020 ARM Limited.
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
virtual std::unique_ptr< IInferenceTestCase > GetTestCase(unsigned int testCaseId)=0
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
bool ValidateDirectory(std::string &dir)