diff options
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetwork.cpp')
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 36 |
1 files changed, 31 insertions, 5 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index bce83583cc..a9b5a3c3f4 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -155,7 +155,8 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::copy(tensorData.begin(), tensorData.end(), inputData); } - else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0) + else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0 || + params.m_InputTypes[inputIndex].compare("qasymmu8") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input); @@ -175,6 +176,26 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, std::copy(tensorData.begin(), tensorData.end(), inputData); } + else if (params.m_InputTypes[inputIndex].compare("qasymms8") == 0) + { + auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input); + + if(inputData == NULL) + { + ARMNN_LOG(fatal) << "Input tensor is null, input type: " + "\"" << params.m_InputTypes[inputIndex] << "\" may be incorrect."; + return EXIT_FAILURE; + } + + std::vector<int8_t> tensorData; + PopulateTensorWithDataGeneric<int8_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); + } else { ARMNN_LOG(fatal) << "Unsupported input tensor data type \"" << params.m_InputTypes[inputIndex] << "\". "; @@ -245,7 +266,8 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, printf("%d ", tfLiteDelageOutputData[i]); } } - else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0) + else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0 || + params.m_OutputTypes[outputIndex].compare("qasymmu8") == 0) { auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId); if(tfLiteDelageOutputData == NULL) @@ -374,13 +396,17 @@ int MainImpl(const ExecuteNetworkParams& params, if (params.m_OutputTypes[i].compare("float") == 0) { outputDataContainers.push_back(std::vector<float>(model.GetOutputSize(i))); - } else if (params.m_OutputTypes[i].compare("int") == 0) + } + else if (params.m_OutputTypes[i].compare("int") == 0) { outputDataContainers.push_back(std::vector<int>(model.GetOutputSize(i))); - } else if (params.m_OutputTypes[i].compare("qasymm8") == 0) + } + else if (params.m_OutputTypes[i].compare("qasymm8") == 0 || + params.m_OutputTypes[i].compare("qasymmu8") == 0) { outputDataContainers.push_back(std::vector<uint8_t>(model.GetOutputSize(i))); - } else if (params.m_OutputTypes[i].compare("qsymms8") == 0) + } + else if (params.m_OutputTypes[i].compare("qasymms8") == 0) { outputDataContainers.push_back(std::vector<int8_t>(model.GetOutputSize(i))); } else |