aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefTensorHandle.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r--src/backends/reference/RefTensorHandle.cpp83
1 files changed, 43 insertions, 40 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index e196b61ccd..eccdc26542 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -12,19 +12,16 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R
m_MemoryManager(memoryManager),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
m_Imported(false),
m_IsImportEnabled(false)
{
}
-RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo,
- MemorySourceFlags importFlags)
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo)
: m_TensorInfo(tensorInfo),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_ImportFlags(importFlags),
m_Imported(false),
m_IsImportEnabled(true)
{
@@ -115,43 +112,52 @@ void RefTensorHandle::CopyInFrom(const void* src)
memcpy(dest, src, m_TensorInfo.GetNumBytes());
}
+MemorySourceFlags RefTensorHandle::GetImportFlags() const
+{
+ if (m_IsImportEnabled)
+ {
+ return static_cast<MemorySourceFlags>(MemorySource::Malloc);
+ }
+ else
+ {
+ return static_cast<MemorySourceFlags>(MemorySource::Undefined);
+ }
+}
+
bool RefTensorHandle::Import(void* memory, MemorySource source)
{
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_IsImportEnabled && source == MemorySource::Malloc)
{
- if (m_IsImportEnabled && source == MemorySource::Malloc)
+ // Check memory alignment
+ if(!CanBeImported(memory, source))
{
- // Check memory alignment
- if(!CanBeImported(memory, source))
+ if (m_Imported)
{
- if (m_Imported)
- {
- m_Imported = false;
- m_UnmanagedMemory = nullptr;
- }
- return false;
+ m_Imported = false;
+ m_UnmanagedMemory = nullptr;
}
+ return false;
+ }
- // m_UnmanagedMemory not yet allocated.
- if (!m_Imported && !m_UnmanagedMemory)
- {
- m_UnmanagedMemory = memory;
- m_Imported = true;
- return true;
- }
+ // m_UnmanagedMemory not yet allocated.
+ if (!m_Imported && !m_UnmanagedMemory)
+ {
+ m_UnmanagedMemory = memory;
+ m_Imported = true;
+ return true;
+ }
- // m_UnmanagedMemory initially allocated with Allocate().
- if (!m_Imported && m_UnmanagedMemory)
- {
- return false;
- }
+ // m_UnmanagedMemory initially allocated with Allocate().
+ if (!m_Imported && m_UnmanagedMemory)
+ {
+ return false;
+ }
- // m_UnmanagedMemory previously imported.
- if (m_Imported)
- {
- m_UnmanagedMemory = memory;
- return true;
- }
+ // m_UnmanagedMemory previously imported.
+ if (m_Imported)
+ {
+ m_UnmanagedMemory = memory;
+ return true;
}
}
@@ -160,17 +166,14 @@ bool RefTensorHandle::Import(void* memory, MemorySource source)
bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
{
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_IsImportEnabled && source == MemorySource::Malloc)
{
- if (m_IsImportEnabled && source == MemorySource::Malloc)
+ uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
+ if (reinterpret_cast<uintptr_t>(memory) % alignment)
{
- uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
- if (reinterpret_cast<uintptr_t>(memory) % alignment)
- {
- return false;
- }
- return true;
+ return false;
}
+ return true;
}
return false;
}