aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/TensorHandle.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/TensorHandle.cpp')
-rw-r--r--src/backends/backendsCommon/TensorHandle.cpp35
1 files changed, 31 insertions, 4 deletions
diff --git a/src/backends/backendsCommon/TensorHandle.cpp b/src/backends/backendsCommon/TensorHandle.cpp
index d55fca24d4..bc221adbe3 100644
--- a/src/backends/backendsCommon/TensorHandle.cpp
+++ b/src/backends/backendsCommon/TensorHandle.cpp
@@ -103,12 +103,30 @@ void ScopedTensorHandle::Allocate()
void ScopedTensorHandle::CopyOutTo(void* memory) const
{
- memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
+ const void* src = GetTensor<void>();
+ if (src == nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyOutTo called with a null src pointer");
+ }
+ if (memory == nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyOutTo called with a null dest pointer");
+ }
+ memcpy(memory, src, GetTensorInfo().GetNumBytes());
}
void ScopedTensorHandle::CopyInFrom(const void* memory)
{
- memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
+ void* dest = GetTensor<void>();
+ if (dest == nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyInFrom called with a null dest pointer");
+ }
+ if (memory == nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyInFrom called with a null src pointer");
+ }
+ memcpy(dest, memory, GetTensorInfo().GetNumBytes());
}
void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
@@ -118,8 +136,17 @@ void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
{
- ARMNN_ASSERT(GetTensor<void>() == nullptr);
- ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
+ if (GetTensor<void>() != nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyFrom called on an already allocated TensorHandle");
+ }
+ if (GetTensorInfo().GetNumBytes() != numBytes)
+ {
+ std::stringstream msg;
+ msg << "TensorHandle:CopyFrom: Number of bytes in the tensor info (" << GetTensorInfo().GetNumBytes() <<
+ ") does not match the number of bytes being copied (" << numBytes << ")";
+ throw armnn::Exception(msg.str());
+ }
if (srcMemory)
{