aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
authorMohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>2024-04-08 14:38:31 +0100
committerSuhail M <MohammedSuhail.Munshi@arm.com>2024-04-22 15:35:41 +0000
commit0e2123695083df5fc1a98af22bbb51808c413350 (patch)
tree3606439df27480ab7a45097b491775a44c12d032 /tests/validation/fixtures/ScatterLayerFixture.h
parent7377107378d6c26439320fce78a551e85b5ad36a (diff)
downloadComputeLibrary-0e2123695083df5fc1a98af22bbb51808c413350.tar.gz
Multi-Dimensional and Batched Scatter Reference and Dataset Implementation.
Resolves: [COMPMID-6893, COMPMID-6895, COMPMID-6898] Change-Id: I355f46aeba2213cd8d067cac7643d8d96e713c93 Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11430 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h18
1 files changed, 17 insertions, 1 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 451a1e1416..91e28b58f7 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -54,7 +54,7 @@ public:
protected:
template <typename U>
- void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
+ void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
{
switch(tensor.data_type())
{
@@ -135,6 +135,22 @@ protected:
{
// Output Quantization not currently in use - fixture should be extended to support this.
ARM_COMPUTE_UNUSED(o_qinfo);
+ TensorShape src_shape = a_shape;
+ TensorShape updates_shape = b_shape;
+ TensorShape indices_shape = c_shape;
+
+ // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
+ if(c_shape.num_dimensions() >= 3)
+ {
+ indices_shape = indices_shape.collapsed_from(1);
+ updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims
+ }
+
+ // 2. Collapse data dims into a single dim.
+ // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over.
+ src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim.
+ src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim
+ updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
// Create reference tensors
SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };