diff options
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 53 |
1 files changed, 32 insertions, 21 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index 7d86b110a7..b9e566eace 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -13,19 +13,20 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R m_Pool(nullptr), m_UnmanagedMemory(nullptr), m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)), - m_Imported(false) + m_Imported(false), + m_IsImportEnabled(false) { } -RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager, +RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags) : m_TensorInfo(tensorInfo), - m_MemoryManager(memoryManager), m_Pool(nullptr), m_UnmanagedMemory(nullptr), m_ImportFlags(importFlags), - m_Imported(false) + m_Imported(false), + m_IsImportEnabled(true) { } @@ -44,31 +45,39 @@ RefTensorHandle::~RefTensorHandle() void RefTensorHandle::Manage() { - ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice"); - ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()"); + if (!m_IsImportEnabled) + { + ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice"); + ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()"); - m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes()); + m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes()); + } } void RefTensorHandle::Allocate() { - if (!m_UnmanagedMemory) + // If import is enabled, do not allocate the tensor + if (!m_IsImportEnabled) { - if (!m_Pool) + + if (!m_UnmanagedMemory) { - // unmanaged - m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes()); + if (!m_Pool) + { + // unmanaged + m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes()); + } + else + { + m_MemoryManager->Allocate(m_Pool); + } } else { - m_MemoryManager->Allocate(m_Pool); + throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle" + "that already has allocated memory."); } } - else - { - throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle" - "that already has allocated memory."); - } } const void* RefTensorHandle::Map(bool /*unused*/) const @@ -82,11 +91,14 @@ void* RefTensorHandle::GetPointer() const { return m_UnmanagedMemory; } - else + else if (m_Pool) { - ARMNN_ASSERT_MSG(m_Pool, "RefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle"); return m_MemoryManager->GetPointer(m_Pool); } + else + { + throw NullPointerException("RefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle"); + } } void RefTensorHandle::CopyOutTo(void* dest) const @@ -105,10 +117,9 @@ void RefTensorHandle::CopyInFrom(const void* src) bool RefTensorHandle::Import(void* memory, MemorySource source) { - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) { - if (source == MemorySource::Malloc) + if (m_IsImportEnabled && source == MemorySource::Malloc) { // Check memory alignment constexpr uintptr_t alignment = sizeof(size_t); |