diff options
Diffstat (limited to 'ArmnnPreparedModel_1_3.cpp')
-rw-r--r-- | ArmnnPreparedModel_1_3.cpp | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/ArmnnPreparedModel_1_3.cpp b/ArmnnPreparedModel_1_3.cpp index dcac2813..16ea113c 100644 --- a/ArmnnPreparedModel_1_3.cpp +++ b/ArmnnPreparedModel_1_3.cpp @@ -145,6 +145,9 @@ RequestThread_1_3<ArmnnPreparedModel_1_3, HalVersion, CallbackContext_1_3> ArmnnPreparedModel_1_3<HalVersion>::m_RequestThread; template<typename HalVersion> +std::unique_ptr<armnn::Threadpool> ArmnnPreparedModel_1_3<HalVersion>::m_Threadpool(nullptr); + +template<typename HalVersion> template<typename TensorBindingCollection> void ArmnnPreparedModel_1_3<HalVersion>::DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings) @@ -183,7 +186,7 @@ ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId netw // Enable profiling if required. m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled); - if (asyncModelExecutionEnabled) + if (m_AsyncModelExecutionEnabled) { std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles; for (unsigned int i=0; i < numberOfThreads; ++i) @@ -191,8 +194,16 @@ ArmnnPreparedModel_1_3<HalVersion>::ArmnnPreparedModel_1_3(armnn::NetworkId netw memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(networkId)); } + if (!m_Threadpool) + { + m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles); + } + else + { + m_Threadpool->LoadMemHandles(memHandles); + } + m_WorkingMemHandle = memHandles.back(); - m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles); } } @@ -205,6 +216,12 @@ ArmnnPreparedModel_1_3<HalVersion>::~ArmnnPreparedModel_1_3() // Unload the network associated with this model. m_Runtime->UnloadNetwork(m_NetworkId); + // Unload the network memhandles from the threadpool + if (m_AsyncModelExecutionEnabled) + { + m_Threadpool->UnloadMemHandles(m_NetworkId); + } + // Dump the profiling info to a file if required. DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId, profiler.get()); } |