aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefGatherWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefGatherWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefGatherWorkload.cpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp
index 129dcf1b27..a5cc998b40 100644
--- a/src/backends/reference/workloads/RefGatherWorkload.cpp
+++ b/src/backends/reference/workloads/RefGatherWorkload.cpp
@@ -9,6 +9,7 @@
#include "Profiling.hpp"
#include "RefWorkloadUtils.hpp"
#include <ResolveType.hpp>
+#include <fmt/format.h>
namespace armnn
{
@@ -36,6 +37,14 @@ void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<
Decoder<float>& decoder = *decoderPtr;
const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map());
+ // Check for negative indices, it could not be checked in validate as we do not have access to the values there
+ for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i)
+ {
+ if (indicesData[i] < 0)
+ {
+ throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i)));
+ }
+ }
std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Encoder<float>& encoder = *encoderPtr;