diff options
author | Matthew Bentham <matthew.bentham@arm.com> | 2022-11-23 12:11:32 +0000 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2022-12-14 12:53:00 +0000 |
commit | c30abd843e68dfbd186ca25e8a8ecaefcf95776f (patch) | |
tree | e2f17d59df2dbef75cc056dd355560df22e93d78 /src/backends/reference/RefTensorHandle.cpp | |
parent | 6d2647df4ce2e15bff8548e74993aa4b12ea8f34 (diff) | |
download | armnn-c30abd843e68dfbd186ca25e8a8ecaefcf95776f.tar.gz |
Refactor: Remove m_ImportFlags from RefTensorHandle
The import flags for a RefTensorHandle shouldn't be a data member,
as RefTensorHandle can only import from MemorySource::Malloc. Instead,
use m_ImportEnabled to determine what to return from GetImportFlags().
Simplifies the code in Import and CanBeImported.
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Change-Id: Ic629858920f7dd32f99ee27f150b81d8b67144cf
Diffstat (limited to 'src/backends/reference/RefTensorHandle.cpp')
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 83 |
1 files changed, 43 insertions, 40 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index e196b61ccd..eccdc26542 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -12,19 +12,16 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R m_MemoryManager(memoryManager), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)), m_Imported(false), m_IsImportEnabled(false) { } -RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo, - MemorySourceFlags importFlags) +RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo) : m_TensorInfo(tensorInfo), m_Pool(nullptr), m_UnmanagedMemory(nullptr), - m_ImportFlags(importFlags), m_Imported(false), m_IsImportEnabled(true) { @@ -115,43 +112,52 @@ void RefTensorHandle::CopyInFrom(const void* src) memcpy(dest, src, m_TensorInfo.GetNumBytes()); } +MemorySourceFlags RefTensorHandle::GetImportFlags() const +{ + if (m_IsImportEnabled) + { + return static_cast<MemorySourceFlags>(MemorySource::Malloc); + } + else + { + return static_cast<MemorySourceFlags>(MemorySource::Undefined); + } +} + bool RefTensorHandle::Import(void* memory, MemorySource source) { - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + if (m_IsImportEnabled && source == MemorySource::Malloc) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + // Check memory alignment + if(!CanBeImported(memory, source)) { - // Check memory alignment - if(!CanBeImported(memory, source)) + if (m_Imported) { - if (m_Imported) - { - m_Imported = false; - m_UnmanagedMemory = nullptr; - } - return false; + m_Imported = false; + m_UnmanagedMemory = nullptr; } + return false; + } - // m_UnmanagedMemory not yet allocated. - if (!m_Imported && !m_UnmanagedMemory) - { - m_UnmanagedMemory = memory; - m_Imported = true; - return true; - } + // m_UnmanagedMemory not yet allocated. + if (!m_Imported && !m_UnmanagedMemory) + { + m_UnmanagedMemory = memory; + m_Imported = true; + return true; + } - // m_UnmanagedMemory initially allocated with Allocate(). - if (!m_Imported && m_UnmanagedMemory) - { - return false; - } + // m_UnmanagedMemory initially allocated with Allocate(). + if (!m_Imported && m_UnmanagedMemory) + { + return false; + } - // m_UnmanagedMemory previously imported. - if (m_Imported) - { - m_UnmanagedMemory = memory; - return true; - } + // m_UnmanagedMemory previously imported. + if (m_Imported) + { + m_UnmanagedMemory = memory; + return true; } } @@ -160,17 +166,14 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) { - if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + if (m_IsImportEnabled && source == MemorySource::Malloc) { - if (m_IsImportEnabled && source == MemorySource::Malloc) + uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); + if (reinterpret_cast<uintptr_t>(memory) % alignment) { - uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); - if (reinterpret_cast<uintptr_t>(memory) % alignment) - { - return false; - } - return true; + return false; } + return true; } return false; } |