aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNikhil Raj <nikhil.raj@arm.com>2022-01-05 16:04:08 +0000
committerNikhil Raj <nikhil.raj@arm.com>2022-01-17 11:32:48 +0000
commit53e06599a3af44db90c37d1cda34fc85ec9c27fa (patch)
tree913616009ed42cbd0d27a18a433325971e8fd091
parent56ccf68c7858560f2ba00f19076b3cb112970881 (diff)
downloadarmnn-53e06599a3af44db90c37d1cda34fc85ec9c27fa.tar.gz
IVGCVSW-6672 Implement CanBeImported function to RefTensorHandle
Signed-off-by: Nikhil Raj <nikhil.raj@arm.com> Change-Id: Icaa3aa7ef3e5cc3984941d095edfe8f0b2137879
-rw-r--r--src/backends/reference/RefTensorHandle.cpp23
-rw-r--r--src/backends/reference/RefTensorHandle.hpp1
-rw-r--r--src/backends/reference/test/RefTensorHandleTests.cpp33
3 files changed, 55 insertions, 2 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index 5229e9d62b..0be9708cff 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -122,8 +122,7 @@ bool RefTensorHandle::Import(void* memory, MemorySource source)
if (m_IsImportEnabled && source == MemorySource::Malloc)
{
// Check memory alignment
- uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
- if (reinterpret_cast<uintptr_t>(memory) % alignment)
+ if(!CanBeImported(memory, source))
{
if (m_Imported)
{
@@ -160,4 +159,24 @@ bool RefTensorHandle::Import(void* memory, MemorySource source)
return false;
}
+bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
+{
+ if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ {
+ if (m_IsImportEnabled && source == MemorySource::Malloc)
+ {
+ uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
+ if (reinterpret_cast<uintptr_t>(memory) % alignment)
+ {
+ return false;
+ }
+
+ return true;
+
+ }
+
+ }
+ return false;
+}
+
}
diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp
index a3264f55ef..a7eab034b2 100644
--- a/src/backends/reference/RefTensorHandle.hpp
+++ b/src/backends/reference/RefTensorHandle.hpp
@@ -57,6 +57,7 @@ public:
}
virtual bool Import(void* memory, MemorySource source) override;
+ virtual bool CanBeImported(void* memory, MemorySource source) override;
private:
// Only used for testing
diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp
index 39f5a2aeed..3504f53bc7 100644
--- a/src/backends/reference/test/RefTensorHandleTests.cpp
+++ b/src/backends/reference/test/RefTensorHandleTests.cpp
@@ -253,6 +253,39 @@ TEST_CASE("MisalignedPointer")
delete[] testPtr;
}
+TEST_CASE("CheckCanBeImported")
+{
+ TensorInfo info({1}, DataType::Float32);
+ RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+
+ int* testPtr = new int(4);
+
+ // Not supported
+ CHECK(!handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::DmaBuf));
+
+ // Supported
+ CHECK(handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::Malloc));
+
+ delete testPtr;
+
+}
+
+TEST_CASE("MisalignedCanBeImported")
+{
+ TensorInfo info({2}, DataType::Float32);
+ RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc));
+
+ // Allocate a 2 int array
+ int* testPtr = new int[2];
+
+ // Increment pointer by 1 byte
+ void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1);
+
+ CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc));
+
+ delete[] testPtr;
+}
+
#endif
}