From 37c989a58a04985dfdc21089c7dacc7e1925a4d0 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Fri, 24 Feb 2023 15:52:21 +0000 Subject: Add support for arbitrary parameters for CPU Gather * The shape of input and indices tensors, and the gather axis can be any number, as long as these are valid and the output tensor doesn't have more dimensions than the library supports. * Update the reference code to be more generic and straightforward. * Add necessary test cases. Signed-off-by: Viet-Hoa Do Resolves: COMPMID-5919 Change-Id: Ic7e2032777aa97ecc147f61d5388528697508ab1 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9199 Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- tests/validation/reference/Gather.cpp | 74 +++++++++++++---------------------- 1 file changed, 28 insertions(+), 46 deletions(-) (limited to 'tests/validation/reference') diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 8de1a473eb..12d1a3cd3c 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019, 2022 Arm Limited. + * Copyright (c) 2018-2019, 2022-2023 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -39,61 +39,43 @@ namespace reference template <typename T> SimpleTensor<T> gather(const SimpleTensor<T> &src, const SimpleTensor<uint32_t> &indices, uint32_t actual_axis) { - const auto *indices_ptr = static_cast<const uint32_t *>(indices.data()); const TensorShape dst_shape = arm_compute::misc::shape_calculator::compute_gather_shape(src.shape(), indices.shape(), actual_axis); SimpleTensor<T> dst(dst_shape, src.data_type()); + const auto src_ptr = static_cast<const T *>(src.data()); + const auto indices_ptr = static_cast<const uint32_t *>(indices.data()); + const auto dst_ptr = static_cast<T *>(dst.data()); + Window win; win.use_tensor_dimensions(dst_shape); - if(indices.shape().num_dimensions() == 1u) - { - execute_window_loop(win, [&](const Coordinates & id) + + execute_window_loop(win, [&](const Coordinates &dst_coords) { + // Calculate the coordinates of the index value. + Coordinates idx_coords; + + for(size_t i = 0; i < indices.shape().num_dimensions(); ++i) { - Coordinates offset; - for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim) - { - if(dim == actual_axis) - { - offset.set(dim, indices_ptr[id[dim]]); - } - else - { - offset.set(dim, id[dim]); - } - } - *reinterpret_cast<T *>(dst(id)) = *reinterpret_cast<const T *>(src(offset)); - }); - } - else - { - if(actual_axis == 1) + idx_coords.set(i, dst_coords[i + actual_axis]); + } + + // Calculate the coordinates of the source data. 
+ Coordinates src_coords; + + for(size_t i = 0; i < actual_axis; ++i) { - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - execute_window_loop(win, [&](const Coordinates & id) - { - auto *dst_ptr = dst(id); - Coordinates index_offset; - for(uint32_t k = 0; k < indices.shape().num_dimensions(); ++k) - { - index_offset.set(k, id[k + 1]); - } - const uint32_t row = *reinterpret_cast<const uint32_t *>(indices(index_offset)); - Coordinates src_offset; - src_offset.set(0, 0); - src_offset.set(1, row); - for(uint32_t j = 2; j < src.shape().num_dimensions(); ++j) - { - src_offset.set(j, id[1 + indices.shape().num_dimensions() + (j - 2)]); - } - const auto in_ptr_row = src(src_offset); - memcpy(dst_ptr, in_ptr_row, src.shape()[0] * src.element_size()); - }); + src_coords.set(i, dst_coords[i]); } - else + + src_coords.set(actual_axis, indices_ptr[coords2index(indices.shape(), idx_coords)]); + + for(size_t i = actual_axis + 1; i < src.shape().num_dimensions(); ++i) { - ARM_COMPUTE_ERROR("Not implemented."); + src_coords.set(i, dst_coords[i + indices.shape().num_dimensions() - 1]); } - } + + // Copy the data. + dst_ptr[coords2index(dst.shape(), dst_coords)] = src_ptr[coords2index(src.shape(), src_coords)]; + }); return dst; } -- cgit v1.2.1