From 55a8ffda24fff5515803df10fb4863d46a1effdf Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Wed, 7 Apr 2021 20:10:49 +0100 Subject: IVGCVSW-5823 Refactor Async Network API * Moved IAsyncNetwork into IRuntime. * All LoadedNetworks can be executed Asynchronously. Signed-off-by: Mike Kelly Change-Id: Ibbc901ab9110dc2f881425b75489bccf9ad54169 --- include/armnn/ArmNN.hpp | 2 +- include/armnn/IAsyncNetwork.hpp | 64 ------------------------------------- include/armnn/IRuntime.hpp | 34 ++++++++++---------- include/armnn/IWorkingMemHandle.hpp | 5 +++ include/armnn/NetworkFwd.hpp | 7 ---- 5 files changed, 23 insertions(+), 89 deletions(-) delete mode 100644 include/armnn/IAsyncNetwork.hpp (limited to 'include') diff --git a/include/armnn/ArmNN.hpp b/include/armnn/ArmNN.hpp index ac4d33f737..e4d5ce1fa1 100644 --- a/include/armnn/ArmNN.hpp +++ b/include/armnn/ArmNN.hpp @@ -7,9 +7,9 @@ #include "BackendId.hpp" #include "Descriptors.hpp" #include "Exceptions.hpp" -#include "IAsyncNetwork.hpp" #include "INetwork.hpp" #include "IRuntime.hpp" +#include "IWorkingMemHandle.hpp" #include "LstmParams.hpp" #include "Optional.hpp" #include "QuantizedLstmParams.hpp" diff --git a/include/armnn/IAsyncNetwork.hpp b/include/armnn/IAsyncNetwork.hpp deleted file mode 100644 index c234ae55ac..0000000000 --- a/include/armnn/IAsyncNetwork.hpp +++ /dev/null @@ -1,64 +0,0 @@ -// -// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include - -#include "INetwork.hpp" -#include "IProfiler.hpp" -#include "IWorkingMemHandle.hpp" -#include "Tensor.hpp" -#include "Types.hpp" - -#include - -namespace armnn -{ -struct INetworkProperties; - -namespace profiling -{ -class ProfilingService; -} - -namespace experimental -{ -class AsyncNetworkImpl; - -class IAsyncNetwork -{ -public: - IAsyncNetwork(std::unique_ptr net, - const INetworkProperties& networkProperties, - profiling::ProfilingService& profilingService); - ~IAsyncNetwork(); - - TensorInfo GetInputTensorInfo(LayerBindingId layerId) const; - TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const; - - /// Thread safe execution of the network. Returns once execution is complete. - /// Will block until this and any other thread using the same workingMem object completes. - Status Execute(const InputTensors& inputTensors, - const OutputTensors& outputTensors, - IWorkingMemHandle& workingMemHandle); - - /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have - /// overlapped Execution by calling this function from different threads. - std::unique_ptr CreateWorkingMemHandle(); - - /// Get the profiler used for this network - std::shared_ptr GetProfiler() const; - - /// Register a debug callback function to be used with this network - void RegisterDebugCallback(const DebugCallbackFunction& func); - -private: - std::unique_ptr pAsyncNetworkImpl; -}; - -} // end experimental namespace - -} // end armnn namespace diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index 9f7032914f..fc203e67e4 100644 --- a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -5,9 +5,9 @@ #pragma once #include "BackendOptions.hpp" -#include "IAsyncNetwork.hpp" #include "INetwork.hpp" #include "IProfiler.hpp" +#include "IWorkingMemHandle.hpp" #include "Tensor.hpp" #include "Types.hpp" #include "TypesUtils.hpp" @@ -28,12 +28,14 @@ using IRuntimePtr = std::unique_ptr; struct INetworkProperties { - INetworkProperties(bool importEnabled = false, bool exportEnabled = false) + INetworkProperties(bool importEnabled = false, bool exportEnabled = false, bool asyncEnabled = false) : m_ImportEnabled(importEnabled), - m_ExportEnabled(exportEnabled) {} + m_ExportEnabled(exportEnabled), + m_AsyncEnabled(asyncEnabled) {} const bool m_ImportEnabled; const bool m_ExportEnabled; + const bool m_AsyncEnabled; virtual ~INetworkProperties() {} }; @@ -145,20 +147,6 @@ public: std::string& errorMessage, const INetworkProperties& networkProperties); - /// This is an experimental function. - /// Creates an executable network. This network is thread safe allowing for multiple networks to be - /// loaded simultaneously via different threads. - /// Note that the network is never registered with the runtime so does not need to be 'Unloaded'. - /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. - /// @param [in] network Complete network to load into the IRuntime. - /// @param [out] errorMessage Error message if there were any errors. - /// @param [out] networkProperties the INetworkProperties that govern how the network should operate. - /// @return The IAsyncNetwork - std::unique_ptr CreateAsyncNetwork(NetworkId& networkIdOut, - IOptimizedNetworkPtr network, - std::string& errorMessage, - const INetworkProperties& networkProperties); - TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; @@ -167,6 +155,14 @@ public: const InputTensors& inputTensors, const OutputTensors& outputTensors); + /// This is an experimental function. + /// Evaluates a network using input in inputTensors and outputs filled into outputTensors. + /// This function performs a thread safe execution of the network. Returns once execution is complete. + /// Will block until this and any other thread using the same workingMem object completes. + Status Execute(IWorkingMemHandle& workingMemHandle, + const InputTensors& inputTensors, + const OutputTensors& outputTensors); + /// Unloads a network from the IRuntime. /// At the moment this only removes the network from the m_Impl->m_Network. /// This might need more work in the future to be AndroidNN compliant. @@ -176,6 +172,10 @@ public: const IDeviceSpec& GetDeviceSpec() const; + /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have + /// overlapped Execution by calling this function from different threads. + std::unique_ptr CreateWorkingMemHandle(NetworkId networkId); + /// Gets the profiler corresponding to the given network id. /// @param networkId The id of the network for which to get the profile. /// @return A pointer to the requested profiler, or nullptr if not found. diff --git a/include/armnn/IWorkingMemHandle.hpp b/include/armnn/IWorkingMemHandle.hpp index 921b7e1f40..171fa3d81c 100644 --- a/include/armnn/IWorkingMemHandle.hpp +++ b/include/armnn/IWorkingMemHandle.hpp @@ -10,6 +10,8 @@ namespace armnn { +using NetworkId = int; + namespace experimental { @@ -20,6 +22,9 @@ class IWorkingMemHandle public: virtual ~IWorkingMemHandle() {}; + /// Returns the NetworkId of the Network that this IWorkingMemHandle works with. + virtual NetworkId GetNetworkId() = 0; + /// Allocate the backing memory required for execution. If this is not called, then allocation will be /// deferred to execution time. The mutex must be locked. virtual void Allocate() = 0; diff --git a/include/armnn/NetworkFwd.hpp b/include/armnn/NetworkFwd.hpp index 6c2970f28b..5db9ec4ebe 100644 --- a/include/armnn/NetworkFwd.hpp +++ b/include/armnn/NetworkFwd.hpp @@ -10,13 +10,6 @@ namespace armnn struct LstmInputParams; struct QuantizedLstmInputParams; -namespace experimental -{ - -class IAsyncNetwork; - -} // end experimental namespace - class INetwork; class IOptimizedNetwork; class Graph; -- cgit v1.2.1