diff options
Diffstat (limited to 'tests/validation/CL/Select.cpp')
-rw-r--r-- | tests/validation/CL/Select.cpp | 122 |
1 files changed, 31 insertions, 91 deletions
diff --git a/tests/validation/CL/Select.cpp b/tests/validation/CL/Select.cpp index f366ce7d9a..d3540cae48 100644 --- a/tests/validation/CL/Select.cpp +++ b/tests/validation/CL/Select.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -98,41 +98,21 @@ using CLSelectFixture = SelectValidationFixture<CLTensor, CLAccessor, CLSelect, TEST_SUITE(Float) TEST_SUITE(F16) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, run_small_dataset, - shape, same_rank) +FIXTURE_DATA_TEST_CASE(RunSmall, + CLSelectFixture<half>, + framework::DatasetMode::PRECOMMIT, + combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16))) { - const DataType dt = DataType::F16; - - // Create tensors - CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8); - CLTensor ref_x = create_tensor<CLTensor>(shape, dt); - CLTensor ref_y = create_tensor<CLTensor>(shape, dt); - CLTensor dst = create_tensor<CLTensor>(shape, dt); - - // Create and Configure function - CLSelect select; - select.configure(&ref_c, &ref_x, &ref_y, &dst); - - // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape); - validate(dst.info()->valid_region(), valid_region); - - // Validate padding - const int step = 16 / arm_compute::data_size_from_type(dt); - const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding(); - if(same_rank) - { - validate(ref_c.info()->padding(), padding); - } - validate(ref_x.info()->padding(), padding); - validate(ref_y.info()->padding(), padding); - validate(dst.info()->padding(), padding); + // Validate output + validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunSmall, +FIXTURE_DATA_TEST_CASE(RunOneDim, CLSelectFixture<half>, framework::DatasetMode::PRECOMMIT, - combine(run_small_dataset, framework::dataset::make("DataType", DataType::F16))) + combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)), + framework::dataset::make("has_same_rank", { false, true })), + framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(CLAccessor(_target), _reference); @@ -149,41 +129,21 @@ FIXTURE_DATA_TEST_CASE(RunLarge, TEST_SUITE_END() // F16 TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, run_small_dataset, - shape, same_rank) +FIXTURE_DATA_TEST_CASE(RunSmall, + CLSelectFixture<float>, + framework::DatasetMode::PRECOMMIT, + combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32))) { - const DataType dt = DataType::F32; - - // Create tensors - CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8); - CLTensor ref_x = create_tensor<CLTensor>(shape, dt); - CLTensor ref_y = create_tensor<CLTensor>(shape, dt); - CLTensor dst = create_tensor<CLTensor>(shape, dt); - - // Create and Configure function - CLSelect select; - select.configure(&ref_c, &ref_x, &ref_y, &dst); - - // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape); - validate(dst.info()->valid_region(), valid_region); - - // Validate padding - const int step = 16 / arm_compute::data_size_from_type(dt); - const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding(); - if(same_rank) - { - validate(ref_c.info()->padding(), padding); - } - validate(ref_x.info()->padding(), padding); - validate(ref_y.info()->padding(), padding); - validate(dst.info()->padding(), padding); + // Validate output + validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunSmall, +FIXTURE_DATA_TEST_CASE(RunOneDim, CLSelectFixture<float>, framework::DatasetMode::PRECOMMIT, - combine(run_small_dataset, framework::dataset::make("DataType", DataType::F32))) + combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)), + framework::dataset::make("has_same_rank", { false, true })), + framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(CLAccessor(_target), _reference); @@ -202,41 +162,21 @@ TEST_SUITE_END() // Float TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, run_small_dataset, - shape, same_rank) +FIXTURE_DATA_TEST_CASE(RunSmall, + CLSelectFixture<uint8_t>, + framework::DatasetMode::PRECOMMIT, + combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8))) { - const DataType dt = DataType::QASYMM8; - - // Create tensors - CLTensor ref_c = create_tensor<CLTensor>(detail::select_condition_shape(shape, same_rank), DataType::U8); - CLTensor ref_x = create_tensor<CLTensor>(shape, dt); - CLTensor ref_y = create_tensor<CLTensor>(shape, dt); - CLTensor dst = create_tensor<CLTensor>(shape, dt); - - // Create and Configure function - CLSelect select; - select.configure(&ref_c, &ref_x, &ref_y, &dst); - - // Validate valid region - const ValidRegion valid_region = shape_to_valid_region(shape); - validate(dst.info()->valid_region(), valid_region); - - // Validate padding - const int step = 16 / arm_compute::data_size_from_type(dt); - const PaddingSize padding = PaddingCalculator(shape.x(), step).required_padding(); - if(same_rank) - { - validate(ref_c.info()->padding(), padding); - } - validate(ref_x.info()->padding(), padding); - validate(ref_y.info()->padding(), padding); - validate(dst.info()->padding(), padding); + // Validate output + validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunSmall, +FIXTURE_DATA_TEST_CASE(RunOneDim, CLSelectFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, - combine(run_small_dataset, framework::dataset::make("DataType", DataType::QASYMM8))) + combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)), + framework::dataset::make("has_same_rank", { false, true })), + framework::dataset::make("DataType", DataType::QASYMM8))) { // Validate output validate(CLAccessor(_target), _reference); |