diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2018-11-12 18:10:43 +0000 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2018-11-13 14:41:52 +0000 |
commit | 56055193e82471a70b82e4eb11a8884c5904af75 (patch) | |
tree | bf66d0ba0088d963def8485c7e894b12d7a65b82 /src/armnn | |
parent | 95807cef855738ca481ace30f32ed9f245a098dd (diff) | |
download | armnn-56055193e82471a70b82e4eb11a8884c5904af75.tar.gz |
IVGCVSW-2066: Add IMemoryManager and integrate into the backends framework
Change-Id: I93223c8678165cbc3d39f461c36bb8610dc81c05
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 8 | ||||
-rw-r--r-- | src/armnn/test/UnitTests.hpp | 5 |
2 files changed, 9 insertions, 4 deletions
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 92433d11c6..24d119c260 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -13,6 +13,7 @@ #include <backendsCommon/CpuTensorHandle.hpp> #include <backendsCommon/BackendRegistry.hpp> +#include <backendsCommon/IMemoryManager.hpp> #include <boost/polymorphic_cast.hpp> #include <boost/assert.hpp> @@ -90,8 +91,11 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<OptimizedNetwork> net) { auto createBackend = BackendRegistryInstance().GetFactory(backend); auto it = m_Backends.emplace(std::make_pair(backend, createBackend())); - m_WorkloadFactories.emplace(std::make_pair(backend, - it.first->second->CreateWorkloadFactory())); + + auto memoryManager = it.first->second->CreateMemoryManager(); + auto workloadFactory = it.first->second->CreateWorkloadFactory(std::move(memoryManager)); + + m_WorkloadFactories.emplace(std::make_pair(backend, std::move(workloadFactory))); } layer->CreateTensorHandles(m_OptimizedNetwork->GetGraph(), GetWorkloadFactory(*layer)); } diff --git a/src/armnn/test/UnitTests.hpp b/src/armnn/test/UnitTests.hpp index 40765e8e4f..44b737cae4 100644 --- a/src/armnn/test/UnitTests.hpp +++ b/src/armnn/test/UnitTests.hpp @@ -8,6 +8,7 @@ #include <armnn/Utils.hpp> #include <reference/RefWorkloadFactory.hpp> #include <backendsCommon/test/LayerTests.hpp> +#include <backendsCommon/test/WorkloadFactoryHelper.hpp> #include "TensorHelpers.hpp" #include <boost/test/unit_test.hpp> @@ -65,7 +66,7 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>(); armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); - FactoryType workloadFactory; + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(); auto testResult = (*testFunction)(workloadFactory, args...); CompareTestResultIfSupported(testName, testResult); } @@ -79,7 +80,7 @@ void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args) template<typename FactoryType, typename TFuncPtr, typename... Args> void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args) { - FactoryType workloadFactory; + FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(); armnn::RefWorkloadFactory refWorkloadFactory; auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...); CompareTestResultIfSupported(testName, testResult); |