aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-05-09 13:24:15 +0100
committerMichael Kozlov <michael.kozlov@arm.com>2024-05-13 16:13:57 +0100
commit623e2f51d2f4f445c5b99db8a4c600e1d155dc8e (patch)
treefed682eb794d05695ddd315109ec372f85957ba5 /tests/datasets
parentc1575b2c12a4cee3a60c711fe6521025a814b159 (diff)
downloadComputeLibrary-623e2f51d2f4f445c5b99db8a4c600e1d155dc8e.tar.gz
ScatterND fix for scalar cases
- Padding with batched scalar cases is unsupported, adds checks. - Adds tests for scalar cases, without padding. Resolves: [COMPMID-7015] Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com> Signed-off-by: Michael Kozlov <michael.kozlov@arm.com>
Diffstat (limited to 'tests/datasets')
-rw-r--r--tests/datasets/ScatterDataset.h14
1 files changed, 13 insertions, 1 deletions
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index 4ad269ec85..8fd4448d2d 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -180,7 +180,6 @@ public:
// NOTE: Updates/Indices tensors are now batched.
// NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
// k is the number of batch dimensions
-
// k = 2
add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
@@ -197,6 +196,18 @@ public:
}
};
+class SmallScatterScalarDataset final : public ScatterDataset
+{
+public:
+ // batched scalar case
+ SmallScatterScalarDataset()
+ {
+ add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U));
+ }
+};
+
// This dataset is for data types that does not require full testing. It contains selected tests from the above.
class SmallScatterMixedDataset final : public ScatterDataset
{
@@ -205,6 +216,7 @@ public:
{
add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));