diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2018-11-14 18:35:18 +0000 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2018-11-15 10:38:19 +0000 |
commit | 5caf907efc31e774f8afde54b17a5596477772f6 (patch) | |
tree | 9fcdfe44ccf7c96e5088a2cef06b7d74dfd3221c /src/armnn | |
parent | dd9d8ca997cb6c63677249350247e9f44525104c (diff) | |
download | armnn-5caf907efc31e774f8afde54b17a5596477772f6.tar.gz |
IVGCVSW-2136: Remove memory management methods from workload factories
Change-Id: Idc0f94590566ac362f7e1d1999361d025cc2f67a
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 21 | ||||
-rw-r--r-- | src/armnn/LoadedNetwork.hpp | 6 | ||||
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 13 |
3 files changed, 29 insertions, 11 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 24d119c260..3464fb0277 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -92,10 +92,11 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net) auto createBackend = BackendRegistryInstance().GetFactory(backend); auto it = m_Backends.emplace(std::make_pair(backend, createBackend())); - auto memoryManager = it.first->second->CreateMemoryManager(); - auto workloadFactory = it.first->second->CreateWorkloadFactory(std::move(memoryManager)); + IBackendInternal::IMemoryManagerSharedPtr memoryManager = it.first->second->CreateMemoryManager(); + auto workloadFactory = it.first->second->CreateWorkloadFactory(memoryManager); - m_WorkloadFactories.emplace(std::make_pair(backend, std::move(workloadFactory))); + m_WorkloadFactories.emplace(std::make_pair(backend, + std::make_pair(std::move(workloadFactory), memoryManager))); } layer->CreateTensorHandles(m_OptimizedNetwork->GetGraph(), GetWorkloadFactory(*layer)); } @@ -182,7 +183,7 @@ const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) co CHECK_LOCATION()); } - workloadFactory = it->second.get(); + workloadFactory = it->second.first.get(); BOOST_ASSERT_MSG(workloadFactory, "No workload factory"); @@ -416,7 +417,11 @@ void LoadedNetwork::AllocateWorkingMemory() } for (auto&& workloadFactory : m_WorkloadFactories) { - workloadFactory.second->Acquire(); + IBackendInternal::IMemoryManagerSharedPtr memoryManager = workloadFactory.second.second; + if (memoryManager) + { + memoryManager->Acquire(); + } } m_IsWorkingMemAllocated = true; } @@ -431,7 +436,11 @@ void LoadedNetwork::FreeWorkingMemory() // Informs the memory managers to release memory in it's respective memory group for (auto&& workloadFactory : m_WorkloadFactories) { - workloadFactory.second->Release(); + IBackendInternal::IMemoryManagerSharedPtr memoryManager = workloadFactory.second.second; + if (memoryManager) + { + memoryManager->Release(); + } } m_IsWorkingMemAllocated = false; } diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index 65dd4ec25b..03a741fb75 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -62,7 +62,11 @@ private: const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const; using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>; - using WorkloadFactoryMap = std::unordered_map<BackendId, IBackendInternal::IWorkloadFactoryPtr>; + + using WorkloadFactoryWithMemoryManager = + std::pair<IBackendInternal::IWorkloadFactoryPtr, IBackendInternal::IMemoryManagerSharedPtr>; + + using WorkloadFactoryMap = std::unordered_map<BackendId, WorkloadFactoryWithMemoryManager>; BackendPtrMap m_Backends; WorkloadFactoryMap m_WorkloadFactories; diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 44b737cae4..f489ca030c 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -66,8 +66,10 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>(); armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); - FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(); - auto testResult = (*testFunction)(workloadFactory, args...); + auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager(); + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager); + + auto testResult = (*testFunction)(workloadFactory, memoryManager, args...); CompareTestResultIfSupported(testName, testResult); } @@ -80,9 +82,12 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) template<typename FactoryType, typename TFuncPtr, typename... Args> void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { - FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(); + auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager(); + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager); + armnn::RefWorkloadFactory refWorkloadFactory; - auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...); + + auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...); CompareTestResultIfSupported(testName, testResult); } |