diff options
Diffstat (limited to 'tests/ExecuteNetwork/ArmNNExecutor.hpp')
-rw-r--r-- | tests/ExecuteNetwork/ArmNNExecutor.hpp | 161 |
1 files changed, 161 insertions, 0 deletions
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.hpp b/tests/ExecuteNetwork/ArmNNExecutor.hpp new file mode 100644 index 0000000000..c4adc9e120 --- /dev/null +++ b/tests/ExecuteNetwork/ArmNNExecutor.hpp @@ -0,0 +1,161 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "IExecutor.hpp" +#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" +#include "ExecuteNetworkProgramOptions.hpp" +#include "armnn/utility/NumericCast.hpp" +#include "armnn/utility/Timer.hpp" + +#include <armnn/ArmNN.hpp> +#include <armnn/Threadpool.hpp> +#include <armnn/Logging.hpp> +#include <armnn/utility/Timer.hpp> +#include <armnn/BackendRegistry.hpp> +#include <armnn/utility/Assert.hpp> +#include <armnn/utility/NumericCast.hpp> + +#include <armnnUtils/Filesystem.hpp> +#include <HeapProfiling.hpp> + +#include <fmt/format.h> + +#if defined(ARMNN_SERIALIZER) +#include "armnnDeserializer/IDeserializer.hpp" +#endif +#if defined(ARMNN_TF_LITE_PARSER) +#include <armnnTfLiteParser/ITfLiteParser.hpp> +#endif +#if defined(ARMNN_ONNX_PARSER) +#include <armnnOnnxParser/IOnnxParser.hpp> +#endif + +class ArmNNExecutor : public IExecutor +{ +public: + ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions); + + std::vector<const void* > Execute() override; + void PrintNetworkInfo() override; + void CompareAndPrintResult(std::vector<const void*> otherOutput) override; + +private: + + struct IParser; + struct IOInfo; + struct IOStorage; + + using BindingPointInfo = armnn::BindingPointInfo; + + std::unique_ptr<IParser> CreateParser(); + + void ExecuteAsync(); + void ExecuteSync(); + void SetupInputsAndOutputs(); + + IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet); + + void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration); + + armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network); + + struct IOStorage + { + IOStorage(size_t size) + { + m_Mem = operator new(size); + } + ~IOStorage() + { + operator delete(m_Mem); + } + IOStorage(IOStorage&& rhs) + { + this->m_Mem = rhs.m_Mem; + rhs.m_Mem = nullptr; + } + + IOStorage(const IOStorage& rhs) = delete; + IOStorage& operator=(IOStorage& rhs) = delete; + IOStorage& operator=(IOStorage&& rhs) = delete; + + void* m_Mem; + }; + + struct IOInfo + { + std::vector<std::string> m_InputNames; + std::vector<std::string> m_OutputNames; + std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap; + std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap; + }; + + IOInfo m_IOInfo; + std::vector<IOStorage> m_InputStorage; + std::vector<IOStorage> m_OutputStorage; + std::vector<armnn::InputTensors> m_InputTensorsVec; + std::vector<armnn::OutputTensors> m_OutputTensorsVec; + std::vector<std::vector<unsigned int>> m_ImportedInputIds; + std::vector<std::vector<unsigned int>> m_ImportedOutputIds; + std::shared_ptr<armnn::IRuntime> m_Runtime; + armnn::NetworkId m_NetworkId; + ExecuteNetworkParams m_Params; + + struct IParser + { + virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0; + virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0; + virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0; + + virtual ~IParser(){}; + }; + +#if defined(ARMNN_SERIALIZER) + class ArmNNDeserializer : public IParser + { + public: + ArmNNDeserializer(); + + armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; + armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override; + armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override; + + private: + armnnDeserializer::IDeserializerPtr m_Parser; + }; +#endif + +#if defined(ARMNN_TF_LITE_PARSER) + class TfliteParser : public IParser + { + public: + TfliteParser(const ExecuteNetworkParams& params); + + armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; + armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override; + armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override; + + private: + armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}}; + }; +#endif + +#if defined(ARMNN_ONNX_PARSER) + class OnnxParser : public IParser + { + public: + OnnxParser(); + + armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override; + armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override; + armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override; + + private: + armnnOnnxParser::IOnnxParserPtr m_Parser; + }; +#endif +};
\ No newline at end of file |