ArmNN
 20.05
InferenceTest.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "InferenceModel.hpp"
8 
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
13 
14 #include <boost/program_options.hpp>
15 
16 
17 namespace armnn
18 {
19 
20 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21 {
22  std::string token;
23  in >> token;
24  compute = armnn::ParseComputeDevice(token.c_str());
25  if (compute == armnn::Compute::Undefined)
26  {
27  in.setstate(std::ios_base::failbit);
28  throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29  }
30  return in;
31 }
32 
33 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
34 {
35  std::string token;
36  in >> token;
37  armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38  if (compute == armnn::Compute::Undefined)
39  {
40  in.setstate(std::ios_base::failbit);
41  throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
42  }
43  backend = compute;
44  return in;
45 }
46 
47 namespace test
48 {
49 
51 {
52 public:
54 };
55 
57 {
58  unsigned int m_IterationCount;
59  std::string m_InferenceTimesFile;
61  std::string m_DynamicBackendsPath;
62 
64  : m_IterationCount(0)
65  , m_EnableProfiling(0)
66  , m_DynamicBackendsPath()
67  {}
68 };
69 
70 enum class TestCaseResult
71 {
72  /// The test completed without any errors.
73  Ok,
74  /// The test failed (e.g. the prediction didn't match the validation file).
75  /// This will eventually fail the whole program but the remaining test cases will still be run.
76  Failed,
77  /// The test failed with a fatal error. The remaining tests will not be run.
78  Abort
79 };
80 
82 {
83 public:
84  virtual ~IInferenceTestCase() {}
85 
86  virtual void Run() = 0;
87  virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
88 };
89 
91 {
92 public:
94 
95  virtual void AddCommandLineOptions(boost::program_options::options_description& options)
96  {
97  IgnoreUnused(options);
98  };
99  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
100  {
101  IgnoreUnused(commonOptions);
102  return true;
103  };
104  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
105  virtual bool OnInferenceTestFinished() { return true; };
106 };
107 
108 template <typename TModel>
110 {
111 public:
112  using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
113 
114  InferenceModelTestCase(TModel& model,
115  unsigned int testCaseId,
116  const std::vector<TContainer>& inputs,
117  const std::vector<unsigned int>& outputSizes)
118  : m_Model(model)
119  , m_TestCaseId(testCaseId)
120  , m_Inputs(std::move(inputs))
121  {
122  // Initialize output vector
123  const size_t numOutputs = outputSizes.size();
124  m_Outputs.reserve(numOutputs);
125 
126  for (size_t i = 0; i < numOutputs; i++)
127  {
128  m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
129  }
130  }
131 
132  virtual void Run() override
133  {
134  m_Model.Run(m_Inputs, m_Outputs);
135  }
136 
137 protected:
138  unsigned int GetTestCaseId() const { return m_TestCaseId; }
139  const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
140 
141 private:
142  TModel& m_Model;
143  unsigned int m_TestCaseId;
144  std::vector<TContainer> m_Inputs;
145  std::vector<TContainer> m_Outputs;
146 };
147 
148 template <typename TTestCaseDatabase, typename TModel>
150 {
151 public:
152  ClassifierTestCase(int& numInferencesRef,
153  int& numCorrectInferencesRef,
154  const std::vector<unsigned int>& validationPredictions,
155  std::vector<unsigned int>* validationPredictionsOut,
156  TModel& model,
157  unsigned int testCaseId,
158  unsigned int label,
159  std::vector<typename TModel::DataType> modelInput);
160 
161  virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
162 
163 private:
164  unsigned int m_Label;
165  InferenceModelInternal::QuantizationParams m_QuantizationParams;
166 
167  /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
168  /// @{
169  int& m_NumInferencesRef;
170  int& m_NumCorrectInferencesRef;
171  const std::vector<unsigned int>& m_ValidationPredictions;
172  std::vector<unsigned int>* m_ValidationPredictionsOut;
173  /// @}
174 };
175 
176 template <typename TDatabase, typename InferenceModel>
178 {
179 public:
180  template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
181  ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
182 
183  virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
184  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
185  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
186  virtual bool OnInferenceTestFinished() override;
187 
188 private:
189  void ReadPredictions();
190 
191  typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
192  std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
193  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
194  std::unique_ptr<InferenceModel> m_Model;
195 
196  std::string m_DataDir;
197  std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
198  std::unique_ptr<TDatabase> m_Database;
199 
200  int m_NumInferences; // Referenced by test cases.
201  int m_NumCorrectInferences; // Referenced by test cases.
202 
203  std::string m_ValidationFileIn;
204  std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
205 
206  std::string m_ValidationFileOut;
207  std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
208 };
209 
210 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
211  InferenceTestOptions& outParams);
212 
213 bool ValidateDirectory(std::string& dir);
214 
215 bool InferenceTest(const InferenceTestOptions& params,
216  const std::vector<unsigned int>& defaultTestCaseIds,
217  IInferenceTestCaseProvider& testCaseProvider);
218 
219 template<typename TConstructTestCaseProvider>
220 int InferenceTestMain(int argc,
221  char* argv[],
222  const std::vector<unsigned int>& defaultTestCaseIds,
223  TConstructTestCaseProvider constructTestCaseProvider);
224 
225 template<typename TDatabase,
226  typename TParser,
227  typename TConstructDatabaseCallable>
228 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
229  const char* inputBindingName, const char* outputBindingName,
230  const std::vector<unsigned int>& defaultTestCaseIds,
231  TConstructDatabaseCallable constructDatabase,
232  const armnn::TensorShape* inputTensorShape = nullptr);
233 
234 } // namespace test
235 } // namespace armnn
236 
237 #include "InferenceTest.inl"
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
std::istream & operator>>(std::istream &in, armnn::Compute &compute)
const std::vector< TContainer > & GetOutputs() const
Exception(const std::string &message)
Definition: Exceptions.cpp:12
virtual void AddCommandLineOptions(boost::program_options::options_description &options)
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
Compute
The Compute enum is deprecated and is being replaced by BackendId.
Definition: BackendId.hpp:21
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
constexpr armnn::Compute ParseComputeDevice(const char *str)
Deprecated function that will be removed together with the Compute enum.
Definition: TypesUtils.hpp:148
InferenceModelTestCase(TModel &model, unsigned int testCaseId, const std::vector< TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
std::pair< float, int32_t > QuantizationParams
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
The test failed with a fatal error. The remaining tests will not be run.
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
bool ValidateDirectory(std::string &dir)
boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
The test completed without any errors.