diff options
author | Nikhil Raj <nikhil.raj@arm.com> | 2022-01-05 16:04:08 +0000 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-01-17 11:32:48 +0000 |
commit | 53e06599a3af44db90c37d1cda34fc85ec9c27fa (patch) | |
tree | 913616009ed42cbd0d27a18a433325971e8fd091 /src | |
parent | 56ccf68c7858560f2ba00f19076b3cb112970881 (diff) | |
download | armnn-53e06599a3af44db90c37d1cda34fc85ec9c27fa.tar.gz |
IVGCVSW-6672 Implement CanBeImported function to RefTensorHandle
Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
Change-Id: Icaa3aa7ef3e5cc3984941d095edfe8f0b2137879
Diffstat (limited to 'src')
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 23 | ||||
-rw-r--r-- | src/backends/reference/RefTensorHandle.hpp | 1 | ||||
-rw-r--r-- | src/backends/reference/test/RefTensorHandleTests.cpp | 33 |
3 files changed, 55 insertions, 2 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index 5229e9d62b..0be9708cff 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -122,8 +122,7 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) if (m_IsImportEnabled && source == MemorySource::Malloc) { // Check memory alignment - uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); - if (reinterpret_cast<uintptr_t>(memory) % alignment) + if(!CanBeImported(memory, source)) { if (m_Imported) { @@ -160,4 +159,24 @@ bool RefTensorHandle::Import(void* memory, MemorySource source) return false; } +bool RefTensorHandle::CanBeImported(void *memory, MemorySource source) +{ + if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) + { + if (m_IsImportEnabled && source == MemorySource::Malloc) + { + uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType()); + if (reinterpret_cast<uintptr_t>(memory) % alignment) + { + return false; + } + + return true; + + } + + } + return false; +} + } diff --git a/src/backends/reference/RefTensorHandle.hpp b/src/backends/reference/RefTensorHandle.hpp index a3264f55ef..a7eab034b2 100644 --- a/src/backends/reference/RefTensorHandle.hpp +++ b/src/backends/reference/RefTensorHandle.hpp @@ -57,6 +57,7 @@ public: } virtual bool Import(void* memory, MemorySource source) override; + virtual bool CanBeImported(void* memory, MemorySource source) override; private: // Only used for testing diff --git a/src/backends/reference/test/RefTensorHandleTests.cpp b/src/backends/reference/test/RefTensorHandleTests.cpp index 39f5a2aeed..3504f53bc7 100644 --- a/src/backends/reference/test/RefTensorHandleTests.cpp +++ b/src/backends/reference/test/RefTensorHandleTests.cpp @@ -253,6 +253,39 @@ TEST_CASE("MisalignedPointer") delete[] testPtr; } +TEST_CASE("CheckCanBeImported") +{ + TensorInfo info({1}, DataType::Float32); + RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc)); + + int* testPtr = new int(4); + + // Not supported + CHECK(!handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::DmaBuf)); + + // Supported + CHECK(handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::Malloc)); + + delete testPtr; + +} + +TEST_CASE("MisalignedCanBeImported") +{ + TensorInfo info({2}, DataType::Float32); + RefTensorHandle handle(info, static_cast<unsigned int>(MemorySource::Malloc)); + + // Allocate a 2 int array + int* testPtr = new int[2]; + + // Increment pointer by 1 byte + void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1); + + CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc)); + + delete[] testPtr; +} + #endif } |