diff options
Diffstat (limited to 'tests/InferenceTest.hpp')
-rw-r--r-- | tests/InferenceTest.hpp | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp index 3c22df9a5e..91a65ea494 100644 --- a/tests/InferenceTest.hpp +++ b/tests/InferenceTest.hpp @@ -100,7 +100,7 @@ template <typename TModel> class InferenceModelTestCase : public IInferenceTestCase { public: - using TContainer = std::vector<typename TModel::DataType>; + using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>; InferenceModelTestCase(TModel& model, unsigned int testCaseId, @@ -112,11 +112,11 @@ public: { // Initialize output vector const size_t numOutputs = outputSizes.size(); - m_Outputs.resize(numOutputs); + m_Outputs.reserve(numOutputs); for (size_t i = 0; i < numOutputs; i++) { - m_Outputs[i].resize(outputSizes[i]); + m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i])); } } @@ -147,6 +147,12 @@ struct ToFloat<float> // assuming that float models are not quantized return value; } + + static inline float Convert(int value, const InferenceModelInternal::QuantizationParams &) + { + // assuming that float models are not quantized + return static_cast<float>(value); + } }; template <> @@ -159,6 +165,22 @@ struct ToFloat<uint8_t> quantizationParams.first, quantizationParams.second); } + + static inline float Convert(int value, + const InferenceModelInternal::QuantizationParams & quantizationParams) + { + return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value), + quantizationParams.first, + quantizationParams.second); + } + + static inline float Convert(float value, + const InferenceModelInternal::QuantizationParams & quantizationParams) + { + return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value), + quantizationParams.first, + quantizationParams.second); + } }; template <typename TTestCaseDatabase, typename TModel> |