From 05269f013cf2b7c4a53f5950cdd6bfea26367769 Mon Sep 17 00:00:00 2001
From: Gunes Bayir
Date: Thu, 9 May 2024 13:24:15 +0100
Subject: ScatterND fix for scalar cases

- Padding with batched scalar cases is unsupported, adds checks.
- Adds tests for scalar cases, without padding.

Resolves: [COMPMID-7015]
Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d
Signed-off-by: Mohammed Suhail Munshi
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536
Reviewed-by: Gunes Bayir
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Benchmark: Arm Jenkins
---
 src/gpu/cl/kernels/ClScatterKernel.cpp          | 47 ++++++++++++++--------
 tests/datasets/ScatterDataset.h                 | 14 ++++++-
 tests/validation/CL/ScatterLayer.cpp            | 52 +++++++++++++++++++------
 tests/validation/fixtures/ScatterLayerFixture.h | 23 ++++++-----
 4 files changed, 97 insertions(+), 39 deletions(-)

diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index f76a674b27..19adc1ef34 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -69,7 +69,10 @@ Status ClScatterKernel::validate(const ITensorInfo *updates,
     const int32_t data_dim  = upt_dims - (ind_dims - 1); // Number of batch dims is the number of indices dims - 1
 
     const int32_t index_len = ind_shape[0];
+    bool          unsupported_padding_config =
+        (dst_dims == index_len) && index_len > 1 && (dst->has_padding() || updates->has_padding());
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config, "Padding is not supported with these shapes.");
 
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16,
@@ -99,9 +102,8 @@ Status ClScatterKernel::validate(const ITensorInfo *updates,
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((ind_dims < 2), "Shape of Indices tensor must be at least 2D");
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-        index_len >= dst_dims && dst_dims != 1,
-        "Index length should be smaller than number of output dims (or equal to with 1D output)");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > dst_dims && dst_dims != 1,
+                                    "Index length should be smaller than or equal to number of output dims");
 
     return Status{};
 }
@@ -116,25 +118,31 @@ void ClScatterKernel::configure(const ClCompileContext &compile_context,
     ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
 
     const TensorShape &dst_shape = dst->tensor_shape();
+    const int          index_len = indices->dimension(0);
 
-    const bool is_scalar_block = updates->num_dimensions() == 1; // Checks for replacing only a single element.
-    const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
+    // Check for single element data block
+    const bool is_scalar_block = (dst->num_dimensions() == static_cast<size_t>(index_len));
+    const int  n0         = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
     const int  partial_n0 = updates->dimension(0) % n0;
 
     // The GWS will be 2D [x, y]
     // x-dimension refers to the x coordinate of the dst tensor
     // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor
-    Window win = calculate_max_window(dst_shape, Steps(n0));
-    const int index_len = indices->dimension(0);
+    Window win;
 
-    // Collapse the dimensions corresponding to indices in the execution window
-    for (int i = 0; i < index_len; ++i)
+    if (!is_scalar_block)
     {
-        win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
-    }
+        win = calculate_max_window(dst_shape, Steps(n0));
+
+        // Collapse the dimensions corresponding to indices in the execution window
+        for (int i = 0; i < index_len; ++i)
+        {
+            win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
+        }
 
-    win = win.collapse(win, 1);
+        win = win.collapse(win, 1);
+    }
 
     // Set build options
     CLBuildOptions build_opts;
@@ -206,11 +214,18 @@ void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::Com
         utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
     auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
 
-    const ITensorInfo *dst_info = dst->info();
-    const int          num_dims = dst_info->num_dimensions();
-    const int          ind_dims = indices->info()->num_dimensions();
+    const ITensorInfo *dst_info  = dst->info();
+    const ITensorInfo *upd_info  = updates->info();
+    const int          num_dims  = dst_info->num_dimensions();
+    const int          ind_dims  = indices->info()->num_dimensions();
+    const int          index_len = indices->info()->dimension(0);
 
-    const int index_len = indices->info()->dimension(0);
+    bool unsupported_padding_config =
+        num_dims == index_len && index_len > 1 && (dst_info->has_padding() || upd_info->has_padding());
+    if (unsupported_padding_config)
+    {
+        ARM_COMPUTE_ERROR("Unsupported Configuration! Padding not supported with these shapes.");
+    }
 
     // calculate m-dimensional data block strides in updates and destination tensors
     const int upt_block_stride =
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index 4ad269ec85..8fd4448d2d 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -180,7 +180,6 @@ public:
         // NOTE: Updates/Indices tensors are now batched.
         // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
         // k is the number of batch dimensions
-        // k = 2
         add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
         add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
@@ -197,6 +196,18 @@ public:
     }
 };
 
+class SmallScatterScalarDataset final : public ScatterDataset
+{
+public:
+    // batched scalar case
+    SmallScatterScalarDataset()
+    {
+        add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U));
+        add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
+        add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U));
+    }
+};
+
 // This dataset is for data types that does not require full testing. It contains selected tests from the above.
 class SmallScatterMixedDataset final : public ScatterDataset
 {
@@ -205,6 +216,7 @@ public:
     {
         add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
         add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+        add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
         add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
         add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
         add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index e327ff9522..b1531eb64a 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -125,7 +125,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::Datase
                                make("DataType", {DataType::F32}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -136,7 +137,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework
                                make("DataType", {DataType::F32}),
                                make("ScatterFunction", {ScatterFunction::Add}),
                                make("ZeroInit", {true}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -147,7 +149,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture<float>, framework
                                make("DataType", {DataType::F32}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -158,7 +161,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture<float>, frame
                                make("DataType", {DataType::F32}),
                                make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
                                make("ZeroInit", {false}),
-                               make("Inplace", {false, true})))
+                               make("Inplace", {false, true}),
+                               make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
@@ -169,20 +173,38 @@ FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>
                                make("DataType", {DataType::F32}),
                                make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {true})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
+
+// m+k, k-1-D m+n-D case
+FIXTURE_DATA_TEST_CASE(RunSmallScatterScalar, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+                       combine(datasets::SmallScatterScalarDataset(),
+                               make("DataType", {DataType::F32}),
+                               make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
+                               make("ZeroInit", {false}),
+                               make("Inplace", {false}),
+                               make("Padding", {false}))) // NOTE: Padding not supported in this dataset
+{
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
 TEST_SUITE_END() // FP32
+
+// NOTE: Padding is disabled for the SmallScatterMixedDataset due to certain shapes not supporting padding.
+// Padding is well tested in F32 Datatype test cases.
+
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
                        combine(datasets::SmallScatterMixedDataset(),
                                make("DataType", {DataType::F16}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
@@ -196,7 +218,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int32_t>, framework:
                                make("DataType", {DataType::S32}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -208,7 +231,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int16_t>, framework:
                                make("DataType", {DataType::S16}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -220,7 +244,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int8_t>, framework::
                                make("DataType", {DataType::S8}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -232,7 +257,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint32_t>, framework
                                make("DataType", {DataType::U32}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -244,7 +270,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint16_t>, framework
                                make("DataType", {DataType::U16}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
@@ -256,7 +283,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint8_t>, framework:
                                make("DataType", {DataType::U8}),
                                allScatterFunctions,
                                make("ZeroInit", {false}),
-                               make("Inplace", {false})))
+                               make("Inplace", {false}),
+                               make("Padding", {false})))
 {
     validate(CLAccessor(_target), _reference, tolerance_int);
 }
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 5cd9b8115c..af161ef98b 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -48,7 +48,7 @@ class ScatterGenericValidationFixture : public framework::Fixture
 {
 public:
     void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
-               TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+               TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace, bool padding,
                QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
     {
         // this is for improving randomness across tests
                 + updates_shape[4] + updates_shape[5] + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
-        _target    = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
+        _target    = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, padding, src_qinfo, o_qinfo);
         _reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
     }
 
@@ -104,11 +104,11 @@ protected:
     {
         // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
         const int32_t max = std::min({shape[0] , shape[1], shape[2]}) + 1;
-        library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
+        library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(0), static_cast<int32_t>(max));
     }
 
     TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
-                              const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+                              const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, bool padding,
                               QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
     {
         // 1. Create relevant tensors using ScatterInfo data structure.
@@ -146,11 +146,14 @@ protected:
         ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
 
-        add_padding_x({ &src, &updates, &indices});
-
-        if(!inplace)
+        if(padding)
         {
-            add_padding_x({ &dst });
+            add_padding_x({ &src, &updates, &indices});
+
+            if(!inplace)
+            {
+                add_padding_x({ &dst });
+            }
         }
 
         // Allocate tensors
@@ -237,10 +240,10 @@ class ScatterValidationFixture : public ScatterGenericValidationFixture
         ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
-            indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+            indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, padding,
             QuantizationInfo(), QuantizationInfo());
     }
 };
-- 
cgit v1.2.1
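
The substance of the change, separated from the diff mechanics: a batched scalar scatter (index length equal to the destination rank and greater than one) cannot be combined with padded destination or updates tensors, so both ClScatterKernel::validate() and run_op() now reject that combination, and the new SmallScatterScalarDataset tests run with padding disabled. A minimal sketch of the rejection predicate follows; the TensorDesc struct and is_unsupported_padding_config() helper are illustrative stand-ins for this note only, not Compute Library types.

    #include <cstdint>

    // Illustrative stand-in for the tensor metadata the kernel inspects (not an ACL type).
    struct TensorDesc
    {
        int32_t num_dimensions; // tensor rank
        bool    has_padding;    // true if any dimension carries padding
    };

    // Mirrors the condition added in ClScatterKernel::validate()/run_op(): reject padded
    // dst/updates tensors when the index length equals the dst rank and exceeds one.
    bool is_unsupported_padding_config(const TensorDesc &dst, const TensorDesc &updates, int32_t index_len)
    {
        return (dst.num_dimensions == index_len) && (index_len > 1) && (dst.has_padding || updates.has_padding);
    }

For example, the first scalar-dataset entry uses a dst of shape (6, 5) with indices of shape (2, 6): index_len is 2 and equals the dst rank, so adding padding to dst or updates would make the configuration invalid, which is why the scalar tests set Padding to false.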