diff options
-rw-r--r-- | src/core/CL/cl_kernels/arg_min_max.cl | 8 | ||||
-rw-r--r-- | tests/validation/CL/ArgMinMax.cpp | 4 |
2 files changed, 6 insertions, 6 deletions
diff --git a/src/core/CL/cl_kernels/arg_min_max.cl b/src/core/CL/cl_kernels/arg_min_max.cl index 6ef0a61ac5..6e57ed0af1 100644 --- a/src/core/CL/cl_kernels/arg_min_max.cl +++ b/src/core/CL/cl_kernels/arg_min_max.cl @@ -38,8 +38,8 @@ #define ISGREATER(x, y) (x > y) ? 1 : 0 #define ISLESS(x, y) (x < y) ? 1 : 0 #else // !defined(WIDTH) -#define ISGREATER(x, y) select((VEC_SIGNED_INT_IN)0, (VEC_SIGNED_INT_IN)-1, x > y) -#define ISLESS(x, y) select((VEC_SIGNED_INT_IN)0, (VEC_SIGNED_INT_IN)-1, x < y) +#define ISGREATER(x, y) select((VEC_SIGNED_INT_IN)0, (VEC_SIGNED_INT_IN)-1, (VEC_SIGNED_INT_IN)(x > y)) +#define ISLESS(x, y) select((VEC_SIGNED_INT_IN)0, (VEC_SIGNED_INT_IN)-1, (VEC_SIGNED_INT_IN)(x < y)) #endif // defined(WIDTH) #endif // defined(FLOAT_DATA_TYPE) @@ -342,7 +342,7 @@ __kernel void arg_min_max_y( } #endif // defined(HEIGHT) -#if defined(DEPTH) +#if defined(DEPTH) && !defined(BATCH) /** This kernel performs reduction on z-axis. * * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float @@ -390,7 +390,7 @@ __kernel void arg_min_max_z( // Store result STORE_VECTOR_SELECT(indx, DATA_TYPE_OUTPUT, output_addr, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0); } -#endif /* defined(DEPTH) */ +#endif /* defined(DEPTH) && !defined(BATCH) */ #if defined(BATCH) && defined(DEPTH) /** This kernel performs reduction on w-axis. diff --git a/tests/validation/CL/ArgMinMax.cpp b/tests/validation/CL/ArgMinMax.cpp index 2508c63524..1d849ed0c7 100644 --- a/tests/validation/CL/ArgMinMax.cpp +++ b/tests/validation/CL/ArgMinMax.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -45,7 +45,7 @@ namespace { const auto ArgMinMaxSmallDataset = framework::dataset::make("Shape", { - TensorShape{ 2U, 7U, 1U, 3U }, + TensorShape{ 1U, 7U, 1U, 3U }, TensorShape{ 149U, 5U, 1U, 2U }, TensorShape{ 166U, 5U, 1U, 2U }, TensorShape{ 322U, 5U, 1U, 2U }, |