about summary refs log tree commit diff
path: root/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
diff options
context:
space:
mode:
author    Mike Kelly <mike.kelly@arm.com>    2021-07-21 09:42:43 +0100
committer mike.kelly <mike.kelly@arm.com>    2021-07-26 08:56:53 +0000
commit    d7ed6d4e53a877a25fcea754d76c8831451f18f1 (patch)
tree      74edb3b7cdc991232bb8f8577ae2fd89dfc95b0a /tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
parent    4adf0de1f2380c215b7d6f643afe04ef4366df1e (diff)
download  armnn-d7ed6d4e53a877a25fcea754d76c8831451f18f1.tar.gz
GitHub #557 wrong result in int8 model (branch: experimental/CustomAllocator3)
* Added support for qasymms8 (int8) to ImageTensorGenerator
* Added qasymmu8 as alias for qasymm8 in ImageTensorGenerator
* Added support for qasymms8 (int8) to ExecuteNetwork
* Added qasymmu8 as alias for qasymm8 in ExecuteNetwork
* Set tflite to be the default model format in ImageTensorGenerator as it's the only supported model format.

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ieda7b78e668ea390e3565cd65a41fe0a9c8a5b83
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp')
-rw-r--r-- tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp | 30
1 file changed, 28 insertions(+), 2 deletions(-)
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
index 23b892ffb4..0906c1cf3f 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
@@ -40,6 +40,13 @@ auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
}
template<>
+auto ParseDataArray<armnn::DataType::QAsymmS8>(std::istream& stream)
+{
+ return ParseArrayImpl<int8_t>(stream,
+ [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
+}
+
+template<>
auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
{
return ParseArrayImpl<uint8_t>(stream,
@@ -54,7 +61,20 @@ auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream)
[](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
}
-
+template<>
+auto ParseDataArray<armnn::DataType::QAsymmS8>(std::istream& stream,
+ const float& quantizationScale,
+ const int32_t& quantizationOffset)
+{
+ return ParseArrayImpl<int8_t>(stream,
+ [&quantizationScale, &quantizationOffset](const std::string& s)
+ {
+ return armnn::numeric_cast<int8_t>(
+ armnn::Quantize<int8_t>(std::stof(s),
+ quantizationScale,
+ quantizationOffset));
+ });
+}
template<>
auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
@@ -232,12 +252,18 @@ void PopulateTensorWithData(TContainer& tensorData,
ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) :
GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements);
}
- else if (dataTypeStr.compare("qasymm8") == 0)
+ else if (dataTypeStr.compare("qasymm8") == 0 || dataTypeStr.compare("qasymmu8") == 0)
{
tensorData = readFromFile ?
ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
}
+ else if (dataTypeStr.compare("qasymms8") == 0)
+ {
+ tensorData = readFromFile ?
+ ParseDataArray<armnn::DataType::QAsymmS8>(inputTensorFile) :
+ GenerateDummyTensorData<armnn::DataType::QAsymmS8>(numElements);
+ }
else
{
std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;