aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScatterLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h18
1 files changed, 17 insertions, 1 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 451a1e1416..91e28b58f7 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -54,7 +54,7 @@ public:
protected:
template <typename U>
- void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
+ void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
{
switch(tensor.data_type())
{
@@ -135,6 +135,22 @@ protected:
{
// Output Quantization not currently in use - fixture should be extended to support this.
ARM_COMPUTE_UNUSED(o_qinfo);
+ TensorShape src_shape = a_shape;
+ TensorShape updates_shape = b_shape;
+ TensorShape indices_shape = c_shape;
+
+ // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
+ if(c_shape.num_dimensions() >= 3)
+ {
+ indices_shape = indices_shape.collapsed_from(1);
+ updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims
+ }
+
+ // 2. Collapse data dims into a single dim.
+ // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over.
+ src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim.
+ src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim
+ updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
// Create reference tensors
SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };