aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorMohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>2024-04-29 22:53:58 +0100
committerSuhail M <MohammedSuhail.Munshi@arm.com>2024-05-08 12:07:31 +0000
commit2fea13593a4753316ae488edf489cb4b00150153 (patch)
tree423e6369a74c44b505dd8fd4d62bde0946ec2e32 /tests/validation/fixtures
parentc22e1263ba3a6945ceb1fdccb33eac512fd156fb (diff)
downloadComputeLibrary-2fea13593a4753316ae488edf489cb4b00150153.tar.gz
Add batched indices support to Scatter GPU Implementation
Resolves: [COMPMID-6897] Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Change-Id: I70b1c3c5f0de8484fcb6c3b0cc0d0d8c059b0f58 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11525 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h15
1 files changed, 8 insertions, 7 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index 35e6b647f3..5cd9b8115c 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -103,7 +103,7 @@ protected:
void fill_indices(U &&tensor, int i, const TensorShape &shape)
{
// Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
- const int32_t max = std::max({shape[0] , shape[1], shape[2]});
+ const int32_t max = std::min({shape[0] , shape[1], shape[2]}) + 1;
library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
}
@@ -197,12 +197,13 @@ protected:
TensorShape src_shape = a_shape;
TensorShape updates_shape = b_shape;
TensorShape indices_shape = c_shape;
+ const int num_ind_dims = c_shape.num_dimensions();
// 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
- if(c_shape.num_dimensions() >= 3)
+ if(num_ind_dims >= 3)
{
indices_shape = indices_shape.collapsed_from(1);
- updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims
+ updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - (num_ind_dims -1)); // Collapses batch dims
}
// 2. Collapse data dims into a single dim.
@@ -212,16 +213,16 @@ protected:
updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
// Create reference tensors
- SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
- SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
- SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
+ SimpleTensor<T> src{ src_shape, data_type, 1, a_qinfo };
+ SimpleTensor<T> updates{updates_shape, data_type, 1, QuantizationInfo() };
+ SimpleTensor<int32_t> indices{ indices_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
fill(src, 0 + _hash);
fill(updates, 1 + _hash);
fill_indices(indices, 2 + _hash, out_shape);
- // Calculate individual reference.
+ // Calculate individual reference using collapsed shapes
return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
}