aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClBackend.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-07 17:52:36 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-08 20:15:32 +0100
commite5f0b2409c2e557a5a78e2f4659d203154289b23 (patch)
tree0e32680ed15ed5157c78d5deeabda2c0ceeeb4a3 /src/backends/cl/ClBackend.cpp
parentae12306486efc55293a40048618abe5e8b19151b (diff)
downloadarmnn-e5f0b2409c2e557a5a78e2f4659d203154289b23.tar.gz
IVGCVSW-5818 Enable import on GPU
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I4e4eb107aa2bfa09625840d738001f33152e6792
Diffstat (limited to 'src/backends/cl/ClBackend.cpp')
-rw-r--r--src/backends/cl/ClBackend.cpp41
1 files changed, 38 insertions, 3 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp
index f97cb4bba8..35770d9219 100644
--- a/src/backends/cl/ClBackend.cpp
+++ b/src/backends/cl/ClBackend.cpp
@@ -4,12 +4,13 @@
//
#include "ClBackend.hpp"
+#include "ClBackendContext.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"
-#include "ClWorkloadFactory.hpp"
-#include "ClBackendContext.hpp"
+#include "ClImportTensorHandleFactory.hpp"
#include "ClLayerSupport.hpp"
#include "ClTensorHandleFactory.hpp"
+#include "ClWorkloadFactory.hpp"
#include <armnn/BackendRegistry.hpp>
#include <armnn/Descriptors.hpp>
@@ -71,6 +72,8 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
registry.RegisterMemoryManager(memoryManager);
registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
+ registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
+ static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));
return std::make_unique<ClWorkloadFactory>(
PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
@@ -83,6 +86,24 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
registry.RegisterMemoryManager(memoryManager);
registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
+ registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
+ static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));
+
+ return std::make_unique<ClWorkloadFactory>(
+ PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
+}
+
+IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
+ TensorHandleFactoryRegistry& registry,
+ const ModelOptions& modelOptions,
+ MemorySourceFlags inputFlags,
+ MemorySourceFlags outputFlags) const
+{
+ auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+
+ registry.RegisterMemoryManager(memoryManager);
+ registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager));
+ registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));
return std::make_unique<ClWorkloadFactory>(
PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
@@ -90,7 +111,8 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferences() const
{
- return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic()};
+ return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic(),
+ ClImportTensorHandleFactory::GetIdStatic()};
}
void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry)
@@ -99,6 +121,19 @@ void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& regis
registry.RegisterMemoryManager(mgr);
registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr));
+ registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(
+ static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc)));
+}
+
+void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry,
+ MemorySourceFlags inputFlags,
+ MemorySourceFlags outputFlags)
+{
+ auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
+
+ registry.RegisterMemoryManager(mgr);
+ registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr));
+ registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags));
}
IBackendInternal::IBackendContextPtr ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const