ArmNN
 21.02
InferenceTest.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "InferenceModel.hpp"
8 
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
13 
14 #include <cxxopts/cxxopts.hpp>
15 #include <fmt/format.h>
16 
17 
18 namespace armnn
19 {
20 
21 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
22 {
23  std::string token;
24  in >> token;
25  compute = armnn::ParseComputeDevice(token.c_str());
26  if (compute == armnn::Compute::Undefined)
27  {
28  in.setstate(std::ios_base::failbit);
29  throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
30  }
31  return in;
32 }
33 
34 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
35 {
36  std::string token;
37  in >> token;
38  armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
39  if (compute == armnn::Compute::Undefined)
40  {
41  in.setstate(std::ios_base::failbit);
42  throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
43  }
44  backend = compute;
45  return in;
46 }
47 
48 namespace test
49 {
50 
52 {
53 public:
55 };
56 
58 {
59  unsigned int m_IterationCount;
60  std::string m_InferenceTimesFile;
62  std::string m_DynamicBackendsPath;
63 
65  : m_IterationCount(0)
66  , m_EnableProfiling(0)
67  , m_DynamicBackendsPath()
68  {}
69 };
70 
/// Outcome reported for a single inference test case.
enum class TestCaseResult
{
    /// The test case finished and no error was detected.
    Ok,
    /// The test case failed (for example, the prediction did not match the validation file).
    /// The program as a whole will eventually fail, but the remaining test cases still run.
    Failed,
    /// The test case hit a fatal error; the remaining test cases are not run.
    Abort
};
81 
83 {
84 public:
85  virtual ~IInferenceTestCase() {}
86 
87  virtual void Run() = 0;
88  virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
89 };
90 
92 {
93 public:
95 
96  virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
97  {
98  IgnoreUnused(options, required);
99  };
100  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
101  {
102  IgnoreUnused(commonOptions);
103  return true;
104  };
105  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
106  virtual bool OnInferenceTestFinished() { return true; };
107 };
108 
109 template <typename TModel>
111 {
112 public:
113  using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
114 
115  InferenceModelTestCase(TModel& model,
116  unsigned int testCaseId,
117  const std::vector<TContainer>& inputs,
118  const std::vector<unsigned int>& outputSizes)
119  : m_Model(model)
120  , m_TestCaseId(testCaseId)
121  , m_Inputs(std::move(inputs))
122  {
123  // Initialize output vector
124  const size_t numOutputs = outputSizes.size();
125  m_Outputs.reserve(numOutputs);
126 
127  for (size_t i = 0; i < numOutputs; i++)
128  {
129  m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
130  }
131  }
132 
133  virtual void Run() override
134  {
135  m_Model.Run(m_Inputs, m_Outputs);
136  }
137 
138 protected:
139  unsigned int GetTestCaseId() const { return m_TestCaseId; }
140  const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
141 
142 private:
143  TModel& m_Model;
144  unsigned int m_TestCaseId;
145  std::vector<TContainer> m_Inputs;
146  std::vector<TContainer> m_Outputs;
147 };
148 
149 template <typename TTestCaseDatabase, typename TModel>
151 {
152 public:
153  ClassifierTestCase(int& numInferencesRef,
154  int& numCorrectInferencesRef,
155  const std::vector<unsigned int>& validationPredictions,
156  std::vector<unsigned int>* validationPredictionsOut,
157  TModel& model,
158  unsigned int testCaseId,
159  unsigned int label,
160  std::vector<typename TModel::DataType> modelInput);
161 
162  virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
163 
164 private:
165  unsigned int m_Label;
166  InferenceModelInternal::QuantizationParams m_QuantizationParams;
167 
168  /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
169  /// @{
170  int& m_NumInferencesRef;
171  int& m_NumCorrectInferencesRef;
172  const std::vector<unsigned int>& m_ValidationPredictions;
173  std::vector<unsigned int>* m_ValidationPredictionsOut;
174  /// @}
175 };
176 
177 template <typename TDatabase, typename InferenceModel>
179 {
180 public:
181  template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
182  ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
183 
184  virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
185  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
186  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
187  virtual bool OnInferenceTestFinished() override;
188 
189 private:
190  void ReadPredictions();
191 
192  typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
193  std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
194  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
195  std::unique_ptr<InferenceModel> m_Model;
196 
197  std::string m_DataDir;
198  std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
199  std::unique_ptr<TDatabase> m_Database;
200 
201  int m_NumInferences; // Referenced by test cases.
202  int m_NumCorrectInferences; // Referenced by test cases.
203 
204  std::string m_ValidationFileIn;
205  std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
206 
207  std::string m_ValidationFileOut;
208  std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
209 };
210 
211 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
212  InferenceTestOptions& outParams);
213 
214 bool ValidateDirectory(std::string& dir);
215 
216 bool InferenceTest(const InferenceTestOptions& params,
217  const std::vector<unsigned int>& defaultTestCaseIds,
218  IInferenceTestCaseProvider& testCaseProvider);
219 
220 template<typename TConstructTestCaseProvider>
221 int InferenceTestMain(int argc,
222  char* argv[],
223  const std::vector<unsigned int>& defaultTestCaseIds,
224  TConstructTestCaseProvider constructTestCaseProvider);
225 
226 template<typename TDatabase,
227  typename TParser,
228  typename TConstructDatabaseCallable>
229 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
230  const char* inputBindingName, const char* outputBindingName,
231  const std::vector<unsigned int>& defaultTestCaseIds,
232  TConstructDatabaseCallable constructDatabase,
233  const armnn::TensorShape* inputTensorShape = nullptr);
234 
235 } // namespace test
236 } // namespace armnn
237 
238 #include "InferenceTest.inl"
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
std::istream & operator>>(std::istream &in, armnn::Compute &compute)
const std::vector< TContainer > & GetOutputs() const
Exception(const std::string &message)
Definition: Exceptions.cpp:12
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
Compute
The Compute enum is now deprecated and it is now being replaced by BackendId.
Definition: BackendId.hpp:21
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
constexpr armnn::Compute ParseComputeDevice(const char *str)
Deprecated function that will be removed together with the Compute enum.
Definition: TypesUtils.hpp:160
InferenceModelTestCase(TModel &model, unsigned int testCaseId, const std::vector< TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
std::pair< float, int32_t > QuantizationParams
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
The test failed with a fatal error. The remaining tests will not be run.
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
bool ValidateDirectory(std::string &dir)
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
The test completed without any errors.
mapbox::util::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer