author     Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>   2019-02-14 15:28:02 +0000
committer  Aron Virginas-Tar <aron.virginas-tar@arm.com>           2019-02-19 13:29:41 +0000
commit     4fbae33571871ce584e421657e8ffba299e89d67 (patch)
tree       e916ce9a00295d9700038d75a86dbdc7d52a1a80
parent     3166c3edeb64d834ba27031ddd39b5b1f940b604 (diff)
download   armnn-4fbae33571871ce584e421657e8ffba299e89d67.tar.gz
IVGCVSW-2662 Implement conversion tool to serialize ArmNN networks
Change-Id: I86fb93b8d739ab043e019643184a091e700fd7c4
Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
-rw-r--r--  CMakeLists.txt                           20
-rw-r--r--  src/armnnConverter/ArmnnConverter.cpp   344
2 files changed, 364 insertions, 0 deletions
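Once built, the new ArmnnConverter tool is driven entirely from the command line; its options are defined in ArmnnConverter.cpp below. A typical invocation might look like the following, where the model file, tensor names and shape are placeholders chosen for illustration:

    ./ArmnnConverter -f tensorflow-binary -m model.pb -i input_tensor -s "1,224,224,3" -o output_tensor -p model.armnn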
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 429562362f..47c0ae039b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -633,3 +633,23 @@ if(BUILD_UNIT_TESTS)
addDllCopyCommands(UnitTests)
endif()
+
+if (BUILD_ARMNN_SERIALIZER AND BUILD_TF_PARSER)
+ set(ArmnnConverter_sources
+ src/armnnConverter/ArmnnConverter.cpp)
+
+ add_executable_ex(ArmnnConverter ${ArmnnConverter_sources})
+ target_include_directories(ArmnnConverter PRIVATE src/armnn)
+ target_include_directories(ArmnnConverter PRIVATE src/armnnUtils)
+
+ target_link_libraries(ArmnnConverter armnnTfParser)
+ target_link_libraries(ArmnnConverter armnnSerializer)
+
+ target_link_libraries(ArmnnConverter armnn)
+ target_link_libraries(ArmnnConverter ${CMAKE_THREAD_LIBS_INIT})
+ target_link_libraries(ArmnnConverter
+ ${Boost_SYSTEM_LIBRARY}
+ ${Boost_FILESYSTEM_LIBRARY}
+ ${Boost_PROGRAM_OPTIONS_LIBRARY})
+ addDllCopyCommands(ArmnnConverter)
+endif()
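The new target is only configured when both BUILD_ARMNN_SERIALIZER and BUILD_TF_PARSER are enabled. As a rough sketch (the usual ArmNN dependency options such as Boost and protobuf paths are omitted here), the configure and build step could look like:

    cmake .. -DBUILD_ARMNN_SERIALIZER=1 -DBUILD_TF_PARSER=1
    make ArmnnConverter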
diff --git a/src/armnnConverter/ArmnnConverter.cpp b/src/armnnConverter/ArmnnConverter.cpp
new file mode 100644
index 0000000000..fbec1449a8
--- /dev/null
+++ b/src/armnnConverter/ArmnnConverter.cpp
@@ -0,0 +1,344 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <armnn/ArmNN.hpp>
+
+#include <armnnSerializer/ISerializer.hpp>
+#include <armnnTfParser/ITfParser.hpp>
+
+#include <Logging.hpp>
+#include <HeapProfiling.hpp>
+
+#include <boost/format.hpp>
+#include <boost/algorithm/string/split.hpp>
+#include <boost/algorithm/string/classification.hpp>
+#include <boost/program_options.hpp>
+
+#include <iostream>
+#include <fstream>
+
+namespace
+{
+
+namespace po = boost::program_options;
+
+armnn::TensorShape ParseTensorShape(std::istream& stream)
+{
+ std::vector<unsigned int> result;
+ std::string line;
+
+ while (std::getline(stream, line))
+ {
+ std::vector<std::string> tokens;
+ try
+ {
+ // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
+ boost::split(tokens, line, boost::algorithm::is_any_of(","), boost::token_compress_on);
+ }
+ catch (const std::exception& e)
+ {
+ BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
+ continue;
+ }
+ for (const std::string& token : tokens)
+ {
+ if (!token.empty())
+ {
+ try
+ {
+ result.push_back(boost::numeric_cast<unsigned int>(std::stoi((token))));
+ }
+ catch (const std::exception&)
+ {
+ BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
+ }
+ }
+ }
+ }
+
+ return armnn::TensorShape(boost::numeric_cast<unsigned int>(result.size()), result.data());
+}
+
+bool CheckOption(const po::variables_map& vm,
+ const char* option)
+{
+ if (option == nullptr)
+ {
+ return false;
+ }
+
+ // Check whether 'option' is provided.
+ return vm.find(option) != vm.end();
+}
+
+void CheckOptionDependency(const po::variables_map& vm,
+ const char* option,
+ const char* required)
+{
+ if (option == nullptr || required == nullptr)
+ {
+ throw po::error("Invalid option to check dependency for");
+ }
+
+ // Check that if 'option' is provided, 'required' is also provided.
+ if (CheckOption(vm, option) && !vm[option].defaulted())
+ {
+ if (!CheckOption(vm, required) || vm[required].defaulted())
+ {
+ throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
+ }
+ }
+}
+
+void CheckOptionDependencies(const po::variables_map& vm)
+{
+ CheckOptionDependency(vm, "model-path", "model-format");
+ CheckOptionDependency(vm, "model-path", "input-name");
+ CheckOptionDependency(vm, "model-path", "output-name");
+ CheckOptionDependency(vm, "input-tensor-shape", "model-path");
+}
+
+int ParseCommandLineArgs(int argc, const char* argv[],
+ std::string& modelFormat,
+ std::string& modelPath,
+ std::vector<std::string>& inputNames,
+ std::vector<std::string>& inputTensorShapeStrs,
+ std::vector<std::string>& outputNames,
+ std::string& outputPath, bool& isModelBinary)
+{
+ po::options_description desc("Options");
+
+ desc.add_options()
+ ("help", "Display usage information")
+ ("model-format,f", po::value(&modelFormat)->required(),"tensorflow-binary or tensorflow-text.")
+ ("model-path,m", po::value(&modelPath)->required(), "Path to model file")
+ ("input-name,i", po::value<std::vector<std::string>>()->multitoken(),
+ "Identifier of the input tensors in the network separated by whitespace")
+ ("input-tensor-shape,s", po::value<std::vector<std::string>>()->multitoken(),
+ "The shape of the input tensor in the network as a flat array of integers separated by comma"
+ "Multiple shapes are separated by whitespace"
+ "This parameter is optional, depending on the network.")
+ ("output-name,o", po::value<std::vector<std::string>>()->multitoken(),
+ "Identifier of the output tensor in the network.")
+ ("output-path,p", po::value(&outputPath)->required(), "Path to serialize the network to.");
+
+ po::variables_map vm;
+ try
+ {
+ po::store(po::parse_command_line(argc, argv, desc), vm);
+
+ if (CheckOption(vm, "help") || argc <= 1)
+ {
+ std::cout << "Convert a neural network model from provided file to ArmNN format " << std::endl;
+ std::cout << std::endl;
+ std::cout << desc << std::endl;
+ return EXIT_SUCCESS;
+ }
+
+ po::notify(vm);
+ }
+ catch (const po::error& e)
+ {
+ std::cerr << e.what() << std::endl << std::endl;
+ std::cerr << desc << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ try
+ {
+ CheckOptionDependencies(vm);
+ }
+ catch (const po::error& e)
+ {
+ std::cerr << e.what() << std::endl << std::endl;
+ std::cerr << desc << std::endl;
+ return EXIT_FAILURE;
+ }
+
+ if (modelFormat.find("bin") != std::string::npos)
+ {
+ isModelBinary = true;
+ }
+ else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
+ {
+ isModelBinary = false;
+ }
+ else
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
+ return EXIT_FAILURE;
+ }
+
+ inputNames = vm["input-name"].as<std::vector<std::string>>();
+ outputNames = vm["output-name"].as<std::vector<std::string>>();
+
+ // "input-tensor-shape" is optional; read it only if it was actually provided.
+ if (CheckOption(vm, "input-tensor-shape"))
+ {
+ inputTensorShapeStrs = vm["input-tensor-shape"].as<std::vector<std::string>>();
+ }
+
+ return EXIT_SUCCESS;
+}
+
+class ArmnnConverter
+{
+public:
+ ArmnnConverter(const std::string& modelPath,
+ const std::vector<std::string>& inputNames,
+ const std::vector<armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& outputNames,
+ const std::string& outputPath,
+ bool isModelBinary)
+ : m_NetworkPtr(armnn::INetworkPtr(nullptr, [](armnn::INetwork *){})),
+ m_ModelPath(modelPath),
+ m_InputNames(inputNames),
+ m_InputShapes(inputShapes),
+ m_OutputNames(outputNames),
+ m_OutputPath(outputPath),
+ m_IsModelBinary(isModelBinary) {}
+
+ bool Serialize()
+ {
+ if (m_NetworkPtr.get() == nullptr)
+ {
+ return false;
+ }
+
+ auto serializer(armnnSerializer::ISerializer::Create());
+
+ serializer->Serialize(*m_NetworkPtr);
+
+ std::ofstream file(m_OutputPath, std::ios::out | std::ios::binary);
+
+ bool retVal = serializer->SaveSerializedToStream(file);
+
+ return retVal;
+ }
+
+ template <typename IParser>
+ bool CreateNetwork ()
+ {
+ // Create a network from a file on disk
+ auto parser(IParser::Create());
+
+ std::map<std::string, armnn::TensorShape> inputShapes;
+ if (!m_InputShapes.empty())
+ {
+ const size_t numInputShapes = m_InputShapes.size();
+ const size_t numInputBindings = m_InputNames.size();
+ if (numInputShapes < numInputBindings)
+ {
+ throw armnn::Exception(boost::str(boost::format(
+ "Not every input has its tensor shape specified: expected=%1%, got=%2%")
+ % numInputBindings % numInputShapes));
+ }
+
+ for (size_t i = 0; i < numInputShapes; i++)
+ {
+ inputShapes[m_InputNames[i]] = m_InputShapes[i];
+ }
+ }
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ m_NetworkPtr = (m_IsModelBinary ?
+ parser->CreateNetworkFromBinaryFile(m_ModelPath.c_str(), inputShapes, m_OutputNames) :
+ parser->CreateNetworkFromTextFile(m_ModelPath.c_str(), inputShapes, m_OutputNames));
+ }
+
+ return m_NetworkPtr.get() != nullptr;
+ }
+
+private:
+ armnn::INetworkPtr m_NetworkPtr;
+ std::string m_ModelPath;
+ std::vector<std::string> m_InputNames;
+ std::vector<armnn::TensorShape> m_InputShapes;
+ std::vector<std::string> m_OutputNames;
+ std::string m_OutputPath;
+ bool m_IsModelBinary;
+};
+
+} // anonymous namespace
+
+int main(int argc, const char* argv[])
+{
+
+#if !defined(ARMNN_TF_PARSER)
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
+ return EXIT_FAILURE;
+#endif
+
+#if !defined(ARMNN_SERIALIZER)
+ BOOST_LOG_TRIVIAL(fatal) << "Not built with Serializer support.";
+ return EXIT_FAILURE;
+#endif
+
+#ifdef NDEBUG
+ armnn::LogSeverity level = armnn::LogSeverity::Info;
+#else
+ armnn::LogSeverity level = armnn::LogSeverity::Debug;
+#endif
+
+ armnn::ConfigureLogging(true, true, level);
+ armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
+
+ std::string modelFormat;
+ std::string modelPath;
+
+ std::vector<std::string> inputNames;
+ std::vector<std::string> inputTensorShapeStrs;
+ std::vector<armnn::TensorShape> inputTensorShapes;
+
+ std::vector<std::string> outputNames;
+ std::string outputPath;
+
+ bool isModelBinary = true;
+
+ if (ParseCommandLineArgs(
+ argc, argv, modelFormat, modelPath, inputNames, inputTensorShapeStrs, outputNames, outputPath, isModelBinary)
+ != EXIT_SUCCESS)
+ {
+ return EXIT_FAILURE;
+ }
+
+ for (const std::string& shapeStr : inputTensorShapeStrs)
+ {
+ if (!shapeStr.empty())
+ {
+ std::stringstream ss(shapeStr);
+
+ try
+ {
+ armnn::TensorShape shape = ParseTensorShape(ss);
+ inputTensorShapes.push_back(shape);
+ }
+ catch (const armnn::InvalidArgumentException& e)
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
+ return EXIT_FAILURE;
+ }
+ }
+ }
+
+ ArmnnConverter converter(modelPath, inputNames, inputTensorShapes, outputNames, outputPath, isModelBinary);
+
+ if (modelFormat.find("tensorflow") != std::string::npos)
+ {
+ if (!converter.CreateNetwork<armnnTfParser::ITfParser>())
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Failed to load model from file";
+ return EXIT_FAILURE;
+ }
+ }
+ else
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat;
+ return EXIT_FAILURE;
+ }
+
+ if (!converter.Serialize())
+ {
+ BOOST_LOG_TRIVIAL(fatal) << "Failed to serialize model";
+ return EXIT_FAILURE;
+ }
+
+ return EXIT_SUCCESS;
+}
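The converter itself reduces to the parser and serializer calls used in CreateNetwork() and Serialize() above, so the same conversion can be driven from user code without the command-line wrapper. A minimal sketch, assuming the same headers and build flags as this patch and using placeholder file and tensor names:

    #include <armnn/ArmNN.hpp>
    #include <armnnSerializer/ISerializer.hpp>
    #include <armnnTfParser/ITfParser.hpp>

    #include <cstdlib>
    #include <fstream>
    #include <map>
    #include <string>
    #include <vector>

    int main()
    {
        // Placeholder paths and tensor names; adjust for the model being converted.
        const char* tfModel   = "model.pb";
        const char* armnnFile = "model.armnn";

        // Parse a binary TensorFlow model, pinning the input shape in the same way
        // the tool does when --input-tensor-shape is given.
        auto parser = armnnTfParser::ITfParser::Create();
        std::map<std::string, armnn::TensorShape> inputShapes =
            { { "input_tensor", armnn::TensorShape({ 1, 224, 224, 3 }) } };
        std::vector<std::string> outputNames = { "output_tensor" };

        armnn::INetworkPtr network =
            parser->CreateNetworkFromBinaryFile(tfModel, inputShapes, outputNames);

        // Serialize the parsed network to the ArmNN format and write it to disk.
        auto serializer = armnnSerializer::ISerializer::Create();
        serializer->Serialize(*network);

        std::ofstream file(armnnFile, std::ios::out | std::ios::binary);
        return serializer->SaveSerializedToStream(file) ? EXIT_SUCCESS : EXIT_FAILURE;
    }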