diff options
author | Ciara Sookarry <ciara.sookarry@arm.com> | 2023-10-11 17:04:04 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2023-10-25 15:04:40 +0000 |
commit | abd3c214fa6f32075b243eb9d2f58ee75b8979e0 (patch) | |
tree | fa20b16084e5177221b8ae3441c878b4b2d2f955 /src/backends | |
parent | 3e4b60897bde2ad7ab5b730c7c5d727e41cc0eef (diff) | |
download | armnn-abd3c214fa6f32075b243eb9d2f58ee75b8979e0.tar.gz |
IVGCVSW-7751 DTS: Fix Gather and GatherNd Tests in CpuRef
* Report unsupported when indices have negative values
Signed-off-by: Ciara Sookarry <ciara.sookarry@arm.com>
Change-Id: I9592dcd8c5556d57bedc0d2236f0338c83e597d2
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/reference/workloads/RefGatherNdWorkload.cpp | 9 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefGatherWorkload.cpp | 9 |
2 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp index 6d98d54a77..fa820aeb82 100644 --- a/src/backends/reference/workloads/RefGatherNdWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: MIT // +#include <fmt/format.h> #include "RefGatherNdWorkload.hpp" #include "Gather.hpp" @@ -36,6 +37,14 @@ void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vecto const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map()); std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements()); + // Check for negative indices, it could not be checked in validate as we do not have access to the values there + for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i) + { + if (indices[i] < 0) + { + throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i))); + } + } std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map()); diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp index 129dcf1b27..a5cc998b40 100644 --- a/src/backends/reference/workloads/RefGatherWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherWorkload.cpp @@ -9,6 +9,7 @@ #include "Profiling.hpp" #include "RefWorkloadUtils.hpp" #include <ResolveType.hpp> +#include <fmt/format.h> namespace armnn { @@ -36,6 +37,14 @@ void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector< Decoder<float>& decoder = *decoderPtr; const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map()); + // Check for negative indices, it could not be checked in validate as we do not have access to the values there + for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i) + { + if (indicesData[i] < 0) + { + throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i))); + } + } std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map()); Encoder<float>& encoder = *encoderPtr; |