diff options
author | Finn Williams <Finn.Williams@arm.com> | 2020-11-20 13:57:53 +0000 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2020-11-20 17:41:33 +0000 |
commit | 56870183198842be1706562d8386f4e5f534e9b6 (patch) | |
tree | ce50c3c0398d4804c9a505edfa062d7034fe395d /tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp | |
parent | 55518ca7faaf6c2b0cd567afe9fb39d529a10150 (diff) | |
download | armnn-56870183198842be1706562d8386f4e5f534e9b6.tar.gz |
IVGCVSW-5559 Add int8_t to tflite delegate on ExecuteNetwork
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I56afc73d48848bc40842692831c05316484757a4
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp')
-rw-r--r-- | tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp index 9d9e616e98..742f968a7a 100644 --- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp +++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp @@ -7,10 +7,13 @@ #include <armnn/IRuntime.hpp> #include <armnn/Types.hpp> +#include <armnn/Logging.hpp> +#include <armnn/utility/StringUtils.hpp> #include <mapbox/variant.hpp> #include <iostream> +#include <fstream> std::vector<unsigned int> ParseArray(std::istream& stream); @@ -68,4 +71,51 @@ bool ValidatePath(const std::string& file, const bool expectFile); * @param expectFile bool - If true, checks for a regular file. * @return bool - True if all given strings are valid paths., false otherwise. * */ -bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
\ No newline at end of file +bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile); + +template<typename T, typename TParseElementFunc> +std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:") +{ + std::vector<T> result; + // Processes line-by-line. + std::string line; + while (std::getline(stream, line)) + { + std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars); + for (const std::string& token : tokens) + { + if (!token.empty()) // See https://stackoverflow.com/questions/10437406/ + { + try + { + result.push_back(parseElementFunc(token)); + } + catch (const std::exception&) + { + ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored."; + } + } + } + } + + return result; +} + +template <typename T, typename TParseElementFunc> +void PopulateTensorWithDataGeneric(std::vector<T>& tensorData, + unsigned int numElements, + const armnn::Optional<std::string>& dataFile, + TParseElementFunc parseFunction) +{ + const bool readFromFile = dataFile.has_value() && !dataFile.value().empty(); + + std::ifstream inputTensorFile; + if (readFromFile) + { + inputTensorFile = std::ifstream(dataFile.value()); + } + + tensorData = readFromFile ? + ParseArrayImpl<T>(inputTensorFile, parseFunction) : + std::vector<T>(numElements, static_cast<T>(0)); +} |