From a25582c4e7dddd26419e0a3316614e8309928934 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 15 Mar 2023 16:52:05 +0000 Subject: Fix the gather layer indices check * If the index is out-of-bound, both CPU and GPU implementations of the gather layer will output 0. Resolves: COMPMID-5964 Signed-off-by: Viet-Hoa Do Change-Id: Ib029b3acfb31452f2097c8c75448fb2697cfa332 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9487 Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Tello Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- tests/validation/reference/Gather.cpp | 36 ++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) (limited to 'tests/validation/reference') diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 12d1a3cd3c..c90c04f8cc 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -46,10 +46,14 @@ SimpleTensor gather(const SimpleTensor &src, const SimpleTensor const auto indices_ptr = static_cast(indices.data()); const auto dst_ptr = static_cast(dst.data()); + const uint32_t index_limit = src.shape()[actual_axis]; + Window win; win.use_tensor_dimensions(dst_shape); execute_window_loop(win, [&](const Coordinates &dst_coords) { + const auto dst_addr = coords2index(dst.shape(), dst_coords); + // Calculate the coordinates of the index value. Coordinates idx_coords; @@ -58,23 +62,33 @@ SimpleTensor gather(const SimpleTensor &src, const SimpleTensor idx_coords.set(i, dst_coords[i + actual_axis]); } - // Calculate the coordinates of the source data. - Coordinates src_coords; + const auto index = indices_ptr[coords2index(indices.shape(), idx_coords)]; - for(size_t i = 0; i < actual_axis; ++i) + if(index < index_limit) { - src_coords.set(i, dst_coords[i]); - } + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) + { + src_coords.set(i, dst_coords[i]); + } - src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + src_coords.set(actual_axis, index); - for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + { + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + } + + // Copy the data. + const auto src_addr = coords2index(src.shape(), src_coords); + dst_ptr[dst_addr] = src_ptr[src_addr]; + } + else { - src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + dst_ptr[dst_addr] = 0; } - - // Copy the data. - dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; }); return dst; -- cgit v1.2.1