diff options
author | Kevin May <kevin.may@arm.com> | 2023-06-01 16:42:05 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2023-06-09 10:55:17 +0000 |
commit | 49f8d6aeb645595289f04061cdcfefc404dd1652 (patch) | |
tree | bc560a5d6178e556bf6063e937b0bbbd28b07275 | |
parent | e6b0e900322a92028b2d10225cd35e5346636bc1 (diff) | |
download | armnn-49f8d6aeb645595289f04061cdcfefc404dd1652.tar.gz |
IVGCVSW-7691 Replace asserts with exceptions in Ref Gather
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: If6731b4757257d983c09210b50315cd5d9837e20
-rw-r--r-- | src/backends/reference/workloads/Gather.cpp | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp index 48039432a5..585fc81b00 100644 --- a/src/backends/reference/workloads/Gather.cpp +++ b/src/backends/reference/workloads/Gather.cpp @@ -1,12 +1,13 @@ // -// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Gather.hpp" #include <armnn/backends/WorkloadData.hpp> -#include <armnn/utility/NumericCast.hpp> + +#include <fmt/format.h> namespace armnn { @@ -22,7 +23,11 @@ void Gather(const TensorInfo& paramsInfo, IgnoreUnused(outputInfo); const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions()); - ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank); + if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int)) + { + throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range", + axis_int, paramsRank, paramsRank))); + } const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int) : static_cast<unsigned int>(axis_int); @@ -47,8 +52,15 @@ void Gather(const TensorInfo& paramsInfo, { for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j) { - unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]); - ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]); + unsigned int index = + (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j]) + : static_cast<unsigned int>(indices[j]); + + if (index >= paramsShape[axis]) + { + throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}", + index, paramsShape[axis] ))); + } unsigned int startOffset = (paramsInnerProduct * index) + offset; unsigned int endOffset = startOffset + paramsInnerProduct; @@ -65,7 +77,10 @@ void Gather(const TensorInfo& paramsInfo, offset += paramsShape[axis] * paramsInnerProduct; } - ARMNN_ASSERT(outIndex == outputInfo.GetNumElements()); + if (outIndex != outputInfo.GetNumElements()) + { + throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex))); + } } } //namespace armnn
\ No newline at end of file |