diff options
author | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-03-15 16:52:05 +0000 |
---|---|---|
committer | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-04-28 13:16:39 +0000 |
commit | a25582c4e7dddd26419e0a3316614e8309928934 (patch) | |
tree | 3e71a83870f561d2abb8df802c56009224628152 /tests/validation/reference | |
parent | eaae8999ac8027a5fb96162061ad8ccc490515cb (diff) | |
download | ComputeLibrary-a25582c4e7dddd26419e0a3316614e8309928934.tar.gz |
Fix the gather layer indices check
* If the index is out-of-bound, both CPU and GPU implementations
of the gather layer will output 0.
Resolves: COMPMID-5964
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: Ib029b3acfb31452f2097c8c75448fb2697cfa332
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9487
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/Gather.cpp | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 12d1a3cd3c..c90c04f8cc 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -46,10 +46,14 @@ SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); const auto dst_ptr = static_cast<T *>(dst.data()); + const uint32_t index_limit = src.shape()[actual_axis]; + Window win; win.use_tensor_dimensions(dst_shape); execute_window_loop(win, [&](const Coordinates &dst_coords) { + const auto dst_addr = coords2index(dst.shape(), dst_coords); + // Calculate the coordinates of the index value. Coordinates idx_coords; @@ -58,23 +62,33 @@ SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> idx_coords.set(i, dst_coords[i + actual_axis]); } - // Calculate the coordinates of the source data. - Coordinates src_coords; + const auto index = indices_ptr[coords2index(indices.shape(), idx_coords)]; - for(size_t i = 0; i < actual_axis; ++i) + if(index < index_limit) { - src_coords.set(i, dst_coords[i]); - } + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) + { + src_coords.set(i, dst_coords[i]); + } - src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + src_coords.set(actual_axis, index); - for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) + { + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + } + + // Copy the data. + const auto src_addr = coords2index(src.shape(), src_coords); + dst_ptr[dst_addr] = src_ptr[src_addr]; + } + else { - src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); + dst_ptr[dst_addr] = 0; } - - // Copy the data. - dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; }); return dst; |