diff options
Diffstat (limited to 'src/backends/cl/ClBackend.cpp')
-rw-r--r-- | src/backends/cl/ClBackend.cpp | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/cl/ClBackend.cpp b/src/backends/cl/ClBackend.cpp index 8abb16ccca..0fc5da78d1 100644 --- a/src/backends/cl/ClBackend.cpp +++ b/src/backends/cl/ClBackend.cpp @@ -133,6 +133,15 @@ IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory( MemorySourceFlags inputFlags, MemorySourceFlags outputFlags) const { + // To allow force import if inputFlags/outputFlags are Undefined, set it as Malloc + if (inputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined)) + { + inputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc); + } + if (outputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined)) + { + outputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc); + } std::shared_ptr<ClMemoryManager> memoryManager; if (m_UsingCustomAllocator) { @@ -193,6 +202,15 @@ void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& regis MemorySourceFlags inputFlags, MemorySourceFlags outputFlags) { + // To allow force import if inputFlags/outputFlags are Undefined, set it as Malloc + if (inputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined)) + { + inputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc); + } + if (outputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined)) + { + outputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc); + } std::shared_ptr<ClMemoryManager> memoryManager; if (m_UsingCustomAllocator) { |