From 623e2f51d2f4f445c5b99db8a4c600e1d155dc8e Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Thu, 9 May 2024 13:24:15 +0100 Subject: ScatterND fix for scalar cases - Padding with batched scalar cases is unsupported, adds checks. - Adds tests for scalar cases, without padding. Resolves: [COMPMID-7015] Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d Signed-off-by: Mohammed Suhail Munshi Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536 Reviewed-by: Gunes Bayir Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Benchmark: Arm Jenkins Signed-off-by: Michael Kozlov --- tests/validation/CL/ScatterLayer.cpp | 52 +++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 12 deletions(-) (limited to 'tests/validation/CL') diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp index e327ff9522..b1531eb64a 100644 --- a/tests/validation/CL/ScatterLayer.cpp +++ b/tests/validation/CL/ScatterLayer.cpp @@ -125,7 +125,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture, framework::Datase make("DataType", {DataType::F32}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {true}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } @@ -136,7 +137,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture, framework make("DataType", {DataType::F32}), make("ScatterFunction", {ScatterFunction::Add}), make("ZeroInit", {true}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {true}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } @@ -147,7 +149,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture, framework make("DataType", {DataType::F32}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {true}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } @@ -158,7 +161,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture, frame make("DataType", {DataType::F32}), make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }), make("ZeroInit", {false}), - make("Inplace", {false, true}))) + make("Inplace", {false, true}), + make("Padding", {true}))) { validate(CLAccessor(_target), _reference, tolerance_f32); } @@ -169,20 +173,38 @@ FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture make("DataType", {DataType::F32}), make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}), make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {true}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f32); +} + +// m+k, k-1-D m+n-D case +FIXTURE_DATA_TEST_CASE(RunSmallScatterScalar, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterScalarDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}), + make("ZeroInit", {false}), + make("Inplace", {false}), + make("Padding", {false}))) // NOTE: Padding not supported in this datset { validate(CLAccessor(_target), _reference, tolerance_f32); } TEST_SUITE_END() // FP32 + +// NOTE: Padding is disabled for the SmallScatterMixedDataset due certain shapes not supporting padding. +// Padding is well tested in F32 Datatype test cases. + TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallScatterMixedDataset(), make("DataType", {DataType::F16}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_f16); } @@ -196,7 +218,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework: make("DataType", {DataType::S32}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_int); } @@ -208,7 +231,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework: make("DataType", {DataType::S16}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_int); } @@ -220,7 +244,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework:: make("DataType", {DataType::S8}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_int); } @@ -232,7 +257,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework make("DataType", {DataType::U32}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_int); } @@ -244,7 +270,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework make("DataType", {DataType::U16}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_int); } @@ -256,7 +283,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework: make("DataType", {DataType::U8}), allScatterFunctions, make("ZeroInit", {false}), - make("Inplace", {false}))) + make("Inplace", {false}), + make("Padding", {false}))) { validate(CLAccessor(_target), _reference, tolerance_int); } -- cgit v1.2.1