diff options
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r-- | tests/InferenceModel.hpp | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index eb5f708c81..eb3b2ccd42 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -5,15 +5,18 @@ #pragma once #include <armnn/ArmNN.hpp> +#if defined(ARMNN_SERIALIZER) +#include "armnnDeserializeParser/IDeserializeParser.hpp" +#endif #if defined(ARMNN_TF_LITE_PARSER) #include <armnnTfLiteParser/ITfLiteParser.hpp> #endif - -#include <HeapProfiling.hpp> #if defined(ARMNN_ONNX_PARSER) #include <armnnOnnxParser/IOnnxParser.hpp> #endif +#include <HeapProfiling.hpp> + #include <backendsCommon/BackendRegistry.hpp> #include <boost/algorithm/string/join.hpp> @@ -160,6 +163,49 @@ public: } }; +#if defined(ARMNN_SERIALIZER) +template <> +struct CreateNetworkImpl<armnnDeserializeParser::IDeserializeParser> +{ +public: + using IParser = armnnDeserializeParser::IDeserializeParser; + using Params = InferenceModelInternal::Params; + using BindingPointInfo = InferenceModelInternal::BindingPointInfo; + + static armnn::INetworkPtr Create(const Params& params, + std::vector<BindingPointInfo>& inputBindings, + std::vector<BindingPointInfo>& outputBindings) + { + auto parser(IParser::Create()); + BOOST_ASSERT(parser); + + armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}}; + + { + ARMNN_SCOPED_HEAP_PROFILING("Parsing"); + const std::string& modelPath = params.m_ModelPath; + network = parser->CreateNetworkFromBinaryFile(modelPath.c_str()); + } + + unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId); + + for (const std::string& inputLayerName : params.m_InputBindings) + { + BindingPointInfo inputBinding = parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName); + inputBindings.push_back(inputBinding); + } + + for (const std::string& outputLayerName : params.m_OutputBindings) + { + BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName); + outputBindings.push_back(outputBinding); + } + + return network; + } +}; +#endif + #if defined(ARMNN_TF_LITE_PARSER) template <> struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser> |