diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/fixtures/GatherFixture.h | 7 | ||||
-rw-r--r-- | tests/validation/reference/Gather.cpp | 36 |
2 files changed, 30 insertions, 13 deletions
diff --git a/tests/validation/fixtures/GatherFixture.h b/tests/validation/fixtures/GatherFixture.h index 452a201f82..f6f70023b9 100644 --- a/tests/validation/fixtures/GatherFixture.h +++ b/tests/validation/fixtures/GatherFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -67,7 +67,10 @@ protected: std::mt19937 gen(library->seed()); uint32_t *indices_ptr = static_cast<uint32_t *>(indices.data()); - std::uniform_int_distribution<uint32_t> dist_index(0, input_shape[actual_axis] - 1); + // 10% of the time the index is out-of-range. + uint32_t max_index = input_shape[actual_axis] + input_shape[actual_axis] / 9 + 1; + + std::uniform_int_distribution<uint32_t> dist_index(0, max_index - 1); //Let's consider 1D indices for(unsigned int ind = 0; ind < indices_shape[0]; ind++) { diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 12d1a3cd3c..c90c04f8cc 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -46,10 +46,14 @@ SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); const auto dst_ptr = static_cast<T *>(dst.data()); + const uint32_t index_limit = src.shape()[actual_axis]; + Window win; win.use_tensor_dimensions(dst_shape); execute_window_loop(win, [&](const Coordinates &dst_coords) { + const auto dst_addr = coords2index(dst.shape(), dst_coords); + // Calculate the coordinates of the index value. Coordinates idx_coords; @@ -58,23 +62,33 @@ SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> idx_coords.set(i, dst_coords[i + actual_axis]); } - // Calculate the coordinates of the source data. - Coordinates src_coords; + const auto index = indices_ptr[coords2index(indices.shape(), idx_coords)]; - for(size_t i = 0; i < actual_axis; ++i) + if(index < index_limit) { - src_coords.set(i, dst_coords[i]); - } + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) + { + src_coords.set(i, dst_coords[i]); + } - src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + src_coords.set(actual_axis, index); - for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + { + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + } + + // Copy the data. + const auto src_addr = coords2index(src.shape(), src_coords); + dst_ptr[dst_addr] = src_ptr[src_addr]; + } + else { - src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + dst_ptr[dst_addr] = 0; } - - // Copy the data. - dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; }); return dst; |