From 2b183fb359774cbac5d628579ec2b4a7b6b41def Mon Sep 17 00:00:00 2001 From: Derek Lamberti Date: Mon, 18 Feb 2019 16:36:57 +0000 Subject: IVGCVSW-2736 Deserialize using istream instead of filename Change-Id: I5656b23d9783e7f953e677001d16e41eedeb42b2 Signed-off-by: Derek Lamberti --- .../armnnDeserializeParser/IDeserializeParser.hpp | 7 ++--- src/armnnDeserializeParser/DeserializeParser.cpp | 35 +++++++--------------- src/armnnDeserializeParser/DeserializeParser.hpp | 19 +++++------- tests/InferenceModel.hpp | 16 ++++++++-- 4 files changed, 34 insertions(+), 43 deletions(-) diff --git a/include/armnnDeserializeParser/IDeserializeParser.hpp b/include/armnnDeserializeParser/IDeserializeParser.hpp index bb9726e427..ab64dc9e14 100644 --- a/include/armnnDeserializeParser/IDeserializeParser.hpp +++ b/include/armnnDeserializeParser/IDeserializeParser.hpp @@ -28,12 +28,11 @@ public: static IDeserializeParserPtr Create(); static void Destroy(IDeserializeParser* parser); - /// Create the network from a flatbuffers binary file on disk - virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0; - - /// Create the network from a flatbuffers binary + /// Create an input network from binary file contents virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector& binaryContent) = 0; + /// Create an input network from a binary input stream + virtual armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) = 0; /// Retrieve binding info (layer id and tensor info) for the network input identified by /// the given layer name and layers id diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index de9b1a98c7..9af5087cff 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -352,13 +352,6 @@ void IDeserializeParser::Destroy(IDeserializeParser* parser) delete parser; } -INetworkPtr DeserializeParser::CreateNetworkFromBinaryFile(const char* graphFile) -{ - ResetParser(); - m_Graph = LoadGraphFromFile(graphFile, m_FileContent); - return CreateNetworkFromGraph(); -} - INetworkPtr DeserializeParser::CreateNetworkFromBinary(const std::vector& binaryContent) { ResetParser(); @@ -366,25 +359,11 @@ INetworkPtr DeserializeParser::CreateNetworkFromBinary(const std::vector(file)), std::istreambuf_iterator()); - return LoadGraphFromBinary(reinterpret_cast(fileContent.c_str()), fileContent.size()); + ResetParser(); + m_Graph = LoadGraphFromBinary(binaryContent); + return CreateNetworkFromGraph(); } DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(const uint8_t* binaryContent, size_t len) @@ -406,6 +385,12 @@ DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(const uint8_t return GetSerializedGraph(binaryContent); } +DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(std::istream& binaryContent) +{ + std::string content((std::istreambuf_iterator(binaryContent)), std::istreambuf_iterator()); + return GetSerializedGraph(content.data()); +} + INetworkPtr DeserializeParser::CreateNetworkFromGraph() { m_Network = INetwork::Create(); diff --git a/src/armnnDeserializeParser/DeserializeParser.hpp b/src/armnnDeserializeParser/DeserializeParser.hpp index 666cbca33c..aee647c636 100644 --- a/src/armnnDeserializeParser/DeserializeParser.hpp +++ b/src/armnnDeserializeParser/DeserializeParser.hpp @@ -25,26 +25,25 @@ public: public: - /// Create the network from a flatbuffers binary file on disk - virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override; + /// Create an input network from binary file contents + armnn::INetworkPtr CreateNetworkFromBinary(const std::vector& binaryContent) override; - virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector& binaryContent) override; + /// Create an input network from a binary input stream + armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override; /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name - virtual BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, - const std::string& name) const override; + BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override; /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name - virtual BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, - const std::string& name) const override; + BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override; DeserializeParser(); ~DeserializeParser() {} public: // testable helpers - static GraphPtr LoadGraphFromFile(const char* fileName, std::string& fileContent); static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len); + static GraphPtr LoadGraphFromBinary(std::istream& binaryContent); static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex); static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex); static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr); @@ -91,10 +90,6 @@ private: std::vector m_ParserFunctions; std::string m_layerName; - /// This holds the data of the file that was read in from CreateNetworkFromBinaryFile - /// Needed for m_Graph to point to - std::string m_FileContent; - /// A mapping of an output slot to each of the input slots it should be connected to /// The outputSlot is from the layer that creates this tensor as one of its outputs /// The inputSlots are from the layers that use this tensor as one of their inputs diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index eb3b2ccd42..4819523595 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -183,8 +183,20 @@ public: { ARMNN_SCOPED_HEAP_PROFILING("Parsing"); - const std::string& modelPath = params.m_ModelPath; - network = parser->CreateNetworkFromBinaryFile(modelPath.c_str()); + + boost::system::error_code errorCode; + boost::filesystem::path pathToFile(params.m_ModelPath); + if (!boost::filesystem::exists(pathToFile, errorCode)) + { + throw armnn::FileNotFoundException(boost::str( + boost::format("Cannot find the file (%1%) errorCode: %2% %3%") % + params.m_ModelPath % + errorCode % + CHECK_LOCATION().AsString())); + } + std::ifstream file(params.m_ModelPath, std::ios::binary); + + network = parser->CreateNetworkFromBinary(file); } unsigned int subGraphId = boost::numeric_cast(params.m_SubgraphId); -- cgit v1.2.1