aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2018-11-12 18:10:43 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2018-11-13 14:41:52 +0000
commit56055193e82471a70b82e4eb11a8884c5904af75 (patch)
treebf66d0ba0088d963def8485c7e894b12d7a65b82 /src/armnn
parent95807cef855738ca481ace30f32ed9f245a098dd (diff)
downloadarmnn-56055193e82471a70b82e4eb11a8884c5904af75.tar.gz
IVGCVSW-2066: Add IMemoryManager and integrate into the backends framework
Change-Id: I93223c8678165cbc3d39f461c36bb8610dc81c05
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/LoadedNetwork.cpp8
-rw-r--r--src/armnn/test/UnitTests.hpp5
2 files changed, 9 insertions, 4 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 92433d11c6..24d119c260 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -13,6 +13,7 @@
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/BackendRegistry.hpp>
+#include <backendsCommon/IMemoryManager.hpp>
#include <boost/polymorphic_cast.hpp>
#include <boost/assert.hpp>
@@ -90,8 +91,11 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net)
{
auto createBackend = BackendRegistryInstance().GetFactory(backend);
auto it = m_Backends.emplace(std::make_pair(backend, createBackend()));
- m_WorkloadFactories.emplace(std::make_pair(backend,
- it.first->second->CreateWorkloadFactory()));
+
+ auto memoryManager = it.first->second->CreateMemoryManager();
+ auto workloadFactory = it.first->second->CreateWorkloadFactory(std::move(memoryManager));
+
+ m_WorkloadFactories.emplace(std::make_pair(backend, std::move(workloadFactory)));
}
layer->CreateTensorHandles(m_OptimizedNetwork->GetGraph(), GetWorkloadFactory(*layer));
}
diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp
index 40765e8e4f..44b737cae4 100644
--- a/src/armnn/test/UnitTests.hpp
+++ b/src/armnn/test/UnitTests.hpp
@@ -8,6 +8,7 @@
#include <armnn/Utils.hpp>
#include <reference/RefWorkloadFactory.hpp>
#include <backendsCommon/test/LayerTests.hpp>
+#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
#include "TensorHelpers.hpp"
#include <boost/test/unit_test.hpp>
@@ -65,7 +66,7 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
- FactoryType workloadFactory;
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory();
auto testResult = (*testFunction)(workloadFactory, args...);
CompareTestResultIfSupported(testName, testResult);
}
@@ -79,7 +80,7 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
template<typename FactoryType, typename TFuncPtr, typename... Args>
void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{
- FactoryType workloadFactory;
+ FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory();
armnn::RefWorkloadFactory refWorkloadFactory;
auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
CompareTestResultIfSupported(testName, testResult);