diff options
Diffstat (limited to 'tests/validation/reference/Gather.cpp')
-rw-r--r-- | tests/validation/reference/Gather.cpp | 74 |
1 files changed, 28 insertions, 46 deletions
diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 8de1a473eb..12d1a3cd3c 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019, 2022 Arm Limited. + * Copyright (c) 2018-2019, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,61 +39,43 @@ namespace reference template <typename T> SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &indices, uint32_t actual_axis) { - const auto *indices_ptr = static_cast<const uint32_t *>(indices.data()); const TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_gather_shape(src.shape(), indices.shape(), actual_axis); SimpleTensor<T> dst(dst_shape, src.data_type()); + const auto src_ptr = static_cast<const T *>(src.data()); + const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); + const auto dst_ptr = static_cast<T *>(dst.data()); + Window win; win.use_tensor_dimensions(dst_shape); - if(indices.shape().num_dimensions() == 1u) - { - execute_window_loop(win, [&](const Coordinates & id) + + execute_window_loop(win, [&](const Coordinates &dst_coords) { + // Calculate the coordinates of the index value. + Coordinates idx_coords; + + for(size_t i = 0; i < indices.shape().num_dimensions(); ++i) { - Coordinates offset; - for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim) - { - if(dim == actual_axis) - { - offset.set(dim, indices_ptr[id[dim]]); - } - else - { - offset.set(dim, id[dim]); - } - } - *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset)); - }); - } - else - { - if(actual_axis == 1) + idx_coords.set(i, dst_coords[i + actual_axis]); + } + + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) { - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - execute_window_loop(win, [&](const Coordinates & id) - { - auto *dst_ptr = dst(id); - Coordinates index_offset; - for(uint32_t k = 0; k < indices.shape().num_dimensions(); ++k) - { - index_offset.set(k, id[k + 1]); - } - const uint32_t row = *reinterpret_cast<const uint32_t *>(indices(index_offset)); - Coordinates src_offset; - src_offset.set(0, 0); - src_offset.set(1, row); - for(uint32_t j = 2; j < src.shape().num_dimensions(); ++j) - { - src_offset.set(j, id[1 + indices.shape().num_dimensions() + (j - 2)]); - } - const auto in_ptr_row = src(src_offset); - memcpy(dst_ptr, in_ptr_row, src.shape()[0] * src.element_size()); - }); + src_coords.set(i, dst_coords[i]); } - else + + src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) { - ARM_COMPUTE_ERROR("Not implemented."); + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); } - } + + // Copy the data. + dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; + }); return dst; } |