diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/datasets/GatherDataset.h | 9 | ||||
-rw-r--r-- | tests/validation/reference/Gather.cpp | 74 |
2 files changed, 36 insertions, 47 deletions
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h index 8fec5441b1..487ce19bc7 100644 --- a/tests/datasets/GatherDataset.h +++ b/tests/datasets/GatherDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019, 2022 Arm Limited. + * Copyright (c) 2018-2019, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -116,6 +116,13 @@ public: add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1); add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1); add_config(TensorShape(1U, 5U, 3U), TensorShape(1U, 7U, 3U), 1); + + add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0); + add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0); + add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0); + + add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U), 2); + add_config(TensorShape(8U, 2U, 3U), TensorShape(4U, 2U, 5U), 2); } }; diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 8de1a473eb..12d1a3cd3c 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019, 2022 Arm Limited. + * Copyright (c) 2018-2019, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,61 +39,43 @@ namespace reference template <typename T> SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &indices, uint32_t actual_axis) { - const auto *indices_ptr = static_cast<const uint32_t *>(indices.data()); const TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_gather_shape(src.shape(), indices.shape(), actual_axis); SimpleTensor<T> dst(dst_shape, src.data_type()); + const auto src_ptr = static_cast<const T *>(src.data()); + const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); + const auto dst_ptr = static_cast<T *>(dst.data()); + Window win; win.use_tensor_dimensions(dst_shape); - if(indices.shape().num_dimensions() == 1u) - { - execute_window_loop(win, [&](const Coordinates & id) + + execute_window_loop(win, [&](const Coordinates &dst_coords) { + // Calculate the coordinates of the index value. + Coordinates idx_coords; + + for(size_t i = 0; i < indices.shape().num_dimensions(); ++i) { - Coordinates offset; - for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim) - { - if(dim == actual_axis) - { - offset.set(dim, indices_ptr[id[dim]]); - } - else - { - offset.set(dim, id[dim]); - } - } - *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset)); - }); - } - else - { - if(actual_axis == 1) + idx_coords.set(i, dst_coords[i + actual_axis]); + } + + // Calculate the coordinates of the source data. + Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) { - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - execute_window_loop(win, [&](const Coordinates & id) - { - auto *dst_ptr = dst(id); - Coordinates index_offset; - for(uint32_t k = 0; k < indices.shape().num_dimensions(); ++k) - { - index_offset.set(k, id[k + 1]); - } - const uint32_t row = *reinterpret_cast<const uint32_t *>(indices(index_offset)); - Coordinates src_offset; - src_offset.set(0, 0); - src_offset.set(1, row); - for(uint32_t j = 2; j < src.shape().num_dimensions(); ++j) - { - src_offset.set(j, id[1 + indices.shape().num_dimensions() + (j - 2)]); - } - const auto in_ptr_row = src(src_offset); - memcpy(dst_ptr, in_ptr_row, src.shape()[0] * src.element_size()); - }); + src_coords.set(i, dst_coords[i]); } - else + + src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) { - ARM_COMPUTE_ERROR("Not implemented."); + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); } - } + + // Copy the data. + dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; + }); return dst; } |