// // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "InferenceModel.hpp" #include #include #include #include #include #include #include namespace armnn { inline std::istream& operator>>(std::istream& in, armnn::Compute& compute) { std::string token; in >> token; compute = armnn::ParseComputeDevice(token.c_str()); if (compute == armnn::Compute::Undefined) { in.setstate(std::ios_base::failbit); throw cxxopts::exceptions::exception(fmt::format("Unrecognised compute device: {}", token)); } return in; } inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend) { std::string token; in >> token; armnn::Compute compute = armnn::ParseComputeDevice(token.c_str()); if (compute == armnn::Compute::Undefined) { in.setstate(std::ios_base::failbit); throw cxxopts::exceptions::exception(fmt::format("Unrecognised compute device: {}", token)); } backend = compute; return in; } namespace test { class TestFrameworkException : public Exception { public: using Exception::Exception; }; struct InferenceTestOptions { unsigned int m_IterationCount; std::string m_InferenceTimesFile; bool m_EnableProfiling; std::string m_DynamicBackendsPath; InferenceTestOptions() : m_IterationCount(0) , m_EnableProfiling(0) , m_DynamicBackendsPath() {} }; enum class TestCaseResult { /// The test completed without any errors. Ok, /// The test failed (e.g. the prediction didn't match the validation file). /// This will eventually fail the whole program but the remaining test cases will still be run. Failed, /// The test failed with a fatal error. The remaining tests will not be run. Abort }; class IInferenceTestCase { public: virtual ~IInferenceTestCase() {} virtual void Run() = 0; virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0; }; class IInferenceTestCaseProvider { public: virtual ~IInferenceTestCaseProvider() {} virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector& required) { IgnoreUnused(options, required); }; virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) { IgnoreUnused(commonOptions); return true; }; virtual std::unique_ptr GetTestCase(unsigned int testCaseId) = 0; virtual bool OnInferenceTestFinished() { return true; }; }; template class InferenceModelTestCase : public IInferenceTestCase { public: InferenceModelTestCase(TModel& model, unsigned int testCaseId, const std::vector& inputs, const std::vector& outputSizes) : m_Model(model) , m_TestCaseId(testCaseId) , m_Inputs(std::move(inputs)) { // Initialize output vector const size_t numOutputs = outputSizes.size(); m_Outputs.reserve(numOutputs); for (size_t i = 0; i < numOutputs; i++) { m_Outputs.push_back(std::vector(outputSizes[i])); } } virtual void Run() override { m_Model.Run(m_Inputs, m_Outputs); } protected: unsigned int GetTestCaseId() const { return m_TestCaseId; } const std::vector& GetOutputs() const { return m_Outputs; } private: TModel& m_Model; unsigned int m_TestCaseId; std::vector m_Inputs; std::vector m_Outputs; }; template class ClassifierTestCase : public InferenceModelTestCase { public: ClassifierTestCase(int& numInferencesRef, int& numCorrectInferencesRef, const std::vector& validationPredictions, std::vector* validationPredictionsOut, TModel& model, unsigned int testCaseId, unsigned int label, std::vector modelInput); virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override; private: unsigned int m_Label; InferenceModelInternal::QuantizationParams m_QuantizationParams; /// These fields reference the corresponding member in the ClassifierTestCaseProvider. /// @{ int& m_NumInferencesRef; int& m_NumCorrectInferencesRef; const std::vector& m_ValidationPredictions; std::vector* m_ValidationPredictionsOut; /// @} }; template class ClassifierTestCaseProvider : public IInferenceTestCaseProvider { public: template ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel); virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector& required) override; virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override; virtual std::unique_ptr GetTestCase(unsigned int testCaseId) override; virtual bool OnInferenceTestFinished() override; private: void ReadPredictions(); typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions; std::function(const InferenceTestOptions& commonOptions, typename InferenceModel::CommandLineOptions)> m_ConstructModel; std::unique_ptr m_Model; std::string m_DataDir; std::function m_ConstructDatabase; std::unique_ptr m_Database; int m_NumInferences; // Referenced by test cases. int m_NumCorrectInferences; // Referenced by test cases. std::string m_ValidationFileIn; std::vector m_ValidationPredictions; // Referenced by test cases. std::string m_ValidationFileOut; std::vector m_ValidationPredictionsOut; // Referenced by test cases. }; bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, InferenceTestOptions& outParams); bool ValidateDirectory(std::string& dir); bool InferenceTest(const InferenceTestOptions& params, const std::vector& defaultTestCaseIds, IInferenceTestCaseProvider& testCaseProvider); template int InferenceTestMain(int argc, char* argv[], const std::vector& defaultTestCaseIds, TConstructTestCaseProvider constructTestCaseProvider); template int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary, const char* inputBindingName, const char* outputBindingName, const std::vector& defaultTestCaseIds, TConstructDatabaseCallable constructDatabase, const armnn::TensorShape* inputTensorShape = nullptr); } // namespace test } // namespace armnn #include "InferenceTest.inl"