aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLReductionOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLReductionOperationKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLReductionOperationKernel.cpp6
1 files changed, 4 insertions, 2 deletions
diff --git a/src/core/CL/kernels/CLReductionOperationKernel.cpp b/src/core/CL/kernels/CLReductionOperationKernel.cpp
index db4850f14e..cb57070612 100644
--- a/src/core/CL/kernels/CLReductionOperationKernel.cpp
+++ b/src/core/CL/kernels/CLReductionOperationKernel.cpp
@@ -49,7 +49,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
if(input->num_channels() == 1)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
}
else
{
@@ -160,8 +160,10 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
{
data_type_promoted = "uint";
}
+
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
build_opts.add_option("-DDATA_TYPE_PROMOTED=" + data_type_promoted);
+ build_opts.add_option_if(is_data_type_float(input->info()->data_type()), "-DFLOAT_DATA_TYPE");
build_opts.add_option_if(op == ReductionOperation::SUM_SQUARE, "-DSUM_SQUARE");
build_opts.add_option_if(op == ReductionOperation::MEAN_SUM, "-DMEAN");
build_opts.add_option_if(op == ReductionOperation::ARG_IDX_MAX, "-DARG_MAX");
@@ -199,7 +201,7 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
if(is_serial_op)
{
build_opts.add_option("-DWIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
- build_opts.add_option_if_else(_input->info()->data_type() == DataType::F32, "-DCOND_DATA_TYPE=int", "-DCOND_DATA_TYPE=short");
+ build_opts.add_option_if_else(_input->info()->data_type() == DataType::F16, "-DCOND_DATA_TYPE=short", "-DCOND_DATA_TYPE=int");
kernel_axis_name = "non_parallel_x";
}
else