From abd3c214fa6f32075b243eb9d2f58ee75b8979e0 Mon Sep 17 00:00:00 2001 From: Ciara Sookarry Date: Wed, 11 Oct 2023 17:04:04 +0100 Subject: IVGCVSW-7751 DTS: Fix Gather and GatherNd Tests in CpuRef * Report unsupported when indices have negative values Signed-off-by: Ciara Sookarry Change-Id: I9592dcd8c5556d57bedc0d2236f0338c83e597d2 --- src/backends/reference/workloads/RefGatherNdWorkload.cpp | 9 +++++++++ src/backends/reference/workloads/RefGatherWorkload.cpp | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp index 6d98d54a77..fa820aeb82 100644 --- a/src/backends/reference/workloads/RefGatherNdWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: MIT // +#include #include "RefGatherNdWorkload.hpp" #include "Gather.hpp" @@ -36,6 +37,14 @@ void RefGatherNdWorkload::Execute(std::vector inputs, std::vecto const int32_t* indicesDataPtr = reinterpret_cast(inputs[1]->Map()); std::vector indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements()); + // Check for negative indices, it could not be checked in validate as we do not have access to the values there + for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i) + { + if (indices[i] < 0) + { + throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i))); + } + } std::unique_ptr> output_encoderPtr = MakeEncoder(outputInfo, outputs[0]->Map()); diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp index 129dcf1b27..a5cc998b40 100644 --- a/src/backends/reference/workloads/RefGatherWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherWorkload.cpp @@ -9,6 +9,7 @@ #include "Profiling.hpp" #include "RefWorkloadUtils.hpp" #include +#include namespace armnn { @@ -36,6 +37,14 @@ void RefGatherWorkload::Execute(std::vector inputs, std::vector< Decoder& decoder = *decoderPtr; const int32_t* indicesData = reinterpret_cast(inputs[1]->Map()); + // Check for negative indices, it could not be checked in validate as we do not have access to the values there + for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i) + { + if (indicesData[i] < 0) + { + throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i))); + } + } std::unique_ptr> encoderPtr = MakeEncoder(outputInfo, outputs[0]->Map()); Encoder& encoder = *encoderPtr; -- cgit v1.2.1