diff options
Diffstat (limited to 'src/backends/cl/ClBackend.cpp')
-rw-r--r-- | src/backends/cl/ClBackend.cpp | 62 |
1 files changed, 53 insertions, 9 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index f1e52c1998..b85232e75c 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -49,6 +49,10 @@ const BackendId& ClBackend::GetIdStatic() IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const { + if (m_UsingCustomAllocator) + { + return std::make_unique<ClMemoryManager>(m_CustomAllocator); + } return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); } @@ -69,7 +73,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( TensorHandleFactoryRegistry& registry) const { - auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + std::shared_ptr<ClMemoryManager> memoryManager; + if (m_UsingCustomAllocator) + { + memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator); + } + else + { + memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + } registry.RegisterMemoryManager(memoryManager); registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); @@ -83,7 +95,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( TensorHandleFactoryRegistry& registry, const ModelOptions& modelOptions) const { - auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + std::shared_ptr<ClMemoryManager> memoryManager; + if (m_UsingCustomAllocator) + { + memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator); + } + else + { + memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + } registry.RegisterMemoryManager(memoryManager); registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); @@ -100,7 +120,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( MemorySourceFlags inputFlags, MemorySourceFlags outputFlags) const { - auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + std::shared_ptr<ClMemoryManager> memoryManager; + if (m_UsingCustomAllocator) + { + memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator); + } + else + { + memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + } registry.RegisterMemoryManager(memoryManager); registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); @@ -118,10 +146,18 @@ std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferen void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) { - auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + std::shared_ptr<ClMemoryManager> memoryManager; + if (m_UsingCustomAllocator) + { + memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator); + } + else + { + memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + } - registry.RegisterMemoryManager(mgr); - registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr)); + registry.RegisterMemoryManager(memoryManager); + registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>( static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc))); } @@ -130,10 +166,18 @@ void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& regis MemorySourceFlags inputFlags, MemorySourceFlags outputFlags) { - auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + std::shared_ptr<ClMemoryManager> memoryManager; + if (m_UsingCustomAllocator) + { + memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator); + } + else + { + memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + } - registry.RegisterMemoryManager(mgr); - registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr)); + registry.RegisterMemoryManager(memoryManager); + registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags)); } |