From 49f8d6aeb645595289f04061cdcfefc404dd1652 Mon Sep 17 00:00:00 2001 From: Kevin May Date: Thu, 1 Jun 2023 16:42:05 +0100 Subject: IVGCVSW-7691 Replace asserts with exceptions in Ref Gather Signed-off-by: Kevin May Change-Id: If6731b4757257d983c09210b50315cd5d9837e20 --- src/backends/reference/workloads/Gather.cpp | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp index 48039432a5..585fc81b00 100644 --- a/src/backends/reference/workloads/Gather.cpp +++ b/src/backends/reference/workloads/Gather.cpp @@ -1,12 +1,13 @@ // -// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Gather.hpp" #include -#include + +#include namespace armnn { @@ -22,7 +23,11 @@ void Gather(const TensorInfo& paramsInfo, IgnoreUnused(outputInfo); const int paramsRank = static_cast(paramsInfo.GetNumDimensions()); - ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank); + if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int)) + { + throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range", + axis_int, paramsRank, paramsRank))); + } const unsigned int axis = (axis_int < 0) ? static_cast(paramsRank + axis_int) : static_cast(axis_int); @@ -47,8 +52,15 @@ void Gather(const TensorInfo& paramsInfo, { for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j) { - unsigned int index = armnn::numeric_cast(indices[j]); - ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]); + unsigned int index = + (indices[j] < 0) ? static_cast(static_cast(paramsShape[axis]) + indices[j]) + : static_cast(indices[j]); + + if (index >= paramsShape[axis]) + { + throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}", + index, paramsShape[axis] ))); + } unsigned int startOffset = (paramsInnerProduct * index) + offset; unsigned int endOffset = startOffset + paramsInnerProduct; @@ -65,7 +77,10 @@ void Gather(const TensorInfo& paramsInfo, offset += paramsShape[axis] * paramsInnerProduct; } - ARMNN_ASSERT(outIndex == outputInfo.GetNumElements()); + if (outIndex != outputInfo.GetNumElements()) + { + throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex))); + } } } //namespace armnn \ No newline at end of file -- cgit v1.2.1