aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/CL/cl_kernels/select.cl8
-rw-r--r--tests/validation/CL/Select.cpp35
2 files changed, 38 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/select.cl b/src/core/CL/cl_kernels/select.cl
index 4d22d5bf07..6fd4bd4ce3 100644
--- a/src/core/CL/cl_kernels/select.cl
+++ b/src/core/CL/cl_kernels/select.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -86,7 +86,7 @@ __kernel void select_same_rank(
// Calculate result
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
- res0 = select(in_y, in_x, in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0);
+ res0 = select(in_y, in_x, CONVERT(in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0, SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)));
// Boundary-aware store
STORE_VECTOR_SELECT(res, DATA_TYPE, (__global DATA_TYPE *)out_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
@@ -152,7 +152,7 @@ __kernel void select_different_rank_2(
// Calculate result
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
- res0 = select(in_y, in_x, in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0);
+ res0 = select(in_y, in_x, CONVERT(in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0, SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)));
// Boundary-aware store
STORE_VECTOR_SELECT(res, DATA_TYPE, (__global DATA_TYPE *)out_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
@@ -220,7 +220,7 @@ __kernel void select_different_rank_n(
// Calculate result
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
- res0 = select(in_y, in_x, in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0);
+ res0 = select(in_y, in_x, CONVERT(in_c > (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))0, SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)));
// Boundary-aware store
STORE_VECTOR_SELECT(res, DATA_TYPE, (__global DATA_TYPE *)out_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0);
diff --git a/tests/validation/CL/Select.cpp b/tests/validation/CL/Select.cpp
index 3d7c61aab5..d3540cae48 100644
--- a/tests/validation/CL/Select.cpp
+++ b/tests/validation/CL/Select.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -107,6 +107,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunOneDim,
+ CLSelectFixture<half>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)),
+ framework::dataset::make("has_same_rank", { false, true })),
+ framework::dataset::make("DataType", DataType::F16)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLSelectFixture<half>,
framework::DatasetMode::NIGHTLY,
@@ -127,6 +138,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunOneDim,
+ CLSelectFixture<float>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)),
+ framework::dataset::make("has_same_rank", { false, true })),
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLSelectFixture<float>,
framework::DatasetMode::NIGHTLY,
@@ -149,6 +171,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
validate(CLAccessor(_target), _reference);
}
+FIXTURE_DATA_TEST_CASE(RunOneDim,
+ CLSelectFixture<uint8_t>,
+ framework::DatasetMode::PRECOMMIT,
+ combine(combine(framework::dataset::make("Shape", TensorShape(1U, 16U)),
+ framework::dataset::make("has_same_rank", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
FIXTURE_DATA_TEST_CASE(RunLarge,
CLSelectFixture<uint8_t>,
framework::DatasetMode::NIGHTLY,