path: root/tests/InferenceTest.hpp
author    Ferran Balaguer <ferran.balaguer@arm.com>  2019-02-08 17:09:55 +0000
committer Matteo Martincigh <matteo.martincigh@arm.com>  2019-02-11 08:48:53 +0000
commit    c602f29d57f34b6bf0805d379b2174667d8bf52f (patch)
tree      50cdc475ec8732575c0cf2c56d4ced770215c4a2 /tests/InferenceTest.hpp
parent    9c5d33a26ebc4be391ae4da9de584be2e453c78f (diff)
download  armnn-c602f29d57f34b6bf0805d379b2174667d8bf52f.tar.gz
IVGCVSW-2529 DeepSpeech v1 test
Change-Id: Ieb99ac1aa347cee4b28b831753855c4614220648
Diffstat (limited to 'tests/InferenceTest.hpp')
-rw-r--r--  tests/InferenceTest.hpp  28
1 file changed, 25 insertions(+), 3 deletions(-)
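The substantive change below is to InferenceModelTestCase::TContainer, which becomes a boost::variant over float, int and unsigned char vectors so that a single container type can carry the outputs of differently typed models; the ToFloat helpers gain matching Convert overloads. The following sketch is illustrative only and is not part of the commit: the TContainer alias mirrors the one introduced in the patch, while the PrintFirstElement visitor and the sample values are hypothetical. It shows how code holding such a variant container can get back at the typed data with boost::apply_visitor or boost::get.

#include <boost/variant.hpp>
#include <iostream>
#include <vector>

// Same alias as the one introduced in this patch (assumed context).
using TContainer =
    boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;

// Hypothetical visitor: prints the first element of whichever vector type is held.
struct PrintFirstElement : boost::static_visitor<void>
{
    template <typename T>
    void operator()(const std::vector<T>& v) const
    {
        if (!v.empty())
        {
            std::cout << static_cast<float>(v[0]) << std::endl;
        }
    }
};

int main()
{
    std::vector<TContainer> outputs;
    outputs.push_back(std::vector<float>{0.5f, 1.5f});        // output of a float model
    outputs.push_back(std::vector<unsigned char>{128, 255});  // output of a quantized model

    // Type-generic access via a visitor.
    for (const auto& output : outputs)
    {
        boost::apply_visitor(PrintFirstElement(), output);
    }

    // Direct access when the held type is known up front.
    const auto& floats = boost::get<std::vector<float>>(outputs[0]);
    std::cout << floats.size() << std::endl;
    return 0;
}

Compared with the old std::vector<typename TModel::DataType> alias, the variant keeps the container type independent of the model's data type, which is what lets one test-case class serve float, int and quantized uint8 models.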
diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp
index 3c22df9a5e..91a65ea494 100644
--- a/tests/InferenceTest.hpp
+++ b/tests/InferenceTest.hpp
@@ -100,7 +100,7 @@ template <typename TModel>
 class InferenceModelTestCase : public IInferenceTestCase
 {
 public:
-    using TContainer = std::vector<typename TModel::DataType>;
+    using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
     InferenceModelTestCase(TModel& model,
                            unsigned int testCaseId,
@@ -112,11 +112,11 @@ public:
     {
         // Initialize output vector
         const size_t numOutputs = outputSizes.size();
-        m_Outputs.resize(numOutputs);
+        m_Outputs.reserve(numOutputs);
         for (size_t i = 0; i < numOutputs; i++)
         {
-            m_Outputs[i].resize(outputSizes[i]);
+            m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
         }
     }
@@ -147,6 +147,12 @@ struct ToFloat<float>
         // assuming that float models are not quantized
         return value;
     }
+
+    static inline float Convert(int value, const InferenceModelInternal::QuantizationParams &)
+    {
+        // assuming that float models are not quantized
+        return static_cast<float>(value);
+    }
 };
 template <>
@@ -159,6 +165,22 @@ struct ToFloat<uint8_t>
                                           quantizationParams.first,
                                           quantizationParams.second);
     }
+
+    static inline float Convert(int value,
+                                const InferenceModelInternal::QuantizationParams & quantizationParams)
+    {
+        return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value),
+                                          quantizationParams.first,
+                                          quantizationParams.second);
+    }
+
+    static inline float Convert(float value,
+                                const InferenceModelInternal::QuantizationParams & quantizationParams)
+    {
+        return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value),
+                                          quantizationParams.first,
+                                          quantizationParams.second);
+    }
 };
 template <typename TTestCaseDatabase, typename TModel>
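For reference, the new ToFloat<uint8_t>::Convert overloads added above route int and float inputs through armnn::Dequantize<uint8_t> together with the model's quantization parameters, which the surrounding code treats as a (scale, offset) pair via quantizationParams.first and .second. The standalone sketch below works through that affine dequantization, real = scale * (quantized - offset), with hypothetical example values; DequantizeSketch and the chosen scale and zero point are assumptions for illustration, not code from this commit.

#include <cstdint>
#include <iostream>
#include <utility>

// Assumed to match InferenceModelInternal::QuantizationParams: (scale, offset).
using QuantizationParams = std::pair<float, int32_t>;

// Standalone sketch of the affine dequantization the new Convert overloads rely on:
// real = scale * (quantized - offset).
float DequantizeSketch(uint8_t value, const QuantizationParams& params)
{
    const float scale   = params.first;
    const int32_t offset = params.second;
    return scale * (static_cast<int32_t>(value) - offset);
}

int main()
{
    // Hypothetical parameters: scale 0.5, zero point 128.
    const QuantizationParams params{0.5f, 128};

    std::cout << DequantizeSketch(128, params) << std::endl; // 0.5 * (128 - 128) =   0.0
    std::cout << DequantizeSketch(200, params) << std::endl; // 0.5 * (200 - 128) =  36.0
    std::cout << DequantizeSketch(64,  params) << std::endl; // 0.5 * ( 64 - 128) = -32.0
    return 0;
}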