From 301e33f8f94be6427bf2377570388c379d8c8466 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 29 Apr 2024 17:00:14 +0100 Subject: Add fp16 and integer data type support for ScatterNd in Gpu Resolves: COMPMID-6899 Change-Id: I3743f2c9e5c21e1ec9f4c81d08c148666afad33a Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11505 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Jakub Sujak Reviewed-by: Sang Won Ha Comments-Addressed: Arm Jenkins --- tests/datasets/ScatterDataset.h | 14 ++++ tests/validation/CL/ScatterLayer.cpp | 96 ++++++++++++++++++++++++- tests/validation/fixtures/ScatterLayerFixture.h | 21 +++++- tests/validation/reference/ScatterLayer.cpp | 8 ++- 4 files changed, 135 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h index c0858941db..9dcf859a8f 100644 --- a/tests/datasets/ScatterDataset.h +++ b/tests/datasets/ScatterDataset.h @@ -185,6 +185,20 @@ public: add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(6U, 2U), TensorShape(5U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U)); } }; + +// This dataset is for data types that does not require full testing. It contains selected tests from the above. +class SmallScatterMixedDataset final : public ScatterDataset +{ +public: + SmallScatterMixedDataset() + { + add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U)); + add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U)); + add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U)); + add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U)); + // TODO: add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U)); + } +}; } // namespace datasets } // namespace test } // namespace arm_compute diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp index 4a2462c7d2..2970d82572 100644 --- a/tests/validation/CL/ScatterLayer.cpp +++ b/tests/validation/CL/ScatterLayer.cpp @@ -41,6 +41,8 @@ namespace validation namespace { RelativeTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */ +RelativeTolerance tolerance_f16(0.02f); /**< Tolerance value for comparing reference's output against implementation's output for fp16 data type */ +RelativeTolerance tolerance_int(0); /**< Tolerance value for comparing reference's output against implementation's output for integer data types */ } // namespace template @@ -53,6 +55,7 @@ TEST_SUITE(Scatter) DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid + TensorInfo(TensorShape(15U), 1, DataType::U8), // Valid TensorInfo(TensorShape(8U), 1, DataType::F32), TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims. TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims. @@ -63,6 +66,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( }), make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16), TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(15U), 1, DataType::U8), TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(217U), 1, DataType::F32), TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), @@ -72,6 +76,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( TensorInfo(TensorShape(1U), 1, DataType::F32), }), make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32), + TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), TensorInfo(TensorShape(1U, 15U), 1, DataType::S32), TensorInfo(TensorShape(1U, 2U), 1, DataType::S32), TensorInfo(TensorShape(1U, 271U), 1, DataType::S32), @@ -83,6 +88,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( }), make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16), TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(15U), 1, DataType::U8), TensorInfo(TensorShape(8U), 1, DataType::F32), TensorInfo(TensorShape(271U, 3U), 1, DataType::F32), TensorInfo(TensorShape(271U), 1, DataType::F32), @@ -92,6 +98,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), }), make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false), + ScatterInfo(ScatterFunction::Max, false), ScatterInfo(ScatterFunction::Max, false), ScatterInfo(ScatterFunction::Min, false), ScatterInfo(ScatterFunction::Add, false), @@ -101,7 +108,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip( ScatterInfo(ScatterFunction::Update, false), ScatterInfo(ScatterFunction::Update, false), }), - make("Expected", { false, true, true, false, false, false, false, false, false })), + make("Expected", { false, true, true, true, false, false, false, false, false, false })), input_info, updates_info, indices_info, output_info, scatter_info, expected) { const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info); @@ -168,7 +175,94 @@ FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture } TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::F16}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_f16); +} +TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float + +TEST_SUITE(Integer) +TEST_SUITE(S32) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::S32}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_int); +} +TEST_SUITE_END() // S32 + +TEST_SUITE(S16) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::S16}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_int); +} +TEST_SUITE_END() // S16 + +TEST_SUITE(S8) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::S8}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_int); +} +TEST_SUITE_END() // S8 + +TEST_SUITE(U32) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::U32}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_int); +} +TEST_SUITE_END() // U32 + +TEST_SUITE(U16) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::U16}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_int); +} +TEST_SUITE_END() // U16 + +TEST_SUITE(U8) +FIXTURE_DATA_TEST_CASE(RunSmallMixed, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, + combine(datasets::SmallScatterMixedDataset(), + make("DataType", {DataType::U8}), + allScatterFunctions, + make("ZeroInit", {false}), + make("Inplace", {false}))) +{ + validate(CLAccessor(_target), _reference, tolerance_int); +} +TEST_SUITE_END() // U8 +TEST_SUITE_END() // Integer + TEST_SUITE_END() // Scatter TEST_SUITE_END() // CL } // namespace validation diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 4fb2d7f127..35e6b647f3 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -63,13 +63,30 @@ public: protected: template - void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f) + void fill(U &&tensor, int i) { switch(tensor.data_type()) { case DataType::F32: + case DataType::F16: { - std::uniform_real_distribution distribution(lo, hi); + std::uniform_real_distribution distribution(-10.f, 10.f); + library->fill(tensor, distribution, i); + break; + } + case DataType::S32: + case DataType::S16: + case DataType::S8: + { + std::uniform_int_distribution distribution(-100, 100); + library->fill(tensor, distribution, i); + break; + } + case DataType::U32: + case DataType::U16: + case DataType::U8: + { + std::uniform_int_distribution distribution(0, 200); library->fill(tensor, distribution, i); break; } diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp index 283022e8e2..c9e6035e14 100644 --- a/tests/validation/reference/ScatterLayer.cpp +++ b/tests/validation/reference/ScatterLayer.cpp @@ -138,7 +138,13 @@ SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor } template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); - +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); +template SimpleTensor scatter_layer(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info); } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1