Diffstat (limited to 'tests/ExecuteNetwork')
-rw-r--r-- tests/ExecuteNetwork/ExecuteNetwork.cpp               | 36
-rw-r--r-- tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp |  6
2 files changed, 34 insertions(+), 8 deletions(-)
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index bce83583cc..a9b5a3c3f4 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -155,7 +155,8 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
std::copy(tensorData.begin(), tensorData.end(), inputData);
}
- else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0)
+ else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0 ||
+ params.m_InputTypes[inputIndex].compare("qasymmu8") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input);
@@ -175,6 +176,26 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
std::copy(tensorData.begin(), tensorData.end(), inputData);
}
+ else if (params.m_InputTypes[inputIndex].compare("qasymms8") == 0)
+ {
+ auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input);
+
+ if(inputData == NULL)
+ {
+ ARMNN_LOG(fatal) << "Input tensor is null, input type: "
+ "\"" << params.m_InputTypes[inputIndex] << "\" may be incorrect.";
+ return EXIT_FAILURE;
+ }
+
+ std::vector<int8_t> tensorData;
+ PopulateTensorWithDataGeneric<int8_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
+ }
else
{
ARMNN_LOG(fatal) << "Unsupported input tensor data type \"" << params.m_InputTypes[inputIndex] << "\". ";
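
The new qasymms8 branch follows the same pattern as the existing float/int/uint8 branches: read values from the data file, narrow each one to the target element type with a checked cast, and copy the result into the interpreter's input buffer. Below is a minimal standalone sketch of that parsing step, using only the standard library in place of PopulateTensorWithDataGeneric and armnn::numeric_cast; the whitespace-separated input format and the helper names are assumptions for illustration, not ArmNN's API.

#include <cstdint>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Checked narrowing cast in the spirit of armnn::numeric_cast:
// throws instead of silently wrapping on out-of-range input.
int8_t CheckedToInt8(const std::string& s)
{
    int value = std::stoi(s); // itself throws on malformed/huge input
    if (value < std::numeric_limits<int8_t>::min() ||
        value > std::numeric_limits<int8_t>::max())
    {
        throw std::out_of_range("value " + s + " does not fit in int8_t");
    }
    return static_cast<int8_t>(value);
}

// Reads up to numElements whitespace-separated values from a stream,
// mimicking what the lambda passed to PopulateTensorWithDataGeneric does
// per token (names and stream source are placeholders).
std::vector<int8_t> ReadInt8Tensor(std::istream& dataFile, size_t numElements)
{
    std::vector<int8_t> tensorData;
    tensorData.reserve(numElements);
    std::string token;
    while (tensorData.size() < numElements && dataFile >> token)
    {
        tensorData.push_back(CheckedToInt8(token));
    }
    return tensorData;
}

int main()
{
    std::istringstream dataFile("-128 0 42 127");
    for (int8_t v : ReadInt8Tensor(dataFile, 4))
    {
        std::cout << static_cast<int>(v) << ' ';
    }
    std::cout << '\n'; // prints: -128 0 42 127
}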
@@ -245,7 +266,8 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
printf("%d ", tfLiteDelageOutputData[i]);
}
}
- else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0)
+ else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0 ||
+ params.m_OutputTypes[outputIndex].compare("qasymmu8") == 0)
{
auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId);
if(tfLiteDelageOutputData == NULL)
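
The qasymmu8/qasymms8 names make explicit that these are asymmetrically quantized 8-bit values, so the raw integers printed by this output path are quantized codes rather than real values. The standard asymmetric mapping back to real numbers is real = scale * (q - zeroPoint). A small sketch of that formula follows; the scale and zero point are illustrative stand-ins, not values read from any model:

#include <cstdint>
#include <iostream>

// Standard asymmetric dequantization: real = scale * (q - zeroPoint).
// In a TfLite model the scale/zeroPoint come from the tensor's
// quantization parameters; here they are hard-coded stand-ins.
float Dequantize(uint8_t q, float scale, int32_t zeroPoint)
{
    return scale * (static_cast<int32_t>(q) - zeroPoint);
}

int main()
{
    const float   scale     = 0.5f;
    const int32_t zeroPoint = 128;
    const uint8_t samples[] = {0, 128, 255};
    for (uint8_t q : samples)
    {
        std::cout << static_cast<int>(q) << " -> "
                  << Dequantize(q, scale, zeroPoint) << '\n';
    }
    // prints: 0 -> -64, 128 -> 0, 255 -> 63.5
}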
@@ -374,13 +396,17 @@ int MainImpl(const ExecuteNetworkParams& params,
if (params.m_OutputTypes[i].compare("float") == 0)
{
outputDataContainers.push_back(std::vector<float>(model.GetOutputSize(i)));
- } else if (params.m_OutputTypes[i].compare("int") == 0)
+ }
+ else if (params.m_OutputTypes[i].compare("int") == 0)
{
outputDataContainers.push_back(std::vector<int>(model.GetOutputSize(i)));
- } else if (params.m_OutputTypes[i].compare("qasymm8") == 0)
+ }
+ else if (params.m_OutputTypes[i].compare("qasymm8") == 0 ||
+ params.m_OutputTypes[i].compare("qasymmu8") == 0)
{
outputDataContainers.push_back(std::vector<uint8_t>(model.GetOutputSize(i)));
- } else if (params.m_OutputTypes[i].compare("qsymms8") == 0)
+ }
+ else if (params.m_OutputTypes[i].compare("qasymms8") == 0)
{
outputDataContainers.push_back(std::vector<int8_t>(model.GetOutputSize(i)));
} else
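
Note that MainImpl pushes vectors of four different element types into the same outputDataContainers vector, which only works because the container's value type is variant-like. Here is a sketch of the same string-to-type dispatch using std::variant; the TensorContainer and MakeOutputContainer names are hypothetical, chosen for illustration, and ArmNN's actual container type may differ:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Stand-in for a container that can hold outputs of differing element types.
using TensorContainer = std::variant<std::vector<float>,
                                     std::vector<int>,
                                     std::vector<uint8_t>,
                                     std::vector<int8_t>>;

// Maps the command-line type string to an appropriately typed buffer,
// mirroring the chain of compare() calls in MainImpl above.
TensorContainer MakeOutputContainer(const std::string& type, size_t numElements)
{
    if (type == "float")    { return std::vector<float>(numElements); }
    if (type == "int")      { return std::vector<int>(numElements); }
    if (type == "qasymm8" ||
        type == "qasymmu8") { return std::vector<uint8_t>(numElements); }
    if (type == "qasymms8") { return std::vector<int8_t>(numElements); }
    throw std::invalid_argument("Unsupported output tensor data type \"" + type + "\"");
}

int main()
{
    TensorContainer c = MakeOutputContainer("qasymms8", 4);
    std::cout << std::get<std::vector<int8_t>>(c).size() << '\n'; // prints 4
}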
diff --git a/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp b/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp
index 25ddecf3ba..b12547f51c 100644
--- a/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetworkProgramOptions.cpp
@@ -232,7 +232,7 @@ ProgramOptions::ProgramOptions() : m_CxxOptions{"ExecuteNetwork",
cxxopts::value<bool>(m_ExNetParams.m_ParseUnsupported)->default_value("false")->implicit_value("true"))
("q,quantize-input",
- "If this option is enabled, all float inputs will be quantized to qasymm8. "
+ "If this option is enabled, all float inputs will be quantized as appropriate for the model's inputs. "
"If unset, default to not quantized. Accepted values (true or false)",
cxxopts::value<bool>(m_ExNetParams.m_QuantizeInput)->default_value("false")->implicit_value("true"))
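
For reference, quantizing a float input into an asymmetric 8-bit domain is the inverse of the dequantization shown earlier: q = round(r / scale) + zeroPoint, clamped to the target type's range. A sketch for the unsigned 8-bit case, again with stand-in quantization parameters rather than anything taken from a model; this illustrates the general scheme, not ExecuteNetwork's internal implementation:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Asymmetric quantization of a float to uint8:
// q = round(r / scale) + zeroPoint, clamped to [0, 255].
uint8_t QuantizeToU8(float r, float scale, int32_t zeroPoint)
{
    int32_t q = static_cast<int32_t>(std::lround(r / scale)) + zeroPoint;
    return static_cast<uint8_t>(std::clamp(q, 0, 255));
}

int main()
{
    std::cout << static_cast<int>(QuantizeToU8(-64.0f, 0.5f, 128)) << '\n'; // 0
    std::cout << static_cast<int>(QuantizeToU8(  0.0f, 0.5f, 128)) << '\n'; // 128
    std::cout << static_cast<int>(QuantizeToU8( 63.5f, 0.5f, 128)) << '\n'; // 255
}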
@@ -264,13 +264,13 @@ ProgramOptions::ProgramOptions() : m_CxxOptions{"ExecuteNetwork",
("y,input-type",
"The type of the input tensors in the network separated by comma. "
"If unset, defaults to \"float\" for all defined inputs. "
- "Accepted values (float, int or qasymm8).",
+ "Accepted values (float, int, qasymms8 or qasymmu8).",
cxxopts::value<std::string>())
("z,output-type",
"The type of the output tensors in the network separated by comma. "
"If unset, defaults to \"float\" for all defined outputs. "
- "Accepted values (float, int or qasymm8).",
+ "Accepted values (float, int, qasymms8 or qasymmu8).",
cxxopts::value<std::string>())
("T,tflite-executor",