aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/reference/workloads/RefGatherNdWorkload.cpp9
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp9
2 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp
index 6d98d54a77..fa820aeb82 100644
--- a/src/backends/reference/workloads/RefGatherNdWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include <fmt/format.h>
#include "RefGatherNdWorkload.hpp"
#include "Gather.hpp"
@@ -36,6 +37,14 @@ void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vecto
const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map());
std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements());
+ // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+ for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+ {
+ if (indices[i] < 0)
+ {
+ throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i)));
+ }
+ }
std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
index 129dcf1b27..a5cc998b40 100644
--- a/src/backends/reference/workloads/RefGatherWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -9,6 +9,7 @@
#include "Profiling.hpp"
#include "RefWorkloadUtils.hpp"
#include <ResolveType.hpp>
+#include <fmt/format.h>
namespace armnn
{
@@ -36,6 +37,14 @@ void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<
Decoder<float>& decoder = *decoderPtr;
const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map());
+ // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+ for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+ {
+ if (indicesData[i] < 0)
+ {
+ throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i)));
+ }
+ }
std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Encoder<float>& encoder = *encoderPtr;