diff options
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 83 |
1 files changed, 43 insertions, 40 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index e196b61ccd..eccdc26542 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -12,19 +12,16 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R m_MemoryManager(memoryManager), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)), m_Imported(false), m_IsImportEnabled(false) { } -RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, - MemorySourceFlags importFlags) +RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo) : m_TensorInfo(tensorInfo), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_ImportFlags(importFlags), m_Imported(false), m_IsImportEnabled(true) { @@ -115,43 +112,52 @@ void RefTensorHandle::CopyInFrom(const void* src) memcpy(dest, src, m_TensorInfo.GetNumBytes()); } +MemorySourceFlags RefTensorHandle::GetImportFlags() const +{ + if (m_IsImportEnabled) + { + return static_cast<MemorySourceFlags>(MemorySource::Malloc); + } + else + { + return static_cast<MemorySourceFlags>(MemorySource::Undefined); + } +} + bool RefTensorHandle::Import(void* memory, MemorySource source) { - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + if (m_IsImportEnabled && source == MemorySource::Malloc) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + // Check memory alignment + if(!CanBeImported(memory, source)) { - // Check memory alignment - if(!CanBeImported(memory, source)) + if (m_Imported) { - if (m_Imported) - { - m_Imported = false; - m_UnmanagedMemory = nullptr; - } - return false; + m_Imported = false; + m_UnmanagedMemory = nullptr; } + return false; + } - // m_UnmanagedMemory not yet allocated. - if (!m_Imported && !m_UnmanagedMemory) - { - m_UnmanagedMemory = memory; - m_Imported = true; - return true; - } + // m_UnmanagedMemory not yet allocated. + if (!m_Imported && !m_UnmanagedMemory) + { + m_UnmanagedMemory = memory; + m_Imported = true; + return true; + } - // m_UnmanagedMemory initially allocated with Allocate(). - if (!m_Imported && m_UnmanagedMemory) - { - return false; - } + // m_UnmanagedMemory initially allocated with Allocate(). + if (!m_Imported && m_UnmanagedMemory) + { + return false; + } - // m_UnmanagedMemory previously imported. - if (m_Imported) - { - m_UnmanagedMemory = memory; - return true; - } + // m_UnmanagedMemory previously imported. + if (m_Imported) + { + m_UnmanagedMemory = memory; + return true; } } @@ -160,17 +166,14 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) { - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + if (m_IsImportEnabled && source == MemorySource::Malloc) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); + if (reinterpret_cast<uintptr_t>(memory) % alignment) { - uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); - if (reinterpret_cast<uintptr_t>(memory) % alignment) - { - return false; - } - return true; + return false; } + return true; } return false; } |