aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefTensorHandle.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r--src/backends/reference/RefTensorHandle.cpp55
1 files changed, 54 insertions, 1 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index fe9310f423..59ccec6bac 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -11,7 +11,21 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R
m_TensorInfo(tensorInfo),
m_MemoryManager(memoryManager),
m_Pool(nullptr),
- m_UnmanagedMemory(nullptr)
+ m_UnmanagedMemory(nullptr),
+ m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
+ m_Imported(false)
+{
+
+}
+
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager,
+ MemorySourceFlags importFlags)
+ : m_TensorInfo(tensorInfo),
+ m_MemoryManager(memoryManager),
+ m_Pool(nullptr),
+ m_UnmanagedMemory(nullptr),
+ m_ImportFlags(importFlags),
+ m_Imported(false)
{
}
@@ -86,4 +100,43 @@ void RefTensorHandle::CopyInFrom(const void* src)
memcpy(dest, src, m_TensorInfo.GetNumBytes());
}
+bool RefTensorHandle::Import(void* memory, MemorySource source)
+{
+
+ if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ {
+ if (source == MemorySource::Malloc)
+ {
+ // Checks the 16 byte memory alignment.
+ if (reinterpret_cast<uint64_t>(memory) % 16)
+ {
+ return false;
+ }
+
+ // m_UnmanagedMemory not yet allocated.
+ if (!m_Imported && !m_UnmanagedMemory)
+ {
+ m_UnmanagedMemory = memory;
+ m_Imported = true;
+ return true;
+ }
+
+ // m_UnmanagedMemory initially allocated with Allocate().
+ if (!m_Imported && m_UnmanagedMemory)
+ {
+ return false;
+ }
+
+ // m_UnmanagedMemory previously imported.
+ if (m_Imported)
+ {
+ m_UnmanagedMemory = memory;
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
}