ArmNN
 22.05
InferenceTest.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "InferenceModel.hpp"
8 
9 #include <armnn/ArmNN.hpp>
10 #include <armnn/Logging.hpp>
11 #include <armnn/TypesUtils.hpp>
13 
15 
16 #include <cxxopts/cxxopts.hpp>
17 #include <fmt/format.h>
18 
19 
20 namespace armnn
21 {
22 
23 inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24 {
25  std::string token;
26  in >> token;
27  compute = armnn::ParseComputeDevice(token.c_str());
28  if (compute == armnn::Compute::Undefined)
29  {
30  in.setstate(std::ios_base::failbit);
31  throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
32  }
33  return in;
34 }
35 
36 inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37 {
38  std::string token;
39  in >> token;
40  armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
41  if (compute == armnn::Compute::Undefined)
42  {
43  in.setstate(std::ios_base::failbit);
44  throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
45  }
46  backend = compute;
47  return in;
48 }
49 
50 namespace test
51 {
52 
// NOTE(review): extraction artifact — the class-declaration line (Doxygen
// source line 53) and the body line 56 were dropped from this dump, so only
// the braces and the access specifier survive. Judging by the trailing index
// ("Base class for all ArmNN exceptions...", Exception(const std::string&)),
// this is presumably the test-framework exception type deriving from
// armnn::Exception — confirm against the upstream InferenceTest.hpp before
// editing this fragment; restore the missing lines from the original header.
54 {
55 public:
57 };
58 
// NOTE(review): extraction artifact — the struct-declaration line (source
// line 59), the m_EnableProfiling member (line 63, initialised by the
// constructor below) and the constructor signature line (line 66) are missing
// from this dump. Restore them from the original header rather than editing
// this fragment in place.
60 {
// How many times to run each inference -- presumably 0 means "use the
// default"; TODO confirm against InferenceTest.cpp.
61  unsigned int m_IterationCount;
// Optional path of a file that receives per-inference timings -- TODO confirm.
62  std::string m_InferenceTimesFile;
// Optional search path for dynamically loaded backends.
64  std::string m_DynamicBackendsPath;
65 
// Defaults: zero iterations, profiling disabled, no dynamic backend path.
67  : m_IterationCount(0)
68  , m_EnableProfiling(0)
69  , m_DynamicBackendsPath()
70  {}
71 };
72 
/// Outcome reported by IInferenceTestCase::ProcessResult for one test case.
enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok,
    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed,
    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort
};
83 
85 {
86 public:
87  virtual ~IInferenceTestCase() {}
88 
89  virtual void Run() = 0;
90  virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
91 };
92 
// NOTE(review): extraction artifact — the class-declaration line (source line
// 93) and line 96 (presumably a virtual destructor, required for deletion
// through the base pointer) are missing from this dump. The class name is
// IInferenceTestCaseProvider (it is referenced by ParseCommandLine and
// InferenceTest below); restore the missing lines from the original header.
// Providers create test cases on demand and may contribute command-line
// options of their own.
94 {
95 public:
97 
// Hook for adding provider-specific options; the default adds none.
98  virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
99  {
100  IgnoreUnused(options, required);
101  };
// Hook for validating/consuming parsed options; the default accepts.
102  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
103  {
104  IgnoreUnused(commonOptions);
105  return true;
106  };
// Creates the test case for the given id; ownership passes to the caller.
107  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
// Called once after all test cases ran; returning false fails the run.
108  virtual bool OnInferenceTestFinished() { return true; };
109 };
110 
111 template <typename TModel>
113 {
114 public:
115 
116  InferenceModelTestCase(TModel& model,
117  unsigned int testCaseId,
118  const std::vector<armnnUtils::TContainer>& inputs,
119  const std::vector<unsigned int>& outputSizes)
120  : m_Model(model)
121  , m_TestCaseId(testCaseId)
122  , m_Inputs(std::move(inputs))
123  {
124  // Initialize output vector
125  const size_t numOutputs = outputSizes.size();
126  m_Outputs.reserve(numOutputs);
127 
128  for (size_t i = 0; i < numOutputs; i++)
129  {
130  m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
131  }
132  }
133 
134  virtual void Run() override
135  {
136  m_Model.Run(m_Inputs, m_Outputs);
137  }
138 
139 protected:
140  unsigned int GetTestCaseId() const { return m_TestCaseId; }
141  const std::vector<armnnUtils::TContainer>& GetOutputs() const { return m_Outputs; }
142 
143 private:
144  TModel& m_Model;
145  unsigned int m_TestCaseId;
146  std::vector<armnnUtils::TContainer> m_Inputs;
147  std::vector<armnnUtils::TContainer> m_Outputs;
148 };
149 
150 template <typename TTestCaseDatabase, typename TModel>
// NOTE(review): extraction artifact — the class-declaration line (source line
// 151) is missing from this dump. ProcessResult is marked `override` and the
// constructor takes (model, testCaseId, modelInput), so this presumably
// derives from InferenceModelTestCase<TModel> — confirm against the original
// header. Runs one classification and scores the predicted label against the
// expected one; constructor and ProcessResult are defined in InferenceTest.inl.
152 {
153 public:
// numInferencesRef / numCorrectInferencesRef aggregate results across test
// cases; they live in the ClassifierTestCaseProvider (see note below).
154  ClassifierTestCase(int& numInferencesRef,
155  int& numCorrectInferencesRef,
156  const std::vector<unsigned int>& validationPredictions,
157  std::vector<unsigned int>* validationPredictionsOut,
158  TModel& model,
159  unsigned int testCaseId,
160  unsigned int label,
161  std::vector<typename TModel::DataType> modelInput);
162 
// Compares the inference output with m_Label / the validation predictions.
163  virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
164 
165 private:
// Ground-truth label for this test case.
166  unsigned int m_Label;
167  InferenceModelInternal::QuantizationParams m_QuantizationParams;
168 
169  /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
170  /// @{
171  int& m_NumInferencesRef;
172  int& m_NumCorrectInferencesRef;
173  const std::vector<unsigned int>& m_ValidationPredictions;
174  std::vector<unsigned int>* m_ValidationPredictionsOut;
175  /// @}
176 };
177 
178 template <typename TDatabase, typename InferenceModel>
// NOTE(review): extraction artifact — the class-declaration line (source line
// 179) is missing from this dump. The member functions carry `override` and
// mirror IInferenceTestCaseProvider's virtuals, so this presumably derives
// from IInferenceTestCaseProvider — confirm against the original header.
// Provides ClassifierTestCase instances backed by a test-image database and a
// lazily constructed model; definitions live in InferenceTest.inl.
180 {
181 public:
// constructDatabase/constructModel are factories invoked once the
// command-line options are known.
182  template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
183  ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
184 
185  virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
186  virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
187  virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
188  virtual bool OnInferenceTestFinished() override;
189 
190 private:
// Loads m_ValidationPredictions from m_ValidationFileIn.
191  void ReadPredictions();
192 
193  typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
// Deferred model factory; invoked with the parsed options.
194  std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
195  typename InferenceModel::CommandLineOptions)> m_ConstructModel;
196  std::unique_ptr<InferenceModel> m_Model;
197 
// Directory holding the test data; passed to the database factory.
198  std::string m_DataDir;
199  std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
200  std::unique_ptr<TDatabase> m_Database;
201 
202  int m_NumInferences; // Referenced by test cases.
203  int m_NumCorrectInferences; // Referenced by test cases.
204 
// Optional file of expected predictions to validate against.
205  std::string m_ValidationFileIn;
206  std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
207 
// Optional file to which this run's predictions are written.
208  std::string m_ValidationFileOut;
209  std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
210 };
211 
// Parses the command line of an ArmNN (or referencetests) inference test
// program, filling outParams and letting the provider register/consume its
// own options. Returns false on a parse failure.
212 bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
213  InferenceTestOptions& outParams);
214 
// Checks that `dir` names a usable directory -- presumably also normalising
// the path in place (non-const reference); TODO confirm in InferenceTest.cpp.
215 bool ValidateDirectory(std::string& dir);
216 
// Runs the given test cases (or defaultTestCaseIds when none were specified)
// against the provider; returns overall success.
217 bool InferenceTest(const InferenceTestOptions& params,
218  const std::vector<unsigned int>& defaultTestCaseIds,
219  IInferenceTestCaseProvider& testCaseProvider);
220 
// Convenience entry point: builds the provider via the supplied factory,
// parses the command line and runs InferenceTest; returns a process exit code.
221 template<typename TConstructTestCaseProvider>
222 int InferenceTestMain(int argc,
223  char* argv[],
224  const std::vector<unsigned int>& defaultTestCaseIds,
225  TConstructTestCaseProvider constructTestCaseProvider);
226 
// Classifier-specific entry point: wires a TDatabase (via constructDatabase)
// and a TParser-based model for the given model file/bindings into a
// ClassifierTestCaseProvider, then delegates to InferenceTestMain.
227 template<typename TDatabase,
228  typename TParser,
229  typename TConstructDatabaseCallable>
230 int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
231  const char* inputBindingName, const char* outputBindingName,
232  const std::vector<unsigned int>& defaultTestCaseIds,
233  TConstructDatabaseCallable constructDatabase,
234  const armnn::TensorShape* inputTensorShape = nullptr);
235 
236 } // namespace test
237 } // namespace armnn
238 
239 #include "InferenceTest.inl"
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
std::istream & operator>>(std::istream &in, armnn::Compute &compute)
InferenceModelTestCase(TModel &model, unsigned int testCaseId, const std::vector< armnnUtils::TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
Exception(const std::string &message)
Definition: Exceptions.cpp:12
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
const std::vector< armnnUtils::TContainer > & GetOutputs() const
Compute
The Compute enum is now deprecated and it is now being replaced by BackendId.
Definition: BackendId.hpp:21
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
constexpr armnn::Compute ParseComputeDevice(const char *str)
Deprecated function that will be removed together with the Compute enum.
Definition: TypesUtils.hpp:182
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
std::pair< float, int32_t > QuantizationParams
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
The test failed with a fatal error. The remaining tests will not be run.
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
bool ValidateDirectory(std::string &dir)
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
The test completed without any errors.