12 #include <boost/core/ignore_unused.hpp> 13 #include <boost/program_options.hpp> 26 in.setstate(std::ios_base::failbit);
27 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
39 in.setstate(std::ios_base::failbit);
40 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
64 , m_EnableProfiling(0)
65 , m_DynamicBackendsPath()
85 virtual void Run() = 0;
96 boost::ignore_unused(options);
100 boost::ignore_unused(commonOptions);
103 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(
unsigned int testCaseId) = 0;
107 template <
typename TModel>
111 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
114 unsigned int testCaseId,
115 const std::vector<TContainer>& inputs,
116 const std::vector<unsigned int>& outputSizes)
118 , m_TestCaseId(testCaseId)
119 , m_Inputs(
std::move(inputs))
122 const size_t numOutputs = outputSizes.size();
123 m_Outputs.reserve(numOutputs);
125 for (
size_t i = 0; i < numOutputs; i++)
127 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
131 virtual void Run()
override 133 m_Model.Run(m_Inputs, m_Outputs);
138 const std::vector<TContainer>&
GetOutputs()
const {
return m_Outputs; }
142 unsigned int m_TestCaseId;
143 std::vector<TContainer> m_Inputs;
144 std::vector<TContainer> m_Outputs;
147 template <
typename TTestCaseDatabase,
typename TModel>
152 int& numCorrectInferencesRef,
153 const std::vector<unsigned int>& validationPredictions,
154 std::vector<unsigned int>* validationPredictionsOut,
156 unsigned int testCaseId,
158 std::vector<typename TModel::DataType> modelInput);
163 unsigned int m_Label;
168 int& m_NumInferencesRef;
169 int& m_NumCorrectInferencesRef;
170 const std::vector<unsigned int>& m_ValidationPredictions;
171 std::vector<unsigned int>* m_ValidationPredictionsOut;
175 template <
typename TDatabase,
typename InferenceModel>
179 template <
typename TConstructDatabaseCallable,
typename TConstructModelCallable>
182 virtual void AddCommandLineOptions(boost::program_options::options_description&
options)
override;
184 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(
unsigned int testCaseId)
override;
185 virtual bool OnInferenceTestFinished()
override;
188 void ReadPredictions();
193 std::unique_ptr<InferenceModel> m_Model;
195 std::string m_DataDir;
196 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
197 std::unique_ptr<TDatabase> m_Database;
200 int m_NumCorrectInferences;
202 std::string m_ValidationFileIn;
203 std::vector<unsigned int> m_ValidationPredictions;
205 std::string m_ValidationFileOut;
206 std::vector<unsigned int> m_ValidationPredictionsOut;
215 const std::vector<unsigned int>& defaultTestCaseIds,
218 template<
typename TConstructTestCaseProv
ider>
221 const std::vector<unsigned int>& defaultTestCaseIds,
222 TConstructTestCaseProvider constructTestCaseProvider);
224 template<
typename TDatabase,
226 typename TConstructDatabaseCallable>
228 const char* inputBindingName,
const char* outputBindingName,
229 const std::vector<unsigned int>& defaultTestCaseIds,
230 TConstructDatabaseCallable constructDatabase,
InferenceModelTestCase(TModel &model, unsigned int testCaseId, const std::vector< TContainer > &inputs, const std::vector< unsigned int > &outputSizes)
unsigned int m_IterationCount
virtual ~IInferenceTestCaseProvider()
std::string m_InferenceTimesFile
virtual void AddCommandLineOptions(boost::program_options::options_description &options)
virtual bool OnInferenceTestFinished()
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
std::pair< float, int32_t > QuantizationParams
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
Exception(const std::string &message)
boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
std::string m_DynamicBackendsPath
The test completed without any errors.
Base class for all ArmNN exceptions so that users can filter to just those.
std::istream & operator>>(std::istream &in, armnn::Compute &compute)
const std::vector< TContainer > & GetOutputs() const
unsigned int GetTestCaseId() const
The test failed with a fatal error. The remaining tests will not be run.
constexpr armnn::Compute ParseComputeDevice(const char *str)
virtual ~IInferenceTestCase()
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
armnn::Runtime::CreationOptions::ExternalProfilingOptions options
virtual void Run() override
bool ValidateDirectory(std::string &dir)