aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceTest.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /tests/InferenceTest.hpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'tests/InferenceTest.hpp')
-rw-r--r--tests/InferenceTest.hpp44
1 files changed, 38 insertions, 6 deletions
diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp
index 5f53c06a88..181afe4d8f 100644
--- a/tests/InferenceTest.hpp
+++ b/tests/InferenceTest.hpp
@@ -6,11 +6,14 @@
#include "armnn/ArmNN.hpp"
#include "armnn/TypesUtils.hpp"
+#include "InferenceModel.hpp"
+
#include <Logging.hpp>
#include <boost/log/core/core.hpp>
#include <boost/program_options.hpp>
+
namespace armnn
{
@@ -40,9 +43,11 @@ struct InferenceTestOptions
{
unsigned int m_IterationCount;
std::string m_InferenceTimesFile;
+ bool m_EnableProfiling;
InferenceTestOptions()
- : m_IterationCount(0)
+ : m_IterationCount(0),
+ m_EnableProfiling(0)
{}
};
@@ -108,6 +113,31 @@ private:
std::vector<typename TModel::DataType> m_Output;
};
+template <typename TDataType>
+struct ToFloat { }; // nothing defined for the generic case
+
+template <>
+struct ToFloat<float>
+{
+ static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &)
+ {
+ // assuming that float models are not quantized
+ return value;
+ }
+};
+
+template <>
+struct ToFloat<uint8_t>
+{
+ static inline float Convert(uint8_t value,
+ const InferenceModelInternal::QuantizationParams & quantizationParams)
+ {
+ return armnn::Dequantize<uint8_t>(value,
+ quantizationParams.first,
+ quantizationParams.second);
+ }
+};
+
template <typename TTestCaseDatabase, typename TModel>
class ClassifierTestCase : public InferenceModelTestCase<TModel>
{
@@ -125,6 +155,8 @@ public:
private:
unsigned int m_Label;
+ InferenceModelInternal::QuantizationParams m_QuantizationParams;
+
/// These fields reference the corresponding member in the ClassifierTestCaseProvider.
/// @{
int& m_NumInferencesRef;
@@ -154,17 +186,17 @@ private:
std::unique_ptr<InferenceModel> m_Model;
std::string m_DataDir;
- std::function<TDatabase(const char*)> m_ConstructDatabase;
+ std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
std::unique_ptr<TDatabase> m_Database;
- int m_NumInferences; // Referenced by test cases
- int m_NumCorrectInferences; // Referenced by test cases
+ int m_NumInferences; // Referenced by test cases.
+ int m_NumCorrectInferences; // Referenced by test cases.
std::string m_ValidationFileIn;
- std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases
+ std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
std::string m_ValidationFileOut;
- std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases
+ std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
};
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,