From ada3200f5cec0b6a37f898d5d6f8e69395d7bcb1 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Wed, 24 Apr 2024 10:27:13 +0100 Subject: Add update/index/output (m+1)/2d/(m+n) support for CLScatter Resolves: COMPMID-6894, COMPMID-6896 Change-Id: I9d29fd3701a7e0f28d83f81a6c42a7234c2587c3 Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11477 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Ramy Elgammal Dynamic-Fusion: Ramy Elgammal Benchmark: Arm Jenkins --- tests/validation/fixtures/ScatterLayerFixture.h | 82 +++++++++++++++++++------ 1 file changed, 64 insertions(+), 18 deletions(-) (limited to 'tests/validation/fixtures/ScatterLayerFixture.h') diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 91e28b58f7..4fb2d7f127 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -29,6 +29,7 @@ #include "tests/Globals.h" #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" +#include "tests/validation/Helpers.h" #include "tests/validation/Validation.h" #include "tests/validation/reference/ScatterLayer.h" #include "tests/SimpleTensor.h" @@ -46,9 +47,17 @@ template fill_tensor_uniform(tensor, i, static_cast(-2), static_cast(max)); } - TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, + const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace, + QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // 1. Create relevant tensors using ScatterInfo data structure. // ---------------------------------------------------- @@ -94,14 +105,22 @@ protected: FunctionType scatter; // Configure operator - // When scatter_info.zero_initialization is true, pass nullptr to scatter function. + // When scatter_info.zero_initialization is true, pass nullptr for src + // because dst does not need to be initialized with src values. if(info.zero_initialization) { scatter.configure(nullptr, &updates, &indices, &dst, info); } else { - scatter.configure(&src, &updates, &indices, &dst, info); + if(inplace) + { + scatter.configure(&src, &updates, &indices, &src, info); + } + else + { + scatter.configure(&src, &updates, &indices, &dst, info); + } } // Assertions @@ -110,28 +129,51 @@ protected: ARM_COMPUTE_ASSERT(indices.info()->is_resizable()); ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + add_padding_x({ &src, &updates, &indices}); + + if(!inplace) + { + add_padding_x({ &dst }); + } + // Allocate tensors src.allocator()->allocate(); updates.allocator()->allocate(); indices.allocator()->allocate(); - dst.allocator()->allocate(); + + if(!inplace) + { + dst.allocator()->allocate(); + } ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); ARM_COMPUTE_ASSERT(!updates.info()->is_resizable()); ARM_COMPUTE_ASSERT(!indices.info()->is_resizable()); - ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + if(!inplace) + { + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + } // Fill update (a) and indices (b) tensors. - fill(AccessorType(src), 0); - fill(AccessorType(updates), 1); - fill_indices(AccessorType(indices), 2, out_shape); + fill(AccessorType(src), 0 + _hash); + fill(AccessorType(updates), 1+ _hash); + fill_indices(AccessorType(indices), 2 + _hash, out_shape); scatter.run(); - return dst; + + if(inplace) + { + return src; + } + else + { + return dst; + } } - SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type, - ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) + SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, + const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo) { // Output Quantization not currently in use - fixture should be extended to support this. ARM_COMPUTE_UNUSED(o_qinfo); @@ -158,9 +200,9 @@ protected: SimpleTensor indices{ c_shape, DataType::S32, 1, QuantizationInfo() }; // Fill reference - fill(src, 0); - fill(updates, 1); - fill_indices(indices, 2, out_shape); + fill(src, 0 + _hash); + fill(updates, 1 + _hash); + fill_indices(indices, 2 + _hash, out_shape); // Calculate individual reference. return reference::scatter_layer(src, updates, indices, out_shape, info); @@ -168,6 +210,7 @@ protected: TensorType _target{}; SimpleTensor _reference{}; + int32_t _hash{}; }; // This fixture will use the same shape for updates as indices. @@ -175,9 +218,12 @@ template { public: - void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init) + void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, + TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace) { - ScatterGenericValidationFixture::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo()); + ScatterGenericValidationFixture::setup(src_shape, update_shape, + indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace, + QuantizationInfo(), QuantizationInfo()); } }; -- cgit v1.2.1