diff options
Diffstat (limited to 'tests/datasets/GatherDataset.h')
-rw-r--r-- | tests/datasets/GatherDataset.h | 60 |
1 files changed, 59 insertions, 1 deletions
diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h index 29a99d5239..74ea3b4a06 100644 --- a/tests/datasets/GatherDataset.h +++ b/tests/datasets/GatherDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2022-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -106,6 +106,64 @@ private: std::vector<int> _axis{}; }; +class SmallGatherMultiDimIndicesDataset final : public GatherDataset +{ +public: + SmallGatherMultiDimIndicesDataset() + { + add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1); + add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1); + add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1); + add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1); + add_config(TensorShape(1U, 5U, 3U), TensorShape(1U, 7U, 3U), 1); + + add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0); + add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0); + add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0); + + add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U), 2); + add_config(TensorShape(8U, 2U, 3U), TensorShape(4U, 2U, 5U), 2); + } +}; + +class CLSmallGatherMultiDimIndicesDataset final : public GatherDataset +{ +public: + CLSmallGatherMultiDimIndicesDataset() + { + add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 0); + add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 0); + add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 0); + add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 0); + + add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 0); + add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 0); + add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 0); + + add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),0); + + add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1); + add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1); + add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1); + add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1); + + add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 1); + add_config(TensorShape(9U), TensorShape(3U, 2U, 4U), 1); + add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 1); + + add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),1); + + add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 2); + add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 2); + add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 2); + + add_config(TensorShape(3U, 5U), TensorShape(2U, 3U), 2); + add_config(TensorShape(5U, 3U, 4U), TensorShape(5U, 6U), 2); + + add_config(TensorShape(7U, 4U, 5U), TensorShape(2U, 3U),2); + } +}; + class SmallGatherDataset final : public GatherDataset { public: |