diff options
-rw-r--r-- | src/backends/reference/workloads/RefGatherNdWorkload.cpp | 9 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefGatherWorkload.cpp | 9 |
2 files changed, 18 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefGatherNdWorkload.cpp b/src/backends/reference/workloads/RefGatherNdWorkload.cpp index 6d98d54a77..fa820aeb82 100644 --- a/src/backends/reference/workloads/RefGatherNdWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherNdWorkload.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: MIT // +#include <fmt/format.h> #include "RefGatherNdWorkload.hpp" #include "Gather.hpp" @@ -36,6 +37,14 @@ void RefGatherNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vecto const int32_t* indicesDataPtr = reinterpret_cast<int32_t*>(inputs[1]->Map()); std::vector<int32_t> indices(indicesDataPtr, indicesDataPtr + inputInfo1.GetNumElements()); + // Check for negative indices, it could not be checked in validate as we do not have access to the values there + for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i) + { + if (indices[i] < 0) + { + throw InvalidArgumentException((fmt::format("GatherNd: indices[{}] < 0", i))); + } + } std::unique_ptr<Encoder<float>> output_encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map()); diff --git a/src/backends/reference/workloads/RefGatherWorkload.cpp b/src/backends/reference/workloads/RefGatherWorkload.cpp index 129dcf1b27..a5cc998b40 100644 --- a/src/backends/reference/workloads/RefGatherWorkload.cpp +++ b/src/backends/reference/workloads/RefGatherWorkload.cpp @@ -9,6 +9,7 @@ #include "Profiling.hpp" #include "RefWorkloadUtils.hpp" #include <ResolveType.hpp> +#include <fmt/format.h> namespace armnn { @@ -36,6 +37,14 @@ void RefGatherWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector< Decoder<float>& decoder = *decoderPtr; const int32_t* indicesData = reinterpret_cast<int32_t*>(inputs[1]->Map()); + // Check for negative indices, it could not be checked in validate as we do not have access to the values there + for (unsigned int i = 0; i < inputInfo1.GetNumElements(); ++i) + { + if (indicesData[i] < 0) + { + throw InvalidArgumentException((fmt::format("Gather: indices[{}] < 0", i))); + } + } std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map()); Encoder<float>& encoder = *encoderPtr; |