aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/LoadedNetwork.cpp21
-rw-r--r--src/armnn/LoadedNetwork.hpp6
-rw-r--r--src/armnn/test/UnitTests.hpp13
3 files changed, 29 insertions, 11 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 24d119c260..3464fb0277 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -92,10 +92,11 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
auto createBackend = BackendRegistryInstance().GetFactory(backend);
auto it = m_Backends.emplace(std::make_pair(backend, createBackend()));
- auto memoryManager = it.first->second->CreateMemoryManager();
- auto workloadFactory = it.first->second->CreateWorkloadFactory(std::move(memoryManager));
+ IBackendInternal::IMemoryManagerSharedPtr memoryManager = it.first->second->CreateMemoryManager();
+ auto workloadFactory = it.first->second->CreateWorkloadFactory(memoryManager);
- m_WorkloadFactories.emplace(std::make_pair(backend, std::move(workloadFactory)));
+ m_WorkloadFactories.emplace(std::make_pair(backend,
+ std::make_pair(std::move(workloadFactory), memoryManager)));
}
layer->CreateTensorHandles(m_OptimizedNetwork->GetGraph(), GetWorkloadFactory(*layer));
}
@@ -182,7 +183,7 @@ const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) co
CHECK_LOCATION());
}
- workloadFactory = it->second.get();
+ workloadFactory = it->second.first.get();
BOOST_ASSERT_MSG(workloadFactory, "No workload factory");
@@ -416,7 +417,11 @@ void LoadedNetwork::AllocateWorkingMemory()
}
for (auto&& workloadFactory : m_WorkloadFactories)
{
- workloadFactory.second->Acquire();
+ IBackendInternal::IMemoryManagerSharedPtr memoryManager = workloadFactory.second.second;
+ if (memoryManager)
+ {
+ memoryManager->Acquire();
+ }
}
m_IsWorkingMemAllocated = true;
}
@@ -431,7 +436,11 @@ void LoadedNetwork::FreeWorkingMemory()
// Informs the memory managers to release memory in it's respective memory group
for (auto&& workloadFactory : m_WorkloadFactories)
{
- workloadFactory.second->Release();
+ IBackendInternal::IMemoryManagerSharedPtr memoryManager = workloadFactory.second.second;
+ if (memoryManager)
+ {
+ memoryManager->Release();
+ }
}
m_IsWorkingMemAllocated = false;
}
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 65dd4ec25b..03a741fb75 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -62,7 +62,11 @@ private:
const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
- using WorkloadFactoryMap = std::unordered_map<BackendId, IBackendInternal::IWorkloadFactoryPtr>;
+
+ using WorkloadFactoryWithMemoryManager =
+ std::pair<IBackendInternal::IWorkloadFactoryPtr, IBackendInternal::IMemoryManagerSharedPtr>;
+
+ using WorkloadFactoryMap = std::unordered_map<BackendId, WorkloadFactoryWithMemoryManager>;
BackendPtrMap m_Backends;
WorkloadFactoryMap m_WorkloadFactories;
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 44b737cae4..f489ca030c 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -66,8 +66,10 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
- FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory();
- auto testResult = (*testFunction)(workloadFactory, args...);
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
+ auto testResult = (*testFunction)(workloadFactory, memoryManager, args...);
CompareTestResultIfSupported(testName, testResult);
}
@@ -80,9 +82,12 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
template<typename FactoryType, typename TFuncPtr, typename... Args>
void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{
- FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory();
+ auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
+
armnn::RefWorkloadFactory refWorkloadFactory;
- auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
+
+ auto testResult = (*testFunction)(workloadFactory, memoryManager, refWorkloadFactory, args...);
CompareTestResultIfSupported(testName, testResult);
}