aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefTensorHandle.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r--src/backends/reference/RefTensorHandle.cpp23
1 files changed, 21 insertions, 2 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index 5229e9d62b..0be9708cff 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -122,8 +122,7 @@ bool RefTensorHandle::Import(void* memory, MemorySource source)
if (m_IsImportEnabled && source == MemorySource::Malloc)
{
// Check memory alignment
- uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
- if (reinterpret_cast<uintptr_t>(memory) % alignment)
+ if(!CanBeImported(memory, source))
{
if (m_Imported)
{
@@ -160,4 +159,24 @@ bool RefTensorHandle::Import(void* memory, MemorySource source)
return false;
}
+bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
+{
+ if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ {
+ if (m_IsImportEnabled && source == MemorySource::Malloc)
+ {
+ uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
+ if (reinterpret_cast<uintptr_t>(memory) % alignment)
+ {
+ return false;
+ }
+
+ return true;
+
+ }
+
+ }
+ return false;
+}
+
}