diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-11-25 13:55:24 +0000 |
---|---|---|
committer | Francis Murtagh <francis.murtagh@arm.com> | 2023-01-03 14:45:41 +0000 |
commit | 2ae3224c559dcd3033be1bfd41be08113048dc50 (patch) | |
tree | e130d940fb6d080807dc29bf0c9602e1bd30af4a /src/armnnOnnxParser/OnnxParser.cpp | |
parent | a09e0336d3e89931e8e0d43010197155a45d2ec7 (diff) | |
download | armnn-2ae3224c559dcd3033be1bfd41be08113048dc50.tar.gz |
GitHub #709 Provide a CreateNetworkFromBinary method for the ONNX parser
* Added CreateNetworkFromBinary to the ONNX parser
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I5ca72ee49c7b098f9fb4aaf55a8bc077230cb30e
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 51 |
1 files changed, 50 insertions, 1 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 63fb60382c..552d4e4163 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "OnnxParser.hpp" @@ -50,6 +50,17 @@ armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFil return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile); } +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes); +} + armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile) { return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile); @@ -731,6 +742,44 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile, return CreateNetworkFromModel(*modelProto); } +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) +{ + ResetParser(); + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + ResetParser(); + m_InputShapes = inputShapes; + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent) +{ + if (binaryContent.size() == 0) + { + throw ParseException(fmt::format("Missing binary content", CHECK_LOCATION().AsString())); + } + // Parse the file into a message + ModelPtr modelProto = std::make_unique<onnx::ModelProto>(); + + google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size())); + codedStream.SetTotalBytesLimit(INT_MAX); + bool success = modelProto.get()->ParseFromCodedStream(&codedStream); + + if (!success) + { + std::stringstream error; + error << "Failed to parse graph"; + throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString())); + } + return modelProto; +} + ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile) { FILE* fd = fopen(graphFile, "rb"); |