diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
commit | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch) | |
tree | c9a70aeb2887006160c1b3d265c27efadb7bdbae /tests/ExecuteNetwork/ExecuteNetwork.cpp | |
download | armnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz |
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetwork.cpp')
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp new file mode 100644 index 0000000000..5e9e6532cf --- /dev/null +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -0,0 +1,244 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "armnn/ArmNN.hpp" +#if defined(ARMNN_CAFFE_PARSER) +#include "armnnCaffeParser/ICaffeParser.hpp" +#endif +#include "Logging.hpp" +#include "../InferenceTest.hpp" + +#include <boost/program_options.hpp> +#include <boost/algorithm/string/split.hpp> +#include <boost/algorithm/string/classification.hpp> + +#include <iostream> +#include <fstream> + +namespace +{ + +template<typename T, typename TParseElementFunc> +std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc) +{ + std::vector<T> result; + // Process line-by-line + std::string line; + while (std::getline(stream, line)) + { + std::vector<std::string> tokens; + boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on); + for (const std::string& token : tokens) + { + if (!token.empty()) // See https://stackoverflow.com/questions/10437406/ + { + try + { + result.push_back(parseElementFunc(token)); + } + catch (const std::exception&) + { + BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored."; + } + } + } + } + + return result; +} + +} + +template<typename T> +std::vector<T> ParseArray(std::istream& stream); + +template<> +std::vector<float> ParseArray(std::istream& stream) +{ + return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); }); +} + +template<> +std::vector<unsigned int> ParseArray(std::istream& stream) +{ + return ParseArrayImpl<unsigned int>(stream, + [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); }); +} + +void PrintArray(const std::vector<float>& v) +{ + for (size_t i = 0; i < v.size(); i++) + { + printf("%f ", v[i]); + } + printf("\n"); +} + +template<typename TParser, typename TDataType> +int MainImpl(const char* modelPath, bool isModelBinary, armnn::Compute computeDevice, + const char* inputName, const armnn::TensorShape* inputTensorShape, const char* inputTensorDataFilePath, + const char* outputName) +{ + // Load input tensor + std::vector<TDataType> input; + { + std::ifstream inputTensorFile(inputTensorDataFilePath); + if (!inputTensorFile.good()) + { + BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath; + return 1; + } + input = ParseArray<TDataType>(inputTensorFile); + } + + try + { + // Create an InferenceModel, which will parse the model and load it into an IRuntime + typename InferenceModel<TParser, TDataType>::Params params; + params.m_ModelPath = modelPath; + params.m_IsModelBinary = isModelBinary; + params.m_ComputeDevice = computeDevice; + params.m_InputBinding = inputName; + params.m_InputTensorShape = inputTensorShape; + params.m_OutputBinding = outputName; + InferenceModel<TParser, TDataType> model(params); + + // Execute the model + std::vector<TDataType> output(model.GetOutputSize()); + model.Run(input, output); + + // Print the output tensor + PrintArray(output); + } + catch (armnn::Exception const& e) + { + BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what(); + return 1; + } + + return 0; +} + +int main(int argc, char* argv[]) +{ + // Configure logging for both the ARMNN library and this test program +#ifdef NDEBUG + armnn::LogSeverity level = armnn::LogSeverity::Info; +#else + armnn::LogSeverity level = armnn::LogSeverity::Debug; +#endif + armnn::ConfigureLogging(true, true, level); + armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level); + + // Configure boost::program_options for command-line parsing + namespace po = boost::program_options; + + std::string modelFormat; + std::string modelPath; + std::string inputName; + std::string inputTensorShapeStr; + std::string inputTensorDataFilePath; + std::string outputName; + armnn::Compute computeDevice; + + po::options_description desc("Options"); + try + { + desc.add_options() + ("help", "Display usage information") + ("model-format,f", po::value(&modelFormat)->required(), + "caffe-binary, caffe-text, tensorflow-binary or tensorflow-text.") + ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt") + ("compute,c", po::value<armnn::Compute>(&computeDevice)->required(), + "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") + ("input-name,i", po::value(&inputName)->required(), "Identifier of the input tensor in the network.") + ("input-tensor-shape,s", po::value(&inputTensorShapeStr), + "The shape of the input tensor in the network as a flat array of integers separated by whitespace. " + "This parameter is optional, depending on the network.") + ("input-tensor-data,d", po::value(&inputTensorDataFilePath)->required(), + "Path to a file containing the input data as a flat array separated by whitespace.") + ("output-name,o", po::value(&outputName)->required(), "Identifier of the output tensor in the network."); + } + catch (const std::exception& e) + { + // Coverity points out that default_value(...) can throw a bad_lexical_cast, + // and that desc.add_options() can throw boost::io::too_few_args. + // They really won't in any of these cases. + BOOST_ASSERT_MSG(false, "Caught unexpected exception"); + BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what(); + return 1; + } + + // Parse the command-line + po::variables_map vm; + try + { + po::store(po::parse_command_line(argc, argv, desc), vm); + + if (vm.count("help") || argc <= 1) + { + std::cout << "Executes a neural network model using the provided input tensor. " << std::endl; + std::cout << "Prints the resulting output tensor." << std::endl; + std::cout << std::endl; + std::cout << desc << std::endl; + return 1; + } + + po::notify(vm); + } + catch (po::error& e) + { + std::cerr << e.what() << std::endl << std::endl; + std::cerr << desc << std::endl; + return 1; + } + + // Parse model binary flag from the model-format string we got from the command-line + bool isModelBinary; + if (modelFormat.find("bin") != std::string::npos) + { + isModelBinary = true; + } + else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos) + { + isModelBinary = false; + } + else + { + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'"; + return 1; + } + + // Parse input tensor shape from the string we got from the command-line. + std::unique_ptr<armnn::TensorShape> inputTensorShape; + if (!inputTensorShapeStr.empty()) + { + std::stringstream ss(inputTensorShapeStr); + std::vector<unsigned int> dims = ParseArray<unsigned int>(ss); + inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data()); + } + + // Forward to implementation based on the parser type + if (modelFormat.find("caffe") != std::string::npos) + { +#if defined(ARMNN_CAFFE_PARSER) + return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice, + inputName.c_str(), inputTensorShape.get(), inputTensorDataFilePath.c_str(), outputName.c_str()); +#else + BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support."; + return 1; +#endif + } + else if (modelFormat.find("tensorflow") != std::string::npos) + { + BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support."; + return 1; + } + else + { + BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << + "'. Please include 'caffe' or 'tensorflow'"; + return 1; + } +} |