From ada3200f5cec0b6a37f898d5d6f8e69395d7bcb1 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Wed, 24 Apr 2024 10:27:13 +0100 Subject: Add update/index/output (m+1)/2d/(m+n) support for CLScatter Resolves: COMPMID-6894, COMPMID-6896 Change-Id: I9d29fd3701a7e0f28d83f81a6c42a7234c2587c3 Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11477 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Ramy Elgammal Dynamic-Fusion: Ramy Elgammal Benchmark: Arm Jenkins --- docs/user_guide/release_version_and_change_log.dox | 3 + src/core/CL/cl_kernels/common/scatter.cl | 147 ++++++++++++++------- src/gpu/cl/ClKernelLibrary.cpp | 1 + src/gpu/cl/kernels/ClScatterKernel.cpp | 114 ++++++++++++---- src/gpu/cl/kernels/ClScatterKernel.h | 1 + src/gpu/cl/operators/ClScatter.cpp | 2 - tests/datasets/ScatterDataset.h | 2 +- tests/validation/CL/ScatterLayer.cpp | 131 ++++++++++++------ tests/validation/fixtures/ScatterLayerFixture.h | 82 +++++++++--- 9 files changed, 350 insertions(+), 133 deletions(-) diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox index b29b81580d..952753effb 100644 --- a/docs/user_guide/release_version_and_change_log.dox +++ b/docs/user_guide/release_version_and_change_log.dox @@ -41,6 +41,9 @@ If there is more than one release in a month then an extra sequential number is @section S2_2_changelog Changelog +v24.05 Public major release + - Add @ref CLScatter operator for FP32 data type + v24.04 Public major release - Add Bfloat16 data type support for @ref NEMatMul. - Add support for SoftMax in SME2 for FP32 and FP16. diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl index 73b714e042..ac9f828df2 100644 --- a/src/core/CL/cl_kernels/common/scatter.cl +++ b/src/core/CL/cl_kernels/common/scatter.cl @@ -22,8 +22,7 @@ * SOFTWARE. */ #include "helpers.h" - -#if defined(INDICES_SHAPE_Y) && defined(DATA_TYPE) && defined(OUT_SHAPE_X) && defined(SCATTER_FUNCTION) +#include "tile_helpers.h" // The below defines the various reduce operations for our purposes. // Where a corresponds to the existing value, and b the new value. @@ -33,64 +32,114 @@ #define MIN_OP(a, b) fmin(a, b) #define UPDATE_OP(a, b) (b) -/** Performs the ScatterND operation - * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short - * @note the size of the dst tensor in the "x" dimension should be passed using -DOUT_SHAPE_X at compile time. - * @note the number of values in the indices tensor in the y-dim should be passed with -DINDICES_SHAPE_Y at compile time. - * @note Negative indices are treated as out of bounds. +#ifdef SCATTER_MP1D_2D_MPND + +/** This kernel performs scatter operation + * + * @note Datatype should be given as a compile-time argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short + * @note Number of indices should be given as a compile-time argument using -DNUM_INDICES, e.g. -DNUM_INDICES=3 + * @note Index length should be given as a compile-time argument using -DINDEX_LENGTH, e.g. -DINDEX_LENGTH=2 + * @note Outermost output shapes should be given as a compile-time argument using -DOUT_SHAPE_N_MINUS_X, where + * X must be 1,2,3,4,5, e.g. -DOUT_SHAPE_N_MINUS_1=3, ... + * @note Number of elements to copy in a row should be given as a compile-time argument using -DN0, e.g. -DN0=4 + * @note Number of partial elements at the edge to copy in a row should be given as a compile-time argument using + * -DPARTIAL_N0, e.g. -DPARTIAL_N0=2 + * @note Scatter function should be given as a compile-time argument using -DSCATTER_FUNCTION, e.g. -DSCATTER_FUNCTION=ADD + * @note If the kernel should skip reading the output tensor, -DSKIP_OUTPUT_READ option should be provided. + * @note Kernel name in uppercase letters should be provided as a compile-time argument, e.g. -DSCATTER_MP1D_2D_MPND * - * @param[in] updates_ptr Pointer to the source tensor. Supported data types: All - * @param[in] updates_stride_x Stride of the source tensor in X dimension (in bytes) - * @param[in] updates_step_x updates_stride_x * number of elements along X processed per work item (in bytes) - * @param[in] updates_stride_y Stride of the source tensor in Y dimension (in bytes) - * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per work item (in bytes) - * @param[in] updates_stride_z Stride of the source tensor in Y dimension (in bytes) - * @param[in] updates_step_z updates_stride_z * number of elements along Z processed per work item (in bytes) - * @param[in] updates_stride_w Stride of the source tensor in Z dimension (in bytes) - * @param[in] updates_step_w updates_stride_w * number of elements along W processed per work item (in bytes) - * @param[in] updates_offset_first_element_in_bytes Offset of the first element in the source tensor - * @param[in] indices_ptr Pointer to the indices vector. Supported data types: S32. - * @param[in] indices_stride_x Stride of the indices vector in X dimension (in bytes) - * @param[in] indices_step_x updates_stride_x * number of elements along X processed per work item (in bytes) - * @param[in] indices_offset_first_element_in_bytes Offset of the first element in the indices vector - * @param[out] output_ptr Pointer to the destination tensor. Supported data types: same as @p updates_ptr + * @param[in] updates_ptr Pointer to the updates tensor. Data Types: F32 + * @param[in] updates_stride_x Stride of the updates tensor in X dimension (in bytes) + * @param[in] updates_step_x updates_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] updates_stride_y Stride of the updates tensor in Y dimension (in bytes) + * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] updates_offset_first_element_in_bytes The offset of the first element in the updates tensor + * @param[in] indices_ptr Pointer to the indices tensor. Data Types: S32 + * @param[in] indices_stride_x Stride of the indices tensor in X dimension (in bytes) + * @param[in] indices_step_x indices_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] indices_stride_y Stride of the indices tensor in Y dimension (in bytes) + * @param[in] indices_step_y indices_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] indices_offset_first_element_in_bytes The offset of the first element in the indices tensor + * @param[out] output_ptr Pointer to the destination tensor. Same as @p upt_ptr * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes) - * @param[in] output_step_x output_stride_x * number of elements along X processed per work item (in bytes) + * @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes) - * @param[in] output_step_y output_stride_y * number of elements along Y processed per work item (in bytes) - * @param[in] output_stride_z Stride of the destination tensor in Z dimension (in bytes) - * @param[in] output_step_z output_stride_z * number of elements along Z processed per work item (in bytes) - * @param[in] output_stride_w Stride of the destination tensor in W dimension (in bytes) - * @param[in] output_step_w output_stride_w * number of elements along W processed per work item (in bytes) - * @param[in] output_offset_first_element_in_bytes Offset of the first element in the destination tensor + * @param[in] output_step_y output_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] upt_block_stride Update tensor data block stride in bytes + * @param[in] out_block_stride Output tensor data block stride in bytes */ -// The below kernel code is expected to be excecuted sequentially with a single thread to ensure a deterministic outcome. -__kernel void scatter1D( - TENSOR4D_DECLARATION(updates), - TENSOR4D_DECLARATION(indices), - TENSOR4D_DECLARATION(output)) +__kernel void scatter_mp1d_2d_mpnd( + IMAGE_DECLARATION(updates), + IMAGE_DECLARATION(indices), + IMAGE_DECLARATION(output), + int upt_block_stride, + int out_block_stride + ) { - // Currently 1D - only iterate through y dimension of indices. - unsigned int* indices_start_offset = (unsigned int*)(indices_ptr + indices_offset_first_element_in_bytes); - DATA_TYPE* updates_start_offset = (DATA_TYPE*)(updates_ptr + updates_offset_first_element_in_bytes); - DATA_TYPE* out_start_offset = (DATA_TYPE*)(output_ptr + output_offset_first_element_in_bytes); - for (int px = 0; px < INDICES_SHAPE_Y; px++) + const int out_shape[5] = {OUT_SHAPE_N_MINUS_1, OUT_SHAPE_N_MINUS_2, OUT_SHAPE_N_MINUS_3, + OUT_SHAPE_N_MINUS_4, OUT_SHAPE_N_MINUS_5}; + + const int x = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // x-coordinate in the tensor + const int y = get_global_id(1); // collapsed y-coordinate (ignoring the outermost dimensions) + + const bool x_cond = (PARTIAL_N0 != 0 && get_global_id(0) == 0); + + uchar *ind_ptr_raw = indices_ptr + indices_offset_first_element_in_bytes; + const uchar *out_ptr_raw = output_ptr + output_offset_first_element_in_bytes + + x * sizeof(DATA_TYPE) + y * output_stride_y; + + const uchar *upt_ptr_raw = updates_ptr + updates_offset_first_element_in_bytes + + x * sizeof(DATA_TYPE) + y * updates_stride_y; + + for(int index_element = 0; index_element < NUM_INDICES; ++index_element) { - const int index_value = *(indices_start_offset); - DATA_TYPE* out_addr = out_start_offset + index_value; - if((index_value < OUT_SHAPE_X) && (index_value >= 0)) + const int *ind_ptr = (const int *) (ind_ptr_raw); + + // Out of bounds check + bool out_of_bounds = false; + LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH, + { + if(ind_ptr[i] >= out_shape[i] || ind_ptr[i] < 0) + { + out_of_bounds = true; + } + }); + + ind_ptr_raw += indices_stride_y; + + if(out_of_bounds) { - *(__global DATA_TYPE *)(out_addr) = SCATTER_FUNCTION(*(out_addr), *updates_start_offset); + continue; } - // Increment pointers. - indices_start_offset++; - updates_start_offset++; + + // Index calculation + int index = 0; + LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH, + { + index = index * out_shape[i] + ind_ptr[i]; + }); + + DATA_TYPE *out_ptr = (DATA_TYPE *) (out_ptr_raw + index * out_block_stride); + + const DATA_TYPE *upt_ptr = (const DATA_TYPE *) (upt_ptr_raw + index_element * upt_block_stride); + + VEC_DATA_TYPE(DATA_TYPE, N0) data_in0 = VLOAD(N0)(0, (__global DATA_TYPE *) upt_ptr); + +#ifdef SKIP_OUTPUT_READ + STORE_VECTOR_SELECT(data_in, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond); +#else // ifdef SKIP_OUTPUT_READ + VEC_DATA_TYPE(DATA_TYPE, N0) data_out0 = VLOAD(N0)(0, (__global DATA_TYPE *) out_ptr); + data_out0 = SCATTER_FUNCTION(data_out0, data_in0); + + STORE_VECTOR_SELECT(data_out, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond); +#endif // ifdef SKIP_OUTPUT_READ } } -#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && defined(INDICES_SHAPE_Y) +#endif // SCATTER_MP1D_2D_MPND -#if defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y) +#ifdef SCATTER1D_PARALLEL // NOTE : This code is non-deterministic and can only be excecuted with the "update" ScatterFunction // This code is currently unusued as it requires changes to the existing test suite. @@ -114,4 +163,4 @@ __kernel void scatter1D_parallel( } } -#endif //defined(DATA_TYPE) && defined(SCATTER_FUNCTION) && defined(OUT_SHAPE_X) && !defined(INDICES_SHAPE_Y) +#endif // SCATTER1D_PARALLEL diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp index 3e32a27d03..c4117b8a1a 100644 --- a/src/gpu/cl/ClKernelLibrary.cpp +++ b/src/gpu/cl/ClKernelLibrary.cpp @@ -441,6 +441,7 @@ const std::map ClKernelLibrary::_kernel_program_map = {"reorg_layer_nhwc", "nhwc/reorg_layer.cl"}, {"scale_nearest_neighbour_nhwc", "nhwc/scale.cl"}, {"scale_bilinear_nhwc", "nhwc/scale.cl"}, + {"scatter_mp1d_2d_mpnd", "common/scatter.cl"}, {"scatter1D", "common/scatter.cl"}, {"space_to_batch_nhwc", "nhwc/space_to_batch.cl"}, {"space_to_batch_static_nhwc", "nhwc/space_to_batch.cl"}, diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp index c95e156679..9c25b63c72 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.cpp +++ b/src/gpu/cl/kernels/ClScatterKernel.cpp @@ -27,17 +27,26 @@ #include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/helpers/AdjustVecSize.h" #include "src/common/utils/Log.h" #include "src/core/helpers/WindowHelpers.h" #include "support/Cast.h" +#include + namespace arm_compute { namespace opencl { namespace kernels { + +namespace +{ +constexpr int max_index_length = 5; +} // namespace + ClScatterKernel::ClScatterKernel() { } @@ -47,21 +56,33 @@ Status ClScatterKernel::validate(const ITensorInfo *updates, const ITensorInfo *dst, const ScatterInfo &info) { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->num_dimensions() > 1, "Only 1D output tensors are currently supported."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(indices->num_dimensions() > 2, "Only 2D indices tensors are currently supported."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > 1, "Only 1D update tensors are currently supported."); + ARM_COMPUTE_UNUSED(info); + + const TensorShape &ind_shape = indices->tensor_shape(); + const TensorShape &upt_shape = updates->tensor_shape(); + const TensorShape &dst_shape = dst->tensor_shape(); + + const int32_t upt_dims = upt_shape.num_dimensions(); + const int32_t dst_dims = dst_shape.num_dimensions(); + const int32_t ind_dims = ind_shape.num_dimensions(); + + const int32_t index_len = ind_shape[0]; + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported."); ARM_COMPUTE_RETURN_ERROR_ON_MSG( - indices->tensor_shape().y() != updates->tensor_shape()[updates->num_dimensions() - 1], + ind_shape[1] != upt_shape[upt_dims - 1], "Height of indices tensor should match size of highest dimension in updates tensor."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(updates->num_dimensions() > dst->num_dimensions(), - "Update tensor cannot have more dims than output tensor."); - ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_dims > dst_dims, "Update tensor cannot have more dims than output tensor."); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!"); + ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - upt_dims + 1); return Status{}; } + void ClScatterKernel::configure(const ClCompileContext &compile_context, const ITensorInfo *updates, const ITensorInfo *indices, @@ -71,22 +92,51 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context, ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices); ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info); - // Configure kernel window - const auto indices_shape = indices->tensor_shape(); - Window win = calculate_max_window( - *indices, Steps(indices_shape.x(), indices_shape.y())); // Ensures single thread for deterministic output. + const TensorShape &dst_shape = dst->tensor_shape(); + + const bool is_scalar_block = updates->num_dimensions() == 1; + const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0)); + + const int partial_n0 = updates->dimension(0) % n0; + + // The GWS will be 2D [x, y] + // x-dimension refers to the x coordinate of the dst tensor + // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor + Window win = calculate_max_window(dst_shape, Steps(n0)); + const int index_len = indices->dimension(0); + + // Collapse the dimensions corresponding to indices in the execution window + for (int i = 0; i < index_len; ++i) + { + win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1)); + } + + win = win.collapse(win, 1); // Set build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type())); - build_opts.add_option("-DINDICES_DIMS=" + support::cpp11::to_string(indices->num_dimensions())); - build_opts.add_option("-DINDICES_SHAPE_Y=" + support::cpp11::to_string(indices_shape.y())); - build_opts.add_option("-DOUT_SHAPE_X=" + support::cpp11::to_string(dst->tensor_shape().x())); + + const int num_dims = dst->num_dimensions(); + + build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(indices->dimension(1))); + build_opts.add_option("-DINDEX_LENGTH=" + support::cpp11::to_string(index_len)); + + // We provide 5 variables to use in a constant array + for (int i = 1; i <= max_index_length; i++) + { + build_opts.add_option("-DOUT_SHAPE_N_MINUS_" + support::cpp11::to_string(i) + "=" + + support::cpp11::to_string(dst_shape[std::max(num_dims - i, 0)])); + } + + build_opts.add_option("-DN0=" + support::cpp11::to_string(n0)); + build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_n0)); switch (info.func) { case ScatterFunction::Update: build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP"); + build_opts.add_option("-DSKIP_OUTPUT_READ"); break; case ScatterFunction::Add: build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP"); @@ -105,9 +155,12 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context, } // Create kernel - std::string kernel_name("scatter1D"); + std::string kernel_name = "scatter_mp1d_2d_mpnd"; + build_opts.add_option("-D" + upper_string(kernel_name)); + ICLKernel::configure_internal(win); _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + // Set config_id for enabling LWS tuning _config_id = kernel_name; _config_id += "_"; @@ -123,18 +176,29 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context, void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) { - unsigned int idx = 0; - - Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); - const auto updates = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_0)); const auto indices = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_1)); auto dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); - add_4D_tensor_argument(idx, updates, window_collapsed); - add_4D_tensor_argument(idx, indices, window_collapsed); - add_4D_tensor_argument(idx, dst, window_collapsed); + + const ITensorInfo *dst_info = dst->info(); + const int num_dims = dst_info->num_dimensions(); + + const int index_len = indices->info()->dimension(0); + + // calculate m-dimensional data block strides in updates and destination tensors + const int upt_block_stride = updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - 1]; + const int out_block_stride = dst_info->strides_in_bytes()[num_dims - index_len]; + + unsigned int idx = 0; + + add_2D_tensor_argument(idx, updates, window); + add_2D_tensor_argument(idx, indices, window); + add_2D_tensor_argument(idx, dst, window); + + _kernel.setArg(idx++, upt_block_stride); + _kernel.setArg(idx++, out_block_stride); enqueue(queue, *this, window, lws_hint()); } diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h index d2a41adde9..e1b469c88e 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.h +++ b/src/gpu/cl/kernels/ClScatterKernel.h @@ -37,6 +37,7 @@ namespace opencl { namespace kernels { + class ClScatterKernel : public IClKernel { public: diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp index 62711ddfe8..a11ecd7e6a 100644 --- a/src/gpu/cl/operators/ClScatter.cpp +++ b/src/gpu/cl/operators/ClScatter.cpp @@ -48,8 +48,6 @@ Status ClScatter::validate(const ITensorInfo *src, const ScatterInfo &info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::S32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32); // Currently, other datatypes are not suppported. if (src != nullptr) { // Check dst/src are same shape and datatype. diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h index 8b0972f99a..c0858941db 100644 --- a/tests/datasets/ScatterDataset.h +++ b/tests/datasets/ScatterDataset.h @@ -137,7 +137,7 @@ public: // - src/updates/dst should all have same number of dims. Indices should be 2D. add_config(TensorShape(6U, 5U), TensorShape(6U, 2U), TensorShape(1U, 2U), TensorShape(6U, 5U)); add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U)); - add_config(TensorShape(3U, 2U, 4U, 2U), TensorShape(3U, 2U, 4U, 2U), TensorShape(1U, 2U), TensorShape(3U, 2U, 4U, 2U)); + add_config(TensorShape(17U, 3U, 2U, 4U), TensorShape(17U, 3U, 2U, 7U), TensorShape(1U, 7U), TensorShape(17U, 3U, 2U, 4U)); } }; diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp index 5b1d5afe92..4a2462c7d2 100644 --- a/tests/validation/CL/ScatterLayer.cpp +++ b/tests/validation/CL/ScatterLayer.cpp @@ -52,62 +52,117 @@ TEST_SUITE(CL) TEST_SUITE(Scatter) DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types - TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid - TensorInfo(TensorShape(8U), 1, DataType::F32), - TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims. - TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims. - TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype. - }), - make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16), - TensorInfo(TensorShape(15U), 1, DataType::F32), - TensorInfo(TensorShape(2U), 1, DataType::F32), - TensorInfo(TensorShape(217U), 1, DataType::F32), - TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), - TensorInfo(TensorShape(2U), 1, DataType::F32), - }), - make("IndicesInfo",{ TensorInfo(TensorShape(1U, 3U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 2U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32) - }), - make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16), - TensorInfo(TensorShape(15U), 1, DataType::F32), - TensorInfo(TensorShape(8U), 1, DataType::F32), - TensorInfo(TensorShape(271U, 3U), 1, DataType::F32), - TensorInfo(TensorShape(271U), 1, DataType::F32), - TensorInfo(TensorShape(12U), 1, DataType::F32) - }), + TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid + TensorInfo(TensorShape(8U), 1, DataType::F32), + TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims. + TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims. + TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype. + TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), // Number of updates != number of indices + TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), // index_len != (dst_dims - upt_dims + 1) + TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), // index_len > 5 + }), + make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16), + TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(217U), 1, DataType::F32), + TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(9U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(17U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(1U), 1, DataType::F32), + }), + make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 2U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32), + TensorInfo(TensorShape(1U, 4U), 1, DataType::S32), + TensorInfo(TensorShape(3U, 2U), 1, DataType::S32), + TensorInfo(TensorShape(6U, 2U), 1, DataType::S32), + }), + make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16), + TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(8U), 1, DataType::F32), + TensorInfo(TensorShape(271U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(271U), 1, DataType::F32), + TensorInfo(TensorShape(12U), 1, DataType::F32), + TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), + TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), + }), make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false), ScatterInfo(ScatterFunction::Max, false), ScatterInfo(ScatterFunction::Min, false), ScatterInfo(ScatterFunction::Add, false), ScatterInfo(ScatterFunction::Update, false), ScatterInfo(ScatterFunction::Sub, false), - }), - make("Expected", { false, true, true, false, false, false })), + ScatterInfo(ScatterFunction::Sub, false), + ScatterInfo(ScatterFunction::Update, false), + ScatterInfo(ScatterFunction::Update, false), + }), + make("Expected", { false, true, true, false, false, false, false, false, false })), input_info, updates_info, indices_info, output_info, scatter_info, expected) { - const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info); + const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info); ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); } +const auto allScatterFunctions = make("ScatterFunction", + {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }); + TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), - make("DataType", {DataType::F32}), - make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }), - make("ZeroInit", {false}))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::Small1DScatterDataset(), + make("DataType", {DataType::F32}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } // With this test, src should be passed as nullptr. -FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), - make("DataType", {DataType::F32}), - make("ScatterFunction", {ScatterFunction::Add}), - make("ZeroInit", {true}))) +FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::Small1DScatterDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Add}), + make("ZeroInit", {true}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// Updates/src/dst have same no. dims. +FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMultiDimDataset(), + make("DataType", {DataType::F32}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// m+1-D to m+n-D cases +FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMultiIndicesDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }), + make("ZeroInit", {false}), + make("Inplace", {false, true}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// m+k, k-1-D m+n-D case +FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture, framework::DatasetMode::DISABLED, + combine(datasets::SmallScatterBatchedDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }), + make("ZeroInit", {false}), + make("Inplace", {false}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 91e28b58f7..4fb2d7f127 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -29,6 +29,7 @@ #include "tests/Globals.h" #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" +#include "tests/validation/Helpers.h" #include "tests/validation/Validation.h" #include "tests/validation/reference/ScatterLayer.h" #include "tests/SimpleTensor.h" @@ -46,9 +47,17 @@ template fill_tensor_uniform(tensor, i, static_cast(-2), static_cast(max)); } - TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, + const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, + QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // 1. Create relevant tensors using ScatterInfo data structure. // ---------------------------------------------------- @@ -94,14 +105,22 @@ protected: FunctionType scatter; // Configure operator - // When scatter_info.zero_initialization is true, pass nullptr to scatter function. + // When scatter_info.zero_initialization is true, pass nullptr for src + // because dst does not need to be initialized with src values. if(info.zero_initialization) { scatter.configure(nullptr, &updates, &indices, &dst, info); } else { - scatter.configure(&src, &updates, &indices, &dst, info); + if(inplace) + { + scatter.configure(&src, &updates, &indices, &src, info); + } + else + { + scatter.configure(&src, &updates, &indices, &dst, info); + } } // Assertions @@ -110,28 +129,51 @@ protected: ARM_COMPUTE_ASSERT(indices.info()->is_resizable()); ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + add_padding_x({ &src, &updates, &indices}); + + if(!inplace) + { + add_padding_x({ &dst }); + } + // Allocate tensors src.allocator()->allocate(); updates.allocator()->allocate(); indices.allocator()->allocate(); - dst.allocator()->allocate(); + + if(!inplace) + { + dst.allocator()->allocate(); + } ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); ARM_COMPUTE_ASSERT(!updates.info()->is_resizable()); ARM_COMPUTE_ASSERT(!indices.info()->is_resizable()); - ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + if(!inplace) + { + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + } // Fill update (a) and indices (b) tensors. - fill(AccessorType(src), 0); - fill(AccessorType(updates), 1); - fill_indices(AccessorType(indices), 2, out_shape); + fill(AccessorType(src), 0 + _hash); + fill(AccessorType(updates), 1+ _hash); + fill_indices(AccessorType(indices), 2 + _hash, out_shape); scatter.run(); - return dst; + + if(inplace) + { + return src; + } + else + { + return dst; + } } - SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type, - ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, + const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // Output Quantization not currently in use - fixture should be extended to support this. ARM_COMPUTE_UNUSED(o_qinfo); @@ -158,9 +200,9 @@ protected: SimpleTensor indices{ c_shape, DataType::S32, 1, QuantizationInfo() }; // Fill reference - fill(src, 0); - fill(updates, 1); - fill_indices(indices, 2, out_shape); + fill(src, 0 + _hash); + fill(updates, 1 + _hash); + fill_indices(indices, 2 + _hash, out_shape); // Calculate individual reference. return reference::scatter_layer(src, updates, indices, out_shape, info); @@ -168,6 +210,7 @@ protected: TensorType _target{}; SimpleTensor _reference{}; + int32_t _hash{}; }; // This fixture will use the same shape for updates as indices. @@ -175,9 +218,12 @@ template { public: - void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init) + void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, + TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace) { - ScatterGenericValidationFixture::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo()); + ScatterGenericValidationFixture::setup(src_shape, update_shape, + indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, + QuantizationInfo(), QuantizationInfo()); } }; -- cgit v1.2.1