diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-11-20 13:57:53 +0000 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2020-11-20 17:41:33 +0000 |
commit | 56870183198842be1706562d8386f4e5f534e9b6 (patch) | |
tree | ce50c3c0398d4804c9a505edfa062d7034fe395d /tests/ExecuteNetwork | |
parent | 55518ca7faaf6c2b0cd567afe9fb39d529a10150 (diff) | |
download | armnn-56870183198842be1706562d8386f4e5f534e9b6.tar.gz |
IVGCVSW-5559 Add int8_t to tflite delegate on ExecuteNetwork
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I56afc73d48848bc40842692831c05316484757a4
Diffstat (limited to 'tests/ExecuteNetwork')
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 98 |
1 file changed, 55 insertions, 43 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index ba7ce29cd7..be341b670a 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -88,57 +88,50 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, if (params.m_InputTypes[inputIndex].compare("float") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<float>(input); - TContainer tensorData; - PopulateTensorWithData(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - params.m_InputTypes[inputIndex], - armnn::EmptyOptional(), - dataFile); - - mapbox::util::apply_visitor([&](auto&& value) - { - for (unsigned int i = 0; i < inputSize; ++i) - { - inputData[i] = value.data()[i]; - } - }, - tensorData); + std::vector<float> tensorData; + PopulateTensorWithDataGeneric<float>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return std::stof(s); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); + } + else if (params.m_InputTypes[inputIndex].compare("int8") == 0) + { + auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input); + std::vector<int8_t> tensorData; + PopulateTensorWithDataGeneric<int8_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); } else if (params.m_InputTypes[inputIndex].compare("int") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<int32_t>(input); - TContainer tensorData; - PopulateTensorWithData(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - params.m_InputTypes[inputIndex], - armnn::EmptyOptional(), - dataFile); - mapbox::util::apply_visitor([&](auto&& value) - { - for (unsigned int i = 0; i < inputSize; ++i) - { - inputData[i] = value.data()[i]; - } - }, - 
tensorData); + std::vector<int32_t> tensorData; + PopulateTensorWithDataGeneric<int32_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return std::stoi(s); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); } else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input); - TContainer tensorData; - PopulateTensorWithData(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - params.m_InputTypes[inputIndex], - armnn::EmptyOptional(), - dataFile); - mapbox::util::apply_visitor([&](auto&& value) - { - for (unsigned int i = 0; i < inputSize; ++i) - { - inputData[i] = value.data()[i]; - } - }, - tensorData); + std::vector<uint8_t> tensorData; + PopulateTensorWithDataGeneric<uint8_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return armnn::numeric_cast<uint8_t>(std::stoi(s)); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); } else { @@ -203,6 +196,25 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, } } } + else if (params.m_OutputTypes[outputIndex].compare("int8") == 0) + { + auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<int8_t>(tfLiteDelegateOutputId); + if(tfLiteDelageOutputData == NULL) + { + ARMNN_LOG(fatal) << "Output tensor is null, output type: " + "\"" << params.m_OutputTypes[outputIndex] << "\" may be incorrect."; + return EXIT_FAILURE; + } + + for (int i = 0; i < outputSize; ++i) + { + std::cout << signed(tfLiteDelageOutputData[i]) << ", "; + if (i % 60 == 0) + { + std::cout << std::endl; + } + } + } else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0) { auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId); |