aboutsummaryrefslogtreecommitdiff
path: root/tests/ExecuteNetwork/ExecuteNetwork.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetwork.cpp')
-rw-r--r--tests/ExecuteNetwork/ExecuteNetwork.cpp98
1 files changed, 55 insertions, 43 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index ba7ce29cd7..be341b670a 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -88,57 +88,50 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
if (params.m_InputTypes[inputIndex].compare("float") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<float>(input);
- TContainer tensorData;
- PopulateTensorWithData(tensorData,
- params.m_InputTensorShapes[inputIndex]->GetNumElements(),
- params.m_InputTypes[inputIndex],
- armnn::EmptyOptional(),
- dataFile);
-
- mapbox::util::apply_visitor([&](auto&& value)
- {
- for (unsigned int i = 0; i < inputSize; ++i)
- {
- inputData[i] = value.data()[i];
- }
- },
- tensorData);
+ std::vector<float> tensorData;
+ PopulateTensorWithDataGeneric<float>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return std::stof(s); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
+ }
+ else if (params.m_InputTypes[inputIndex].compare("int8") == 0)
+ {
+ auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input);
+ std::vector<int8_t> tensorData;
+ PopulateTensorWithDataGeneric<int8_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("int") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<int32_t>(input);
- TContainer tensorData;
- PopulateTensorWithData(tensorData,
- params.m_InputTensorShapes[inputIndex]->GetNumElements(),
- params.m_InputTypes[inputIndex],
- armnn::EmptyOptional(),
- dataFile);
- mapbox::util::apply_visitor([&](auto&& value)
- {
- for (unsigned int i = 0; i < inputSize; ++i)
- {
- inputData[i] = value.data()[i];
- }
- },
- tensorData);
+ std::vector<int32_t> tensorData;
+ PopulateTensorWithDataGeneric<int32_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return std::stoi(s); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input);
- TContainer tensorData;
- PopulateTensorWithData(tensorData,
- params.m_InputTensorShapes[inputIndex]->GetNumElements(),
- params.m_InputTypes[inputIndex],
- armnn::EmptyOptional(),
- dataFile);
- mapbox::util::apply_visitor([&](auto&& value)
- {
- for (unsigned int i = 0; i < inputSize; ++i)
- {
- inputData[i] = value.data()[i];
- }
- },
- tensorData);
+ std::vector<uint8_t> tensorData;
+ PopulateTensorWithDataGeneric<uint8_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else
{
@@ -203,6 +196,25 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
}
}
}
+ else if (params.m_OutputTypes[outputIndex].compare("int8") == 0)
+ {
+ auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<int8_t>(tfLiteDelegateOutputId);
+ if(tfLiteDelageOutputData == NULL)
+ {
+ ARMNN_LOG(fatal) << "Output tensor is null, output type: "
+ "\"" << params.m_OutputTypes[outputIndex] << "\" may be incorrect.";
+ return EXIT_FAILURE;
+ }
+
+ for (int i = 0; i < outputSize; ++i)
+ {
+ std::cout << signed(tfLiteDelageOutputData[i]) << ", ";
+ if (i % 60 == 0)
+ {
+ std::cout << std::endl;
+ }
+ }
+ }
else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0)
{
auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId);