aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCiara Sookarry <ciara.sookarry@arm.com>2023-10-11 17:04:04 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2023-10-25 15:04:40 +0000
commitabd3c214fa6f32075b243eb9d2f58ee75b8979e0 (patch)
treefa20b16084e5177221b8ae3441c878b4b2d2f955
parent3e4b60897bde2ad7ab5b730c7c5d727e41cc0eef (diff)
downloadarmnn-abd3c214fa6f32075b243eb9d2f58ee75b8979e0.tar.gz
IVGCVSW-7751 DTS: Fix Gather and GatherNd Tests in CpuRef
* Report unsupported when indices have negative values Signed-off-by: Ciara Sookarry <ciara.sookarry@arm.com> Change-Id: I9592dcd8c5556d57bedc0d2236f0338c83e597d2
-rw-r--r--src/backends/reference/workloads/RefGatherNdWorkload.cpp9
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp9
2 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp
index 6d98d54a77..fa820aeb82 100644
--- a/src/backends/reference/workloads/RefGatherNdWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include <fmt/format.h>
#include "RefGatherNdWorkload.hpp"
#include "Gather.hpp"
@@ -36,6 +37,14 @@ void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vecto
const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map());
std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements());
+ // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+ for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+ {
+ if (indices[i] < 0)
+ {
+ throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i)));
+ }
+ }
std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
index 129dcf1b27..a5cc998b40 100644
--- a/src/backends/reference/workloads/RefGatherWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -9,6 +9,7 @@
#include "Profiling.hpp"
#include "RefWorkloadUtils.hpp"
#include <ResolveType.hpp>
+#include <fmt/format.h>
namespace armnn
{
@@ -36,6 +37,14 @@ void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<
Decoder<float>& decoder = *decoderPtr;
const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map());
+ // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+ for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+ {
+ if (indicesData[i] < 0)
+ {
+ throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i)));
+ }
+ }
std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Encoder<float>& encoder = *encoderPtr;