From ada3200f5cec0b6a37f898d5d6f8e69395d7bcb1 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Wed, 24 Apr 2024 10:27:13 +0100 Subject: Add update/index/output (m+1)/2d/(m+n) support for CLScatter Resolves: COMPMID-6894, COMPMID-6896 Change-Id: I9d29fd3701a7e0f28d83f81a6c42a7234c2587c3 Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11477 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Ramy Elgammal Dynamic-Fusion: Ramy Elgammal Benchmark: Arm Jenkins --- tests/validation/CL/ScatterLayer.cpp | 131 +++++++++++++++++------- tests/validation/fixtures/ScatterLayerFixture.h | 82 +++++++++++---- 2 files changed, 157 insertions(+), 56 deletions(-) (limited to 'tests/validation') diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp index 5b1d5afe92..4a2462c7d2 100644 --- a/tests/validation/CL/ScatterLayer.cpp +++ b/tests/validation/CL/ScatterLayer.cpp @@ -52,62 +52,117 @@ TEST_SUITE(CL) TEST_SUITE(Scatter) DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types - TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid - TensorInfo(TensorShape(8U), 1, DataType::F32), - TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims. - TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims. - TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype. 
- }), - make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16), - TensorInfo(TensorShape(15U), 1, DataType::F32), - TensorInfo(TensorShape(2U), 1, DataType::F32), - TensorInfo(TensorShape(217U), 1, DataType::F32), - TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), - TensorInfo(TensorShape(2U), 1, DataType::F32), - }), - make("IndicesInfo",{ TensorInfo(TensorShape(1U, 3U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 2U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), - TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32) - }), - make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16), - TensorInfo(TensorShape(15U), 1, DataType::F32), - TensorInfo(TensorShape(8U), 1, DataType::F32), - TensorInfo(TensorShape(271U, 3U), 1, DataType::F32), - TensorInfo(TensorShape(271U), 1, DataType::F32), - TensorInfo(TensorShape(12U), 1, DataType::F32) - }), + TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid + TensorInfo(TensorShape(8U), 1, DataType::F32), + TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims. + TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims. + TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype. 
+ TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), // Number of updates != number of indices + TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), // index_len != (dst_dims - upt_dims + 1) + TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), // index_len > 5 + }), + make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16), + TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(217U), 1, DataType::F32), + TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(9U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(17U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(1U), 1, DataType::F32), + }), + make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 2U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32), + TensorInfo(TensorShape(1U, 4U), 1, DataType::S32), + TensorInfo(TensorShape(3U, 2U), 1, DataType::S32), + TensorInfo(TensorShape(6U, 2U), 1, DataType::S32), + }), + make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16), + TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(8U), 1, DataType::F32), + TensorInfo(TensorShape(271U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(271U), 1, DataType::F32), + TensorInfo(TensorShape(12U), 1, DataType::F32), + TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), + TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), + }), make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false), ScatterInfo(ScatterFunction::Max, false), ScatterInfo(ScatterFunction::Min, false), ScatterInfo(ScatterFunction::Add, 
false), ScatterInfo(ScatterFunction::Update, false), ScatterInfo(ScatterFunction::Sub, false), - }), - make("Expected", { false, true, true, false, false, false })), + ScatterInfo(ScatterFunction::Sub, false), + ScatterInfo(ScatterFunction::Update, false), + ScatterInfo(ScatterFunction::Update, false), + }), + make("Expected", { false, true, true, false, false, false, false, false, false })), input_info, updates_info, indices_info, output_info, scatter_info, expected) { - const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info); + const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info); ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); } +const auto allScatterFunctions = make("ScatterFunction", + {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }); + TEST_SUITE(Float) TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), - make("DataType", {DataType::F32}), - make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }), - make("ZeroInit", {false}))) +FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::Small1DScatterDataset(), + make("DataType", {DataType::F32}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } // With this test, src should be passed as nullptr. 
-FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), - make("DataType", {DataType::F32}), - make("ScatterFunction", {ScatterFunction::Add}), - make("ZeroInit", {true}))) +FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::Small1DScatterDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Add}), + make("ZeroInit", {true}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// Updates/src/dst have same no. dims. +FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMultiDimDataset(), + make("DataType", {DataType::F32}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// m+1-D to m+n-D cases +FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMultiIndicesDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }), + make("ZeroInit", {false}), + make("Inplace", {false, true}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// m+k, k-1-D m+n-D case +FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture, framework::DatasetMode::DISABLED, + combine(datasets::SmallScatterBatchedDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }), + make("ZeroInit", {false}), + make("Inplace", {false}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 91e28b58f7..4fb2d7f127 100644 --- 
a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -29,6 +29,7 @@ #include "tests/Globals.h" #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" +#include "tests/validation/Helpers.h" #include "tests/validation/Validation.h" #include "tests/validation/reference/ScatterLayer.h" #include "tests/SimpleTensor.h" @@ -46,9 +47,17 @@ template fill_tensor_uniform(tensor, i, static_cast(-2), static_cast(max)); } - TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, + const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, + QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // 1. Create relevant tensors using ScatterInfo data structure. // ---------------------------------------------------- @@ -94,14 +105,22 @@ protected: FunctionType scatter; // Configure operator - // When scatter_info.zero_initialization is true, pass nullptr to scatter function. + // When scatter_info.zero_initialization is true, pass nullptr for src + // because dst does not need to be initialized with src values. 
if(info.zero_initialization) { scatter.configure(nullptr, &updates, &indices, &dst, info); } else { - scatter.configure(&src, &updates, &indices, &dst, info); + if(inplace) + { + scatter.configure(&src, &updates, &indices, &src, info); + } + else + { + scatter.configure(&src, &updates, &indices, &dst, info); + } } // Assertions @@ -110,28 +129,51 @@ protected: ARM_COMPUTE_ASSERT(indices.info()->is_resizable()); ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + add_padding_x({ &src, &updates, &indices}); + + if(!inplace) + { + add_padding_x({ &dst }); + } + // Allocate tensors src.allocator()->allocate(); updates.allocator()->allocate(); indices.allocator()->allocate(); - dst.allocator()->allocate(); + + if(!inplace) + { + dst.allocator()->allocate(); + } ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); ARM_COMPUTE_ASSERT(!updates.info()->is_resizable()); ARM_COMPUTE_ASSERT(!indices.info()->is_resizable()); - ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + if(!inplace) + { + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + } // Fill update (a) and indices (b) tensors. 
- fill(AccessorType(src), 0); - fill(AccessorType(updates), 1); - fill_indices(AccessorType(indices), 2, out_shape); + fill(AccessorType(src), 0 + _hash); + fill(AccessorType(updates), 1+ _hash); + fill_indices(AccessorType(indices), 2 + _hash, out_shape); scatter.run(); - return dst; + + if(inplace) + { + return src; + } + else + { + return dst; + } } - SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type, - ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, + const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // Output Quantization not currently in use - fixture should be extended to support this. ARM_COMPUTE_UNUSED(o_qinfo); @@ -158,9 +200,9 @@ protected: SimpleTensor indices{ c_shape, DataType::S32, 1, QuantizationInfo() }; // Fill reference - fill(src, 0); - fill(updates, 1); - fill_indices(indices, 2, out_shape); + fill(src, 0 + _hash); + fill(updates, 1 + _hash); + fill_indices(indices, 2 + _hash, out_shape); // Calculate individual reference. return reference::scatter_layer(src, updates, indices, out_shape, info); @@ -168,6 +210,7 @@ protected: TensorType _target{}; SimpleTensor _reference{}; + int32_t _hash{}; }; // This fixture will use the same shape for updates as indices. 
@@ -175,9 +218,12 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init) + void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, + TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace) { - ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo()); + ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, + indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, + QuantizationInfo(), QuantizationInfo()); } }; -- cgit v1.2.1