aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-02-12 11:27:53 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-02-12 11:40:24 +0000
commit64e4ccb4546473e922b4ddd699ff6b77a5c2527d (patch)
tree3c5fbd6be6706d7450919030e7c91355d0b3507a /tests/InferenceModel.hpp
parent424951560f2948b49506f178352e788cbe680fd8 (diff)
downloadarmnn-64e4ccb4546473e922b4ddd699ff6b77a5c2527d.tar.gz
IVGCVSW-2663 Enable ExecuteNetwork to load ArmNN files
Change-Id: I1a61a1da2258bd07b39da6063d22a5bd22c1884d Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--tests/InferenceModel.hpp50
1 files changed, 48 insertions, 2 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index eb5f708c81..eb3b2ccd42 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -5,15 +5,18 @@
#pragma once
#include <armnn/ArmNN.hpp>
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif
-
-#include <HeapProfiling.hpp>
#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
+#include <HeapProfiling.hpp>
+
#include <backendsCommon/BackendRegistry.hpp>
#include <boost/algorithm/string/join.hpp>
@@ -160,6 +163,49 @@ public:
}
};
+#if defined(ARMNN_SERIALIZER)
+template <>
+struct CreateNetworkImpl<armnnDeserializeParser::IDeserializeParser>
+{
+public:
+ using IParser = armnnDeserializeParser::IDeserializeParser;
+ using Params = InferenceModelInternal::Params;
+ using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+
+ static armnn::INetworkPtr Create(const Params& params,
+ std::vector<BindingPointInfo>& inputBindings,
+ std::vector<BindingPointInfo>& outputBindings)
+ {
+ auto parser(IParser::Create());
+ BOOST_ASSERT(parser);
+
+ armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ const std::string& modelPath = params.m_ModelPath;
+ network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
+ }
+
+ unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);
+
+ for (const std::string& inputLayerName : params.m_InputBindings)
+ {
+ BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName);
+ inputBindings.push_back(inputBinding);
+ }
+
+ for (const std::string& outputLayerName : params.m_OutputBindings)
+ {
+ BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName);
+ outputBindings.push_back(outputBinding);
+ }
+
+ return network;
+ }
+};
+#endif
+
#if defined(ARMNN_TF_LITE_PARSER)
template <>
struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>