From 386ff1a721cdca3689b009ba31f2d3ac8bea2fae Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 29 Mar 2021 15:04:50 +0100 Subject: IVGCVSW-5790 Merge async prototype * Added thread safe execution mechanism for armnn * Removed duplicate function bool Compare(T a, T b, float tolerance) * Added StridedSliceAsyncEndToEndTest * Fixed memory leak Signed-off-by: Mike Kelly Change-Id: I2d367fc77ee7c01b8953138543e76af5e691211f --- src/armnn/AsyncNetwork.hpp | 106 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 src/armnn/AsyncNetwork.hpp (limited to 'src/armnn/AsyncNetwork.hpp') diff --git a/src/armnn/AsyncNetwork.hpp b/src/armnn/AsyncNetwork.hpp new file mode 100644 index 0000000000..9c525c5472 --- /dev/null +++ b/src/armnn/AsyncNetwork.hpp @@ -0,0 +1,106 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include + +#include "LayerFwd.hpp" +#include "Network.hpp" +#include "Profiling.hpp" +#include "WorkingMemHandle.hpp" + +#include +#include +#include +#include +#include +#include + +#include + +namespace armnn +{ + +namespace experimental +{ + +class AsyncNetwork final : public IAsyncNetwork +{ +public: + using WorkloadQueue = std::vector>; + + AsyncNetwork(std::unique_ptr net, + const INetworkProperties &networkProperties, + profiling::ProfilingService &profilingService); + + ~AsyncNetwork() { FreeWorkingMemory(); } + + TensorInfo GetInputTensorInfo(LayerBindingId layerId) const override; + TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const override; + + /// Thread safe execution of the network. Returns once execution is complete. + /// Will block until this and any other thread using the same workingMem object completes. + virtual Status Execute(const InputTensors& inputTensors, + const OutputTensors& outputTensors, + IWorkingMemHandle& workingMemHandle) override; + + /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have + /// overlapped Execution by calling this function from different threads. + std::unique_ptr CreateWorkingMemHandle() override; + + /// Get the profiler used for this network + std::shared_ptr GetProfiler() const override; + + /// Register a debug callback function to be used with this network + void RegisterDebugCallback(const DebugCallbackFunction& func) override; + +private: + void FreeWorkingMemory(); + + void CollectInputTensorHandles(std::unordered_map >& tensorHandles, + std::vector& inputs, + const armnn::Layer* layer, + const TensorHandleFactoryRegistry& registry, + const bool isMemoryManaged = false); + + void CreateOutputTensorHandles(std::unordered_map >& tensorHandles, + std::vector& outputs, + const armnn::Layer* layer, + const TensorHandleFactoryRegistry& registry, + const bool isMemoryManaged = false); + + void EnqueueInput(const BindableLayer& layer, const ConstTensor& inputTensor, WorkingMemHandle& handle); + + void EnqueueOutput(const BindableLayer& layer, const Tensor& outputTensor, WorkingMemHandle& handle); + + using BackendPtrMap = std::unordered_map; + + using WorkloadFactoryWithMemoryManager = + std::pair; + + using WorkloadFactoryMap = std::unordered_map; + + const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const; + + BackendPtrMap m_Backends; + WorkloadFactoryMap m_WorkloadFactories; + + std::unique_ptr m_OptimizedNetwork; + INetworkProperties m_NetworkProperties; + WorkloadQueue m_WorkloadQueue; + std::shared_ptr m_Profiler; + + TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry; + + /// Profiling Service Instance + profiling::ProfilingService& m_ProfilingService; +}; + +} // end experimental namespace + +} // end armnn namespace -- cgit v1.2.1