diff options
-rw-r--r-- | src/core/CL/cl_kernels/select.cl | 8 | ||||
-rw-r--r-- | tests/validation/CL/Select.cpp | 35 |
2 files changed, 38 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/select.cl b/src/core/CL/cl_kernels/select.cl index 4d22d5bf07..6fd4bd4ce3 100644 --- a/src/core/CL/cl_kernels/select.cl +++ b/src/core/CL/cl_kernels/select.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -86,7 +86,7 @@ __kernel void select_same_rank( // Calculate result VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) - res0 = select(in_y, in_x, in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0); + res0 = select(in_y, in_x, CONVERT(in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0, SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))); // Boundary-aware store STORE_VECTOR_SELECT(res, DATA_TYPE, (__global DATA_TYPE *)out_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0); @@ -152,7 +152,7 @@ __kernel void select_different_rank_2( // Calculate result VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) - res0 = select(in_y, in_x, in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0); + res0 = select(in_y, in_x, CONVERT(in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0, SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))); // Boundary-aware store STORE_VECTOR_SELECT(res, DATA_TYPE, (__global DATA_TYPE *)out_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0); @@ -220,7 +220,7 @@ __kernel void select_different_rank_n( // Calculate result VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) - res0 = select(in_y, in_x, in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0); + res0 = select(in_y, in_x, CONVERT(in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0, SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))); // Boundary-aware store STORE_VECTOR_SELECT(res, DATA_TYPE, (__global DATA_TYPE *)out_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0); diff --git a/tests/validation/CL/Select.cpp b/tests/validation/CL/Select.cpp index 3d7c61aab5..d3540cae48 100644 --- a/tests/validation/CL/Select.cpp +++ b/tests/validation/CL/Select.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -107,6 +107,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, validate(CLAccessor(_target), _reference); } +FIXTURE_DATA_TEST_CASE(RunOneDim, + CLSelectFixture<half>, + framework::DatasetMode::PRECOMMIT, + combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)), + framework::dataset::make("has_same_rank", { false, true })), + framework::dataset::make("DataType", DataType::F16))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + FIXTURE_DATA_TEST_CASE(RunLarge, CLSelectFixture<half>, framework::DatasetMode::NIGHTLY, @@ -127,6 +138,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, validate(CLAccessor(_target), _reference); } +FIXTURE_DATA_TEST_CASE(RunOneDim, + CLSelectFixture<float>, + framework::DatasetMode::PRECOMMIT, + combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)), + framework::dataset::make("has_same_rank", { false, true })), + framework::dataset::make("DataType", DataType::F32))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + FIXTURE_DATA_TEST_CASE(RunLarge, CLSelectFixture<float>, framework::DatasetMode::NIGHTLY, @@ -149,6 +171,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall, validate(CLAccessor(_target), _reference); } +FIXTURE_DATA_TEST_CASE(RunOneDim, + CLSelectFixture<uint8_t>, + framework::DatasetMode::PRECOMMIT, + combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)), + framework::dataset::make("has_same_rank", { false, true })), + framework::dataset::make("DataType", DataType::QASYMM8))) +{ + // Validate output + validate(CLAccessor(_target), _reference); +} + FIXTURE_DATA_TEST_CASE(RunLarge, CLSelectFixture<uint8_t>, framework::DatasetMode::NIGHTLY, |