From 0e2123695083df5fc1a98af22bbb51808c413350 Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Mon, 8 Apr 2024 14:38:31 +0100 Subject: Multi-Dimensional and Batched Scatter Reference and Dataset Implementation. Resolves: [COMPMID-6893, COMPMID-6895, COMPMID-6898] Change-Id: I355f46aeba2213cd8d067cac7643d8d96e713c93 Signed-off-by: Mohammed Suhail Munshi Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11430 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- tests/validation/fixtures/ScatterLayerFixture.h | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'tests/validation/fixtures/ScatterLayerFixture.h') diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 451a1e1416..91e28b58f7 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -54,7 +54,7 @@ public: protected: template - void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f) + void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f) { switch(tensor.data_type()) { @@ -135,6 +135,22 @@ protected: { // Output Quantization not currently in use - fixture should be extended to support this. ARM_COMPUTE_UNUSED(o_qinfo); + TensorShape src_shape = a_shape; + TensorShape updates_shape = b_shape; + TensorShape indices_shape = c_shape; + + // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor. + if(c_shape.num_dimensions() >= 3) + { + indices_shape = indices_shape.collapsed_from(1); + updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims + } + + // 2. Collapse data dims into a single dim. + // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over. + src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim. + src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim + updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim) // Create reference tensors SimpleTensor src{ a_shape, data_type, 1, a_qinfo }; -- cgit v1.2.1