//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "IExecutor.hpp"
#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
#include "ExecuteNetworkProgramOptions.hpp"
#include "armnn/utility/NumericCast.hpp"
#include "armnn/utility/Timer.hpp"

// NOTE(review): the header names of the ten bare #include directives below, and of the two parser
// includes further down, are missing from the text this file was recovered from. The template
// arguments of every std::vector / std::map / std::unique_ptr in this class are missing too.
// Restore them from the upstream source before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

#if defined(ARMNN_SERIALIZER)
#include "armnnDeserializer/IDeserializer.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include
#endif
#if defined(ARMNN_ONNX_PARSER)
#include
#endif

/**
 * IExecutor implementation backed by the Arm NN runtime.
 *
 * The pieces declared here are:
 * - a model front-end (IParser, created by CreateParser()) that builds an armnn::INetwork;
 * - OptimizeNetwork(), which turns that network into an armnn::IOptimizedNetworkPtr;
 * - the IExecutor overrides, which run it on the armnn::IRuntime shared by all executors
 *   (see GetRuntime()).
 *
 * Which front-ends exist depends on the ARMNN_SERIALIZER, ARMNN_TF_LITE_PARSER and
 * ARMNN_ONNX_PARSER build flags. Instances are neither copyable nor (publicly) movable.
 */
class ArmNNExecutor : public IExecutor
{
public:
    /**
     * @param params         ExecuteNetwork parameters; presumably copied into m_Params — confirm in the .cpp.
     * @param runtimeOptions Options for creating the armnn::IRuntime; presumably forwarded to GetRuntime().
     *                       Only the options of the first executor that reaches GetRuntime() take effect.
     */
    ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
    ~ArmNNExecutor();

    ArmNNExecutor(const ArmNNExecutor&) = delete; // No copy constructor.
    ArmNNExecutor & operator=(const ArmNNExecutor&) = delete; // No copy assignment operator.

    // IExecutor overrides. The element type of the returned/compared vectors was lost (see note above);
    // presumably they hold the per-output result buffers — confirm against IExecutor.hpp.
    std::vector Execute() override;
    void PrintNetworkInfo() override;
    void CompareAndPrintResult(std::vector otherOutput) override;

private:
    // Move operations are declared private so code outside the class cannot move an executor.
    // NOTE(review): `= delete` would state this intent directly and turn misuse into a compile error.
    ArmNNExecutor(ArmNNExecutor&&); // No move constructor.
    ArmNNExecutor& operator=(ArmNNExecutor&&); // No move assignment operator.

    /**
     * Returns the armnn::IRuntime shared by all ArmNNExecutors (non-owning pointer).
     *
     * The runtime is held in a function-local static. It is therefore created exactly once, with
     * thread-safe initialisation guaranteed since C++11, from the @p options of the FIRST call.
     * Options passed on later calls are ignored. The runtime is destroyed during static destruction
     * at program exit.
     */
    armnn::IRuntime* GetRuntime(const armnn::IRuntime::CreationOptions& options)
    {
        static armnn::IRuntimePtr instance = armnn::IRuntime::Create(options); // Instantiated on first use.
        return instance.get();
    }

    // Defined further down; forward-declared so the helper declarations below can name them.
    struct IParser;
    struct IOInfo;
    struct IOStorage;

    using BindingPointInfo = armnn::BindingPointInfo;

    // Creates the model front-end; presumably chosen from the model format given in m_Params — confirm.
    std::unique_ptr CreateParser();

    // The two execution paths; presumably Execute() dispatches to one of them based on m_Params.
    void ExecuteAsync();
    void ExecuteSync();

    // Presumably allocates m_InputStorage / m_OutputStorage and fills the tensor vectors below — confirm.
    void SetupInputsAndOutputs();

