diff options
Diffstat (limited to 'tests/validation/reference/Gather.cpp')
-rw-r--r-- | tests/validation/reference/Gather.cpp | 53 |
1 files changed, 41 insertions, 12 deletions
diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 93ac09cf95..c90c04f8cc 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,27 +39,56 @@ namespace reference template <typename T> SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &indices, uint32_t actual_axis) { - const auto *indices_ptr = static_cast<const uint32_t *>(indices.data()); const TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_gather_shape(src.shape(), indices.shape(), actual_axis); SimpleTensor<T> dst(dst_shape, src.data_type()); + const auto src_ptr = static_cast<const T *>(src.data()); + const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); + const auto dst_ptr = static_cast<T *>(dst.data()); + + const uint32_t index_limit = src.shape()[actual_axis]; + Window win; win.use_tensor_dimensions(dst_shape); - execute_window_loop(win, [&](const Coordinates & id) - { - Coordinates offset; - for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim) + + execute_window_loop(win, [&](const Coordinates &dst_coords) { + const auto dst_addr = coords2index(dst.shape(), dst_coords); + + // Calculate the coordinates of the index value. + Coordinates idx_coords; + + for(size_t i = 0; i < indices.shape().num_dimensions(); ++i) { - if(dim == actual_axis) + idx_coords.set(i, dst_coords[i + actual_axis]); + } + + const auto index = indices_ptr[coords2index(indices.shape(), idx_coords)]; + + if(index < index_limit) + { + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) { - offset.set(dim, indices_ptr[id[dim]]); + src_coords.set(i, dst_coords[i]); } - else + + src_coords.set(actual_axis, index); + + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) { - offset.set(dim, id[dim]); + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); } + + // Copy the data. + const auto src_addr = coords2index(src.shape(), src_coords); + dst_ptr[dst_addr] = src_ptr[src_addr]; + } + else + { + dst_ptr[dst_addr] = 0; } - *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset)); }); return dst; @@ -72,4 +101,4 @@ template SimpleTensor<uint8_t> gather(const SimpleTensor<uint8_t> &src, const Si } // namespace reference } // namespace validation } // namespace test -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |