From 2ae3224c559dcd3033be1bfd41be08113048dc50 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Fri, 25 Nov 2022 13:55:24 +0000 Subject: GitHub #709 Provide a CreateNetworkFromBinary method for the ONNX parser * Added CreateNetworkFromBinary to the ONNX parser Signed-off-by: Mike Kelly Change-Id: I5ca72ee49c7b098f9fb4aaf55a8bc077230cb30e --- src/armnnOnnxParser/OnnxParser.cpp | 51 +++++++++++++++++++++++++++++++++++++- src/armnnOnnxParser/OnnxParser.hpp | 10 +++++++- 2 files changed, 59 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 63fb60382c..552d4e4163 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "OnnxParser.hpp" @@ -50,6 +50,17 @@ armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFil return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile); } +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector& binaryContent) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector& binaryContent, + const std::map& inputShapes) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes); +} + armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile) { return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile); @@ -731,6 +742,44 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile, return CreateNetworkFromModel(*modelProto); } +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector& binaryContent) +{ + ResetParser(); + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector& binaryContent, + const std::map& inputShapes) +{ + ResetParser(); + m_InputShapes = inputShapes; + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector& binaryContent) +{ + if (binaryContent.size() == 0) + { + throw ParseException(fmt::format("Missing binary content", CHECK_LOCATION().AsString())); + } + // Parse the file into a message + ModelPtr modelProto = std::make_unique(); + + google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast(binaryContent.size())); + codedStream.SetTotalBytesLimit(INT_MAX); + bool success = modelProto.get()->ParseFromCodedStream(&codedStream); + + if (!success) + { + std::stringstream error; + error << "Failed to parse graph"; + throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString())); + } + return modelProto; +} + ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile) { FILE* fd = fopen(graphFile, "rb"); diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index bb94472c6d..c9f321a5b5 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -38,6 +38,13 @@ public: armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile, const std::map& inputShapes); + /// Create the network from a protobuf binary + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector& binaryContent); + + /// Create the network from a protobuf binary, with inputShapes specified + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector& binaryContent, + const std::map& inputShapes); + /// Create the network from a protobuf text file on disk armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile); @@ -64,6 +71,7 @@ public: OnnxParserImpl(); ~OnnxParserImpl() = default; + static ModelPtr LoadModelFromBinary(const std::vector& binaryContent); static ModelPtr LoadModelFromBinaryFile(const char * fileName); static ModelPtr LoadModelFromTextFile(const char * fileName); static ModelPtr LoadModelFromString(const std::string& inputString); -- cgit v1.2.1