14 #include <cxxopts/cxxopts.hpp> 15 #include <fmt/format.h> 32 template <
typename TTestCaseDatabase,
typename TModel>
34 int& numInferencesRef,
35 int& numCorrectInferencesRef,
36 const std::vector<unsigned int>& validationPredictions,
37 std::vector<unsigned int>* validationPredictionsOut,
39 unsigned int testCaseId,
41 std::vector<typename TModel::DataType> modelInput)
45 , m_QuantizationParams(model.GetQuantizationParams())
46 , m_NumInferencesRef(numInferencesRef)
47 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
48 , m_ValidationPredictions(validationPredictions)
49 , m_ValidationPredictionsOut(validationPredictionsOut)
53 struct ClassifierResultProcessor
55 using ResultMap = std::map<float,int>;
57 ClassifierResultProcessor(
float scale,
int offset)
62 void operator()(
const std::vector<float>& values)
64 SortPredictions(values, [](
float value)
70 void operator()(
const std::vector<int8_t>& values)
72 SortPredictions(values, [](int8_t value)
78 void operator()(
const std::vector<uint8_t>& values)
80 auto& scale = m_Scale;
81 auto& offset = m_Offset;
82 SortPredictions(values, [&scale, &offset](uint8_t value)
88 void operator()(
const std::vector<int>& values)
94 ResultMap& GetResultMap() {
return m_ResultMap; }
97 template<
typename Container,
typename Delegate>
98 void SortPredictions(
const Container& c, Delegate delegate)
101 for (
const auto& value : c)
103 int classification = index++;
107 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
109 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
112 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
117 ResultMap m_ResultMap;
123 template <
typename TTestCaseDatabase,
typename TModel>
129 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
130 mapbox::util::apply_visitor(resultProcessor, output);
132 ARMNN_LOG(
info) <<
"= Prediction values for test #" << testCaseId;
133 auto it = resultProcessor.GetResultMap().rbegin();
134 for (
int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
136 ARMNN_LOG(
info) <<
"Top(" << (i+1) <<
") prediction is " << it->second <<
137 " with value: " << (it->first);
141 unsigned int prediction = 0;
142 mapbox::util::apply_visitor([&](
auto&& value)
145 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
152 ARMNN_LOG(
error) <<
"Prediction for test case " << testCaseId <<
" (" << prediction <<
")" <<
153 " is incorrect (should be " << m_Label <<
")";
158 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
160 ARMNN_LOG(
error) <<
"Prediction for test case " << testCaseId <<
" (" << prediction <<
")" <<
161 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] <<
")";
166 if (m_ValidationPredictionsOut)
168 m_ValidationPredictionsOut->push_back(prediction);
172 m_NumInferencesRef++;
173 if (prediction == m_Label)
175 m_NumCorrectInferencesRef++;
181 template <
typename TDatabase,
typename InferenceModel>
182 template <
typename TConstructDatabaseCallable,
typename TConstructModelCallable>
184 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
185 : m_ConstructModel(constructModel)
186 , m_ConstructDatabase(constructDatabase)
188 , m_NumCorrectInferences(0)
192 template <
typename TDatabase,
typename InferenceModel>
194 cxxopts::Options& options, std::vector<std::string>& required)
197 .allow_unrecognised_options()
199 (
"validation-file-in",
200 "Reads expected predictions from the given file and confirms they match the actual predictions.",
201 cxxopts::value<std::string>(m_ValidationFileIn)->default_value(
""))
202 (
"validation-file-out",
"Predictions are saved to the given file for later use via --validation-file-in.",
203 cxxopts::value<std::string>(m_ValidationFileOut)->default_value(
""))
204 (
"d,data-dir",
"Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
206 required.emplace_back(
"data-dir");
211 template <
typename TDatabase,
typename InferenceModel>
222 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
228 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
237 template <
typename TDatabase,
typename InferenceModel>
238 std::unique_ptr<IInferenceTestCase>
241 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
242 if (testCaseData ==
nullptr)
247 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
249 m_NumCorrectInferences,
250 m_ValidationPredictions,
251 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
254 testCaseData->m_Label,
255 std::move(testCaseData->m_InputImage));
258 template <
typename TDatabase,
typename InferenceModel>
262 armnn::numeric_cast<double>(m_NumInferences);
263 ARMNN_LOG(
info) << std::fixed << std::setprecision(3) <<
"Overall accuracy: " << accuracy;
266 if (!m_ValidationFileOut.empty())
268 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
269 if (validationFileOut.good())
271 for (
const unsigned int prediction : m_ValidationPredictionsOut)
273 validationFileOut << prediction << std::endl;
278 ARMNN_LOG(
error) <<
"Failed to open output validation file: " << m_ValidationFileOut;
286 template <
typename TDatabase,
typename InferenceModel>
290 if (!m_ValidationFileIn.empty())
292 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
293 if (validationFileIn.good())
295 while (!validationFileIn.eof())
298 validationFileIn >> i;
299 m_ValidationPredictions.emplace_back(i);
304 throw armnn::Exception(fmt::format(
"Failed to open input validation file: {}" 305 , m_ValidationFileIn));
310 template<
typename TConstructTestCaseProv
ider>
313 const std::vector<unsigned int>& defaultTestCaseIds,
314 TConstructTestCaseProvider constructTestCaseProvider)
326 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
327 if (!testCaseProvider)
338 const bool success =
InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
339 return success ? 0 : 1;
359 template<
typename TDatabase,
361 typename TConstructDatabaseCallable>
364 const char* modelFilename,
366 const char* inputBindingName,
367 const char* outputBindingName,
368 const std::vector<unsigned int>& defaultTestCaseIds,
369 TConstructDatabaseCallable constructDatabase,
384 return make_unique<TestCaseProvider>(constructDatabase,
391 return std::unique_ptr<InferenceModel>();
395 modelParams.
m_ModelPath = modelOptions.m_ModelDir + modelFilename;
396 modelParams.m_InputBindings = { inputBindingName };
397 modelParams.m_OutputBindings = { outputBindingName };
399 if (inputTensorShape)
401 modelParams.m_InputShapes.push_back(*inputTensorShape);
404 modelParams.m_IsModelBinary = isModelBinary;
405 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
406 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
407 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
409 return std::make_unique<InferenceModel>(modelParams,
410 commonOptions.m_EnableProfiling,
411 commonOptions.m_DynamicBackendsPath);
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
static void AddCommandLineOptions(cxxopts::Options &options, CommandLineOptions &cLineOptions, std::vector< std::string > &required)
virtual const char * what() const noexcept override
#define ARMNN_LOG(severity)
ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
const std::vector< armnnUtils::TContainer > & GetOutputs() const
virtual bool OnInferenceTestFinished() override
#define ARMNN_ASSERT_MSG(COND, MSG)
virtual TestCaseResult ProcessResult(const InferenceTestOptions &params) override
unsigned int GetTestCaseId() const
#define ARMNN_ASSERT(COND)
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required) override
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
Base class for all ArmNN exceptions so that users can filter to just those.
virtual std::unique_ptr< IInferenceTestCase > GetTestCase(unsigned int testCaseId) override
mapbox::util::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char >, std::vector< int8_t > > TContainer
unsigned int m_IterationCount
bool ValidateDirectory(std::string &dir)
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
The test completed without any errors.