From 473b8291a7dc126586d19b82d6c84b4c7a6e44a6 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Mon, 18 Mar 2024 12:13:30 +0000 Subject: Adds Tests and reference implementation for scatter operator with 1D tensors. Resolves: [COMPMID-6890] Change-Id: Ie4a8db24fc6387afa9ddf42b3607e040cdf8df67 Signed-off-by: Mohammed Suhail Munshi Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11339 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/gpu/cl/operators/ClScatter.cpp | 2 +- src/runtime/CL/functions/CLScatter.cpp | 11 +++- tests/datasets/ScatterDataset.h | 5 +- tests/validation/CL/ScatterLayer.cpp | 70 ++++++++++++++++++++++++- tests/validation/fixtures/ScatterLayerFixture.h | 31 +++++++++-- tests/validation/reference/ScatterLayer.cpp | 64 ++++++++++++++++++++-- 6 files changed, 166 insertions(+), 17 deletions(-) diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp index 74d747bc16..af5fbb86f3 100644 --- a/src/gpu/cl/operators/ClScatter.cpp +++ b/src/gpu/cl/operators/ClScatter.cpp @@ -59,7 +59,7 @@ void ClScatter::configure(const CLCompileContext &compile_context, ITensorInfo *dst, const ScatterInfo &info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(src, indices, dst); + ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst); ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info); ARM_COMPUTE_UNUSED(src); ARM_COMPUTE_UNUSED(updates); diff --git a/src/runtime/CL/functions/CLScatter.cpp b/src/runtime/CL/functions/CLScatter.cpp index e1de92968a..e16fcc4ccc 100644 --- a/src/runtime/CL/functions/CLScatter.cpp +++ b/src/runtime/CL/functions/CLScatter.cpp @@ -62,10 +62,17 @@ void CLScatter::configure(const CLCompileContext &compile_context, ICLTensor *output, const ScatterInfo &info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(src, indices, output); + ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, output); _impl->op = std::make_unique(); - _impl->op->configure(compile_context, src->info(), updates->info(), indices->info(), output->info(), info); + if (src) + { // Src not nullptr. + _impl->op->configure(compile_context, src->info(), updates->info(), indices->info(), output->info(), info); + } + else + { + _impl->op->configure(compile_context, nullptr, updates->info(), indices->info(), output->info(), info); + } _impl->run_pack = {{ACL_SRC_0, src}, {ACL_SRC_1, updates}, {ACL_SRC_2, indices}, {ACL_DST, output}}; } diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h index 09f6338432..d204d17855 100644 --- a/tests/datasets/ScatterDataset.h +++ b/tests/datasets/ScatterDataset.h @@ -113,12 +113,13 @@ private: std::vector _dst_shapes{}; }; -class SmallScatterDataset final : public ScatterDataset +class Small1DScatterDataset final : public ScatterDataset { public: - SmallScatterDataset() + Small1DScatterDataset() { add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U)); + add_config(TensorShape(10U), TensorShape(2U), TensorShape(2U), TensorShape(10U)); } }; } // namespace datasets diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp index 66b71ef650..56338f489f 100644 --- a/tests/validation/CL/ScatterLayer.cpp +++ b/tests/validation/CL/ScatterLayer.cpp @@ -24,8 +24,13 @@ #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/functions/CLScatter.h" #include "tests/validation/fixtures/ScatterLayerFixture.h" +#include "tests/datasets/ScatterDataset.h" #include "tests/CL/CLAccessor.h" +#include "arm_compute/function_info/ScatterInfo.h" +#include "tests/framework/Asserts.h" #include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Validation.h" namespace arm_compute { @@ -37,13 +42,74 @@ namespace validation template using CLScatterLayerFixture = ScatterValidationFixture; +using framework::dataset::make; + TEST_SUITE(CL) -TEST_SUITE(ScatterLayer) +TEST_SUITE(Scatter) +DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip( + make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types + TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid + TensorInfo(TensorShape(8U), 1, DataType::F32), + TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims. + TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims. + TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype. + }), + make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16), + TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(217U), 1, DataType::F32), + TensorInfo(TensorShape(217U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), + }), + make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32), + TensorInfo(TensorShape(15U), 1, DataType::U32), + TensorInfo(TensorShape(2U), 1, DataType::U32), + TensorInfo(TensorShape(271U), 1, DataType::U32), + TensorInfo(TensorShape(271U), 1, DataType::U32), + TensorInfo(TensorShape(2U), 1 , DataType::S32) + }), + make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16), + TensorInfo(TensorShape(15U), 1, DataType::F32), + TensorInfo(TensorShape(8U), 1, DataType::F32), + TensorInfo(TensorShape(271U, 3U), 1, DataType::F32), + TensorInfo(TensorShape(271U), 1, DataType::F32), + TensorInfo(TensorShape(12U), 1, DataType::F32) + }), + make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false), + }), + make("Expected", { false, true, true, false, false, false })), + input_info, updates_info, indices_info, output_info, scatter_info, expected) +{ + // TODO: Enable validation tests. + ARM_COMPUTE_UNUSED(input_info); + ARM_COMPUTE_UNUSED(updates_info); + ARM_COMPUTE_UNUSED(indices_info); + ARM_COMPUTE_UNUSED(output_info); + ARM_COMPUTE_UNUSED(scatter_info); + ARM_COMPUTE_UNUSED(expected); +} + TEST_SUITE(Float) TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}), + make("ZeroInit", {false}))) +{ + // TODO: Add validate() here. +} + +// With this test, src should be passed as nullptr. +FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(), + make("DataType", {DataType::F32}), + make("ScatterFunction", {ScatterFunction::Add}), + make("ZeroInit", {true}))) +{ + // TODO: Add validate() here +} TEST_SUITE_END() // FP32 TEST_SUITE_END() // Float -TEST_SUITE_END() // ScatterLayer +TEST_SUITE_END() // Scatter TEST_SUITE_END() // CL } // namespace validation } // namespace test diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 750e272388..bda5532a51 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -25,13 +25,16 @@ #define ACL_TESTS_VALIDATION_FIXTURES_SCATTERLAYERFIXTURE_H #include "arm_compute/core/Utils.h" +#include "arm_compute/runtime/CL/CLTensorAllocator.h" #include "tests/Globals.h" #include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT #include "tests/framework/Fixture.h" #include "tests/validation/Validation.h" #include "tests/validation/reference/ScatterLayer.h" #include "tests/SimpleTensor.h" + #include +#include namespace arm_compute { @@ -68,6 +71,16 @@ protected: } } + // This is used to fill indices tensor with U32 datatype. + // Used to prevent ONLY having values that are out of bounds. + template + void fill_indices(U &&tensor, int i, const TensorShape &shape) + { + // Calculate max indices the shape should contain. Add an arbitrary constant to allow testing for some out of bounds values. + const uint32_t max = std::max({shape[0] , shape[1], shape[2]}) + 5; + library->fill_tensor_uniform(tensor, i, static_cast(0), static_cast(max)); + } + TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // 1. Create relevant tensors using ScatterInfo data structure. @@ -81,7 +94,15 @@ protected: FunctionType scatter; // Configure operator - scatter.configure(&src, &updates, &indices, &dst, info); + // When scatter_info.zero_initialization is true, pass nullptr to scatter function. + if(info.zero_initialization) + { + scatter.configure(nullptr, &updates, &indices, &dst, info); + } + else + { + scatter.configure(&src, &updates, &indices, &dst, info); + } // Assertions ARM_COMPUTE_ASSERT(src.info()->is_resizable()); @@ -103,7 +124,7 @@ protected: // Fill update (a) and indices (b) tensors. fill(AccessorType(src), 0); fill(AccessorType(updates), 1); - fill(AccessorType(indices), 2); + fill_indices(AccessorType(indices), 2, out_shape); scatter.run(); @@ -124,7 +145,7 @@ protected: // Fill reference fill(src, 0); fill(updates, 1); - fill(indices, 2); + fill_indices(indices, 2, out_shape); // Calculate individual reference. auto result = reference::scatter_layer(src, updates, indices, out_shape, info); @@ -141,9 +162,9 @@ template { public: - void setup(TensorShape src_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init) + void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init) { - ScatterGenericValidationFixture::setup(src_shape, indices_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo()); + ScatterGenericValidationFixture::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo()); } }; diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp index 188cce100b..920f2b9990 100644 --- a/tests/validation/reference/ScatterLayer.cpp +++ b/tests/validation/reference/ScatterLayer.cpp @@ -32,16 +32,70 @@ namespace validation { namespace reference { +namespace +{ +template +T reduce_op(const T ¤t,const T &update,const ScatterFunction func) +{ + switch(func) + { + case ScatterFunction::Update: + return update; + break; + case ScatterFunction::Add: + return current + update; + break; + case ScatterFunction::Sub: + return current - update; + break; + case ScatterFunction::Max: + return std::max(current, update); + break; + case ScatterFunction::Min: + return std::min(current, update); + break; + default: + ARM_COMPUTE_ERROR("Unsupported Scatter function"); + break; + } +} + +template float reduce_op(const float ¤t,const float &update,const ScatterFunction func); +} + +// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors. template SimpleTensor scatter_layer_internal(const SimpleTensor &src, const SimpleTensor &updates, const SimpleTensor &indices, const TensorShape &out_shape, const ScatterInfo &info) { - ARM_COMPUTE_UNUSED(src); - ARM_COMPUTE_UNUSED(updates); - ARM_COMPUTE_UNUSED(indices); - ARM_COMPUTE_UNUSED(info); - // Unimplemented reference. SimpleTensor dst{ out_shape, src.data_type(), 1 }; + + // 1. If zero initialization variable is true, fill dst with 0 values. Else copy src data to dst. + if(info.zero_initialization) + { + for (int i = 0; i < src.num_elements(); ++i) + { + dst[i] = static_cast(0); + } + } + else + { + std::copy_n(src.data(), src.num_elements(), dst.data()); + } + + // 2. Get max index of output tensor, then iterate over index tensor. + const auto x_bound = dst.shape().x(); + + + for(int i = 0; i < indices.num_elements(); ++i) + { + // 3. Check whether index is out of bounds for dst, if not then apply reduce op. + const auto index = indices[i]; + if (index < x_bound) // Note : index is always >= 0 as datatype is unsigned. + { + dst[index] = reduce_op(dst[index], updates[i], info.func); + } + } return dst; } -- cgit v1.2.1