diff options
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 98 | ||||
-rw-r--r-- | tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp | 30 | ||||
-rw-r--r-- | tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp | 52 |
3 files changed, 106 insertions, 74 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index ba7ce29cd7..be341b670a 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -88,57 +88,50 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, if (params.m_InputTypes[inputIndex].compare("float") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<float>(input); - TContainer tensorData; - PopulateTensorWithData(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - params.m_InputTypes[inputIndex], - armnn::EmptyOptional(), - dataFile); - - mapbox::util::apply_visitor([&](auto&& value) - { - for (unsigned int i = 0; i < inputSize; ++i) - { - inputData[i] = value.data()[i]; - } - }, - tensorData); + std::vector<float> tensorData; + PopulateTensorWithDataGeneric<float>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return std::stof(s); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); + } + else if (params.m_InputTypes[inputIndex].compare("int8") == 0) + { + auto inputData = tfLiteInterpreter->typed_tensor<int8_t>(input); + std::vector<int8_t> tensorData; + PopulateTensorWithDataGeneric<int8_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return armnn::numeric_cast<int8_t>(std::stoi(s)); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); } else if (params.m_InputTypes[inputIndex].compare("int") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<int32_t>(input); - TContainer tensorData; - PopulateTensorWithData(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - params.m_InputTypes[inputIndex], - armnn::EmptyOptional(), - dataFile); - mapbox::util::apply_visitor([&](auto&& value) - { - for (unsigned int i = 0; i < inputSize; ++i) - { - inputData[i] = value.data()[i]; - } - }, - tensorData); + std::vector<int32_t> tensorData; + PopulateTensorWithDataGeneric<int32_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return std::stoi(s); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); } else if (params.m_InputTypes[inputIndex].compare("qasymm8") == 0) { auto inputData = tfLiteInterpreter->typed_tensor<uint8_t>(input); - TContainer tensorData; - PopulateTensorWithData(tensorData, - params.m_InputTensorShapes[inputIndex]->GetNumElements(), - params.m_InputTypes[inputIndex], - armnn::EmptyOptional(), - dataFile); - mapbox::util::apply_visitor([&](auto&& value) - { - for (unsigned int i = 0; i < inputSize; ++i) - { - inputData[i] = value.data()[i]; - } - }, - tensorData); + std::vector<uint8_t> tensorData; + PopulateTensorWithDataGeneric<uint8_t>(tensorData, + params.m_InputTensorShapes[inputIndex]->GetNumElements(), + dataFile, + [](const std::string& s) + { return armnn::numeric_cast<uint8_t>(std::stoi(s)); }); + + std::copy(tensorData.begin(), tensorData.end(), inputData); } else { @@ -203,6 +196,25 @@ int TfLiteDelegateMainImpl(const ExecuteNetworkParams& params, } } } + else if (params.m_OutputTypes[outputIndex].compare("int8") == 0) + { + auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<int8_t>(tfLiteDelegateOutputId); + if(tfLiteDelageOutputData == NULL) + { + ARMNN_LOG(fatal) << "Output tensor is null, output type: " + "\"" << params.m_OutputTypes[outputIndex] << "\" may be incorrect."; + return EXIT_FAILURE; + } + + for (int i = 0; i < outputSize; ++i) + { + std::cout << signed(tfLiteDelageOutputData[i]) << ", "; + if (i % 60 == 0) + { + std::cout << std::endl; + } + } + } else if (params.m_OutputTypes[outputIndex].compare("qasymm8") == 0) { auto tfLiteDelageOutputData = tfLiteInterpreter->typed_tensor<uint8_t>(tfLiteDelegateOutputId); diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp index 3e7c87d653..2afd941636 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp @@ -25,36 +25,6 @@ #include "armnnOnnxParser/IOnnxParser.hpp" #endif - -template<typename T, typename TParseElementFunc> -std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:") -{ - std::vector<T> result; - // Processes line-by-line. - std::string line; - while (std::getline(stream, line)) - { - std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars); - for (const std::string& token : tokens) - { - if (!token.empty()) // See https://stackoverflow.com/questions/10437406/ - { - try - { - result.push_back(parseElementFunc(token)); - } - catch (const std::exception&) - { - ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored."; - } - } - } - } - - return result; -} - - template<armnn::DataType NonQuantizedType> auto ParseDataArray(std::istream& stream); diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp index 9d9e616e98..742f968a7a 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp @@ -7,10 +7,13 @@ #include <armnn/IRuntime.hpp> #include <armnn/Types.hpp> +#include <armnn/Logging.hpp> +#include <armnn/utility/StringUtils.hpp> #include <mapbox/variant.hpp> #include <iostream> +#include <fstream> std::vector<unsigned int> ParseArray(std::istream& stream); @@ -68,4 +71,51 @@ bool ValidatePath(const std::string& file, const bool expectFile); * @param expectFile bool - If true, checks for a regular file. * @return bool - True if all given strings are valid paths., false otherwise. * */ -bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
\ No newline at end of file +bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile); + +template<typename T, typename TParseElementFunc> +std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:") +{ + std::vector<T> result; + // Processes line-by-line. + std::string line; + while (std::getline(stream, line)) + { + std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars); + for (const std::string& token : tokens) + { + if (!token.empty()) // See https://stackoverflow.com/questions/10437406/ + { + try + { + result.push_back(parseElementFunc(token)); + } + catch (const std::exception&) + { + ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored."; + } + } + } + } + + return result; +} + +template <typename T, typename TParseElementFunc> +void PopulateTensorWithDataGeneric(std::vector<T>& tensorData, + unsigned int numElements, + const armnn::Optional<std::string>& dataFile, + TParseElementFunc parseFunction) +{ + const bool readFromFile = dataFile.has_value() && !dataFile.value().empty(); + + std::ifstream inputTensorFile; + if (readFromFile) + { + inputTensorFile = std::ifstream(dataFile.value()); + } + + tensorData = readFromFile ? + ParseArrayImpl<T>(inputTensorFile, parseFunction) : + std::vector<T>(numElements, static_cast<T>(0)); +} |