From 6b5f674aad30a3438c295c25b5d115007e80b757 Mon Sep 17 00:00:00 2001 From: Matthew Bentham Date: Wed, 23 Nov 2022 18:17:48 +0000 Subject: Change the semantics of RefTensorHandle::Import to 'overlay' existing memory This makes it possible to call Import on an Allocated() or memory-managed Tensor, which is needed for the current implementation of OptimizerOptions::m_ExportEnabled to work (as the last layer before the OutputLayer needs to be able to Import the user's OutputTensor, but this is done after other memory allocation). Signed-off-by: Matthew Bentham Change-Id: I1a885c2da7b1f0f3964ae53b8135b5e96a66614f --- src/backends/reference/RefTensorHandle.cpp | 95 +++++++--------------- src/backends/reference/RefTensorHandle.hpp | 3 +- .../reference/test/RefTensorHandleTests.cpp | 23 ++++-- 3 files changed, 44 insertions(+), 77 deletions(-) diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index eccdc26542..dbfa374945 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -12,8 +12,7 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptrManage(m_TensorInfo.GetNumBytes()); } } void RefTensorHandle::Allocate() { - // If import is enabled, do not allocate the tensor - if (!m_IsImportEnabled) + if (!m_UnmanagedMemory) { - - if (!m_UnmanagedMemory) + if (!m_Pool) { - if (!m_Pool) - { - // unmanaged - m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes()); - } - else - { - m_MemoryManager->Allocate(m_Pool); - } + // unmanaged + m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes()); } else { - throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle" - "that already has allocated memory."); + m_MemoryManager->Allocate(m_Pool); } } + else + { + throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle" + "that already has allocated memory."); + } } const void* RefTensorHandle::Map(bool /*unused*/) const @@ -84,7 +70,11 @@ const void* RefTensorHandle::Map(bool /*unused*/) const void* RefTensorHandle::GetPointer() const { - if (m_UnmanagedMemory) + if (m_ImportedMemory) + { + return m_ImportedMemory; + } + else if (m_UnmanagedMemory) { return m_UnmanagedMemory; } @@ -114,51 +104,22 @@ void RefTensorHandle::CopyInFrom(const void* src) MemorySourceFlags RefTensorHandle::GetImportFlags() const { - if (m_IsImportEnabled) - { - return static_cast(MemorySource::Malloc); - } - else - { - return static_cast(MemorySource::Undefined); - } + return static_cast(MemorySource::Malloc); } bool RefTensorHandle::Import(void* memory, MemorySource source) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + if (source == MemorySource::Malloc) { // Check memory alignment if(!CanBeImported(memory, source)) { - if (m_Imported) - { - m_Imported = false; - m_UnmanagedMemory = nullptr; - } + m_ImportedMemory = nullptr; return false; } - // m_UnmanagedMemory not yet allocated. - if (!m_Imported && !m_UnmanagedMemory) - { - m_UnmanagedMemory = memory; - m_Imported = true; - return true; - } - - // m_UnmanagedMemory initially allocated with Allocate(). - if (!m_Imported && m_UnmanagedMemory) - { - return false; - } - - // m_UnmanagedMemory previously imported. - if (m_Imported) - { - m_UnmanagedMemory = memory; - return true; - } + m_ImportedMemory = memory; + return true; } return false; @@ -166,7 +127,7 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + if (source == MemorySource::Malloc) { uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); if (reinterpret_cast(memory) % alignment) diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp index d916b39ed9..b4dedd5e77 100644 --- a/src/backends/reference/RefTensorHandle.hpp +++ b/src/backends/reference/RefTensorHandle.hpp @@ -71,8 +71,7 @@ private: std::shared_ptr m_MemoryManager; RefMemoryManager::Pool* m_Pool; mutable void* m_UnmanagedMemory; - bool m_Imported; - bool m_IsImportEnabled; + void* m_ImportedMemory; }; } diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp index b5fcc212a9..883df6fe4d 100644 --- a/src/backends/reference/test/RefTensorHandleTests.cpp +++ b/src/backends/reference/test/RefTensorHandleTests.cpp @@ -99,8 +99,14 @@ TEST_CASE("RefTensorHandleFactoryMemoryManaged") memoryManager->Release(); float testPtr[2] = { 2.5f, 5.5f }; - // Cannot import as import is disabled - CHECK(!handle->Import(static_cast(testPtr), MemorySource::Malloc)); + // Check import overlays contents + CHECK(handle->Import(static_cast(testPtr), MemorySource::Malloc)); + { + float* buffer = reinterpret_cast(handle->Map()); + CHECK(buffer != nullptr); // Yields a valid pointer + CHECK(buffer[0] == 2.5f); // Memory is writable and readable + CHECK(buffer[1] == 5.5f); // Memory is writable and readable + } } TEST_CASE("RefTensorHandleFactoryImport") @@ -115,11 +121,12 @@ TEST_CASE("RefTensorHandleFactoryImport") handle->Allocate(); memoryManager->Acquire(); - // No buffer allocated when import is enabled - CHECK_THROWS_AS(handle->Map(), armnn::NullPointerException); + // Check storage has been allocated + void* unmanagedStorage = handle->Map(); + CHECK(unmanagedStorage != nullptr); + // Check importing overlays the storage float testPtr[2] = { 2.5f, 5.5f }; - // Correctly import CHECK(handle->Import(static_cast(testPtr), MemorySource::Malloc)); float* buffer = reinterpret_cast(handle->Map()); CHECK(buffer != nullptr); // Yields a valid pointer after import @@ -142,11 +149,11 @@ TEST_CASE("RefTensorHandleImport") handle.Manage(); handle.Allocate(); - // No buffer allocated when import is enabled - CHECK_THROWS_AS(handle.Map(), armnn::NullPointerException); + // Check unmanaged memory allocated + CHECK(handle.Map()); float testPtr[2] = { 2.5f, 5.5f }; - // Correctly import + // Check imoport overlays the unamaged memory CHECK(handle.Import(static_cast(testPtr), MemorySource::Malloc)); float* buffer = reinterpret_cast(handle.Map()); CHECK(buffer != nullptr); // Yields a valid pointer after import -- cgit v1.2.1