aboutsummaryrefslogtreecommitdiff
path: root/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp')
-rw-r--r--tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp292
1 files changed, 292 insertions, 0 deletions
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
new file mode 100644
index 0000000000..3e7c87d653
--- /dev/null
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.cpp
@@ -0,0 +1,292 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "NetworkExecutionUtils.hpp"
+
+#include <Filesystem.hpp>
+#include <InferenceTest.hpp>
+#include <ResolveType.hpp>
+
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializer/IDeserializer.hpp"
+#endif
+#if defined(ARMNN_CAFFE_PARSER)
+#include "armnnCaffeParser/ICaffeParser.hpp"
+#endif
+#if defined(ARMNN_TF_PARSER)
+#include "armnnTfParser/ITfParser.hpp"
+#endif
+#if defined(ARMNN_TF_LITE_PARSER)
+#include "armnnTfLiteParser/ITfLiteParser.hpp"
+#endif
+#if defined(ARMNN_ONNX_PARSER)
+#include "armnnOnnxParser/IOnnxParser.hpp"
+#endif
+
+
+template<typename T, typename TParseElementFunc>
+std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
+{
+ std::vector<T> result;
+ // Processes line-by-line.
+ std::string line;
+ while (std::getline(stream, line))
+ {
+ std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
+ for (const std::string& token : tokens)
+ {
+ if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
+ {
+ try
+ {
+ result.push_back(parseElementFunc(token));
+ }
+ catch (const std::exception&)
+ {
+ ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
+ }
+ }
+ }
+ }
+
+ return result;
+}
+
+
+template<armnn::DataType NonQuantizedType>
+auto ParseDataArray(std::istream& stream);
+
+template<armnn::DataType QuantizedType>
+auto ParseDataArray(std::istream& stream,
+ const float& quantizationScale,
+ const int32_t& quantizationOffset);
+
+template<>
+auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
+{
+ return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
+}
+
+template<>
+auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
+{
+ return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
+}
+
+template<>
+auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
+{
+ return ParseArrayImpl<uint8_t>(stream,
+ [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
+}
+
+template<>
+auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
+ const float& quantizationScale,
+ const int32_t& quantizationOffset)
+{
+ return ParseArrayImpl<uint8_t>(stream,
+ [&quantizationScale, &quantizationOffset](const std::string& s)
+ {
+ return armnn::numeric_cast<uint8_t>(
+ armnn::Quantize<uint8_t>(std::stof(s),
+ quantizationScale,
+ quantizationOffset));
+ });
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+std::vector<T> GenerateDummyTensorData(unsigned int numElements)
+{
+ return std::vector<T>(numElements, static_cast<T>(0));
+}
+
+
+std::vector<unsigned int> ParseArray(std::istream& stream)
+{
+ return ParseArrayImpl<unsigned int>(
+ stream,
+ [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
+}
+
+std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter)
+{
+ std::stringstream stream(inputString);
+ return ParseArrayImpl<std::string>(stream, [](const std::string& s) {
+ return armnn::stringUtils::StringTrimCopy(s); }, delimiter);
+}
+
+
+TensorPrinter::TensorPrinter(const std::string& binding,
+ const armnn::TensorInfo& info,
+ const std::string& outputTensorFile,
+ bool dequantizeOutput)
+ : m_OutputBinding(binding)
+ , m_Scale(info.GetQuantizationScale())
+ , m_Offset(info.GetQuantizationOffset())
+ , m_OutputTensorFile(outputTensorFile)
+ , m_DequantizeOutput(dequantizeOutput) {}
+
+void TensorPrinter::operator()(const std::vector<float>& values)
+{
+ ForEachValue(values, [](float value)
+ {
+ printf("%f ", value);
+ });
+ WriteToFile(values);
+}
+
+void TensorPrinter::operator()(const std::vector<uint8_t>& values)
+{
+ if(m_DequantizeOutput)
+ {
+ auto& scale = m_Scale;
+ auto& offset = m_Offset;
+ std::vector<float> dequantizedValues;
+ ForEachValue(values, [&scale, &offset, &dequantizedValues](uint8_t value)
+ {
+ auto dequantizedValue = armnn::Dequantize(value, scale, offset);
+ printf("%f ", dequantizedValue);
+ dequantizedValues.push_back(dequantizedValue);
+ });
+ WriteToFile(dequantizedValues);
+ }
+ else
+ {
+ const std::vector<int> intValues(values.begin(), values.end());
+ operator()(intValues);
+ }
+}
+
+void TensorPrinter::operator()(const std::vector<int>& values)
+{
+ ForEachValue(values, [](int value)
+ {
+ printf("%d ", value);
+ });
+ WriteToFile(values);
+}
+
+template<typename Container, typename Delegate>
+void TensorPrinter::ForEachValue(const Container& c, Delegate delegate)
+{
+ std::cout << m_OutputBinding << ": ";
+ for (const auto& value : c)
+ {
+ delegate(value);
+ }
+ printf("\n");
+}
+
+template<typename T>
+void TensorPrinter::WriteToFile(const std::vector<T>& values)
+{
+ if (!m_OutputTensorFile.empty())
+ {
+ std::ofstream outputTensorFile;
+ outputTensorFile.open(m_OutputTensorFile, std::ofstream::out | std::ofstream::trunc);
+ if (outputTensorFile.is_open())
+ {
+ outputTensorFile << m_OutputBinding << ": ";
+ std::copy(values.begin(), values.end(), std::ostream_iterator<T>(outputTensorFile, " "));
+ }
+ else
+ {
+ ARMNN_LOG(info) << "Output Tensor File: " << m_OutputTensorFile << " could not be opened!";
+ }
+ outputTensorFile.close();
+ }
+}
+
+using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
+using QuantizationParams = std::pair<float, int32_t>;
+
+void PopulateTensorWithData(TContainer& tensorData,
+ unsigned int numElements,
+ const std::string& dataTypeStr,
+ const armnn::Optional<QuantizationParams>& qParams,
+ const armnn::Optional<std::string>& dataFile)
+{
+ const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
+ const bool quantizeData = qParams.has_value();
+
+ std::ifstream inputTensorFile;
+ if (readFromFile)
+ {
+ inputTensorFile = std::ifstream(dataFile.value());
+ }
+
+ if (dataTypeStr.compare("float") == 0)
+ {
+ if (quantizeData)
+ {
+ const float qScale = qParams.value().first;
+ const int qOffset = qParams.value().second;
+
+ tensorData = readFromFile ?
+ ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile, qScale, qOffset) :
+ GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
+ }
+ else
+ {
+ tensorData = readFromFile ?
+ ParseDataArray<armnn::DataType::Float32>(inputTensorFile) :
+ GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
+ }
+ }
+ else if (dataTypeStr.compare("int") == 0)
+ {
+ tensorData = readFromFile ?
+ ParseDataArray<armnn::DataType::Signed32>(inputTensorFile) :
+ GenerateDummyTensorData<armnn::DataType::Signed32>(numElements);
+ }
+ else if (dataTypeStr.compare("qasymm8") == 0)
+ {
+ tensorData = readFromFile ?
+ ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
+ GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
+ }
+ else
+ {
+ std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;
+ ARMNN_LOG(fatal) << errorMessage;
+
+ inputTensorFile.close();
+ throw armnn::Exception(errorMessage);
+ }
+
+ inputTensorFile.close();
+}
+
+bool ValidatePath(const std::string& file, const bool expectFile)
+{
+ if (!fs::exists(file))
+ {
+ std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
+ return false;
+ }
+ if (!fs::is_regular_file(file) && expectFile)
+ {
+ std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
+ return false;
+ }
+ return true;
+}
+
+bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile)
+{
+ bool allPathsValid = true;
+ for (auto const& file : fileVec)
+ {
+ if(!ValidatePath(file, expectFile))
+ {
+ allPathsValid = false;
+ }
+ }
+ return allPathsValid;
+}
+
+
+