aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefGatherNdWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefGatherNdWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefGatherNdWorkload.cpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp
index 6d98d54a77..fa820aeb82 100644
--- a/src/backends/reference/workloads/RefGatherNdWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT
//
+#include <fmt/format.h>
#include "RefGatherNdWorkload.hpp"
#include "Gather.hpp"
@@ -36,6 +37,14 @@ void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vecto
const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map());
std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements());
+ // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+ for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+ {
+ if (indices[i] < 0)
+ {
+ throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i)));
+ }
+ }
std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());