aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClBackend.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClBackend.cpp')
-rw-r--r--src/backends/cl/ClBackend.cpp62
1 files changed, 53 insertions, 9 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index f1e52c1998..b85232e75c 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -49,6 +49,10 @@ const BackendId& ClBackend::GetIdStatic()
IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const
{
+ if (m_UsingCustomAllocator)
+ {
+ return std::make_unique<ClMemoryManager>(m_CustomAllocator);
+ }
return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
}
@@ -69,7 +73,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
TensorHandleFactoryRegistry& registry) const
{
- auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ std::shared_ptr<ClMemoryManager> memoryManager;
+ if (m_UsingCustomAllocator)
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
+ }
+ else
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ }
registry.RegisterMemoryManager(memoryManager);
registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
@@ -83,7 +95,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
TensorHandleFactoryRegistry& registry, const ModelOptions& modelOptions) const
{
- auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ std::shared_ptr<ClMemoryManager> memoryManager;
+ if (m_UsingCustomAllocator)
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
+ }
+ else
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ }
registry.RegisterMemoryManager(memoryManager);
registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
@@ -100,7 +120,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
MemorySourceFlags inputFlags,
MemorySourceFlags outputFlags) const
{
- auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ std::shared_ptr<ClMemoryManager> memoryManager;
+ if (m_UsingCustomAllocator)
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
+ }
+ else
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ }
registry.RegisterMemoryManager(memoryManager);
registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
@@ -118,10 +146,18 @@ std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferen
void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry)
{
- auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ std::shared_ptr<ClMemoryManager> memoryManager;
+ if (m_UsingCustomAllocator)
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
+ }
+ else
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ }
- registry.RegisterMemoryManager(mgr);
- registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr));
+ registry.RegisterMemoryManager(memoryManager);
+ registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));
}
@@ -130,10 +166,18 @@ void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& regis
MemorySourceFlags inputFlags,
MemorySourceFlags outputFlags)
{
- auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ std::shared_ptr<ClMemoryManager> memoryManager;
+ if (m_UsingCustomAllocator)
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
+ }
+ else
+ {
+ memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+ }
- registry.RegisterMemoryManager(mgr);
- registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr));
+ registry.RegisterMemoryManager(memoryManager);
+ registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));
}