diff options
Diffstat (limited to 'src/backends/neon/NeonBackend.cpp')
-rw-r--r-- | src/backends/neon/NeonBackend.cpp | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/src/backends/neon/NeonBackend.cpp b/src/backends/neon/NeonBackend.cpp index 54af14e30b..66547ad4df 100644 --- a/src/backends/neon/NeonBackend.cpp +++ b/src/backends/neon/NeonBackend.cpp @@ -74,7 +74,13 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory( BaseMemoryManager::MemoryAffinity::Offset); tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager); - tensorHandleFactoryRegistry.RegisterFactory(std::make_unique<NeonTensorHandleFactory>(memoryManager)); + + auto factory = std::make_unique<NeonTensorHandleFactory>(memoryManager); + // Register copy and import factory pair + tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); + // Register the factory + tensorHandleFactoryRegistry.RegisterFactory(std::move(factory)); + return std::make_unique<NeonWorkloadFactory>( PolymorphicPointerDowncast<NeonMemoryManager>(memoryManager)); @@ -87,7 +93,12 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory( BaseMemoryManager::MemoryAffinity::Offset); tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager); - tensorHandleFactoryRegistry.RegisterFactory(std::make_unique<NeonTensorHandleFactory>(memoryManager)); + + auto factory = std::make_unique<NeonTensorHandleFactory>(memoryManager); + // Register copy and import factory pair + tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); + // Register the factory + tensorHandleFactoryRegistry.RegisterFactory(std::move(factory)); return std::make_unique<NeonWorkloadFactory>( PolymorphicPointerDowncast<NeonMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions)); @@ -424,7 +435,12 @@ void NeonBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistr BaseMemoryManager::MemoryAffinity::Offset); registry.RegisterMemoryManager(memoryManager); - registry.RegisterFactory(std::make_unique<NeonTensorHandleFactory>(memoryManager)); + + auto factory = std::make_unique<NeonTensorHandleFactory>(memoryManager); + // Register copy and import factory pair + registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); + // Register the factory + registry.RegisterFactory(std::move(factory)); } std::unique_ptr<ICustomAllocator> NeonBackend::GetDefaultAllocator() const |