diff options
Diffstat (limited to 'src/backends/reference/RefBackend.cpp')
-rw-r--r-- | src/backends/reference/RefBackend.cpp | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/src/backends/reference/RefBackend.cpp b/src/backends/reference/RefBackend.cpp index ad52434ef8..c38d6b6710 100644 --- a/src/backends/reference/RefBackend.cpp +++ b/src/backends/reference/RefBackend.cpp @@ -38,7 +38,12 @@ IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory( auto memoryManager = std::make_shared<RefMemoryManager>(); tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager); - tensorHandleFactoryRegistry.RegisterFactory(std::make_unique<RefTensorHandleFactory>(memoryManager)); + + std::unique_ptr<RefTensorHandleFactory> factory = std::make_unique<RefTensorHandleFactory>(memoryManager); + // Register copy and import factory pair + tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); + // Register the factory + tensorHandleFactoryRegistry.RegisterFactory(std::move(factory)); return std::make_unique<RefWorkloadFactory>(PolymorphicPointerDowncast<RefMemoryManager>(memoryManager)); } @@ -84,7 +89,13 @@ void RefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry auto memoryManager = std::make_shared<RefMemoryManager>(); registry.RegisterMemoryManager(memoryManager); - registry.RegisterFactory(std::make_unique<RefTensorHandleFactory>(memoryManager)); + + std::unique_ptr<RefTensorHandleFactory> factory = std::make_unique<RefTensorHandleFactory>(memoryManager); + + // Register copy and import factory pair + registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); + // Register the factory + registry.RegisterFactory(std::move(factory)); } std::unique_ptr<ICustomAllocator> RefBackend::GetDefaultAllocator() const |