diff options
author | Mike Kelly <mike.kelly@arm.com> | 2021-07-21 09:42:43 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2021-07-26 08:56:53 +0000 |
commit | d7ed6d4e53a877a25fcea754d76c8831451f18f1 (patch) | |
tree | 74edb3b7cdc991232bb8f8577ae2fd89dfc95b0a /tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp | |
parent | 4adf0de1f2380c215b7d6f643afe04ef4366df1e (diff) | |
download | armnn-experimental/CustomAllocator3.tar.gz |
GitHub #557 wrong result in int8 modelexperimental/CustomAllocator3
* Added support for qasymms8 (int8) to ImageTensorGenerator
* Added qasymmu8 as alias for qasymm8 in ImageTensorGenerator
* Added support for qasymms8 (int8) to ExecuteNetwork
* Added qasymmu8 as alias for qasymm8 in ExecuteNetwork
* Set tflite to be the default model format in ImageTensorGenerator as
it's the only supported model format.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ieda7b78e668ea390e3565cd65a41fe0a9c8a5b83
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp')
-rw-r--r-- | tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp index 23b892ffb4..0906c1cf3f 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp @@ -40,6 +40,13 @@ auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream) } template<> +auto ParseDataArray<armnn::DataType::QAsymmS8>(std::istream& stream) +{ + return ParseArrayImpl<int8_t>(stream, + [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); +} + +template<> auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream) { return ParseArrayImpl<uint8_t>(stream, @@ -54,7 +61,20 @@ auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream) [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); } - +template<> +auto ParseDataArray<armnn::DataType::QAsymmS8>(std::istream& stream, + const float& quantizationScale, + const int32_t& quantizationOffset) +{ + return ParseArrayImpl<int8_t>(stream, + [&quantizationScale, &quantizationOffset](const std::string& s) + { + return armnn::numeric_cast<int8_t>( + armnn::Quantize<int8_t>(std::stof(s), + quantizationScale, + quantizationOffset)); + }); +} template<> auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream, @@ -232,12 +252,18 @@ void PopulateTensorWithData(TContainer& tensorData, ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) : GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements); } - else if (dataTypeStr.compare("qasymm8") == 0) + else if (dataTypeStr.compare("qasymm8") == 0 || dataTypeStr.compare("qasymmu8") == 0) { tensorData = readFromFile ? ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) : GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements); } + else if (dataTypeStr.compare("qasymms8") == 0) + { + tensorData = readFromFile ? + ParseDataArray<armnn::DataType::QAsymmS8>(inputTensorFile) : + GenerateDummyTensorData<armnn::DataType::QAsymmS8>(numElements); + } else { std::string errorMessage = "Unsupported tensor data type " + dataTypeStr; |