aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/ScatterDataset.h
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-05-09 13:24:15 +0100
committerSuhail M <MohammedSuhail.Munshi@arm.com>2024-05-10 11:01:30 +0000
commit05269f013cf2b7c4a53f5950cdd6bfea26367769 (patch)
treebc620cce08a1569c311ce0fd9d74833e5ff63382 /tests/datasets/ScatterDataset.h
parent48f120c64c21d983318c6e65f6d5609a8f8e92e6 (diff)
downloadComputeLibrary-05269f013cf2b7c4a53f5950cdd6bfea26367769.tar.gz
ScatterND fix for scalar cases
- Padding with batched scalar cases is unsupported, adds checks. - Adds tests for scalar cases, without padding. Resolves: [COMPMID-7015] Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/datasets/ScatterDataset.h')
-rw-r--r--tests/datasets/ScatterDataset.h14
1 files changed, 13 insertions, 1 deletions
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index 4ad269ec85..8fd4448d2d 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -180,7 +180,6 @@ public:
// NOTE: Updates/Indices tensors are now batched.
// NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
// k is the number of batch dimensions
-
// k = 2
add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
@@ -197,6 +196,18 @@ public:
}
};
+class SmallScatterScalarDataset final : public ScatterDataset
+{
+public:
+ // batched scalar case
+ SmallScatterScalarDataset()
+ {
+ add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
+ add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U));
+ }
+};
+
// This dataset is for data types that does not require full testing. It contains selected tests from the above.
class SmallScatterMixedDataset final : public ScatterDataset
{
@@ -205,6 +216,7 @@ public:
{
add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U));
add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));