aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-11-25 13:55:24 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2023-01-03 14:45:41 +0000
commit2ae3224c559dcd3033be1bfd41be08113048dc50 (patch)
treee130d940fb6d080807dc29bf0c9602e1bd30af4a
parenta09e0336d3e89931e8e0d43010197155a45d2ec7 (diff)
downloadarmnn-2ae3224c559dcd3033be1bfd41be08113048dc50.tar.gz
GitHub #709 Provide a CreateNetworkFromBinary method for the ONNX parser
* Added CreateNetworkFromBinary to the ONNX parser Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I5ca72ee49c7b098f9fb4aaf55a8bc077230cb30e
-rw-r--r--include/armnnOnnxParser/IOnnxParser.hpp9
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp51
-rw-r--r--src/armnnOnnxParser/OnnxParser.hpp10
3 files changed, 67 insertions, 3 deletions
diff --git a/include/armnnOnnxParser/IOnnxParser.hpp b/include/armnnOnnxParser/IOnnxParser.hpp
index ba7fc83f93..89c22c03de 100644
--- a/include/armnnOnnxParser/IOnnxParser.hpp
+++ b/include/armnnOnnxParser/IOnnxParser.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -27,6 +27,13 @@ public:
static IOnnxParserPtr Create();
static void Destroy(IOnnxParser* parser);
+ /// Create the network from a protobuf binary vector
+ armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
+
+ /// Create the network from a protobuf binary vector, with inputShapes specified
+ armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+ const std::map<std::string, armnn::TensorShape>& inputShapes);
+
/// Create the network from a protobuf binary file on disk
armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 63fb60382c..552d4e4163 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "OnnxParser.hpp"
@@ -50,6 +50,17 @@ armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFil
return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
}
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
+{
+ return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent);
+}
+
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes);
+}
+
armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
{
return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
@@ -731,6 +742,44 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile,
return CreateNetworkFromModel(*modelProto);
}
+INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
+{
+ ResetParser();
+ ModelPtr modelProto = LoadModelFromBinary(binaryContent);
+ return CreateNetworkFromModel(*modelProto);
+}
+
+INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ ResetParser();
+ m_InputShapes = inputShapes;
+ ModelPtr modelProto = LoadModelFromBinary(binaryContent);
+ return CreateNetworkFromModel(*modelProto);
+}
+
+ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent)
+{
+ if (binaryContent.size() == 0)
+ {
+ throw ParseException(fmt::format("Missing binary content", CHECK_LOCATION().AsString()));
+ }
+ // Parse the file into a message
+ ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
+
+ google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size()));
+ codedStream.SetTotalBytesLimit(INT_MAX);
+ bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
+
+ if (!success)
+ {
+ std::stringstream error;
+ error << "Failed to parse graph";
+ throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
+ }
+ return modelProto;
+}
+
ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
{
FILE* fd = fopen(graphFile, "rb");
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index bb94472c6d..c9f321a5b5 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -38,6 +38,13 @@ public:
armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile,
const std::map<std::string, armnn::TensorShape>& inputShapes);
+ /// Create the network from a protobuf binary
+ armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
+
+ /// Create the network from a protobuf binary, with inputShapes specified
+ armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
+ const std::map<std::string, armnn::TensorShape>& inputShapes);
+
/// Create the network from a protobuf text file on disk
armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile);
@@ -64,6 +71,7 @@ public:
OnnxParserImpl();
~OnnxParserImpl() = default;
+ static ModelPtr LoadModelFromBinary(const std::vector<uint8_t>& binaryContent);
static ModelPtr LoadModelFromBinaryFile(const char * fileName);
static ModelPtr LoadModelFromTextFile(const char * fileName);
static ModelPtr LoadModelFromString(const std::string& inputString);