aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-11-20 13:57:53 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2020-11-20 17:42:30 +0000
commit4f55a25217f679205bd39587a26f2a2d1866cb67 (patch)
tree296a1769d3f83203b990ad8e6377afad6948cc32
parent66da7510362d00c6d5b6e8c1fe7f10145efe764b (diff)
downloadarmnn-4f55a25217f679205bd39587a26f2a2d1866cb67.tar.gz
IVGCVSW-5559 Add int8_t to tflite delegate on ExecuteNetwork
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I56afc73d48848bc40842692831c05316484757a4
-rw-r--r--  tests/ExecuteNetwork/ExecuteNetwork.cpp                  98
-rw-r--r--  tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp    30
-rw-r--r--  tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp    52
3 files changed, 106 insertions, 74 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp
index ba7ce29cd7..be341b670a 100644
--- a/tests/ExecuteNetwork/ExecuteNetwork.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp
@@ -88,57 +88,50 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
if (params.m_InputTypes[inputIndex].compare("float") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<float>(input);
- TContainer tensorData;
- PopulateTensorWithData(tensorData,
- params.m_InputTensorShapes[inputIndex]->GetNumElements(),
- params.m_InputTypes[inputIndex],
- armnn::EmptyOptional(),
- dataFile);
-
- mapbox::util::apply_visitor([&](auto&& value)
- {
- for (unsigned int i = 0; i < inputSize; ++i)
- {
- inputData[i] = value.data()[i];
- }
- },
- tensorData);
+ std::vector<float> tensorData;
+ PopulateTensorWithDataGeneric<float>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return std::stof(s); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
+ }
+ else if (params.m_InputTypes[inputIndex].compare("int8") == 0)
+ {
+ auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input);
+ std::vector<int8_t> tensorData;
+ PopulateTensorWithDataGeneric<int8_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("int") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<int32_t>(input);
- TContainer tensorData;
- PopulateTensorWithData(tensorData,
- params.m_InputTensorShapes[inputIndex]->GetNumElements(),
- params.m_InputTypes[inputIndex],
- armnn::EmptyOptional(),
- dataFile);
- mapbox::util::apply_visitor([&](auto&& value)
- {
- for (unsigned int i = 0; i < inputSize; ++i)
- {
- inputData[i] = value.data()[i];
- }
- },
- tensorData);
+ std::vector<int32_t> tensorData;
+ PopulateTensorWithDataGeneric<int32_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return std::stoi(s); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0)
{
auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input);
- TContainer tensorData;
- PopulateTensorWithData(tensorData,
- params.m_InputTensorShapes[inputIndex]->GetNumElements(),
- params.m_InputTypes[inputIndex],
- armnn::EmptyOptional(),
- dataFile);
- mapbox::util::apply_visitor([&](auto&& value)
- {
- for (unsigned int i = 0; i < inputSize; ++i)
- {
- inputData[i] = value.data()[i];
- }
- },
- tensorData);
+ std::vector<uint8_t> tensorData;
+ PopulateTensorWithDataGeneric<uint8_t>(tensorData,
+ params.m_InputTensorShapes[inputIndex]->GetNumElements(),
+ dataFile,
+ [](const std::string& s)
+ { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
+
+ std::copy(tensorData.begin(), tensorData.end(), inputData);
}
else
{
@@ -203,6 +196,25 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params,
}
}
}
+ else if (params.m_OutputTypes[outputIndex].compare("int8") == 0)
+ {
+ auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<int8_t>(tfLiteDelegateOutputId);
+ if(tfLiteDelageOutputData == NULL)
+ {
+ ARMNN_LOG(fatal) << "Output tensor is null, output type: "
+ "\"" << params.m_OutputTypes[outputIndex] << "\" may be incorrect.";
+ return EXIT_FAILURE;
+ }
+
+ for (int i = 0; i < outputSize; ++i)
+ {
+ std::cout << signed(tfLiteDelageOutputData[i]) << ", ";
+ if (i % 60 == 0)
+ {
+ std::cout << std::endl;
+ }
+ }
+ }
else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0)
{
auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId);
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
index 3e7c87d653..2afd941636 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
@@ -25,36 +25,6 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif
-
-template<typename T, typename TParseElementFunc>
-std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
-{
- std::vector<T> result;
- // Processes line-by-line.
- std::string line;
- while (std::getline(stream, line))
- {
- std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
- for (const std::string& token : tokens)
- {
- if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
- {
- try
- {
- result.push_back(parseElementFunc(token));
- }
- catch (const std::exception&)
- {
- ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
- }
- }
- }
- }
-
- return result;
-}
-
-
template<armnn::DataType NonQuantizedType>
auto ParseDataArray(std::istream& stream);
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
index 9d9e616e98..742f968a7a 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
@@ -7,10 +7,13 @@
#include <armnn/IRuntime.hpp>
#include <armnn/Types.hpp>
+#include <armnn/Logging.hpp>
+#include <armnn/utility/StringUtils.hpp>
#include <mapbox/variant.hpp>
#include <iostream>
+#include <fstream>
std::vector<unsigned int> ParseArray(std::istream& stream);
@@ -68,4 +71,51 @@ bool ValidatePath(const std::string& file, const bool expectFile);
* @param expectFile bool - If true, checks for a regular file.
* @return bool - True if all given strings are valid paths., false otherwise.
* */
-bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile); \ No newline at end of file
+bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
+
+template<typename T, typename TParseElementFunc>
+std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
+{
+ std::vector<T> result;
+ // Processes line-by-line.
+ std::string line;
+ while (std::getline(stream, line))
+ {
+ std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
+ for (const std::string& token : tokens)
+ {
+ if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
+ {
+ try
+ {
+ result.push_back(parseElementFunc(token));
+ }
+ catch (const std::exception&)
+ {
+ ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
+ }
+ }
+ }
+ }
+
+ return result;
+}
+
+template <typename T, typename TParseElementFunc>
+void PopulateTensorWithDataGeneric(std::vector<T>& tensorData,
+ unsigned int numElements,
+ const armnn::Optional<std::string>& dataFile,
+ TParseElementFunc parseFunction)
+{
+ const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
+
+ std::ifstream inputTensorFile;
+ if (readFromFile)
+ {
+ inputTensorFile = std::ifstream(dataFile.value());
+ }
+
+ tensorData = readFromFile ?
+ ParseArrayImpl<T>(inputTensorFile, parseFunction) :
+ std::vector<T>(numElements, static_cast<T>(0));
+}