aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-02-18 16:36:57 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-02-19 17:07:12 +0000
commit2b183fb359774cbac5d628579ec2b4a7b6b41def (patch)
tree764f4ee3293bc419ee204b1685d17bb59df410ac
parent263829c2163d79a28f98f24f9dd1e52e1c3cbbef (diff)
downloadarmnn-2b183fb359774cbac5d628579ec2b4a7b6b41def.tar.gz
IVGCVSW-2736 Deserialize using istream instead of filename
Change-Id: I5656b23d9783e7f953e677001d16e41eedeb42b2 Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
-rw-r--r--include/armnnDeserializeParser/IDeserializeParser.hpp7
-rw-r--r--src/armnnDeserializeParser/DeserializeParser.cpp35
-rw-r--r--src/armnnDeserializeParser/DeserializeParser.hpp19
-rw-r--r--tests/InferenceModel.hpp16
4 files changed, 34 insertions, 43 deletions
diff --git a/include/armnnDeserializeParser/IDeserializeParser.hpp b/include/armnnDeserializeParser/IDeserializeParser.hpp
index bb9726e427..ab64dc9e14 100644
--- a/include/armnnDeserializeParser/IDeserializeParser.hpp
+++ b/include/armnnDeserializeParser/IDeserializeParser.hpp
@@ -28,12 +28,11 @@ public:
static IDeserializeParserPtr Create();
static void Destroy(IDeserializeParser* parser);
- /// Create the network from a flatbuffers binary file on disk
- virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0;
-
- /// Create the network from a flatbuffers binary
+ /// Create an input network from binary file contents
virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) = 0;
+ /// Create an input network from a binary input stream
+ virtual armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) = 0;
/// Retrieve binding info (layer id and tensor info) for the network input identified by
/// the given layer name and layers id
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp
index de9b1a98c7..9af5087cff 100644
--- a/src/armnnDeserializeParser/DeserializeParser.cpp
+++ b/src/armnnDeserializeParser/DeserializeParser.cpp
@@ -352,13 +352,6 @@ void IDeserializeParser::Destroy(IDeserializeParser* parser)
delete parser;
}
-INetworkPtr DeserializeParser::CreateNetworkFromBinaryFile(const char* graphFile)
-{
- ResetParser();
- m_Graph = LoadGraphFromFile(graphFile, m_FileContent);
- return CreateNetworkFromGraph();
-}
-
INetworkPtr DeserializeParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
{
ResetParser();
@@ -366,25 +359,11 @@ INetworkPtr DeserializeParser::CreateNetworkFromBinary(const std::vector<uint8_t
return CreateNetworkFromGraph();
}
-DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromFile(const char* fileName, std::string& fileContent)
+armnn::INetworkPtr DeserializeParser::CreateNetworkFromBinary(std::istream& binaryContent)
{
- if (fileName == nullptr)
- {
- throw InvalidArgumentException(boost::str(boost::format("Invalid (null) file name %1%") %
- CHECK_LOCATION().AsString()));
- }
- boost::system::error_code errorCode;
- boost::filesystem::path pathToFile(fileName);
- if (!boost::filesystem::exists(pathToFile, errorCode))
- {
- throw FileNotFoundException(boost::str(boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
- fileName %
- errorCode %
- CHECK_LOCATION().AsString()));
- }
- std::ifstream file(fileName, std::ios::binary);
- fileContent = std::string((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
- return LoadGraphFromBinary(reinterpret_cast<const uint8_t*>(fileContent.c_str()), fileContent.size());
+ ResetParser();
+ m_Graph = LoadGraphFromBinary(binaryContent);
+ return CreateNetworkFromGraph();
}
DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(const uint8_t* binaryContent, size_t len)
@@ -406,6 +385,12 @@ DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(const uint8_t
return GetSerializedGraph(binaryContent);
}
+DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromBinary(std::istream& binaryContent)
+{
+ std::string content((std::istreambuf_iterator<char>(binaryContent)), std::istreambuf_iterator<char>());
+ return GetSerializedGraph(content.data());
+}
+
INetworkPtr DeserializeParser::CreateNetworkFromGraph()
{
m_Network = INetwork::Create();
diff --git a/src/armnnDeserializeParser/DeserializeParser.hpp b/src/armnnDeserializeParser/DeserializeParser.hpp
index 666cbca33c..aee647c636 100644
--- a/src/armnnDeserializeParser/DeserializeParser.hpp
+++ b/src/armnnDeserializeParser/DeserializeParser.hpp
@@ -25,26 +25,25 @@ public:
public:
- /// Create the network from a flatbuffers binary file on disk
- virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
+ /// Create an input network from binary file contents
+ armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
- virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
+ /// Create an input network from a binary input stream
+ armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override;
/// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
- virtual BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId,
- const std::string& name) const override;
+ BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override;
/// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
- virtual BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId,
- const std::string& name) const override;
+ BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override;
DeserializeParser();
~DeserializeParser() {}
public:
// testable helpers
- static GraphPtr LoadGraphFromFile(const char* fileName, std::string& fileContent);
static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
+ static GraphPtr LoadGraphFromBinary(std::istream& binaryContent);
static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr);
@@ -91,10 +90,6 @@ private:
std::vector<LayerParsingFunction> m_ParserFunctions;
std::string m_layerName;
- /// This holds the data of the file that was read in from CreateNetworkFromBinaryFile
- /// Needed for m_Graph to point to
- std::string m_FileContent;
-
/// A mapping of an output slot to each of the input slots it should be connected to
/// The outputSlot is from the layer that creates this tensor as one of its outputs
/// The inputSlots are from the layers that use this tensor as one of their inputs
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index eb3b2ccd42..4819523595 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -183,8 +183,20 @@ public:
{
ARMNN_SCOPED_HEAP_PROFILING("Parsing");
- const std::string& modelPath = params.m_ModelPath;
- network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
+
+ boost::system::error_code errorCode;
+ boost::filesystem::path pathToFile(params.m_ModelPath);
+ if (!boost::filesystem::exists(pathToFile, errorCode))
+ {
+ throw armnn::FileNotFoundException(boost::str(
+ boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
+ params.m_ModelPath %
+ errorCode %
+ CHECK_LOCATION().AsString()));
+ }
+ std::ifstream file(params.m_ModelPath, std::ios::binary);
+
+ network = parser->CreateNetworkFromBinary(file);
}
unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);