diff options
Diffstat (limited to 'src/backends/neon/NeonBackend.cpp')
-rw-r--r-- | src/backends/neon/NeonBackend.cpp | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp index d7be844c21..f86509cbe6 100644 --- a/src/backends/neon/NeonBackend.cpp +++ b/src/backends/neon/NeonBackend.cpp @@ -7,6 +7,7 @@ #include "NeonBackendId.hpp" #include "NeonWorkloadFactory.hpp" #include "NeonLayerSupport.hpp" +#include "NeonTensorHandleFactory.hpp" #include <aclCommon/BaseMemoryManager.hpp> @@ -58,6 +59,17 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory( boost::polymorphic_pointer_downcast<NeonMemoryManager>(memoryManager)); } +IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory( + class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const +{ + auto memoryManager = std::make_shared<NeonMemoryManager>(std::make_unique<arm_compute::Allocator>(), + BaseMemoryManager::MemoryAffinity::Offset); + + tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager); + return std::make_unique<NeonWorkloadFactory>( + boost::polymorphic_pointer_downcast<NeonMemoryManager>(memoryManager)); +} + IBackendInternal::IBackendContextPtr NeonBackend::CreateBackendContext(const IRuntime::CreationOptions&) const { return IBackendContextPtr{}; @@ -83,4 +95,18 @@ OptimizationViews NeonBackend::OptimizeSubgraphView(const SubgraphView& subgraph return optimizationViews; } +std::vector<ITensorHandleFactory::FactoryId> NeonBackend::GetHandleFactoryPreferences() const +{ + return std::vector<ITensorHandleFactory::FactoryId>() = {"Arm/Neon/TensorHandleFactory"}; +} + +void NeonBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry) +{ + auto memoryManager = std::make_shared<NeonMemoryManager>(std::make_unique<arm_compute::Allocator>(), + BaseMemoryManager::MemoryAffinity::Offset); + + registry.RegisterMemoryManager(memoryManager); + registry.RegisterFactory(std::make_unique<NeonTensorHandleFactory>(memoryManager, "Arm/Neon/TensorHandleFactory")); +} + } // namespace armnn |