diff options
Diffstat (limited to 'tests/InferenceTest.hpp')
-rw-r--r-- | tests/InferenceTest.hpp | 44 |
1 files changed, 38 insertions, 6 deletions
diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp index 5f53c06a88..181afe4d8f 100644 --- a/tests/InferenceTest.hpp +++ b/tests/InferenceTest.hpp @@ -6,11 +6,14 @@ #include "armnn/ArmNN.hpp" #include "armnn/TypesUtils.hpp" +#include "InferenceModel.hpp" + #include <Logging.hpp> #include <boost/log/core/core.hpp> #include <boost/program_options.hpp> + namespace armnn { @@ -40,9 +43,11 @@ struct InferenceTestOptions { unsigned int m_IterationCount; std::string m_InferenceTimesFile; + bool m_EnableProfiling; InferenceTestOptions() - : m_IterationCount(0) + : m_IterationCount(0), + m_EnableProfiling(0) {} }; @@ -108,6 +113,31 @@ private: std::vector<typename TModel::DataType> m_Output; }; +template <typename TDataType> +struct ToFloat { }; // nothing defined for the generic case + +template <> +struct ToFloat<float> +{ + static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &) + { + // assuming that float models are not quantized + return value; + } +}; + +template <> +struct ToFloat<uint8_t> +{ + static inline float Convert(uint8_t value, + const InferenceModelInternal::QuantizationParams & quantizationParams) + { + return armnn::Dequantize<uint8_t>(value, + quantizationParams.first, + quantizationParams.second); + } +}; + template <typename TTestCaseDatabase, typename TModel> class ClassifierTestCase : public InferenceModelTestCase<TModel> { @@ -125,6 +155,8 @@ public: private: unsigned int m_Label; + InferenceModelInternal::QuantizationParams m_QuantizationParams; + /// These fields reference the corresponding member in the ClassifierTestCaseProvider. /// @{ int& m_NumInferencesRef; @@ -154,17 +186,17 @@ private: std::unique_ptr<InferenceModel> m_Model; std::string m_DataDir; - std::function<TDatabase(const char*)> m_ConstructDatabase; + std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase; std::unique_ptr<TDatabase> m_Database; - int m_NumInferences; // Referenced by test cases - int m_NumCorrectInferences; // Referenced by test cases + int m_NumInferences; // Referenced by test cases. + int m_NumCorrectInferences; // Referenced by test cases. std::string m_ValidationFileIn; - std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases + std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases. std::string m_ValidationFileOut; - std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases + std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases. }; bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, |