From 386ff1a721cdca3689b009ba31f2d3ac8bea2fae Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 29 Mar 2021 15:04:50 +0100 Subject: IVGCVSW-5790 Merge async prototype * Added thread safe execution mechanism for armnn * Removed duplicate function bool Compare(T a, T b, float tolerance) * Added StridedSliceAsyncEndToEndTest * Fixed memory leak Signed-off-by: Mike Kelly Change-Id: I2d367fc77ee7c01b8953138543e76af5e691211f --- include/armnn/ArmNN.hpp | 1 + include/armnn/IAsyncNetwork.hpp | 51 ++++++++++++++++++++++++++++++++++++ include/armnn/INetwork.hpp | 10 +++++++ include/armnn/IRuntime.hpp | 17 ++++++++++++ include/armnn/IWorkingMemHandle.hpp | 46 ++++++++++++++++++++++++++++++++ include/armnn/NetworkFwd.hpp | 12 ++++++++- include/armnn/backends/IWorkload.hpp | 15 +++++++++-- 7 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 include/armnn/IAsyncNetwork.hpp create mode 100644 include/armnn/IWorkingMemHandle.hpp (limited to 'include/armnn') diff --git a/include/armnn/ArmNN.hpp b/include/armnn/ArmNN.hpp index 4b945b91b3..ac4d33f737 100644 --- a/include/armnn/ArmNN.hpp +++ b/include/armnn/ArmNN.hpp @@ -7,6 +7,7 @@ #include "BackendId.hpp" #include "Descriptors.hpp" #include "Exceptions.hpp" +#include "IAsyncNetwork.hpp" #include "INetwork.hpp" #include "IRuntime.hpp" #include "LstmParams.hpp" diff --git a/include/armnn/IAsyncNetwork.hpp b/include/armnn/IAsyncNetwork.hpp new file mode 100644 index 0000000000..7ef83bbff1 --- /dev/null +++ b/include/armnn/IAsyncNetwork.hpp @@ -0,0 +1,51 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include "INetwork.hpp" +#include "IProfiler.hpp" +#include "IWorkingMemHandle.hpp" +#include "Tensor.hpp" +#include "Types.hpp" + +#include + +namespace armnn +{ + +namespace experimental +{ + +class IAsyncNetwork +{ +public: + virtual ~IAsyncNetwork() {}; + + virtual TensorInfo GetInputTensorInfo(LayerBindingId layerId) const = 0; + virtual TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const = 0; + + /// Thread safe execution of the network. Returns once execution is complete. + /// Will block until this and any other thread using the same workingMem object completes. + virtual Status Execute(const InputTensors& inputTensors, + const OutputTensors& outputTensors, + IWorkingMemHandle& workingMemHandle) = 0; + + /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have + /// overlapped Execution by calling this function from different threads. + virtual std::unique_ptr<IWorkingMemHandle> CreateWorkingMemHandle() = 0; + + /// Get the profiler used for this network + virtual std::shared_ptr<IProfiler> GetProfiler() const = 0; + + /// Register a debug callback function to be used with this network + virtual void RegisterDebugCallback(const DebugCallbackFunction& func) = 0; +}; + +} // end experimental namespace + +} // end armnn namespace diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index bceb07405a..2db6d5de83 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -704,6 +704,12 @@ protected: std::unique_ptr<NetworkImpl> pNetworkImpl; }; +namespace experimental +{ +class AsyncNetwork; +class WorkingMemHandle; +} + struct BackendSettings; struct OptimizationResult; class OptimizedNetworkImpl; @@ -723,6 +729,10 @@ public: protected: friend class LoadedNetwork; + + friend class experimental::AsyncNetwork; + friend class experimental::WorkingMemHandle; + friend Graph& GetGraphForTesting(IOptimizedNetwork* optNetPtr); friend ModelOptions& 
GetModelOptionsForTesting(IOptimizedNetwork* optNetPtr); friend IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp index 9122089b62..9f7032914f 100644 --- a/include/armnn/IRuntime.hpp +++ b/include/armnn/IRuntime.hpp @@ -5,6 +5,7 @@ #pragma once #include "BackendOptions.hpp" +#include "IAsyncNetwork.hpp" #include "INetwork.hpp" #include "IProfiler.hpp" #include "Tensor.hpp" @@ -37,6 +38,8 @@ struct INetworkProperties virtual ~INetworkProperties() {} }; +using namespace armnn::experimental; + class IRuntime { public: @@ -142,6 +145,20 @@ public: std::string& errorMessage, const INetworkProperties& networkProperties); + /// This is an experimental function. + /// Creates an executable network. This network is thread safe allowing for multiple networks to be + /// loaded simultaneously via different threads. + /// Note that the network is never registered with the runtime so does not need to be 'Unloaded'. + /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. + /// @param [in] network Complete network to load into the IRuntime. + /// @param [out] errorMessage Error message if there were any errors. + /// @param [in] networkProperties the INetworkProperties that govern how the network should operate. + /// @return The IAsyncNetwork + std::unique_ptr<IAsyncNetwork> CreateAsyncNetwork(NetworkId& networkIdOut, + IOptimizedNetworkPtr network, + std::string& errorMessage, + const INetworkProperties& networkProperties); + TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const; diff --git a/include/armnn/IWorkingMemHandle.hpp b/include/armnn/IWorkingMemHandle.hpp new file mode 100644 index 0000000000..921b7e1f40 --- /dev/null +++ b/include/armnn/IWorkingMemHandle.hpp @@ -0,0 +1,46 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <mutex> + +namespace armnn +{ + +namespace experimental +{ + +struct WorkingMemDescriptor; + +class IWorkingMemHandle +{ +public: + virtual ~IWorkingMemHandle() {}; + + /// Allocate the backing memory required for execution. If this is not called, then allocation will be + /// deferred to execution time. The mutex must be locked. + virtual void Allocate() = 0; + + /// Free the backing memory required for execution. The mutex must be locked. + virtual void Free() = 0; + + /// IsAllocated returns true if the backing memory is currently allocated. The mutex must be locked. + virtual bool IsAllocated() = 0; + + /// Get a mutex which can be used for synchronizing access to the WorkingMemHandle object. + virtual std::mutex& GetMutex() = 0; + + /// Get the WorkingMemDescriptor for a Layer. The mutex must be locked. + virtual WorkingMemDescriptor& GetWorkingMemDescriptor(LayerGuid id) = 0; + + /// Get the WorkingMemDescriptor at an index. The WorkingMemDescriptors are stored in the same order as + /// the Workloads in a topologically sorted graph. The mutex must be locked. 
+ virtual WorkingMemDescriptor& GetWorkingMemDescriptorAt(unsigned int id) = 0; +}; + +} // end experimental namespace + +} // end armnn namespace diff --git a/include/armnn/NetworkFwd.hpp b/include/armnn/NetworkFwd.hpp index 619839eb64..6c2970f28b 100644 --- a/include/armnn/NetworkFwd.hpp +++ b/include/armnn/NetworkFwd.hpp @@ -6,8 +6,17 @@ namespace armnn { + struct LstmInputParams; struct QuantizedLstmInputParams; + +namespace experimental +{ + +class IAsyncNetwork; + +} // end experimental namespace + class INetwork; class IOptimizedNetwork; class Graph; @@ -15,4 +24,5 @@ class IInputSlot; class IOutputSlot; class IConnectableLayer; class IDataLayer; -} + +} // end armnn namespace diff --git a/include/armnn/backends/IWorkload.hpp b/include/armnn/backends/IWorkload.hpp index 0bd8d2db75..a4827ebcdf 100644 --- a/include/armnn/backends/IWorkload.hpp +++ b/include/armnn/backends/IWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd. All rights reserved. +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -9,6 +9,15 @@ namespace armnn { +namespace experimental +{ + +struct WorkingMemDescriptor; + +} // end experimental namespace + +using namespace armnn::experimental; + /// Workload interface to enqueue a layer computation. class IWorkload { public: @@ -18,9 +27,11 @@ public: virtual void Execute() const = 0; + virtual void ExecuteAsync(WorkingMemDescriptor& desc) = 0; + virtual profiling::ProfilingGuid GetGuid() const = 0; - virtual void RegisterDebugCallback(const DebugCallbackFunction & /*func*/) {} + virtual void RegisterDebugCallback(const DebugCallbackFunction& /*func*/) {} }; } //namespace armnn -- cgit v1.2.1