//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NetworkExecutionUtils.hpp"

#include <armnnUtils/Filesystem.hpp>
#include <InferenceTest.hpp>
#include <ResolveType.hpp>

#if defined(ARMNN_SERIALIZER)
#include "armnnDeserializer/IDeserializer.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#endif
#if defined(ARMNN_ONNX_PARSER)
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif

template<armnn::DataType NonQuantizedType>
auto ParseDataArray(std::istream& stream);

template<armnn::DataType QuantizedType>
auto ParseDataArray(std::istream& stream,
                    const float& quantizationScale,
                    const int32_t& quantizationOffset);

template<>
auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
{
    return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
}

template<>
auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
{
    return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
}

template<>
auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream)
{
    return ParseArrayImpl<int8_t>(stream,
                                  [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
}

template<>
auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
{
    return ParseArrayImpl<uint8_t>(stream,
                                   [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
}

// Parses float values from the stream and quantizes them to QAsymmU8 with the
// supplied scale and offset.
template<>
auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
                                               const float& quantizationScale,
                                               const int32_t& quantizationOffset)
{
    return ParseArrayImpl<uint8_t>(stream,
                                   [&quantizationScale, &quantizationOffset](const std::string& s)
                                   {
                                       return armnn::numeric_cast<uint8_t>(
                                           armnn::Quantize<uint8_t>(std::stof(s),
                                                                    quantizationScale,
                                                                    quantizationOffset));
                                   });
}

// Returns a zero-filled vector whose element type corresponds to ArmnnType.
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
std::vector<T> GenerateDummyTensorData(unsigned int numElements)
{
    return std::vector<T>(numElements, static_cast<T>(0));
}

std::vector<unsigned int> ParseArray(std::istream& stream)
{
    return ParseArrayImpl<unsigned int>(
        stream,
        [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
}

std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter)
{
    std::stringstream stream(inputString);
    return ParseArrayImpl<std::string>(stream,
                                       [](const std::string& s) { return armnn::stringUtils::StringTrimCopy(s); },
                                       delimiter);
}

TensorPrinter::TensorPrinter(const std::string& binding,
                             const armnn::TensorInfo& info,
                             const std::string& outputTensorFile,
                             bool dequantizeOutput)
    : m_OutputBinding(binding)
    , m_Scale(info.GetQuantizationScale())
    , m_Offset(info.GetQuantizationOffset())
    , m_OutputTensorFile(outputTensorFile)
    , m_DequantizeOutput(dequantizeOutput)
{}

void TensorPrinter::operator()(const std::vector<float>& values)
{
    ForEachValue(values, [](float value)
    {
        printf("%f ", value);
    });
    WriteToFile(values);
}

void TensorPrinter::operator()(const std::vector<uint8_t>& values)
{
    if (m_DequantizeOutput)
    {
        auto& scale  = m_Scale;
        auto& offset = m_Offset;
        std::vector<float> dequantizedValues;
        ForEachValue(values, [&scale, &offset, &dequantizedValues](uint8_t value)
        {
            auto dequantizedValue = armnn::Dequantize(value, scale, offset);
            printf("%f ", dequantizedValue);
            dequantizedValues.push_back(dequantizedValue);
        });
        WriteToFile(dequantizedValues);
    }
    else
    {
        const std::vector<int> intValues(values.begin(), values.end());
        operator()(intValues);
    }
}

void TensorPrinter::operator()(const std::vector<int8_t>& values)
{
    ForEachValue(values, [](int8_t value)
    {
        printf("%d ", value);
    });
    WriteToFile(values);
}

void TensorPrinter::operator()(const std::vector<int>& values)
{
    ForEachValue(values, [](int value)
    {
        printf("%d ", value);
    });
    WriteToFile(values);
}

template<typename Container, typename Delegate>
void TensorPrinter::ForEachValue(const Container& c, Delegate delegate)
{
    std::cout << m_OutputBinding << ": ";
    for (const auto& value : c)
    {
        delegate(value);
    }
    printf("\n");
}

template<typename T>
void TensorPrinter::WriteToFile(const std::vector<T>& values)
{
    if (!m_OutputTensorFile.empty())
    {
        std::ofstream outputTensorFile;
        outputTensorFile.open(m_OutputTensorFile, std::ofstream::out | std::ofstream::trunc);
        if (outputTensorFile.is_open())
        {
            outputTensorFile << m_OutputBinding << ": ";
            std::copy(values.begin(), values.end(), std::ostream_iterator<T>(outputTensorFile, " "));
        }
        else
        {
            ARMNN_LOG(info) << "Output Tensor File: " << m_OutputTensorFile << " could not be opened!";
        }
        outputTensorFile.close();
    }
}

using TContainer = mapbox::util::variant<std::vector<float>,
                                         std::vector<int>,
                                         std::vector<uint8_t>,
                                         std::vector<int8_t>>;
using QuantizationParams = std::pair<float, int32_t>;

// Fills tensorData either from a data file (whitespace-separated values) or
// with zero-valued dummy data, converting to the requested data type.
void PopulateTensorWithData(TContainer& tensorData,
                            unsigned int numElements,
                            const std::string& dataTypeStr,
                            const armnn::Optional<QuantizationParams>& qParams,
                            const armnn::Optional<std::string>& dataFile)
{
    const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
    const bool quantizeData = qParams.has_value();

    std::ifstream inputTensorFile;
    if (readFromFile)
    {
        inputTensorFile = std::ifstream(dataFile.value());
    }

    if (dataTypeStr.compare("float") == 0)
    {
        if (quantizeData)
        {
            const float qScale  = qParams.value().first;
            const int   qOffset = qParams.value().second;

            tensorData = readFromFile ?
                         ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile, qScale, qOffset) :
                         GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
        }
        else
        {
            tensorData = readFromFile ?
                         ParseDataArray<armnn::DataType::Float32>(inputTensorFile) :
                         GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
        }
    }
    else if (dataTypeStr.compare("int") == 0)
    {
        tensorData = readFromFile ?
                     ParseDataArray<armnn::DataType::Signed32>(inputTensorFile) :
                     GenerateDummyTensorData<armnn::DataType::Signed32>(numElements);
    }
    else if (dataTypeStr.compare("qsymms8") == 0)
    {
        tensorData = readFromFile ?
                     ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) :
                     GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements);
    }
    else if (dataTypeStr.compare("qasymm8") == 0)
    {
        tensorData = readFromFile ?
                     ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
                     GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
    }
    else
    {
        std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;
        ARMNN_LOG(fatal) << errorMessage;

        inputTensorFile.close();
        throw armnn::Exception(errorMessage);
    }

    inputTensorFile.close();
}

// Returns true if the given path exists and, when expectFile is set, refers to
// a regular file.
bool ValidatePath(const std::string& file, const bool expectFile)
{
    if (!fs::exists(file))
    {
        std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
        return false;
    }
    if (!fs::is_regular_file(file) && expectFile)
    {
        std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
        return false;
    }
    return true;
}

bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile)
{
    bool allPathsValid = true;
    for (auto const& file : fileVec)
    {
        if (!ValidatePath(file, expectFile))
        {
            allPathsValid = false;
        }
    }
    return allPathsValid;
}
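
// -----------------------------------------------------------------------------
// Illustrative usage (kept as a comment so it does not affect the build):
// a minimal sketch of how the helpers above could be combined to populate an
// input tensor and print an output tensor. The element count, the binding name
// "output_0" and the file name are placeholders chosen for this sketch, not
// values used elsewhere in Arm NN.
//
//     // Fill an input container with dummy float data (no data file given).
//     TContainer inputData;
//     PopulateTensorWithData(inputData,
//                            /*numElements=*/ 1000u,
//                            /*dataTypeStr=*/ "float",
//                            /*qParams=*/     armnn::EmptyOptional(),
//                            /*dataFile=*/    armnn::EmptyOptional());
//
//     // Print an output container to stdout and dump it to "output_0.txt".
//     TContainer outputData = std::vector<float>(1000, 0.0f);
//     armnn::TensorInfo outputInfo(armnn::TensorShape({1, 1000}), armnn::DataType::Float32);
//     TensorPrinter printer("output_0", outputInfo, "output_0.txt", false);
//     mapbox::util::apply_visitor(printer, outputData);
// -----------------------------------------------------------------------------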