    // Collects the input/output names and binding-point info of the optimized network.
    IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);

    // Prints the contents of outputTensors for the given iteration number.
    void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);

    // Optimizes @p network and returns the result as an armnn::IOptimizedNetworkPtr.
    armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);

    /**
     * Owns one untyped block of `size` bytes obtained from the global operator new and released with
     * operator delete in the destructor.
     *
     * Only move construction is allowed: the moved-from object is left with m_Mem == nullptr, and
     * operator delete(nullptr) is a no-op, so its destructor is harmless. Copying and all assignment
     * are deleted.
     *
     * NOTE(review): the move constructor could be marked noexcept, and the size_t constructor is not
     * explicit, so a bare size_t converts implicitly to IOStorage.
     */
    struct IOStorage
    {
        IOStorage(size_t size)
        {
            m_Mem = operator new(size);
        }
        ~IOStorage()
        {
            operator delete(m_Mem);
        }
        IOStorage(IOStorage&& rhs)
        {
            // Steal the buffer and leave rhs empty so it is not freed twice.
            this->m_Mem = rhs.m_Mem;
            rhs.m_Mem = nullptr;
        }

        IOStorage(const IOStorage& rhs) = delete;
        IOStorage& operator=(IOStorage& rhs) = delete;
        IOStorage& operator=(IOStorage&& rhs) = delete;

        void* m_Mem; // Owned raw buffer, or nullptr after being moved from.
    };

    // Names of the network's inputs and outputs, plus a binding-point lookup for each
    // (presumably keyed by tensor name — confirm once the template arguments are restored).
    struct IOInfo
    {
        std::vector m_InputNames;
        std::vector m_OutputNames;
        std::map m_InputInfoMap;
        std::map m_OutputInfoMap;
    };

    IOInfo m_IOInfo;                             // Presumably the result of GetIOInfo().
    std::vector m_InputStorage;                  // Presumably IOStorage buffers backing the input tensors.
    std::vector m_OutputStorage;                 // Presumably IOStorage buffers backing the output tensors.
    std::vector m_InputTensorsVec;
    std::vector m_OutputTensorsVec;
    std::vector> m_ImportedInputIds;             // Nested vectors; inner element type lost (see note above).
    std::vector> m_ImportedOutputIds;
    armnn::IRuntime* m_Runtime;                  // Non-owning; presumably from GetRuntime(). No in-class initializer.
    armnn::NetworkId m_NetworkId;                // Presumably the id of the network loaded into m_Runtime.
    ExecuteNetworkParams m_Params;

    /**
     * Abstract model-format front-end.
     *
     * Builds an armnn::INetwork from the model described by params and resolves input/output tensor
     * names to binding points. @p id is the subgraph index in the derived classes that name it.
     */
    struct IParser
    {
        virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
        virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
        virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;

        virtual ~IParser(){};
    };

#if defined(ARMNN_SERIALIZER)
    // Front-end for Arm NN's own serialized format, via armnnDeserializer.
    // The subgraph-id parameter is unnamed, i.e. unused, in the binding-point lookups.
    class ArmNNDeserializer : public IParser
    {
    public:
        ArmNNDeserializer();

        armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
        armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
        armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;

    private:
        armnnDeserializer::IDeserializerPtr m_Parser;
    };
#endif

#if defined(ARMNN_TF_LITE_PARSER)
    // Front-end for TensorFlow Lite models, via armnnTfLiteParser.
    class TfliteParser : public IParser
    {
    public:
        TfliteParser(const ExecuteNetworkParams& params);

        armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
        armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
        armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;

    private:
        // Starts as nullptr with a do-nothing deleter; presumably replaced by a real parser in the
        // constructor — confirm in the .cpp.
        armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
    };
#endif

#if defined(ARMNN_ONNX_PARSER)
    // Front-end for ONNX models, via armnnOnnxParser.
    class OnnxParser : public IParser
    {
    public:
        OnnxParser();

        armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
        armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
        armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;

    private:
        armnnOnnxParser::IOnnxParserPtr m_Parser;
    };
#endif
};