aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2022-11-23 12:11:32 +0000
committerTeresaARM <teresa.charlinreyes@arm.com>2022-12-14 12:53:00 +0000
commitc30abd843e68dfbd186ca25e8a8ecaefcf95776f (patch)
treee2f17d59df2dbef75cc056dd355560df22e93d78
parent6d2647df4ce2e15bff8548e74993aa4b12ea8f34 (diff)
downloadarmnn-c30abd843e68dfbd186ca25e8a8ecaefcf95776f.tar.gz
Refactor: Remove m_ImportFlags from RefTensorHandle
The import flags for a RefTensorHandle shouldn't be a data member, as RefTensorHandle can only import from MemorySource::Malloc. Instead, use m_ImportEnabled to determine what to return from GetImportFlags(). Simplifies the code in Import and CanBeImported. Signed-off-by: Matthew Bentham <matthew.bentham@arm.com> Change-Id: Ic629858920f7dd32f99ee27f150b81d8b67144cf
-rw-r--r--src/backends/reference/RefTensorHandle.cpp83
-rw-r--r--src/backends/reference/RefTensorHandle.hpp8
-rw-r--r--src/backends/reference/RefTensorHandleFactory.cpp6
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp4
-rw-r--r--src/backends/reference/test/RefTensorHandleTests.cpp12
5 files changed, 56 insertions, 57 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index e196b61ccd..eccdc26542 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -12,19 +12,16 @@ RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<R
m_MemoryManager(memoryManager),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
m_Imported(false),
m_IsImportEnabled(false)
{
}
-RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo,
- MemorySourceFlags importFlags)
+RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo)
: m_TensorInfo(tensorInfo),
m_Pool(nullptr),
m_UnmanagedMemory(nullptr),
- m_ImportFlags(importFlags),
m_Imported(false),
m_IsImportEnabled(true)
{
@@ -115,43 +112,52 @@ void RefTensorHandle::CopyInFrom(const void* src)
memcpy(dest, src, m_TensorInfo.GetNumBytes());
}
+MemorySourceFlags RefTensorHandle::GetImportFlags() const
+{
+ if (m_IsImportEnabled)
+ {
+ return static_cast<MemorySourceFlags>(MemorySource::Malloc);
+ }
+ else
+ {
+ return static_cast<MemorySourceFlags>(MemorySource::Undefined);
+ }
+}
+
bool RefTensorHandle::Import(void* memory, MemorySource source)
{
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_IsImportEnabled && source == MemorySource::Malloc)
{
- if (m_IsImportEnabled && source == MemorySource::Malloc)
+ // Check memory alignment
+ if(!CanBeImported(memory, source))
{
- // Check memory alignment
- if(!CanBeImported(memory, source))
+ if (m_Imported)
{
- if (m_Imported)
- {
- m_Imported = false;
- m_UnmanagedMemory = nullptr;
- }
- return false;
+ m_Imported = false;
+ m_UnmanagedMemory = nullptr;
}
+ return false;
+ }
- // m_UnmanagedMemory not yet allocated.
- if (!m_Imported && !m_UnmanagedMemory)
- {
- m_UnmanagedMemory = memory;
- m_Imported = true;
- return true;
- }
+ // m_UnmanagedMemory not yet allocated.
+ if (!m_Imported && !m_UnmanagedMemory)
+ {
+ m_UnmanagedMemory = memory;
+ m_Imported = true;
+ return true;
+ }
- // m_UnmanagedMemory initially allocated with Allocate().
- if (!m_Imported && m_UnmanagedMemory)
- {
- return false;
- }
+ // m_UnmanagedMemory initially allocated with Allocate().
+ if (!m_Imported && m_UnmanagedMemory)
+ {
+ return false;
+ }
- // m_UnmanagedMemory previously imported.
- if (m_Imported)
- {
- m_UnmanagedMemory = memory;
- return true;
- }
+ // m_UnmanagedMemory previously imported.
+ if (m_Imported)
+ {
+ m_UnmanagedMemory = memory;
+ return true;
}
}
@@ -160,17 +166,14 @@ bool RefTensorHandle::Import(void* memory, MemorySource source)
bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
{
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_IsImportEnabled && source == MemorySource::Malloc)
{
- if (m_IsImportEnabled && source == MemorySource::Malloc)
+ uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
+ if (reinterpret_cast<uintptr_t>(memory) % alignment)
{
- uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
- if (reinterpret_cast<uintptr_t>(memory) % alignment)
- {
- return false;
- }
- return true;
+ return false;
}
+ return true;
}
return false;
}
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index a7eab034b2..d916b39ed9 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -17,7 +17,7 @@ class RefTensorHandle : public ITensorHandle
public:
RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager);
- RefTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
+ RefTensorHandle(const TensorInfo& tensorInfo);
~RefTensorHandle();
@@ -51,10 +51,7 @@ public:
return m_TensorInfo;
}
- virtual MemorySourceFlags GetImportFlags() const override
- {
- return m_ImportFlags;
- }
+ virtual MemorySourceFlags GetImportFlags() const override;
virtual bool Import(void* memory, MemorySource source) override;
virtual bool CanBeImported(void* memory, MemorySource source) override;
@@ -74,7 +71,6 @@ private:
std::shared_ptr<RefMemoryManager> m_MemoryManager;
RefMemoryManager::Pool* m_Pool;
mutable void* m_UnmanagedMemory;
- MemorySourceFlags m_ImportFlags;
bool m_Imported;
bool m_IsImportEnabled;
};
diff --git a/src/backends/reference/RefTensorHandleFactory.cpp b/src/backends/reference/RefTensorHandleFactory.cpp
index ade27dd733..da3b798d3d 100644
--- a/src/backends/reference/RefTensorHandleFactory.cpp
+++ b/src/backends/reference/RefTensorHandleFactory.cpp
@@ -48,7 +48,7 @@ std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const
}
else
{
- return std::make_unique<RefTensorHandle>(tensorInfo, m_ImportFlags);
+ return std::make_unique<RefTensorHandle>(tensorInfo);
}
}
@@ -63,7 +63,7 @@ std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const
}
else
{
- return std::make_unique<RefTensorHandle>(tensorInfo, m_ImportFlags);
+ return std::make_unique<RefTensorHandle>(tensorInfo);
}
}
@@ -87,4 +87,4 @@ MemorySourceFlags RefTensorHandleFactory::GetImportFlags() const
return m_ImportFlags;
}
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 69f75cae8a..bfe37d7bf5 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -119,7 +119,7 @@ std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const Tens
}
else
{
- return std::make_unique<RefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
+ return std::make_unique<RefTensorHandle>(tensorInfo);
}
}
@@ -137,7 +137,7 @@ std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const Tens
}
else
{
- return std::make_unique<RefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
+ return std::make_unique<RefTensorHandle>(tensorInfo);
}
}
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index 6f608e8541..b5fcc212a9 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -137,7 +137,7 @@ TEST_CASE("RefTensorHandleFactoryImport")
TEST_CASE("RefTensorHandleImport")
{
TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
- RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+ RefTensorHandle handle(info);
handle.Manage();
handle.Allocate();
@@ -224,7 +224,7 @@ TEST_CASE("TestManagedConstTensorHandle")
TEST_CASE("CheckSourceType")
{
TensorInfo info({1}, DataType::Float32);
- RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+ RefTensorHandle handle(info);
int* testPtr = new int(4);
@@ -243,7 +243,7 @@ TEST_CASE("CheckSourceType")
TEST_CASE("ReusePointer")
{
TensorInfo info({1}, DataType::Float32);
- RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+ RefTensorHandle handle(info);
int* testPtr = new int(4);
@@ -258,7 +258,7 @@ TEST_CASE("ReusePointer")
TEST_CASE("MisalignedPointer")
{
TensorInfo info({2}, DataType::Float32);
- RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+ RefTensorHandle handle(info);
// Allocate a 2 int array
int* testPtr = new int[2];
@@ -274,7 +274,7 @@ TEST_CASE("MisalignedPointer")
TEST_CASE("CheckCanBeImported")
{
TensorInfo info({1}, DataType::Float32);
- RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+ RefTensorHandle handle(info);
int* testPtr = new int(4);
@@ -291,7 +291,7 @@ TEST_CASE("CheckCanBeImported")
TEST_CASE("MisalignedCanBeImported")
{
TensorInfo info({2}, DataType::Float32);
- RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+ RefTensorHandle handle(info);
// Allocate a 2 int array
int* testPtr = new int[2];