diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/Gather.cpp | 27 |
1 files changed, 21 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp index 48039432a5..585fc81b00 100644 --- a/src/backends/reference/workloads/Gather.cpp +++ b/src/backends/reference/workloads/Gather.cpp @@ -1,12 +1,13 @@ // -// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "Gather.hpp" #include <armnn/backends/WorkloadData.hpp> -#include <armnn/utility/NumericCast.hpp> + +#include <fmt/format.h> namespace armnn { @@ -22,7 +23,11 @@ void Gather(const TensorInfo& paramsInfo, IgnoreUnused(outputInfo); const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions()); - ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank); + if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int)) + { + throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range", + axis_int, paramsRank, paramsRank))); + } const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int) : static_cast<unsigned int>(axis_int); @@ -47,8 +52,15 @@ void Gather(const TensorInfo& paramsInfo, { for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j) { - unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]); - ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]); + unsigned int index = + (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j]) + : static_cast<unsigned int>(indices[j]); + + if (index >= paramsShape[axis]) + { + throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}", + index, paramsShape[axis] ))); + } unsigned int startOffset = (paramsInnerProduct * index) + offset; unsigned int endOffset = startOffset + paramsInnerProduct; @@ -65,7 +77,10 @@ void Gather(const TensorInfo& paramsInfo, offset += paramsShape[axis] * paramsInnerProduct; } - ARMNN_ASSERT(outIndex == outputInfo.GetNumElements()); + if (outIndex != outputInfo.GetNumElements()) + { + throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex))); + } } } //namespace armnn
\ No newline at end of file |