8 #include <boost/numeric/conversion/cast.hpp> 9 #include <boost/format.hpp> 10 #include <boost/program_options.hpp> 27 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
29 template <
typename TTestCaseDatabase,
typename TModel>
31 int& numInferencesRef,
32 int& numCorrectInferencesRef,
33 const std::vector<unsigned int>& validationPredictions,
34 std::vector<unsigned int>* validationPredictionsOut,
36 unsigned int testCaseId,
38 std::vector<typename TModel::DataType> modelInput)
40 model, testCaseId,
std::vector<
TContainer>{ modelInput }, { model.GetOutputSize() })
42 , m_QuantizationParams(model.GetQuantizationParams())
43 , m_NumInferencesRef(numInferencesRef)
44 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
45 , m_ValidationPredictions(validationPredictions)
46 , m_ValidationPredictionsOut(validationPredictionsOut)
50 struct ClassifierResultProcessor :
public boost::static_visitor<>
52 using ResultMap = std::map<float,int>;
54 ClassifierResultProcessor(
float scale,
int offset)
59 void operator()(
const std::vector<float>& values)
61 SortPredictions(values, [](
float value)
67 void operator()(
const std::vector<uint8_t>& values)
69 auto& scale = m_Scale;
70 auto& offset = m_Offset;
71 SortPredictions(values, [&scale, &offset](uint8_t value)
77 void operator()(
const std::vector<int>& values)
// Accessor for the accumulated prediction map (ResultMap is std::map<float,int>).
// Returns a non-const reference to m_ResultMap, the member into which
// SortPredictions inserts (delegate(value), classification index) entries.
83 ResultMap& GetResultMap() {
return m_ResultMap; }
86 template<
typename Container,
typename Delegate>
87 void SortPredictions(
const Container& c, Delegate delegate)
90 for (
const auto& value : c)
92 int classification = index++;
96 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
98 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
101 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
106 ResultMap m_ResultMap;
112 template <
typename TTestCaseDatabase,
typename TModel>
118 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
119 boost::apply_visitor(resultProcessor, output);
121 ARMNN_LOG(
info) <<
"= Prediction values for test #" << testCaseId;
122 auto it = resultProcessor.GetResultMap().rbegin();
123 for (
int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
125 ARMNN_LOG(
info) <<
"Top(" << (i+1) <<
") prediction is " << it->second <<
126 " with value: " << (it->first);
130 unsigned int prediction = 0;
131 boost::apply_visitor([&](
auto&& value)
134 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
141 ARMNN_LOG(
error) <<
"Prediction for test case " << testCaseId <<
" (" << prediction <<
")" <<
142 " is incorrect (should be " << m_Label <<
")";
147 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
149 ARMNN_LOG(
error) <<
"Prediction for test case " << testCaseId <<
" (" << prediction <<
")" <<
150 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] <<
")";
155 if (m_ValidationPredictionsOut)
157 m_ValidationPredictionsOut->push_back(prediction);
161 m_NumInferencesRef++;
162 if (prediction == m_Label)
164 m_NumCorrectInferencesRef++;
170 template <
typename TDatabase,
typename InferenceModel>
171 template <
typename TConstructDatabaseCallable,
typename TConstructModelCallable>
173 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
174 : m_ConstructModel(constructModel)
175 , m_ConstructDatabase(constructDatabase)
177 , m_NumCorrectInferences(0)
181 template <
typename TDatabase,
typename InferenceModel>
183 boost::program_options::options_description& options)
185 namespace po = boost::program_options;
187 options.add_options()
188 (
"validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(
""),
189 "Reads expected predictions from the given file and confirms they match the actual predictions.")
190 (
"validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(
""),
191 "Predictions are saved to the given file for later use via --validation-file-in.")
192 (
"data-dir,d", po::value<std::string>(&m_DataDir)->required(),
193 "Path to directory containing test data");
198 template <
typename TDatabase,
typename InferenceModel>
209 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
215 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
224 template <
typename TDatabase,
typename InferenceModel>
225 std::unique_ptr<IInferenceTestCase>
228 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
229 if (testCaseData ==
nullptr)
234 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
236 m_NumCorrectInferences,
237 m_ValidationPredictions,
238 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
241 testCaseData->m_Label,
242 std::move(testCaseData->m_InputImage));
245 template <
typename TDatabase,
typename InferenceModel>
249 boost::numeric_cast<double>(m_NumInferences);
250 ARMNN_LOG(
info) << std::fixed << std::setprecision(3) <<
"Overall accuracy: " << accuracy;
253 if (!m_ValidationFileOut.empty())
255 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
256 if (validationFileOut.good())
258 for (
const unsigned int prediction : m_ValidationPredictionsOut)
260 validationFileOut << prediction << std::endl;
265 ARMNN_LOG(
error) <<
"Failed to open output validation file: " << m_ValidationFileOut;
273 template <
typename TDatabase,
typename InferenceModel>
277 if (!m_ValidationFileIn.empty())
279 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
280 if (validationFileIn.good())
282 while (!validationFileIn.eof())
285 validationFileIn >> i;
286 m_ValidationPredictions.emplace_back(i);
291 throw armnn::Exception(boost::str(boost::format(
"Failed to open input validation file: %1%")
292 % m_ValidationFileIn));
297 template<
typename TConstructTestCaseProvider>
300 const std::vector<unsigned int>& defaultTestCaseIds,
301 TConstructTestCaseProvider constructTestCaseProvider)
313 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
314 if (!testCaseProvider)
325 const bool success =
InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
326 return success ? 0 : 1;
346 template<
typename TDatabase,
348 typename TConstructDatabaseCallable>
351 const char* modelFilename,
353 const char* inputBindingName,
354 const char* outputBindingName,
355 const std::vector<unsigned int>& defaultTestCaseIds,
356 TConstructDatabaseCallable constructDatabase,
371 return make_unique<TestCaseProvider>(constructDatabase,
378 return std::unique_ptr<InferenceModel>();
382 modelParams.
m_ModelPath = modelOptions.m_ModelDir + modelFilename;
383 modelParams.m_InputBindings = { inputBindingName };
384 modelParams.m_OutputBindings = { outputBindingName };
386 if (inputTensorShape)
388 modelParams.m_InputShapes.push_back(*inputTensorShape);
391 modelParams.m_IsModelBinary = isModelBinary;
392 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
393 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
394 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
396 return std::make_unique<InferenceModel>(modelParams,
397 commonOptions.m_EnableProfiling,
398 commonOptions.m_DynamicBackendsPath);
bool ParseCommandLine(int argc, char **argv, IInferenceTestCaseProvider &testCaseProvider, InferenceTestOptions &outParams)
Parse the command line of an ArmNN (or referencetests) inference test program.
void ConfigureLogging(bool printToStandardOutput, bool printToDebugOutput, LogSeverity severity)
Configures the logging behaviour of the ARMNN library.
const std::vector< TContainer > & GetOutputs() const
virtual const char * what() const noexcept override
#define ARMNN_LOG(severity)
ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
Copyright (c) 2020 ARM Limited.
void IgnoreUnused(Ts &&...)
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
virtual bool OnInferenceTestFinished() override
#define ARMNN_ASSERT_MSG(COND, MSG)
virtual TestCaseResult ProcessResult(const InferenceTestOptions &params) override
unsigned int GetTestCaseId() const
#define ARMNN_ASSERT(COND)
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
int ClassifierInferenceTestMain(int argc, char *argv[], const char *modelFilename, bool isModelBinary, const char *inputBindingName, const char *outputBindingName, const std::vector< unsigned int > &defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape *inputTensorShape=nullptr)
bool InferenceTest(const InferenceTestOptions &params, const std::vector< unsigned int > &defaultTestCaseIds, IInferenceTestCaseProvider &testCaseProvider)
boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
Base class for all ArmNN exceptions so that users can filter to just those.
virtual std::unique_ptr< IInferenceTestCase > GetTestCase(unsigned int testCaseId) override
virtual void AddCommandLineOptions(boost::program_options::options_description &options) override
static void AddCommandLineOptions(boost::program_options::options_description &desc, CommandLineOptions &options)
unsigned int m_IterationCount
bool ValidateDirectory(std::string &dir)
boost::variant< std::vector< float >, std::vector< int >, std::vector< unsigned char > > TContainer
int InferenceTestMain(int argc, char *argv[], const std::vector< unsigned int > &defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider)
The test completed without any errors.