aboutsummaryrefslogtreecommitdiff
path: root/tests/ExecuteNetwork/ArmNNExecutor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ExecuteNetwork/ArmNNExecutor.hpp')
-rw-r--r--tests/ExecuteNetwork/ArmNNExecutor.hpp161
1 files changed, 161 insertions, 0 deletions
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.hpp b/tests/ExecuteNetwork/ArmNNExecutor.hpp
new file mode 100644
index 0000000000..c4adc9e120
--- /dev/null
+++ b/tests/ExecuteNetwork/ArmNNExecutor.hpp
@@ -0,0 +1,161 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "IExecutor.hpp"
+#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
+#include "ExecuteNetworkProgramOptions.hpp"
+#include "armnn/utility/NumericCast.hpp"
+#include "armnn/utility/Timer.hpp"
+
+#include <armnn/ArmNN.hpp>
+#include <armnn/Threadpool.hpp>
+#include <armnn/Logging.hpp>
+#include <armnn/utility/Timer.hpp>
+#include <armnn/BackendRegistry.hpp>
+#include <armnn/utility/Assert.hpp>
+#include <armnn/utility/NumericCast.hpp>
+
+#include <armnnUtils/Filesystem.hpp>
+#include <HeapProfiling.hpp>
+
+#include <fmt/format.h>
+
+#if defined(ARMNN_SERIALIZER)
+#include "armnnDeserializer/IDeserializer.hpp"
+#endif
+#if defined(ARMNN_TF_LITE_PARSER)
+#include <armnnTfLiteParser/ITfLiteParser.hpp>
+#endif
+#if defined(ARMNN_ONNX_PARSER)
+#include <armnnOnnxParser/IOnnxParser.hpp>
+#endif
+
+class ArmNNExecutor : public IExecutor
+{
+public:
+ ArmNNExecutor(const ExecuteNetworkParams& params, armnn::IRuntime::CreationOptions runtimeOptions);
+
+ std::vector<const void* > Execute() override;
+ void PrintNetworkInfo() override;
+ void CompareAndPrintResult(std::vector<const void*> otherOutput) override;
+
+private:
+
+ struct IParser;
+ struct IOInfo;
+ struct IOStorage;
+
+ using BindingPointInfo = armnn::BindingPointInfo;
+
+ std::unique_ptr<IParser> CreateParser();
+
+ void ExecuteAsync();
+ void ExecuteSync();
+ void SetupInputsAndOutputs();
+
+ IOInfo GetIOInfo(armnn::IOptimizedNetwork* optNet);
+
+ void PrintOutputTensors(const armnn::OutputTensors* outputTensors, unsigned int iteration);
+
+ armnn::IOptimizedNetworkPtr OptimizeNetwork(armnn::INetwork* network);
+
+ struct IOStorage
+ {
+ IOStorage(size_t size)
+ {
+ m_Mem = operator new(size);
+ }
+ ~IOStorage()
+ {
+ operator delete(m_Mem);
+ }
+ IOStorage(IOStorage&& rhs)
+ {
+ this->m_Mem = rhs.m_Mem;
+ rhs.m_Mem = nullptr;
+ }
+
+ IOStorage(const IOStorage& rhs) = delete;
+ IOStorage& operator=(IOStorage& rhs) = delete;
+ IOStorage& operator=(IOStorage&& rhs) = delete;
+
+ void* m_Mem;
+ };
+
+ struct IOInfo
+ {
+ std::vector<std::string> m_InputNames;
+ std::vector<std::string> m_OutputNames;
+ std::map<std::string, armnn::BindingPointInfo> m_InputInfoMap;
+ std::map<std::string, armnn::BindingPointInfo> m_OutputInfoMap;
+ };
+
+ IOInfo m_IOInfo;
+ std::vector<IOStorage> m_InputStorage;
+ std::vector<IOStorage> m_OutputStorage;
+ std::vector<armnn::InputTensors> m_InputTensorsVec;
+ std::vector<armnn::OutputTensors> m_OutputTensorsVec;
+ std::vector<std::vector<unsigned int>> m_ImportedInputIds;
+ std::vector<std::vector<unsigned int>> m_ImportedOutputIds;
+ std::shared_ptr<armnn::IRuntime> m_Runtime;
+ armnn::NetworkId m_NetworkId;
+ ExecuteNetworkParams m_Params;
+
+ struct IParser
+ {
+ virtual armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) = 0;
+ virtual armnn::BindingPointInfo GetInputBindingPointInfo(size_t id, const std::string& inputName) = 0;
+ virtual armnn::BindingPointInfo GetOutputBindingPointInfo(size_t id, const std::string& outputName) = 0;
+
+ virtual ~IParser(){};
+ };
+
+#if defined(ARMNN_SERIALIZER)
+ class ArmNNDeserializer : public IParser
+ {
+ public:
+ ArmNNDeserializer();
+
+ armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
+ armnn::BindingPointInfo GetInputBindingPointInfo(size_t, const std::string& inputName) override;
+ armnn::BindingPointInfo GetOutputBindingPointInfo(size_t, const std::string& outputName) override;
+
+ private:
+ armnnDeserializer::IDeserializerPtr m_Parser;
+ };
+#endif
+
+#if defined(ARMNN_TF_LITE_PARSER)
+ class TfliteParser : public IParser
+ {
+ public:
+ TfliteParser(const ExecuteNetworkParams& params);
+
+ armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
+ armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
+ armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
+
+ private:
+ armnnTfLiteParser::ITfLiteParserPtr m_Parser{nullptr, [](armnnTfLiteParser::ITfLiteParser*){}};
+ };
+#endif
+
+#if defined(ARMNN_ONNX_PARSER)
+ class OnnxParser : public IParser
+ {
+ public:
+ OnnxParser();
+
+ armnn::INetworkPtr CreateNetwork(const ExecuteNetworkParams& params) override;
+ armnn::BindingPointInfo GetInputBindingPointInfo(size_t subgraphId, const std::string& inputName) override;
+ armnn::BindingPointInfo GetOutputBindingPointInfo(size_t subgraphId, const std::string& outputName) override;
+
+ private:
+ armnnOnnxParser::IOnnxParserPtr m_Parser;
+ };
+#endif
+}; \ No newline at end of file