diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefTensorHandle.cpp | 26 | ||||
-rw-r--r-- | src/backends/reference/workloads/Slice.cpp | 67 | ||||
-rw-r--r-- | src/backends/reference/workloads/StridedSlice.cpp | 9 |
3 files changed, 87 insertions, 15 deletions
diff --git a/src/backends/reference/RefTensorHandle.cpp b/src/backends/reference/RefTensorHandle.cpp index cce992c947..07f497c54e 100644 --- a/src/backends/reference/RefTensorHandle.cpp +++ b/src/backends/reference/RefTensorHandle.cpp @@ -101,16 +101,30 @@ void* RefTensorHandle::GetPointer() const void RefTensorHandle::CopyOutTo(void* dest) const { - const void *src = GetPointer(); - ARMNN_ASSERT(src); - memcpy(dest, src, m_TensorInfo.GetNumBytes()); + const void* src = GetPointer(); + if (src == nullptr) + { + throw NullPointerException("TensorHandle::CopyOutTo called with a null src pointer"); + } + if (dest == nullptr) + { + throw NullPointerException("TensorHandle::CopyOutTo called with a null dest pointer"); + } + memcpy(dest, src, GetTensorInfo().GetNumBytes()); } void RefTensorHandle::CopyInFrom(const void* src) { - void *dest = GetPointer(); - ARMNN_ASSERT(dest); - memcpy(dest, src, m_TensorInfo.GetNumBytes()); + void* dest = GetPointer(); + if (dest == nullptr) + { + throw NullPointerException("RefTensorHandle::CopyInFrom called with a null dest pointer"); + } + if (src == nullptr) + { + throw NullPointerException("RefTensorHandle::CopyInFrom called with a null src pointer"); + } + memcpy(dest, src, GetTensorInfo().GetNumBytes()); } MemorySourceFlags RefTensorHandle::GetImportFlags() const diff --git a/src/backends/reference/workloads/Slice.cpp b/src/backends/reference/workloads/Slice.cpp index d6836c6933..534a063ed5 100644 --- a/src/backends/reference/workloads/Slice.cpp +++ b/src/backends/reference/workloads/Slice.cpp @@ -20,11 +20,28 @@ void Slice(const TensorInfo& inputInfo, const TensorShape& inputShape = inputInfo.GetShape(); const unsigned int numDims = inputShape.GetNumDimensions(); - ARMNN_ASSERT(descriptor.m_Begin.size() == numDims); - ARMNN_ASSERT(descriptor.m_Size.size() == numDims); - constexpr unsigned int maxNumDims = 4; - ARMNN_ASSERT(numDims <= maxNumDims); + if (descriptor.m_Begin.size() != numDims) + { + std::stringstream msg; + msg << "Slice: Number of dimensions (" << numDims << + ") does not match the Begin vector in the descriptor (" << descriptor.m_Begin.size() << ")"; + throw InvalidArgumentException(msg.str()); + } + if (descriptor.m_Size.size() != numDims) + { + std::stringstream msg; + msg << "Slice: Number of dimensions (" << numDims << + ") does not match the Size vector in the descriptor (" << descriptor.m_Size.size() << ")"; + throw InvalidArgumentException(msg.str()); + } + if (numDims > maxNumDims) + { + std::stringstream msg; + msg << "Slice: Number of dimensions (" << numDims << + ") is greater than the maximum supported (" << maxNumDims << ")"; + throw InvalidArgumentException(msg.str()); + } std::vector<unsigned int> paddedInput(4); std::vector<unsigned int> paddedBegin(4); @@ -63,15 +80,47 @@ void Slice(const TensorInfo& inputInfo, unsigned int size2 = paddedSize[2]; unsigned int size3 = paddedSize[3]; - ARMNN_ASSERT(begin0 + size0 <= dim0); - ARMNN_ASSERT(begin1 + size1 <= dim1); - ARMNN_ASSERT(begin2 + size2 <= dim2); - ARMNN_ASSERT(begin3 + size3 <= dim3); + if (begin0 + size0 > dim0) + { + std::stringstream msg; + msg << "Slice: begin0 + size0 (" << (begin0 + size0) << + ") exceeds dim0 (" << dim0 << ")"; + throw InvalidArgumentException(msg.str()); + } + if (begin1 + size1 > dim1) + { + std::stringstream msg; + msg << "Slice: begin1 + size1 (" << (begin1 + size1) << + ") exceeds dim2 (" << dim1 << ")"; + throw InvalidArgumentException(msg.str()); + } + if (begin2 + size2 > dim2) + { + std::stringstream msg; + msg << "Slice: begin2 + size2 (" << (begin2 + size2) << + ") exceeds dim2 (" << dim2 << ")"; + throw InvalidArgumentException(msg.str()); + } + if (begin3 + size3 > dim3) + { + std::stringstream msg; + msg << "Slice: begin3 + size3 (" << (begin3 + size3) << + ") exceeds dim3 (" << dim3 << ")"; + throw InvalidArgumentException(msg.str()); + } + + if (inputData == nullptr) + { + throw armnn::NullPointerException("Slice: Null inputData pointer"); + } + if (outputData == nullptr) + { + throw armnn::NullPointerException("Slice: Null outputData pointer"); + } const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData); unsigned char* output = reinterpret_cast<unsigned char*>(outputData); - IgnoreUnused(dim0); for (unsigned int idx0 = begin0; idx0 < begin0 + size0; ++idx0) { for (unsigned int idx1 = begin1; idx1 < begin1 + size1; ++idx1) diff --git a/src/backends/reference/workloads/StridedSlice.cpp b/src/backends/reference/workloads/StridedSlice.cpp index c5fb121cb3..68600c9a95 100644 --- a/src/backends/reference/workloads/StridedSlice.cpp +++ b/src/backends/reference/workloads/StridedSlice.cpp @@ -93,6 +93,15 @@ void StridedSlice(const TensorInfo& inputInfo, void* outputData, unsigned int dataTypeSize) { + if (inputData == nullptr) + { + throw armnn::InvalidArgumentException("Slice: Null inputData pointer"); + } + if (outputData == nullptr) + { + throw armnn::InvalidArgumentException("Slice: Null outputData pointer"); + } + const unsigned char* input = reinterpret_cast<const unsigned char*>(inputData); unsigned char* output = reinterpret_cast<unsigned char*>(outputData); |