diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-05-07 17:52:36 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-05-08 20:15:32 +0100 |
commit | e5f0b2409c2e557a5a78e2f4659d203154289b23 (patch) | |
tree | 0e32680ed15ed5157c78d5deeabda2c0ceeeb4a3 /src/backends/cl/ClBackend.cpp | |
parent | ae12306486efc55293a40048618abe5e8b19151b (diff) | |
download | armnn-e5f0b2409c2e557a5a78e2f4659d203154289b23.tar.gz |
IVGCVSW-5818 Enable import on GPU
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I4e4eb107aa2bfa09625840d738001f33152e6792
Diffstat (limited to 'src/backends/cl/ClBackend.cpp')
-rw-r--r-- | src/backends/cl/ClBackend.cpp | 41 |
1 files changed, 38 insertions, 3 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index f97cb4bba8..35770d9219 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -4,12 +4,13 @@ // #include "ClBackend.hpp" +#include "ClBackendContext.hpp" #include "ClBackendId.hpp" #include "ClBackendModelContext.hpp" -#include "ClWorkloadFactory.hpp" -#include "ClBackendContext.hpp" +#include "ClImportTensorHandleFactory.hpp" #include "ClLayerSupport.hpp" #include "ClTensorHandleFactory.hpp" +#include "ClWorkloadFactory.hpp" #include <armnn/BackendRegistry.hpp> #include <armnn/Descriptors.hpp> @@ -71,6 +72,8 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( registry.RegisterMemoryManager(memoryManager); registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); + registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>( + static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc))); return std::make_unique<ClWorkloadFactory>( PolymorphicPointerDowncast<ClMemoryManager>(memoryManager)); @@ -83,6 +86,24 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( registry.RegisterMemoryManager(memoryManager); registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); + registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>( + static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc))); + + return std::make_unique<ClWorkloadFactory>( + PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions)); +} + +IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( + TensorHandleFactoryRegistry& registry, + const ModelOptions& modelOptions, + MemorySourceFlags inputFlags, + MemorySourceFlags outputFlags) const +{ + auto memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + + registry.RegisterMemoryManager(memoryManager); + registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(memoryManager)); + registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags)); return std::make_unique<ClWorkloadFactory>( PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions)); @@ -90,7 +111,8 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferences() const { - return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic()}; + return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic(), + ClImportTensorHandleFactory::GetIdStatic()}; } void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) @@ -99,6 +121,19 @@ void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& regis registry.RegisterMemoryManager(mgr); registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr)); + registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>( + static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc))); +} + +void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry, + MemorySourceFlags inputFlags, + MemorySourceFlags outputFlags) +{ + auto mgr = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>()); + + registry.RegisterMemoryManager(mgr); + registry.RegisterFactory(std::make_unique<ClTensorHandleFactory>(mgr)); + registry.RegisterFactory(std::make_unique<ClImportTensorHandleFactory>(inputFlags, outputFlags)); } IBackendInternal::IBackendContextPtr ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const |