16 #include <cxxopts/cxxopts.hpp> 17 #include <fmt/format.h> 30 in.setstate(std::ios_base::failbit);
31 throw cxxopts::OptionException(fmt::format(
"Unrecognised compute device: {}", token));
43 in.setstate(std::ios_base::failbit);
44 throw cxxopts::OptionException(fmt::format(
"Unrecognised compute device: {}", token));
68 , m_EnableProfiling(0)
69 , m_DynamicBackendsPath()
89 virtual void Run() = 0;
107 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(
unsigned int testCaseId) = 0;
111 template <
typename TModel>
117 unsigned int testCaseId,
118 const std::vector<armnnUtils::TContainer>& inputs,
119 const std::vector<unsigned int>& outputSizes)
121 , m_TestCaseId(testCaseId)
122 , m_Inputs(
std::move(inputs))
125 const size_t numOutputs = outputSizes.size();
126 m_Outputs.reserve(numOutputs);
128 for (
size_t i = 0; i < numOutputs; i++)
130 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
134 virtual void Run()
override 136 m_Model.Run(m_Inputs, m_Outputs);
141 const std::vector<armnnUtils::TContainer>&
GetOutputs()
const {
return m_Outputs; }
145 unsigned int m_TestCaseId;
146 std::vector<armnnUtils::TContainer> m_Inputs;
147 std::vector<armnnUtils::TContainer> m_Outputs;
150 template <
typename TTestCaseDatabase,
typename TModel>
155 int& numCorrectInferencesRef,
156 const std::vector<unsigned int>& validationPredictions,
157 std::vector<unsigned int>* validationPredictionsOut,
159 unsigned int testCaseId,
161 std::vector<typename TModel::DataType> modelInput);
166 unsigned int m_Label;
171 int& m_NumInferencesRef;
172 int& m_NumCorrectInferencesRef;
173 const std::vector<unsigned int>& m_ValidationPredictions;
174 std::vector<unsigned int>* m_ValidationPredictionsOut;
178 template <
typename TDatabase,
typename InferenceModel>
182 template <
typename TConstructDatabaseCallable,
typename TConstructModelCallable>
185 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
override;
187 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(
unsigned int testCaseId)
override;
188 virtual bool OnInferenceTestFinished()
override;
191 void ReadPredictions();
196 std::unique_ptr<InferenceModel> m_Model;
198 std::string m_DataDir;
199 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
200 std::unique_ptr<TDatabase> m_Database;
203 int m_NumCorrectInferences;
205 std::string m_ValidationFileIn;
206 std::vector<unsigned int> m_ValidationPredictions;
208 std::string m_ValidationFileOut;
209 std::vector<unsigned int> m_ValidationPredictionsOut;
218 const std::vector<unsigned int>& defaultTestCaseIds,
221 template<
typename TConstructTestCaseProv
ider>
224 const std::vector<unsigned int>& defaultTestCaseIds,
225 TConstructTestCaseProvider constructTestCaseProvider);
227 template<
typename TDatabase,
229 typename TConstructDatabaseCallable>
231 const char* inputBindingName,
const char* outputBindingName,
232 const std::vector<unsigned int>& defaultTestCaseIds,
233 TConstructDatabaseCallable constructDatabase,
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
virtual ~IInferenceTestCaseProvider()
std::istream & operator>>(std::istream &in, armnn::Compute &compute)
std::string m_InferenceTimesFile
virtual void Run() override
virtual bool OnInferenceTestFinished()
InferenceModelTestCase(TModel &model, unsigned int testCaseId, const std::vector< armnnUtils::TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
Exception(const std::string &message)
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
const std::vector< armnnUtils::TContainer > & GetOutputs() const
Compute
The Compute enum is now deprecated and it is now being replaced by BackendId.
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
constexpr armnn::Compute ParseComputeDevice(const char *str)
Deprecated function that will be removed together with the Compute enum.
unsigned int GetTestCaseId() const
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
std::pair< float, int32_t > QuantizationParams
virtual ~IInferenceTestCase()
bool InferenceTest(const InferenceTestOptions ¶ms, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
The test failed with a fatal error. The remaining tests will not be run.
Base class for all ArmNN exceptions so that users can filter to just those.
unsigned int m_IterationCount
bool ValidateDirectory(std::string &dir)
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
std::string m_DynamicBackendsPath
The test completed without any errors.