11 #include <cxxopts/cxxopts.hpp> 12 #include <fmt/format.h> 29 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
31 template <
typename TTestCaseDatabase,
typename TModel>
33 int& numInferencesRef,
34 int& numCorrectInferencesRef,
35 const std::vector<unsigned int>& validationPredictions,
36 std::vector<unsigned int>* validationPredictionsOut,
38 unsigned int testCaseId,
40 std::vector<typename TModel::DataType> modelInput)
42 model, testCaseId,
std::vector<
TContainer>{ modelInput }, { model.GetOutputSize() })
44 , m_QuantizationParams(model.GetQuantizationParams())
45 , m_NumInferencesRef(numInferencesRef)
46 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
47 , m_ValidationPredictions(validationPredictions)
48 , m_ValidationPredictionsOut(validationPredictionsOut)
52 struct ClassifierResultProcessor
54 using ResultMap = std::map<float,int>;
56 ClassifierResultProcessor(
float scale,
int offset)
61 void operator()(
const std::vector<float>& values)
63 SortPredictions(values, [](
float value)
69 void operator()(
const std::vector<uint8_t>& values)
71 auto& scale = m_Scale;
72 auto& offset = m_Offset;
73 SortPredictions(values, [&scale, &offset](uint8_t value)
79 void operator()(
const std::vector<int>& values)
85 ResultMap& GetResultMap() {
return m_ResultMap; }
88 template<
typename Container,
typename Delegate>
89 void SortPredictions(
const Container& c, Delegate delegate)
92 for (
const auto& value : c)
94 int classification = index++;
98 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
100 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
103 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
108 ResultMap m_ResultMap;
114 template <
typename TTestCaseDatabase,
typename TModel>
120 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
121 mapbox::util::apply_visitor(resultProcessor, output);
123 ARMNN_LOG(
info) <<
"= Prediction values for test #" << testCaseId;
124 auto it = resultProcessor.GetResultMap().rbegin();
125 for (
int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
127 ARMNN_LOG(
info) <<
"Top(" << (i+1) <<
") prediction is " << it->second <<
128 " with value: " << (it->first);
132 unsigned int prediction = 0;
133 mapbox::util::apply_visitor([&](
auto&& value)
136 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
143 ARMNN_LOG(
error) <<
"Prediction for test case " << testCaseId <<
" (" << prediction <<
")" <<
144 " is incorrect (should be " << m_Label <<
")";
149 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
151 ARMNN_LOG(
error) <<
"Prediction for test case " << testCaseId <<
" (" << prediction <<
")" <<
152 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] <<
")";
157 if (m_ValidationPredictionsOut)
159 m_ValidationPredictionsOut->push_back(prediction);
163 m_NumInferencesRef++;
164 if (prediction == m_Label)
166 m_NumCorrectInferencesRef++;
172 template <
typename TDatabase,
typename InferenceModel>
173 template <
typename TConstructDatabaseCallable,
typename TConstructModelCallable>
175 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
176 : m_ConstructModel(constructModel)
177 , m_ConstructDatabase(constructDatabase)
179 , m_NumCorrectInferences(0)
183 template <
typename TDatabase,
typename InferenceModel>
185 cxxopts::Options& options, std::vector<std::string>& required)
188 .allow_unrecognised_options()
190 (
"validation-file-in",
191 "Reads expected predictions from the given file and confirms they match the actual predictions.",
192 cxxopts::value<std::string>(m_ValidationFileIn)->default_value(
""))
193 (
"validation-file-out",
"Predictions are saved to the given file for later use via --validation-file-in.",
194 cxxopts::value<std::string>(m_ValidationFileOut)->default_value(
""))
195 (
"d,data-dir",
"Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
197 required.emplace_back(
"data-dir");
202 template <
typename TDatabase,
typename InferenceModel>
213 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
219 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
228 template <
typename TDatabase,
typename InferenceModel>
229 std::unique_ptr<IInferenceTestCase>
232 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
233 if (testCaseData ==
nullptr)
238 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
240 m_NumCorrectInferences,
241 m_ValidationPredictions,
242 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
245 testCaseData->m_Label,
246 std::move(testCaseData->m_InputImage));
249 template <
typename TDatabase,
typename InferenceModel>
253 armnn::numeric_cast<double>(m_NumInferences);
254 ARMNN_LOG(
info) << std::fixed << std::setprecision(3) <<
"Overall accuracy: " << accuracy;
257 if (!m_ValidationFileOut.empty())
259 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
260 if (validationFileOut.good())
262 for (
const unsigned int prediction : m_ValidationPredictionsOut)
264 validationFileOut << prediction << std::endl;
269 ARMNN_LOG(
error) <<
"Failed to open output validation file: " << m_ValidationFileOut;
277 template <
typename TDatabase,
typename InferenceModel>
281 if (!m_ValidationFileIn.empty())
283 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
284 if (validationFileIn.good())
286 while (!validationFileIn.eof())
289 validationFileIn >> i;
290 m_ValidationPredictions.emplace_back(i);
295 throw armnn::Exception(fmt::format(
"Failed to open input validation file: {}" 296 , m_ValidationFileIn));
301 template<
typename TConstructTestCaseProv
ider>
304 const std::vector<unsigned int>& defaultTestCaseIds,
305 TConstructTestCaseProvider constructTestCaseProvider)
317 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
318 if (!testCaseProvider)
329 const bool success =
InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
330 return success ? 0 : 1;
350 template<
typename TDatabase,
352 typename TConstructDatabaseCallable>
355 const char* modelFilename,
357 const char* inputBindingName,
358 const char* outputBindingName,
359 const std::vector<unsigned int>& defaultTestCaseIds,
360 TConstructDatabaseCallable constructDatabase,
375 return make_unique<TestCaseProvider>(constructDatabase,
382 return std::unique_ptr<InferenceModel>();
386 modelParams.
m_ModelPath = modelOptions.m_ModelDir + modelFilename;
387 modelParams.m_InputBindings = { inputBindingName };
388 modelParams.m_OutputBindings = { outputBindingName };
390 if (inputTensorShape)
392 modelParams.m_InputShapes.push_back(*inputTensorShape);
395 modelParams.m_IsModelBinary = isModelBinary;
396 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
397 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
398 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
400 return std::make_unique<InferenceModel>(modelParams,
401 commonOptions.m_EnableProfiling,
402 commonOptions.m_DynamicBackendsPath);
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
const std::vector< TContainer > & GetOutputs() const
static void AddCommandLineOptions(cxxopts::Options &options, CommandLineOptions &cLineOptions, std::vector< std::string > &required)
virtual const char * what() const noexcept override
#define ARMNN_LOG(severity)
ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
Copyright (c) 2021 ARM Limited and Contributors.
void IgnoreUnused(Ts &&...)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
mapbox::util::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
virtual bool OnInferenceTestFinished() override
#define ARMNN_ASSERT_MSG(COND, MSG)
virtual TestCaseResult ProcessResult(const InferenceTestOptions ¶ms) override
unsigned int GetTestCaseId() const
#define ARMNN_ASSERT(COND)
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
virtual void AddCommandLineOptions(cxxopts::Options &options, std::vector< std::string > &required) override
bool InferenceTest(const InferenceTestOptions ¶ms, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
Base class for all ArmNN exceptions so that users can filter to just those.
virtual std::unique_ptr< IInferenceTestCase > GetTestCase(unsigned int testCaseId) override
unsigned int m_IterationCount
bool ValidateDirectory(std::string &dir)
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
The test completed without any errors.
mapbox::util::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer