aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/ScatterLayer.cpp
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-05-09 13:24:15 +0100
committerSuhail M <MohammedSuhail.Munshi@arm.com>2024-05-10 11:01:30 +0000
commit05269f013cf2b7c4a53f5950cdd6bfea26367769 (patch)
treebc620cce08a1569c311ce0fd9d74833e5ff63382 /tests/validation/CL/ScatterLayer.cpp
parent48f120c64c21d983318c6e65f6d5609a8f8e92e6 (diff)
downloadComputeLibrary-05269f013cf2b7c4a53f5950cdd6bfea26367769.tar.gz
ScatterND fix for scalar cases
- Padding with batched scalar cases is unsupported, adds checks. - Adds tests for scalar cases, without padding. Resolves: [COMPMID-7015] Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL/ScatterLayer.cpp')
-rw-r--r--tests/validation/CL/ScatterLayer.cpp52
1 files changed, 40 insertions, 12 deletions
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index e327ff9522..b1531eb64a 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -125,7 +125,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::Datase
make("DataType", {DataType::F32}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {true})))
{
validate(CLAccessor(_target), _reference, tolerance_f32);
}
@@ -136,7 +137,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework
make("DataType", {DataType::F32}),
make("ScatterFunction", {ScatterFunction::Add}),
make("ZeroInit", {true}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {true})))
{
validate(CLAccessor(_target), _reference, tolerance_f32);
}
@@ -147,7 +149,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture<float>, framework
make("DataType", {DataType::F32}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {true})))
{
validate(CLAccessor(_target), _reference, tolerance_f32);
}
@@ -158,7 +161,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture<float>, frame
make("DataType", {DataType::F32}),
make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
make("ZeroInit", {false}),
- make("Inplace", {false, true})))
+ make("Inplace", {false, true}),
+ make("Padding", {true})))
{
validate(CLAccessor(_target), _reference, tolerance_f32);
}
@@ -169,20 +173,38 @@ FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>
make("DataType", {DataType::F32}),
make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {true})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// m+k, k-1-D m+n-D case
+FIXTURE_DATA_TEST_CASE(RunSmallScatterScalar, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterScalarDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add}),
+ make("ZeroInit", {false}),
+ make("Inplace", {false}),
+ make("Padding", {false}))) // NOTE: Padding not supported in this datset
{
validate(CLAccessor(_target), _reference, tolerance_f32);
}
TEST_SUITE_END() // FP32
+
+// NOTE: Padding is disabled for the SmallScatterMixedDataset due certain shapes not supporting padding.
+// Padding is well tested in F32 Datatype test cases.
+
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
combine(datasets::SmallScatterMixedDataset(),
make("DataType", {DataType::F16}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_f16);
}
@@ -196,7 +218,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int32_t>, framework:
make("DataType", {DataType::S32}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_int);
}
@@ -208,7 +231,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int16_t>, framework:
make("DataType", {DataType::S16}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_int);
}
@@ -220,7 +244,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<int8_t>, framework::
make("DataType", {DataType::S8}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_int);
}
@@ -232,7 +257,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint32_t>, framework
make("DataType", {DataType::U32}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_int);
}
@@ -244,7 +270,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint16_t>, framework
make("DataType", {DataType::U16}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_int);
}
@@ -256,7 +283,8 @@ FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture<uint8_t>, framework:
make("DataType", {DataType::U8}),
allScatterFunctions,
make("ZeroInit", {false}),
- make("Inplace", {false})))
+ make("Inplace", {false}),
+ make("Padding", {false})))
{
validate(CLAccessor(_target), _reference, tolerance_int);
}