aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2023-06-01 16:42:05 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2023-06-09 10:55:17 +0000
commit49f8d6aeb645595289f04061cdcfefc404dd1652 (patch)
treebc560a5d6178e556bf6063e937b0bbbd28b07275
parente6b0e900322a92028b2d10225cd35e5346636bc1 (diff)
downloadarmnn-49f8d6aeb645595289f04061cdcfefc404dd1652.tar.gz
IVGCVSW-7691 Replace asserts with exceptions in Ref Gather
Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: If6731b4757257d983c09210b50315cd5d9837e20
-rw-r--r--src/backends/reference/workloads/Gather.cpp27
1 files changed, 21 insertions, 6 deletions
diff --git a/src/backends/reference/workloads/Gather.cpp b/src/backends/reference/workloads/Gather.cpp
index 48039432a5..585fc81b00 100644
--- a/src/backends/reference/workloads/Gather.cpp
+++ b/src/backends/reference/workloads/Gather.cpp
@@ -1,12 +1,13 @@
//
-// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "Gather.hpp"
#include <armnn/backends/WorkloadData.hpp>
-#include <armnn/utility/NumericCast.hpp>
+
+#include <fmt/format.h>
namespace armnn
{
@@ -22,7 +23,11 @@ void Gather(const TensorInfo& paramsInfo,
IgnoreUnused(outputInfo);
const int paramsRank = static_cast<int>(paramsInfo.GetNumDimensions());
- ARMNN_ASSERT(-1 * paramsRank <= axis_int && axis_int < paramsRank);
+ if((axis_int < -1 * paramsRank) || (paramsRank <= axis_int))
+ {
+ throw InvalidArgumentException((fmt::format("Gather: Axis {} is not within [-{}, {}) range",
+ axis_int, paramsRank, paramsRank)));
+ }
const unsigned int axis = (axis_int < 0) ? static_cast<unsigned int>(paramsRank + axis_int)
: static_cast<unsigned int>(axis_int);
@@ -47,8 +52,15 @@ void Gather(const TensorInfo& paramsInfo,
{
for (unsigned int j = 0; j < indicesInfo.GetNumElements(); ++j)
{
- unsigned int index = armnn::numeric_cast<unsigned int>(indices[j]);
- ARMNN_ASSERT(indices[j] >= 0 && index < paramsShape[axis]);
+ unsigned int index =
+ (indices[j] < 0) ? static_cast<unsigned int>(static_cast<int>(paramsShape[axis]) + indices[j])
+ : static_cast<unsigned int>(indices[j]);
+
+ if (index >= paramsShape[axis])
+ {
+ throw InvalidArgumentException((fmt::format("Gather: index >= paramsShape[axis]: {} >= {}",
+ index, paramsShape[axis] )));
+ }
unsigned int startOffset = (paramsInnerProduct * index) + offset;
unsigned int endOffset = startOffset + paramsInnerProduct;
@@ -65,7 +77,10 @@ void Gather(const TensorInfo& paramsInfo,
offset += paramsShape[axis] * paramsInnerProduct;
}
- ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
+ if (outIndex != outputInfo.GetNumElements())
+ {
+ throw InvalidArgumentException((fmt::format("Gather: Invalid outIndex {} ", outIndex)));
+ }
}
} //namespace armnn \ No newline at end of file