aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CL/Select.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/CL/Select.cpp')
-rw-r--r--tests/validation/CL/Select.cpp122
1 files changed, 31 insertions, 91 deletions
diff --git a/tests/validation/CL/Select.cpp b/tests/validation/CL/Select.cpp
index f366ce7d9a..d3540cae48 100644
--- a/tests/validation/CL/Select.cpp
+++ b/tests/validation/CL/Select.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -98,41 +98,21 @@ using CLSelectFixture = SelectValidationFixture<CLTensor, CLAccessor, CLSelect,
TEST_SUITE(Float)
TEST_SUITE(F16)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, run_small_dataset,
- shape, same_rank)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ CLSelectFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16)))
{
- const DataType dt = DataType::F16;
-
- // Create tensors
- CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8);
- CLTensor ref_x = create_tensor<CLTensor>(shape, dt);
- CLTensor ref_y = create_tensor<CLTensor>(shape, dt);
- CLTensor dst = create_tensor<CLTensor>(shape, dt);
-
- // Create and Configure function
- CLSelect select;
- select.configure(&ref_c, &ref_x, &ref_y, &dst);
-
- // Validate valid region
- const ValidRegion valid_region = shape_to_valid_region(shape);
- validate(dst.info()->valid_region(), valid_region);
-
- // Validate padding
- const int step = 16 / arm_compute::data_size_from_type(dt);
- const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding();
- if(same_rank)
- {
- validate(ref_c.info()->padding(), padding);
- }
- validate(ref_x.info()->padding(), padding);
- validate(ref_y.info()->padding(), padding);
- validate(dst.info()->padding(), padding);
+ // Validate output
+ validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunSmall,
+FIXTURE_DATA_TEST_CASE(RunOneDim,
CLSelectFixture<half>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16)))
+ combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)),
+ framework::dataset::make("has_same_rank", { false, true })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -149,41 +129,21 @@ FIXTURE_DATA_TEST_CASE(RunLarge,
TEST_SUITE_END() // F16
TEST_SUITE(FP32)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, run_small_dataset,
- shape, same_rank)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ CLSelectFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32)))
{
- const DataType dt = DataType::F32;
-
- // Create tensors
- CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8);
- CLTensor ref_x = create_tensor<CLTensor>(shape, dt);
- CLTensor ref_y = create_tensor<CLTensor>(shape, dt);
- CLTensor dst = create_tensor<CLTensor>(shape, dt);
-
- // Create and Configure function
- CLSelect select;
- select.configure(&ref_c, &ref_x, &ref_y, &dst);
-
- // Validate valid region
- const ValidRegion valid_region = shape_to_valid_region(shape);
- validate(dst.info()->valid_region(), valid_region);
-
- // Validate padding
- const int step = 16 / arm_compute::data_size_from_type(dt);
- const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding();
- if(same_rank)
- {
- validate(ref_c.info()->padding(), padding);
- }
- validate(ref_x.info()->padding(), padding);
- validate(ref_y.info()->padding(), padding);
- validate(dst.info()->padding(), padding);
+ // Validate output
+ validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunSmall,
+FIXTURE_DATA_TEST_CASE(RunOneDim,
CLSelectFixture<float>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32)))
+ combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)),
+ framework::dataset::make("has_same_rank", { false, true })),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -202,41 +162,21 @@ TEST_SUITE_END() // Float
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, run_small_dataset,
- shape, same_rank)
+FIXTURE_DATA_TEST_CASE(RunSmall,
+ CLSelectFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
{
- const DataType dt = DataType::QASYMM8;
-
- // Create tensors
- CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8);
- CLTensor ref_x = create_tensor<CLTensor>(shape, dt);
- CLTensor ref_y = create_tensor<CLTensor>(shape, dt);
- CLTensor dst = create_tensor<CLTensor>(shape, dt);
-
- // Create and Configure function
- CLSelect select;
- select.configure(&ref_c, &ref_x, &ref_y, &dst);
-
- // Validate valid region
- const ValidRegion valid_region = shape_to_valid_region(shape);
- validate(dst.info()->valid_region(), valid_region);
-
- // Validate padding
- const int step = 16 / arm_compute::data_size_from_type(dt);
- const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding();
- if(same_rank)
- {
- validate(ref_c.info()->padding(), padding);
- }
- validate(ref_x.info()->padding(), padding);
- validate(ref_y.info()->padding(), padding);
- validate(dst.info()->padding(), padding);
+ // Validate output
+ validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunSmall,
+FIXTURE_DATA_TEST_CASE(RunOneDim,
CLSelectFixture<uint8_t>,
framework::DatasetMode::PRECOMMIT,
- combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8)))
+ combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)),
+ framework::dataset::make("has_same_rank", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)))
{
// Validate output
validate(CLAccessor(_target), _reference);