diff options
Diffstat (limited to 'tests/validation/reference/Gather.cpp')
-rw-r--r-- | tests/validation/reference/Gather.cpp | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 12d1a3cd3c..c90c04f8cc 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -46,10 +46,14 @@ SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); const auto dst_ptr = static_cast<T *>(dst.data()); + const uint32_t index_limit = src.shape()[actual_axis]; + Window win; win.use_tensor_dimensions(dst_shape); execute_window_loop(win, [&](const Coordinates &dst_coords) { + const auto dst_addr = coords2index(dst.shape(), dst_coords); + // Calculate the coordinates of the index value. Coordinates idx_coords; @@ -58,23 +62,33 @@ SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> idx_coords.set(i, dst_coords[i + actual_axis]); } - // Calculate the coordinates of the source data. - Coordinates src_coords; + const auto index = indices_ptr[coords2index(indices.shape(), idx_coords)]; - for(size_t i = 0; i < actual_axis; ++i) + if(index < index_limit) { - src_coords.set(i, dst_coords[i]); - } + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) + { + src_coords.set(i, dst_coords[i]); + } - src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + src_coords.set(actual_axis, index); - for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + { + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + } + + // Copy the data. + const auto src_addr = coords2index(src.shape(), src_coords); + dst_ptr[dst_addr] = src_ptr[src_addr]; + } + else { - src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + dst_ptr[dst_addr] = 0; } - - // Copy the data. - dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; }); return dst; |