aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--tests/InferenceModel.hpp50
1 files changed, 48 insertions, 2 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index eb5f708c81..eb3b2ccd42 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -5,15 +5,18 @@
#pragma once
#include <armnn/ArmNN.hpp>
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializeParser/IDeserializeParser.hpp"
+#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include <armnnTfLiteParser/ITfLiteParser.hpp>
#endif
-
-#include <HeapProfiling.hpp>
#if defined(ARMNN_ONNX_PARSER)
#include <armnnOnnxParser/IOnnxParser.hpp>
#endif
+#include <HeapProfiling.hpp>
+
#include <backendsCommon/BackendRegistry.hpp>
#include <boost/algorithm/string/join.hpp>
@@ -160,6 +163,49 @@ public:
}
};
+#if defined(ARMNN_SERIALIZER)
+template <>
+struct CreateNetworkImpl<armnnDeserializeParser::IDeserializeParser>
+{
+public:
+ using IParser = armnnDeserializeParser::IDeserializeParser;
+ using Params = InferenceModelInternal::Params;
+ using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
+
+ static armnn::INetworkPtr Create(const Params& params,
+ std::vector<BindingPointInfo>& inputBindings,
+ std::vector<BindingPointInfo>& outputBindings)
+ {
+ auto parser(IParser::Create());
+ BOOST_ASSERT(parser);
+
+ armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ const std::string& modelPath = params.m_ModelPath;
+ network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());
+ }
+
+ unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);
+
+ for (const std::string& inputLayerName : params.m_InputBindings)
+ {
+ BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName);
+ inputBindings.push_back(inputBinding);
+ }
+
+ for (const std::string& outputLayerName : params.m_OutputBindings)
+ {
+ BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName);
+ outputBindings.push_back(outputBinding);
+ }
+
+ return network;
+ }
+};
+#endif
+
#if defined(ARMNN_TF_LITE_PARSER)
template <>
struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>