diff options
author | Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> | 2024-04-08 14:38:31 +0100 |
---|---|---|
committer | Suhail M <MohammedSuhail.Munshi@arm.com> | 2024-04-22 15:35:41 +0000 |
commit | 0e2123695083df5fc1a98af22bbb51808c413350 (patch) | |
tree | 3606439df27480ab7a45097b491775a44c12d032 /tests/validation/fixtures/ScatterLayerFixture.h | |
parent | 7377107378d6c26439320fce78a551e85b5ad36a (diff) | |
download | ComputeLibrary-0e2123695083df5fc1a98af22bbb51808c413350.tar.gz |
Multi-Dimensional and Batched Scatter Reference and Dataset Implementation.
Resolves: [COMPMID-6893, COMPMID-6895, COMPMID-6898]
Change-Id: I355f46aeba2213cd8d067cac7643d8d96e713c93
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11430
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ScatterLayerFixture.h | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 451a1e1416..91e28b58f7 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -54,7 +54,7 @@ public: protected: template <typename U> - void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f) + void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f) { switch(tensor.data_type()) { @@ -135,6 +135,22 @@ protected: { // Output Quantization not currently in use - fixture should be extended to support this. ARM_COMPUTE_UNUSED(o_qinfo); + TensorShape src_shape = a_shape; + TensorShape updates_shape = b_shape; + TensorShape indices_shape = c_shape; + + // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor. + if(c_shape.num_dimensions() >= 3) + { + indices_shape = indices_shape.collapsed_from(1); + updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims + } + + // 2. Collapse data dims into a single dim. + // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over. + src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim. + src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim + updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim) // Create reference tensors SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo }; |