aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL
diff options
context:
space:
mode:
authorOmar Al Khatib <omar.alkhatib@arm.com>2023-04-26 11:31:45 +0100
committerOmar Al Khatib <omar.alkhatib@arm.com>2023-05-03 13:22:48 +0000
commitcdd1e039ad598aec10d8c1b81e08de9412324bf2 (patch)
tree344bfa6dc1e30604c6e67533eccb08a71e235fde /tests/validation/CL
parent911d5728fccdabbdf41549c58f0266e49c2aeaf0 (diff)
downloadComputeLibrary-cdd1e039ad598aec10d8c1b81e08de9412324bf2.tar.gz
Support multi-dimensional indices in the CL Gather Layer up to four-dimensional output tensors
Resolves [COMPMID-5775] Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com> Change-Id: I6f6c12ac08f0b0ad070ca5d715c531c2c3762c30 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9498 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/CL')
-rw-r--r--tests/validation/CL/Gather.cpp61
1 files changed, 51 insertions, 10 deletions
diff --git a/tests/validation/CL/Gather.cpp b/tests/validation/CL/Gather.cpp
index f0b87d7d9f..7619baae1e 100644
--- a/tests/validation/CL/Gather.cpp
+++ b/tests/validation/CL/Gather.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -48,19 +48,21 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 27U), 1, DataType::F16),
TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices data type
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices dimensionality
- TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32), // Invalid Input dimensionality
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Mismatching data type input/output
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid positive axis value
- TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Invalid negative axis value
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Output shape
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices data type
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid Indices dimensionality
+ TensorInfo(TensorShape(5U, 5U, 5U, 5U, 5U), 1, DataType::F32), // Invalid Input dimensionality
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Mismatching data type input/output
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F32), // Invalid positive axis value
+ TensorInfo(TensorShape(27U, 27U), 1, DataType::F16), // Invalid negative axis value
}),
framework::dataset::make("IndicesInfo", {
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
- TensorInfo(TensorShape(10U), 1, DataType::U8),
TensorInfo(TensorShape(10U, 10U), 1, DataType::U32),
+ TensorInfo(TensorShape(10U), 1, DataType::U8),
+ TensorInfo(TensorShape(10U, 10U, 10U, 10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
TensorInfo(TensorShape(10U), 1, DataType::U32),
@@ -71,7 +73,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
- TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
+ TensorInfo(TensorShape(10U, 27U), 1, DataType::F32),
+ TensorInfo(TensorShape(27U, 10U, 10U, 10U, 10U), 1, DataType::F32),
TensorInfo(TensorShape(10U, 5U, 5U, 5U, 5U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 10U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 27U), 1, DataType::F32),
@@ -82,13 +85,14 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
1,
-2,
0,
+ 0,
1,
0,
1,
2,
-3,
})),
- framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false })),
+ framework::dataset::make("Expected", { true, true, true, false, false, false, false, false, false, false })),
input_info, indices_info, output_info, axis, expected)
{
const Status status = CLGather::validate(&input_info.clone()->set_is_resizable(true), &indices_info.clone()->set_is_resizable(true), &output_info.clone()->set_is_resizable(true), axis);
@@ -111,6 +115,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<half>,
framework::DatasetMode::NIGHTLY,
@@ -131,6 +144,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<float>,
framework::DatasetMode::NIGHTLY,
@@ -152,6 +174,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U8)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<uint8_t>,
framework::DatasetMode::NIGHTLY,
@@ -172,6 +203,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDimIndices,
+ CLGatherFixture<uint16_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(datasets::CLSmallGatherMultiDimIndicesDataset(), framework::dataset::make("DataType", DataType::U16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLGatherFixture<uint16_t>,
framework::DatasetMode::NIGHTLY,