aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefTensorHandle.cpp26
-rw-r--r--src/backends/reference/workloads/Slice.cpp67
-rw-r--r--src/backends/reference/workloads/StridedSlice.cpp9
3 files changed, 87 insertions, 15 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp
index cce992c947..07f497c54e 100644
--- a/src/backends/reference/RefTensorHandle.cpp
+++ b/src/backends/reference/RefTensorHandle.cpp
@@ -101,16 +101,30 @@ void* RefTensorHandle::GetPointer() const
void RefTensorHandle::CopyOutTo(void* dest) const
{
- const void *src = GetPointer();
- ARMNN_ASSERT(src);
- memcpy(dest, src, m_TensorInfo.GetNumBytes());
+ const void* src = GetPointer();
+ if (src == nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyOutTo called with a null src pointer");
+ }
+ if (dest == nullptr)
+ {
+ throw NullPointerException("TensorHandle::CopyOutTo called with a null dest pointer");
+ }
+ memcpy(dest, src, GetTensorInfo().GetNumBytes());
}
void RefTensorHandle::CopyInFrom(const void* src)
{
- void *dest = GetPointer();
- ARMNN_ASSERT(dest);
- memcpy(dest, src, m_TensorInfo.GetNumBytes());
+ void* dest = GetPointer();
+ if (dest == nullptr)
+ {
+ throw NullPointerException("RefTensorHandle::CopyInFrom called with a null dest pointer");
+ }
+ if (src == nullptr)
+ {
+ throw NullPointerException("RefTensorHandle::CopyInFrom called with a null src pointer");
+ }
+ memcpy(dest, src, GetTensorInfo().GetNumBytes());
}
MemorySourceFlags RefTensorHandle::GetImportFlags() const
diff --git a/src/backends/reference/workloads/Slice.cpp b/src/backends/reference/workloads/Slice.cpp
index d6836c6933..534a063ed5 100644
--- a/src/backends/reference/workloads/Slice.cpp
+++ b/src/backends/reference/workloads/Slice.cpp
@@ -20,11 +20,28 @@ void Slice(const TensorInfo& inputInfo,
const TensorShape& inputShape = inputInfo.GetShape();
const unsigned int numDims = inputShape.GetNumDimensions();
- ARMNN_ASSERT(descriptor.m_Begin.size() == numDims);
- ARMNN_ASSERT(descriptor.m_Size.size() == numDims);
-
constexpr unsigned int maxNumDims = 4;
- ARMNN_ASSERT(numDims <= maxNumDims);
+ if (descriptor.m_Begin.size() != numDims)
+ {
+ std::stringstream msg;
+ msg << "Slice: Number of dimensions (" << numDims <<
+ ") does not match the Begin vector in the descriptor (" << descriptor.m_Begin.size() << ")";
+ throw InvalidArgumentException(msg.str());
+ }
+ if (descriptor.m_Size.size() != numDims)
+ {
+ std::stringstream msg;
+ msg << "Slice: Number of dimensions (" << numDims <<
+ ") does not match the Size vector in the descriptor (" << descriptor.m_Size.size() << ")";
+ throw InvalidArgumentException(msg.str());
+ }
+ if (numDims > maxNumDims)
+ {
+ std::stringstream msg;
+ msg << "Slice: Number of dimensions (" << numDims <<
+ ") is greater than the maximum supported (" << maxNumDims << ")";
+ throw InvalidArgumentException(msg.str());
+ }
std::vector<unsigned int> paddedInput(4);
std::vector<unsigned int> paddedBegin(4);
@@ -63,15 +80,47 @@ void Slice(const TensorInfo& inputInfo,
unsigned int size2 = paddedSize[2];
unsigned int size3 = paddedSize[3];
- ARMNN_ASSERT(begin0 + size0 <= dim0);
- ARMNN_ASSERT(begin1 + size1 <= dim1);
- ARMNN_ASSERT(begin2 + size2 <= dim2);
- ARMNN_ASSERT(begin3 + size3 <= dim3);
+ if (begin0 + size0 > dim0)
+ {
+ std::stringstream msg;
+ msg << "Slice: begin0 + size0 (" << (begin0 + size0) <<
+ ") exceeds dim0 (" << dim0 << ")";
+ throw InvalidArgumentException(msg.str());
+ }
+ if (begin1 + size1 > dim1)
+ {
+ std::stringstream msg;
+ msg << "Slice: begin1 + size1 (" << (begin1 + size1) <<
+ ") exceeds dim2 (" << dim1 << ")";
+ throw InvalidArgumentException(msg.str());
+ }
+ if (begin2 + size2 > dim2)
+ {
+ std::stringstream msg;
+ msg << "Slice: begin2 + size2 (" << (begin2 + size2) <<
+ ") exceeds dim2 (" << dim2 << ")";
+ throw InvalidArgumentException(msg.str());
+ }
+ if (begin3 + size3 > dim3)
+ {
+ std::stringstream msg;
+ msg << "Slice: begin3 + size3 (" << (begin3 + size3) <<
+ ") exceeds dim3 (" << dim3 << ")";
+ throw InvalidArgumentException(msg.str());
+ }
+
+ if (inputData == nullptr)
+ {
+ throw armnn::NullPointerException("Slice: Null inputData pointer");
+ }
+ if (outputData == nullptr)
+ {
+ throw armnn::NullPointerException("Slice: Null outputData pointer");
+ }
const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
unsigned char* output = reinterpret_cast<unsigned char*>(outputData);
- IgnoreUnused(dim0);
for (unsigned int idx0 = begin0; idx0 < begin0 + size0; ++idx0)
{
for (unsigned int idx1 = begin1; idx1 < begin1 + size1; ++idx1)
diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp
index c5fb121cb3..68600c9a95 100644
--- a/src/backends/reference/workloads/StridedSlice.cpp
+++ b/src/backends/reference/workloads/StridedSlice.cpp
@@ -93,6 +93,15 @@ void StridedSlice(const TensorInfo& inputInfo,
void* outputData,
unsigned int dataTypeSize)
{
+ if (inputData == nullptr)
+ {
+ throw armnn::InvalidArgumentException("Slice: Null inputData pointer");
+ }
+ if (outputData == nullptr)
+ {
+ throw armnn::InvalidArgumentException("Slice: Null outputData pointer");
+ }
+
const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData);
unsigned char* output = reinterpret_cast<unsigned char*>(outputData);