diff options
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 95 | ||||
-rw-r--r-- | src/backends/reference/RefTensorHandle.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/test/RefTensorHandleTests.cpp | 23 |
3 files changed, 44 insertions, 77 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index eccdc26542..dbfa374945 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -12,8 +12,7 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R m_MemoryManager(memoryManager), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_Imported(false), - m_IsImportEnabled(false) + m_ImportedMemory(nullptr) { } @@ -22,59 +21,46 @@ RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo) : m_TensorInfo(tensorInfo), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_Imported(false), - m_IsImportEnabled(true) + m_ImportedMemory(nullptr) { } RefTensorHandle::~RefTensorHandle() { - if (!m_Pool) - { - // unmanaged - if (!m_Imported) - { - ::operator delete(m_UnmanagedMemory); - } - } + ::operator delete(m_UnmanagedMemory); } void RefTensorHandle::Manage() { - if (!m_IsImportEnabled) - { - ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice"); - ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()"); + ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice"); + ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()"); + if (m_MemoryManager) + { m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes()); } } void RefTensorHandle::Allocate() { - // If import is enabled, do not allocate the tensor - if (!m_IsImportEnabled) + if (!m_UnmanagedMemory) { - - if (!m_UnmanagedMemory) + if (!m_Pool) { - if (!m_Pool) - { - // unmanaged - m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes()); - } - else - { - m_MemoryManager->Allocate(m_Pool); - } + // unmanaged + m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes()); } else { - throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle" - "that already has allocated memory."); + m_MemoryManager->Allocate(m_Pool); } } + else + { + throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle" + "that already has allocated memory."); + } } const void* RefTensorHandle::Map(bool /*unused*/) const @@ -84,7 +70,11 @@ const void* RefTensorHandle::Map(bool /*unused*/) const void* RefTensorHandle::GetPointer() const { - if (m_UnmanagedMemory) + if (m_ImportedMemory) + { + return m_ImportedMemory; + } + else if (m_UnmanagedMemory) { return m_UnmanagedMemory; } @@ -114,51 +104,22 @@ void RefTensorHandle::CopyInFrom(const void* src) MemorySourceFlags RefTensorHandle::GetImportFlags() const { - if (m_IsImportEnabled) - { - return static_cast<MemorySourceFlags>(MemorySource::Malloc); - } - else - { - return static_cast<MemorySourceFlags>(MemorySource::Undefined); - } + return static_cast<MemorySourceFlags>(MemorySource::Malloc); } bool RefTensorHandle::Import(void* memory, MemorySource source) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + if (source == MemorySource::Malloc) { // Check memory alignment if(!CanBeImported(memory, source)) { - if (m_Imported) - { - m_Imported = false; - m_UnmanagedMemory = nullptr; - } + m_ImportedMemory = nullptr; return false; } - // m_UnmanagedMemory not yet allocated. - if (!m_Imported && !m_UnmanagedMemory) - { - m_UnmanagedMemory = memory; - m_Imported = true; - return true; - } - - // m_UnmanagedMemory initially allocated with Allocate(). - if (!m_Imported && m_UnmanagedMemory) - { - return false; - } - - // m_UnmanagedMemory previously imported. - if (m_Imported) - { - m_UnmanagedMemory = memory; - return true; - } + m_ImportedMemory = memory; + return true; } return false; @@ -166,7 +127,7 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + if (source == MemorySource::Malloc) { uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); if (reinterpret_cast<uintptr_t>(memory) % alignment) diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp index d916b39ed9..b4dedd5e77 100644 --- a/src/backends/reference/RefTensorHandle.hpp +++ b/src/backends/reference/RefTensorHandle.hpp @@ -71,8 +71,7 @@ private: std::shared_ptr<RefMemoryManager> m_MemoryManager; RefMemoryManager::Pool* m_Pool; mutable void* m_UnmanagedMemory; - bool m_Imported; - bool m_IsImportEnabled; + void* m_ImportedMemory; }; } diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp index b5fcc212a9..883df6fe4d 100644 --- a/src/backends/reference/test/RefTensorHandleTests.cpp +++ b/src/backends/reference/test/RefTensorHandleTests.cpp @@ -99,8 +99,14 @@ TEST_CASE("RefTensorHandleFactoryMemoryManaged") memoryManager->Release(); float testPtr[2] = { 2.5f, 5.5f }; - // Cannot import as import is disabled - CHECK(!handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc)); + // Check import overlays contents + CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc)); + { + float* buffer = reinterpret_cast<float*>(handle->Map()); + CHECK(buffer != nullptr); // Yields a valid pointer + CHECK(buffer[0] == 2.5f); // Memory is writable and readable + CHECK(buffer[1] == 5.5f); // Memory is writable and readable + } } TEST_CASE("RefTensorHandleFactoryImport") @@ -115,11 +121,12 @@ TEST_CASE("RefTensorHandleFactoryImport") handle->Allocate(); memoryManager->Acquire(); - // No buffer allocated when import is enabled - CHECK_THROWS_AS(handle->Map(), armnn::NullPointerException); + // Check storage has been allocated + void* unmanagedStorage = handle->Map(); + CHECK(unmanagedStorage != nullptr); + // Check importing overlays the storage float testPtr[2] = { 2.5f, 5.5f }; - // Correctly import CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc)); float* buffer = reinterpret_cast<float*>(handle->Map()); CHECK(buffer != nullptr); // Yields a valid pointer after import @@ -142,11 +149,11 @@ TEST_CASE("RefTensorHandleImport") handle.Manage(); handle.Allocate(); - // No buffer allocated when import is enabled - CHECK_THROWS_AS(handle.Map(), armnn::NullPointerException); + // Check unmanaged memory allocated + CHECK(handle.Map()); float testPtr[2] = { 2.5f, 5.5f }; - // Correctly import + // Check imoport overlays the unamaged memory CHECK(handle.Import(static_cast<void*>(testPtr), MemorySource::Malloc)); float* buffer = reinterpret_cast<float*>(handle.Map()); CHECK(buffer != nullptr); // Yields a valid pointer after import |