diff options
author | Gunes Bayir <gunes.bayir@arm.com> | 2024-05-09 13:24:15 +0100 |
---|---|---|
committer | Suhail M <MohammedSuhail.Munshi@arm.com> | 2024-05-10 11:01:30 +0000 |
commit | 05269f013cf2b7c4a53f5950cdd6bfea26367769 (patch) | |
tree | bc620cce08a1569c311ce0fd9d74833e5ff63382 /tests/datasets/ScatterDataset.h | |
parent | 48f120c64c21d983318c6e65f6d5609a8f8e92e6 (diff) | |
download | ComputeLibrary-05269f013cf2b7c4a53f5950cdd6bfea26367769.tar.gz |
ScatterND fix for scalar cases
- Padding with batched scalar cases is unsupported, adds checks.
- Adds tests for scalar cases, without padding.
Resolves: [COMPMID-7015]
Change-Id: Ib9cf5db990420ff4b442d003ef9424e365bee86d
Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11536
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/datasets/ScatterDataset.h')
-rw-r--r-- | tests/datasets/ScatterDataset.h | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h index 4ad269ec85..8fd4448d2d 100644 --- a/tests/datasets/ScatterDataset.h +++ b/tests/datasets/ScatterDataset.h @@ -180,7 +180,6 @@ public: // NOTE: Updates/Indices tensors are now batched. // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1 // k is the number of batch dimensions - // k = 2 add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U)); add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(5U, 5U, 6U, 2U), TensorShape(3U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U)); @@ -197,6 +196,18 @@ public: } }; +class SmallScatterScalarDataset final : public ScatterDataset +{ +public: + // batched scalar case + SmallScatterScalarDataset() + { + add_config(TensorShape(6U, 5U), TensorShape(6U), TensorShape(2U, 6U), TensorShape(6U, 5U)); + add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U)); + add_config(TensorShape(3U, 3U, 6U, 5U), TensorShape(6U, 6U), TensorShape(4U, 6U, 6U), TensorShape(3U, 3U, 6U, 5U)); + } +}; + // This dataset is for data types that does not require full testing. It contains selected tests from the above. class SmallScatterMixedDataset final : public ScatterDataset { @@ -205,6 +216,7 @@ public: { add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U)); add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U)); + add_config(TensorShape(6U, 5U), TensorShape(6U, 6U), TensorShape(2U, 6U, 6U), TensorShape(6U, 5U)); add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U)); add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U)); add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U)); |