ArmNN  NotReleased
InferenceTest.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/ArmNN.hpp>
8 #include <armnn/Logging.hpp>
9 #include <armnn/TypesUtils.hpp>
10 #include "InferenceModel.hpp"
11 
12 #include <boost/core/ignore_unused.hpp>
13 #include <boost/program_options.hpp>
14 
15 
16 namespace armnn
17 {
18 
19 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
20 {
21  std::string token;
22  in >> token;
23  compute = armnn::ParseComputeDevice(token.c_str());
24  if (compute == armnn::Compute::Undefined)
25  {
26  in.setstate(std::ios_base::failbit);
27  throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
28  }
29  return in;
30 }
31 
32 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
33 {
34  std::string token;
35  in >> token;
36  armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
37  if (compute == armnn::Compute::Undefined)
38  {
39  in.setstate(std::ios_base::failbit);
40  throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
41  }
42  backend = compute;
43  return in;
44 }
45 
46 namespace test
47 {
48 
50 {
51 public:
53 };
54 
56 {
57  unsigned int m_IterationCount;
58  std::string m_InferenceTimesFile;
60  std::string m_DynamicBackendsPath;
61 
63  : m_IterationCount(0)
64  , m_EnableProfiling(0)
65  , m_DynamicBackendsPath()
66  {}
67 };
68 
/// Outcome reported by IInferenceTestCase::ProcessResult for a single test case.
enum class TestCaseResult : int
{
    Ok     = 0, ///< The test completed without any errors.
    Failed = 1, ///< The test case failed. NOTE(review): unlike Abort this is not documented here as fatal; presumably the remaining test cases still run — confirm in InferenceTest.inl.
    Abort  = 2  ///< The test failed with a fatal error. The remaining tests will not be run.
};
79 
81 {
82 public:
83  virtual ~IInferenceTestCase() {}
84 
85  virtual void Run() = 0;
86  virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
87 };
88 
90 {
91 public:
93 
94  virtual void AddCommandLineOptions(boost::program_options::options_description& options)
95  {
96  boost::ignore_unused(options);
97  };
98  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
99  {
100  boost::ignore_unused(commonOptions);
101  return true;
102  };
103  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
104  virtual bool OnInferenceTestFinished() { return true; };
105 };
106 
107 template <typename TModel>
109 {
110 public:
111  using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
112 
113  InferenceModelTestCase(TModel& model,
114  unsigned int testCaseId,
115  const std::vector<TContainer>& inputs,
116  const std::vector<unsigned int>& outputSizes)
117  : m_Model(model)
118  , m_TestCaseId(testCaseId)
119  , m_Inputs(std::move(inputs))
120  {
121  // Initialize output vector
122  const size_t numOutputs = outputSizes.size();
123  m_Outputs.reserve(numOutputs);
124 
125  for (size_t i = 0; i < numOutputs; i++)
126  {
127  m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
128  }
129  }
130 
131  virtual void Run() override
132  {
133  m_Model.Run(m_Inputs, m_Outputs);
134  }
135 
136 protected:
137  unsigned int GetTestCaseId() const { return m_TestCaseId; }
138  const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
139 
140 private:
141  TModel& m_Model;
142  unsigned int m_TestCaseId;
143  std::vector<TContainer> m_Inputs;
144  std::vector<TContainer> m_Outputs;
145 };
146 
147 template <typename TTestCaseDatabase, typename TModel>
149 {
150 public:
151  ClassifierTestCase(int& numInferencesRef,
152  int& numCorrectInferencesRef,
153  const std::vector<unsigned int>& validationPredictions,
154  std::vector<unsigned int>* validationPredictionsOut,
155  TModel& model,
156  unsigned int testCaseId,
157  unsigned int label,
158  std::vector<typename TModel::DataType> modelInput);
159 
160  virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
161 
162 private:
163  unsigned int m_Label;
164  InferenceModelInternal::QuantizationParams m_QuantizationParams;
165 
168  int& m_NumInferencesRef;
169  int& m_NumCorrectInferencesRef;
170  const std::vector<unsigned int>& m_ValidationPredictions;
171  std::vector<unsigned int>* m_ValidationPredictionsOut;
173 };
174 
175 template <typename TDatabase, typename InferenceModel>
177 {
178 public:
179  template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
180  ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
181 
182  virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
183  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
184  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
185  virtual bool OnInferenceTestFinished() override;
186 
187 private:
188  void ReadPredictions();
189 
190  typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
191  std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
192  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
193  std::unique_ptr<InferenceModel> m_Model;
194 
195  std::string m_DataDir;
196  std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
197  std::unique_ptr<TDatabase> m_Database;
198 
199  int m_NumInferences; // Referenced by test cases.
200  int m_NumCorrectInferences; // Referenced by test cases.
201 
202  std::string m_ValidationFileIn;
203  std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
204 
205  std::string m_ValidationFileOut;
206  std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
207 };
208 
209 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
210  InferenceTestOptions& outParams);
211 
212 bool ValidateDirectory(std::string& dir);
213 
214 bool InferenceTest(const InferenceTestOptions& params,
215  const std::vector<unsigned int>& defaultTestCaseIds,
216  IInferenceTestCaseProvider& testCaseProvider);
217 
218 template<typename TConstructTestCaseProvider>
219 int InferenceTestMain(int argc,
220  char* argv[],
221  const std::vector<unsigned int>& defaultTestCaseIds,
222  TConstructTestCaseProvider constructTestCaseProvider);
223 
224 template<typename TDatabase,
225  typename TParser,
226  typename TConstructDatabaseCallable>
227 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
228  const char* inputBindingName, const char* outputBindingName,
229  const std::vector<unsigned int>& defaultTestCaseIds,
230  TConstructDatabaseCallable constructDatabase,
231  const armnn::TensorShape* inputTensorShape = nullptr);
232 
233 } // namespace test
234 } // namespace armnn
235 
236 #include "InferenceTest.inl"
InferenceModelTestCase(TModel &model, unsigned int testCaseId, const std::vector< TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
virtual void AddCommandLineOptions(boost::program_options::options_description &options)
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
std::pair< float, int32_t > QuantizationParams
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
Exception(const std::string &message)
Definition: Exceptions.cpp:12
boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
The test completed without any errors.
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
std::istream & operator>>(std::istream &in, armnn::Compute &compute)
const std::vector< TContainer > & GetOutputs() const
The test failed with a fatal error. The remaining tests will not be run.
constexpr armnn::Compute ParseComputeDevice(const char *str)
Definition: TypesUtils.hpp:145
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
bool ValidateDirectory(std::string &dir)