diff options
-rw-r--r-- | include/armnnOnnxParser/IOnnxParser.hpp | 9 | ||||
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 51 | ||||
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.hpp | 10 |
3 files changed, 67 insertions, 3 deletions
diff --git a/include/armnnOnnxParser/IOnnxParser.hpp b/include/armnnOnnxParser/IOnnxParser.hpp index ba7fc83f93..89c22c03de 100644 --- a/include/armnnOnnxParser/IOnnxParser.hpp +++ b/include/armnnOnnxParser/IOnnxParser.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -27,6 +27,13 @@ public: static IOnnxParserPtr Create(); static void Destroy(IOnnxParser* parser); + /// Create the network from a protobuf binary vector + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent); + + /// Create the network from a protobuf binary vector, with inputShapes specified + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes); + /// Create the network from a protobuf binary file on disk armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 63fb60382c..552d4e4163 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "OnnxParser.hpp" @@ -50,6 +50,17 @@ armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFil return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile); } +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes); +} + armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile) { return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile); @@ -731,6 +742,44 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile, return CreateNetworkFromModel(*modelProto); } +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) +{ + ResetParser(); + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + ResetParser(); + m_InputShapes = inputShapes; + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent) +{ + if (binaryContent.size() == 0) + { + throw ParseException(fmt::format("Missing binary content", CHECK_LOCATION().AsString())); + } + // Parse the file into a message + ModelPtr modelProto = std::make_unique<onnx::ModelProto>(); + + google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size())); + codedStream.SetTotalBytesLimit(INT_MAX); + bool success = modelProto.get()->ParseFromCodedStream(&codedStream); + + if (!success) + { + std::stringstream error; + error << "Failed to parse graph"; + throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString())); + } + return modelProto; +} + ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile) { FILE* fd = fopen(graphFile, "rb"); diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index bb94472c6d..c9f321a5b5 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -38,6 +38,13 @@ public: armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile, const std::map<std::string, armnn::TensorShape>& inputShapes); + /// Create the network from a protobuf binary + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent); + + /// Create the network from a protobuf binary, with inputShapes specified + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes); + /// Create the network from a protobuf text file on disk armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile); @@ -64,6 +71,7 @@ public: OnnxParserImpl(); ~OnnxParserImpl() = default; + static ModelPtr LoadModelFromBinary(const std::vector<uint8_t>& binaryContent); static ModelPtr LoadModelFromBinaryFile(const char * fileName); static ModelPtr LoadModelFromTextFile(const char * fileName); static ModelPtr LoadModelFromString(const std::string& inputString); |