From c30abd843e68dfbd186ca25e8a8ecaefcf95776f Mon Sep 17 00:00:00 2001 From: Matthew Bentham Date: Wed, 23 Nov 2022 12:11:32 +0000 Subject: Refactor: Remove m_ImportFlags from RefTensorHandle The import flags for a RefTensorHandle shouldn't be a data member, as RefTensorHandle can only import from MemorySource::Malloc. Instead, use m_ImportEnabled to determine what to return from GetImportFlags(). Simplifies the code in Import and CanBeImported. Signed-off-by: Matthew Bentham Change-Id: Ic629858920f7dd32f99ee27f150b81d8b67144cf --- src/backends/reference/RefTensorHandle.cpp | 83 +++++++++++----------- src/backends/reference/RefTensorHandle.hpp | 8 +-- src/backends/reference/RefTensorHandleFactory.cpp | 6 +- src/backends/reference/RefWorkloadFactory.cpp | 4 +- .../reference/test/RefTensorHandleTests.cpp | 12 ++-- 5 files changed, 56 insertions(+), 57 deletions(-) diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index e196b61ccd..eccdc26542 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -12,19 +12,16 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr(MemorySource::Undefined)), m_Imported(false), m_IsImportEnabled(false) { } -RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, - MemorySourceFlags importFlags) +RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo) : m_TensorInfo(tensorInfo), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_ImportFlags(importFlags), m_Imported(false), m_IsImportEnabled(true) { @@ -115,43 +112,52 @@ void RefTensorHandle::CopyInFrom(const void* src) memcpy(dest, src, m_TensorInfo.GetNumBytes()); } +MemorySourceFlags RefTensorHandle::GetImportFlags() const +{ + if (m_IsImportEnabled) + { + return static_cast(MemorySource::Malloc); + } + else + { + return static_cast(MemorySource::Undefined); + } +} + bool RefTensorHandle::Import(void* memory, MemorySource source) { - if (m_ImportFlags & static_cast(source)) + if (m_IsImportEnabled && source == MemorySource::Malloc) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + // Check memory alignment + if(!CanBeImported(memory, source)) { - // Check memory alignment - if(!CanBeImported(memory, source)) + if (m_Imported) { - if (m_Imported) - { - m_Imported = false; - m_UnmanagedMemory = nullptr; - } - return false; + m_Imported = false; + m_UnmanagedMemory = nullptr; } + return false; + } - // m_UnmanagedMemory not yet allocated. - if (!m_Imported && !m_UnmanagedMemory) - { - m_UnmanagedMemory = memory; - m_Imported = true; - return true; - } + // m_UnmanagedMemory not yet allocated. + if (!m_Imported && !m_UnmanagedMemory) + { + m_UnmanagedMemory = memory; + m_Imported = true; + return true; + } - // m_UnmanagedMemory initially allocated with Allocate(). - if (!m_Imported && m_UnmanagedMemory) - { - return false; - } + // m_UnmanagedMemory initially allocated with Allocate(). + if (!m_Imported && m_UnmanagedMemory) + { + return false; + } - // m_UnmanagedMemory previously imported. - if (m_Imported) - { - m_UnmanagedMemory = memory; - return true; - } + // m_UnmanagedMemory previously imported. + if (m_Imported) + { + m_UnmanagedMemory = memory; + return true; } } @@ -160,17 +166,14 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) { - if (m_ImportFlags & static_cast(source)) + if (m_IsImportEnabled && source == MemorySource::Malloc) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); + if (reinterpret_cast(memory) % alignment) { - uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); - if (reinterpret_cast(memory) % alignment) - { - return false; - } - return true; + return false; } + return true; } return false; } diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp index a7eab034b2..d916b39ed9 100644 --- a/src/backends/reference/RefTensorHandle.hpp +++ b/src/backends/reference/RefTensorHandle.hpp @@ -17,7 +17,7 @@ class RefTensorHandle : public ITensorHandle public: RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr &memoryManager); - RefTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags); + RefTensorHandle(const TensorInfo& tensorInfo); ~RefTensorHandle(); @@ -51,10 +51,7 @@ public: return m_TensorInfo; } - virtual MemorySourceFlags GetImportFlags() const override - { - return m_ImportFlags; - } + virtual MemorySourceFlags GetImportFlags() const override; virtual bool Import(void* memory, MemorySource source) override; virtual bool CanBeImported(void* memory, MemorySource source) override; @@ -74,7 +71,6 @@ private: std::shared_ptr m_MemoryManager; RefMemoryManager::Pool* m_Pool; mutable void* m_UnmanagedMemory; - MemorySourceFlags m_ImportFlags; bool m_Imported; bool m_IsImportEnabled; }; diff --git a/src/backends/reference/RefTensorHandleFactory.cpp b/src/backends/reference/RefTensorHandleFactory.cpp index ade27dd733..da3b798d3d 100644 --- a/src/backends/reference/RefTensorHandleFactory.cpp +++ b/src/backends/reference/RefTensorHandleFactory.cpp @@ -48,7 +48,7 @@ std::unique_ptr RefTensorHandleFactory::CreateTensorHandle(const } else { - return std::make_unique(tensorInfo, m_ImportFlags); + return std::make_unique(tensorInfo); } } @@ -63,7 +63,7 @@ std::unique_ptr RefTensorHandleFactory::CreateTensorHandle(const } else { - return std::make_unique(tensorInfo, m_ImportFlags); + return std::make_unique(tensorInfo); } } @@ -87,4 +87,4 @@ MemorySourceFlags RefTensorHandleFactory::GetImportFlags() const return m_ImportFlags; } -} // namespace armnn \ No newline at end of file +} // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 69f75cae8a..bfe37d7bf5 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -119,7 +119,7 @@ std::unique_ptr RefWorkloadFactory::CreateTensorHandle(const Tens } else { - return std::make_unique(tensorInfo, static_cast(MemorySource::Malloc)); + return std::make_unique(tensorInfo); } } @@ -137,7 +137,7 @@ std::unique_ptr RefWorkloadFactory::CreateTensorHandle(const Tens } else { - return std::make_unique(tensorInfo, static_cast(MemorySource::Malloc)); + return std::make_unique(tensorInfo); } } diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp index 6f608e8541..b5fcc212a9 100644 --- a/src/backends/reference/test/RefTensorHandleTests.cpp +++ b/src/backends/reference/test/RefTensorHandleTests.cpp @@ -137,7 +137,7 @@ TEST_CASE("RefTensorHandleFactoryImport") TEST_CASE("RefTensorHandleImport") { TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32); - RefTensorHandle handle(info, static_cast(MemorySource::Malloc)); + RefTensorHandle handle(info); handle.Manage(); handle.Allocate(); @@ -224,7 +224,7 @@ TEST_CASE("TestManagedConstTensorHandle") TEST_CASE("CheckSourceType") { TensorInfo info({1}, DataType::Float32); - RefTensorHandle handle(info, static_cast(MemorySource::Malloc)); + RefTensorHandle handle(info); int* testPtr = new int(4); @@ -243,7 +243,7 @@ TEST_CASE("CheckSourceType") TEST_CASE("ReusePointer") { TensorInfo info({1}, DataType::Float32); - RefTensorHandle handle(info, static_cast(MemorySource::Malloc)); + RefTensorHandle handle(info); int* testPtr = new int(4); @@ -258,7 +258,7 @@ TEST_CASE("ReusePointer") TEST_CASE("MisalignedPointer") { TensorInfo info({2}, DataType::Float32); - RefTensorHandle handle(info, static_cast(MemorySource::Malloc)); + RefTensorHandle handle(info); // Allocate a 2 int array int* testPtr = new int[2]; @@ -274,7 +274,7 @@ TEST_CASE("MisalignedPointer") TEST_CASE("CheckCanBeImported") { TensorInfo info({1}, DataType::Float32); - RefTensorHandle handle(info, static_cast(MemorySource::Malloc)); + RefTensorHandle handle(info); int* testPtr = new int(4); @@ -291,7 +291,7 @@ TEST_CASE("CheckCanBeImported") TEST_CASE("MisalignedCanBeImported") { TensorInfo info({2}, DataType::Float32); - RefTensorHandle handle(info, static_cast(MemorySource::Malloc)); + RefTensorHandle handle(info); // Allocate a 2 int array int* testPtr = new int[2]; -- cgit v1.2.1