diff options
author | Finn Williams <finn.williams@arm.com> | 2022-06-20 13:48:20 +0100 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-07-08 15:19:18 +0100 |
commit | 452c58080e9f8f577de87e0c07d0097aac97f3b8 (patch) | |
tree | 0a3bd2cc754cde1b3133a914597d607c52ce75ff /tests/ExecuteNetwork/ArmNNExecutor.hpp | |
parent | c7b6de86431e26766b60a69bcfcde985af61a028 (diff) | |
download | armnn-452c58080e9f8f577de87e0c07d0097aac97f3b8.tar.gz |
IVGCVSW-6650 Refactor ExecuteNetwork
* Remove InferenceModel
* Add automatic IO type, shape and name configuration
* Depreciate various redundant options
* Add internal output comparison
Signed-off-by: Finn Williams <finn.williams@arm.com>
Change-Id: I2eca248bc91e1655a99ed94990efb8059f541fa9
Diffstat (limited to 'tests/ExecuteNetwork/ArmNNExecutor.hpp')
-rw-r--r-- | tests/ExecuteNetwork/ArmNNExecutor.hpp | 161 |
1 files changed, 161 insertions, 0 deletions
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.hpp b/tests/ExecuteNetwork/ArmNNExecutor.hpp new file mode 100644 index 0000000000..aec7a20a06 --- /dev/null +++ b/tests/ExecuteNetwork/ArmNNExecutor.hpp @@ -0,0 +1,161 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "IExecutor.hpp" +#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" +#include "ExecuteNetworkProgramOptions.hpp" +#include "armnn/utility/NumericCast.hpp" +#include "armnn/utility/Timer.hpp" + +#include <armnn/ArmNN.hpp> +#include <armnn/Threadpool.hpp> +#include <armnn/Logging.hpp> +#include <armnn/utility/Timer.hpp> +#include <armnn/BackendRegistry.hpp> +#include <armnn/utility/Assert.hpp> +#include <armnn/utility/NumericCast.hpp> + +#include <armnnUtils/Filesystem.hpp> +#include <HeapProfiling.hpp> + +#include <fmt/format.h> + +#if defined(ARMNN_SERIALIZER) +#include "armnnDeserializer/IDeserializer.hpp" +#endif +#if defined(ARMNN_TF_LITE_PARSER) +#include <armnnTfLiteParser/ITfLiteParser.hpp> +#endif +#if defined(ARMNN_ONNX_PARSER) +#include <armnnOnnxParser/IOnnxParser.hpp> +#endif + +class ArmNNExecutor : public IExecutor +{ +public: + ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions); + + std::vector<const void *> Execute() override; + void PrintNetworkInfo() override; + void CompareAndPrintResult(std::vector<const void*> otherOutput) override; + +private: + + struct IParser; + struct IOInfo; + struct IOStorage; + + using BindingPointInfo = armnn::BindingPointInfo; + + std::unique_ptr<IParser> CreateParser(); + + void ExecuteAsync(); + void ExecuteSync(); + void SetupInputsAndOutputs(); + + IOInfo GetIOInfo(armnn::INetwork* network); + + void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration); + + armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network); + + struct IOStorage + { + IOStorage(size_t size) + { + m_Mem = operator new(size); + } + ~IOStorage() + { + operator delete(m_Mem); + } + IOStorage(IOStorage &&rhs) + { + this->m_Mem = rhs.m_Mem; + rhs.m_Mem = nullptr; + } + + IOStorage(const IOStorage &rhs) = delete; + IOStorage &operator=(IOStorage &rhs) = delete; + IOStorage &operator=(IOStorage &&rhs) = delete; + + void *m_Mem; + }; + + struct IOInfo + { + std::vector<std::string> m_InputNames; + std::vector<std::string> m_OutputNames; + std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap; + std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap; + }; + + IOInfo m_IOInfo; + std::vector<IOStorage> m_InputStorage; + std::vector<IOStorage> m_OutputStorage; + std::vector<armnn::InputTensors> m_InputTensorsVec; + std::vector<armnn::OutputTensors> m_OutputTensorsVec; + std::vector<std::vector<unsigned int>> m_ImportedInputIds; + std::vector<std::vector<unsigned int>> m_ImportedOutputIds; + std::shared_ptr<armnn::IRuntime> m_Runtime; + armnn::NetworkId m_NetworkId; + ExecuteNetworkParams m_Params; + + struct IParser + { + virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0; + virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string &inputName) = 0; + virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string &outputName) = 0; + + virtual ~IParser(){}; + }; + +#if defined(ARMNN_SERIALIZER) + class ArmNNDeserializer : public IParser + { + public: + ArmNNDeserializer(); + + armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams ¶ms) override; + armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string &inputName) override; + armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string &outputName) override; + + private: + armnnDeserializer::IDeserializerPtr m_Parser; + }; +#endif + +#if defined(ARMNN_TF_LITE_PARSER) + class TfliteParser : public IParser + { + public: + TfliteParser(const ExecuteNetworkParams& params); + + armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams ¶ms) override; + armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string &inputName) override; + armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string &outputName) override; + + private: + armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}}; + }; +#endif + +#if defined(ARMNN_ONNX_PARSER) + class OnnxParser : public IParser + { + public: + OnnxParser(); + + armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams ¶ms) override; + armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string &inputName) override; + armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string &outputName) override; + + private: + armnnOnnxParser::IOnnxParserPtr m_Parser; + }; +#endif +};
\ No newline at end of file |