diff options
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 89 |
1 files changed, 39 insertions, 50 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index fbf8cfbb4c..b35dfd1107 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -451,8 +451,6 @@ private: Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors) { - ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "EnqueueWorkload"); - const Graph& graph = m_OptimizedNetwork->GetGraph(); // Walk graph to determine the order of execution. @@ -471,21 +469,27 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, } // For each input to the network, call EnqueueInput with the data passed by the user. - m_InputQueue.clear(); - m_InputQueue.reserve(graph.GetNumInputs()); - for (const BindableLayer* inputLayer : graph.GetInputLayers()) { - const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId()); - EnqueueInput(*inputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareInputs"); + m_InputQueue.clear(); + m_InputQueue.reserve(graph.GetNumInputs()); + for (const BindableLayer* inputLayer : graph.GetInputLayers()) + { + const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId()); + EnqueueInput(*inputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); + } } // For each output to the network, call EnqueueOutput with the data passed by the user. - m_OutputQueue.clear(); - m_OutputQueue.reserve(graph.GetNumOutputs()); - for (const BindableLayer* outputLayer : graph.GetOutputLayers()) { - const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId()); - EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs"); + m_OutputQueue.clear(); + m_OutputQueue.reserve(graph.GetNumOutputs()); + for (const BindableLayer* outputLayer : graph.GetOutputLayers()) + { + const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId()); + EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); + } } std::unique_ptr<TimelineUtilityMethods> timelineUtils = @@ -684,8 +688,13 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten } } -void LoadedNetwork::AllocateWorkingMemory() +void LoadedNetwork::AllocateWorkingMemory(std::lock_guard<std::mutex>& lock) { + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Working Memory Allocation"); + + // this unused parameter makes sure we can only call this function with a valid lock + IgnoreUnused(lock); + if (m_IsWorkingMemAllocated) { return; @@ -736,49 +745,29 @@ bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUti try { std::lock_guard<std::mutex> lockGuard(m_WorkingMemMutex); - AllocateWorkingMemory(); + AllocateWorkingMemory(lockGuard); ProfilingDynamicGuid workloadInferenceID(0); - for (auto& input : m_InputQueue) + auto ExecuteQueue = [&timelineUtils, &workloadInferenceID, &inferenceGuid](WorkloadQueue& queue) { - if(timelineUtils) + for (auto& workload : queue) { - workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(input->GetGuid(), - inferenceGuid); - } - input->Execute(); - if(timelineUtils) - { - timelineUtils->RecordEndOfLifeEvent(workloadInferenceID); + if(timelineUtils) + { + workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(workload->GetGuid(), + inferenceGuid); + } + workload->Execute(); + if(timelineUtils) + { + timelineUtils->RecordEndOfLifeEvent(workloadInferenceID); + } } - } + }; - for (auto& workload : m_WorkloadQueue) - { - if(timelineUtils) - { - workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(workload->GetGuid(), - inferenceGuid); - } - workload->Execute(); - if(timelineUtils) - { - timelineUtils->RecordEndOfLifeEvent(workloadInferenceID); - } - } - for (auto& output: m_OutputQueue) - { - if(timelineUtils) - { - workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(output->GetGuid(), - inferenceGuid); - } - output->Execute(); - if(timelineUtils) - { - timelineUtils->RecordEndOfLifeEvent(workloadInferenceID); - } - } + ExecuteQueue(m_InputQueue); + ExecuteQueue(m_WorkloadQueue); + ExecuteQueue(m_OutputQueue); } catch (const RuntimeException& error) { |