aboutsummaryrefslogtreecommitdiff
path: root/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp')
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp52
1 files changed, 51 insertions, 1 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
index 9d9e616e98..742f968a7a 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
@@ -7,10 +7,13 @@
#include <armnn/IRuntime.hpp>
#include <armnn/Types.hpp>
+#include <armnn/Logging.hpp>
+#include <armnn/utility/StringUtils.hpp>
#include <mapbox/variant.hpp>
#include <iostream>
+#include <fstream>
std::vector<unsigned int> ParseArray(std::istream& stream);
@@ -68,4 +71,51 @@ bool ValidatePath(const std::string& file, const bool expectFile);
* @param expectFile bool - If true, checks for a regular file.
* @return bool - True if all given strings are valid paths., false otherwise.
* */
-bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile); \ No newline at end of file
+bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
+
+template<typename T, typename TParseElementFunc>
+std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
+{
+ std::vector<T> result;
+ // Processes line-by-line.
+ std::string line;
+ while (std::getline(stream, line))
+ {
+ std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
+ for (const std::string& token : tokens)
+ {
+ if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
+ {
+ try
+ {
+ result.push_back(parseElementFunc(token));
+ }
+ catch (const std::exception&)
+ {
+ ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
+ }
+ }
+ }
+ }
+
+ return result;
+}
+
+template <typename T, typename TParseElementFunc>
+void PopulateTensorWithDataGeneric(std::vector<T>& tensorData,
+ unsigned int numElements,
+ const armnn::Optional<std::string>& dataFile,
+ TParseElementFunc parseFunction)
+{
+ const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
+
+ std::ifstream inputTensorFile;
+ if (readFromFile)
+ {
+ inputTensorFile = std::ifstream(dataFile.value());
+ }
+
+ tensorData = readFromFile ?
+ ParseArrayImpl<T>(inputTensorFile, parseFunction) :
+ std::vector<T>(numElements, static_cast<T>(0));
+}