diff options
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index 5229e9d62b..0be9708cff 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -122,8 +122,7 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) if (m_IsImportEnabled && source == MemorySource::Malloc) { // Check memory alignment - uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); - if (reinterpret_cast<uintptr_t>(memory) % alignment) + if(!CanBeImported(memory, source)) { if (m_Imported) { @@ -160,4 +159,24 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) return false; } +bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) +{ + if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + { + if (m_IsImportEnabled && source == MemorySource::Malloc) + { + uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); + if (reinterpret_cast<uintptr_t>(memory) % alignment) + { + return false; + } + + return true; + + } + + } + return false; +} + } |