diff options
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp')
-rw-r--r-- | tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp index 23b892ffb4..0906c1cf3f 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp @@ -40,6 +40,13 @@ auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream) } template<> +auto ParseDataArray<armnn::DataType::QAsymmS8>(std::istream& stream) +{ + return ParseArrayImpl<int8_t>(stream, + [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); +} + +template<> auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream) { return ParseArrayImpl<uint8_t>(stream, @@ -54,7 +61,20 @@ auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream) [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); } - +template<> +auto ParseDataArray<armnn::DataType::QAsymmS8>(std::istream& stream, + const float& quantizationScale, + const int32_t& quantizationOffset) +{ + return ParseArrayImpl<int8_t>(stream, + [&quantizationScale, &quantizationOffset](const std::string& s) + { + return armnn::numeric_cast<int8_t>( + armnn::Quantize<int8_t>(std::stof(s), + quantizationScale, + quantizationOffset)); + }); +} template<> auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream, @@ -232,12 +252,18 @@ void PopulateTensorWithData(TContainer& tensorData, ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) : GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements); } - else if (dataTypeStr.compare("qasymm8") == 0) + else if (dataTypeStr.compare("qasymm8") == 0 || dataTypeStr.compare("qasymmu8") == 0) { tensorData = readFromFile ? ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) : GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements); } + else if (dataTypeStr.compare("qasymms8") == 0) + { + tensorData = readFromFile ? + ParseDataArray<armnn::DataType::QAsymmS8>(inputTensorFile) : + GenerateDummyTensorData<armnn::DataType::QAsymmS8>(numElements); + } else { std::string errorMessage = "Unsupported tensor data type " + dataTypeStr; |