diff options
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 51 |
1 files changed, 50 insertions, 1 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 63fb60382c..552d4e4163 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "OnnxParser.hpp" @@ -50,6 +50,17 @@ armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFil return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile); } +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes); +} + armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile) { return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile); @@ -731,6 +742,44 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile, return CreateNetworkFromModel(*modelProto); } +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) +{ + ResetParser(); + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent, + const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + ResetParser(); + m_InputShapes = inputShapes; + ModelPtr modelProto = LoadModelFromBinary(binaryContent); + return CreateNetworkFromModel(*modelProto); +} + +ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent) +{ + if (binaryContent.size() == 0) + { + throw ParseException(fmt::format("Missing binary content", CHECK_LOCATION().AsString())); + } + // Parse the file into a message + ModelPtr modelProto = std::make_unique<onnx::ModelProto>(); + + google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size())); + codedStream.SetTotalBytesLimit(INT_MAX); + bool success = modelProto.get()->ParseFromCodedStream(&codedStream); + + if (!success) + { + std::stringstream error; + error << "Failed to parse graph"; + throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString())); + } + return modelProto; +} + ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile) { FILE* fd = fopen(graphFile, "rb"); |