diff options
author | Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> | 2024-03-25 15:55:42 +0000 |
---|---|---|
committer | Suhail M <MohammedSuhail.Munshi@arm.com> | 2024-04-22 14:44:09 +0000 |
commit | 7377107378d6c26439320fce78a551e85b5ad36a (patch) | |
tree | 3aa9c74c59993f9d51924fc123eefa17e3376a79 /tests/validation/CL/ScatterLayer.cpp | |
parent | 5057ce9e1866ffa0388543d81af32083b5b1c684 (diff) | |
download | ComputeLibrary-7377107378d6c26439320fce78a551e85b5ad36a.tar.gz |
Scatter GPU Kernel Implementation for 1D tensors.
Resolves: [COMPMID-6891, COMPMID-6892]
Change-Id: I5b094fff1bff4c4c59cc44f7d6beab0e40133d8e
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11394
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/ScatterLayer.cpp')
-rw-r--r-- | tests/validation/CL/ScatterLayer.cpp | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp index 56338f489f..9711671841 100644 --- a/tests/validation/CL/ScatterLayer.cpp +++ b/tests/validation/CL/ScatterLayer.cpp @@ -38,6 +38,10 @@ namespace test { namespace validation { +namespace +{ +RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */ +} // namespace template <typename T> using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>; @@ -46,7 +50,7 @@ using framework::dataset::make; TEST_SUITE(CL) TEST_SUITE(Scatter) -DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip( +DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid TensorInfo(TensorShape(8U), 1, DataType::F32), @@ -61,12 +65,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip( TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::F32), }), - make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32), - TensorInfo(TensorShape(15U), 1, DataType::U32), - TensorInfo(TensorShape(2U), 1, DataType::U32), - TensorInfo(TensorShape(271U), 1, DataType::U32), - TensorInfo(TensorShape(271U), 1, DataType::U32), - TensorInfo(TensorShape(2U), 1 , DataType::S32) + make("IndicesInfo",{ TensorInfo(TensorShape(1U, 3U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 2U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32) }), make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16), TensorInfo(TensorShape(15U), 1, DataType::F32), @@ -76,27 +80,27 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip( TensorInfo(TensorShape(12U), 1, DataType::F32) }), make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false), + ScatterInfo(ScatterFunction::Max, false), + ScatterInfo(ScatterFunction::Min, false), + ScatterInfo(ScatterFunction::Add, false), + ScatterInfo(ScatterFunction::Update, false), + ScatterInfo(ScatterFunction::Sub, false), }), make("Expected", { false, true, true, false, false, false })), input_info, updates_info, indices_info, output_info, scatter_info, expected) { - // TODO: Enable validation tests. - ARM_COMPUTE_UNUSED(input_info); - ARM_COMPUTE_UNUSED(updates_info); - ARM_COMPUTE_UNUSED(indices_info); - ARM_COMPUTE_UNUSED(output_info); - ARM_COMPUTE_UNUSED(scatter_info); - ARM_COMPUTE_UNUSED(expected); + const Status status = CLScatter::validate(&input_info.clone()->set_is_resizable(true), &updates_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), scatter_info); + ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS); } TEST_SUITE(Float) TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), make("DataType", {DataType::F32}), - make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max }), make("ZeroInit", {false}))) { - // TODO: Add validate() here. + validate(CLAccessor(_target), _reference, tolerance_f32); } // With this test, src should be passed as nullptr. @@ -105,7 +109,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework make("ScatterFunction", {ScatterFunction::Add}), make("ZeroInit", {true}))) { - // TODO: Add validate() here + validate(CLAccessor(_target), _reference, tolerance_f32); } TEST_SUITE_END() // FP32 TEST_SUITE_END() // Float |