about summary refs log tree commit diff
path: root/tests/ExecuteNetwork/ExecuteNetwork.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetwork.cpp')
-rw-r--r-- tests/ExecuteNetwork/ExecuteNetwork.cpp | 36
1 file changed, 31 insertions(+), 5 deletions(-)
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index bce83583cc..a9b5a3c3f4 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -155,7 +155,8 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
std::copy(tensorData.begin(), tensorData.end(), inputData);
}
- else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0)
+ else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0 ||
+ params.m_InputTypes[inputIndex].compare("qasymmu8") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input);
@@ -175,6 +176,26 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
std::copy(tensorData.begin(), tensorData.end(), inputData);
}
+ else if (params.m_InputTypes[inputIndex].compare("qasymms8") == 0)
+ {
+ auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input);
+
+ if(inputData == NULL)
+ {
+ ARMNN_LOG(fatal) << "Input tensor is null, input type: "
+ "\"" << params.m_InputTypes[inputIndex] << "\" may be incorrect.";
+ return EXIT_FAILURE;
+ }
+
+ std::vector<int8_t> tensorData;
+ PopulateTensorWithDataGeneric<int8_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
+ }
else
{
ARMNN_LOG(fatal) << "Unsupported input tensor data type \"" << params.m_InputTypes[inputIndex] << "\". ";
@@ -245,7 +266,8 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
printf("%d ", tfLiteDelageOutputData[i]);
}
}
- else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0)
+ else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0 ||
+ params.m_OutputTypes[outputIndex].compare("qasymmu8") == 0)
{
auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId);
if(tfLiteDelageOutputData == NULL)
@@ -374,13 +396,17 @@ int MainImpl(const ExecuteNetworkParams& params,
if (params.m_OutputTypes[i].compare("float") == 0)
{
outputDataContainers.push_back(std::vector<float>(model.GetOutputSize(i)));
- } else if (params.m_OutputTypes[i].compare("int") == 0)
+ }
+ else if (params.m_OutputTypes[i].compare("int") == 0)
{
outputDataContainers.push_back(std::vector<int>(model.GetOutputSize(i)));
- } else if (params.m_OutputTypes[i].compare("qasymm8") == 0)
+ }
+ else if (params.m_OutputTypes[i].compare("qasymm8") == 0 ||
+ params.m_OutputTypes[i].compare("qasymmu8") == 0)
{
outputDataContainers.push_back(std::vector<uint8_t>(model.GetOutputSize(i)));
- } else if (params.m_OutputTypes[i].compare("qsymms8") == 0)
+ }
+ else if (params.m_OutputTypes[i].compare("qasymms8") == 0)
{
outputDataContainers.push_back(std::vector<int8_t>(model.GetOutputSize(i)));
} else