about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/kernels/CLReductionOperationKernel.cpp
diff options
context:
space:
mode:
author: Michalis Spyrou <michalis.spyrou@arm.com> 2018-11-22 17:36:28 +0000
committer: Michalis Spyrou <michalis.spyrou@arm.com> 2018-11-30 15:46:49 +0000
commit: 7930db48e12dd3a14c1971f41f5b83527efea281 (patch)
tree: d17899ba82203423320bfa8d2dea1e07b045c898 /src/core/CL/kernels/CLReductionOperationKernel.cpp
parent: 95abfddfa08ab85d4f88c6f4d2e077969178f2d5 (diff)
download: ComputeLibrary-7930db48e12dd3a14c1971f41f5b83527efea281.tar.gz
COMPMID-1728 CL: Implement ArgMax/ArgMin
Change-Id: I7eae2e55cc0b0b7bbebb7617299daaca6f75f40c
Reviewed-on: https://review.mlplatform.org/292
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLReductionOperationKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLReductionOperationKernel.cpp 36
1 files changed, 27 insertions, 9 deletions
diff --git a/src/core/CL/kernels/CLReductionOperationKernel.cpp b/src/core/CL/kernels/CLReductionOperationKernel.cpp
index ef46325e4d..f6dc4a8806 100644
--- a/src/core/CL/kernels/CLReductionOperationKernel.cpp
+++ b/src/core/CL/kernels/CLReductionOperationKernel.cpp
@@ -53,19 +53,29 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
+ if(op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QASYMM8, "Not supported operation for QASYMM8");
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
}
return Status{};
}
-std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis)
+std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
// Output tensor auto initialization if not yet initialized
TensorShape output_shape{ input->tensor_shape() };
output_shape.set(axis, 1);
- auto_init_if_empty(*output, output_shape, 1, input->data_type());
+ const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
+ DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
+ auto_init_if_empty(*output, output_shape, 1, output_data_type);
const unsigned int num_elems_processed_per_iteration = (is_data_type_quantized(input->data_type()) && (axis == 0)) ? 1 : 16;
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
@@ -136,7 +146,7 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
// Set build options
CLBuildOptions build_opts;
std::string data_type_promoted = get_cl_type_from_data_type(input->info()->data_type());
- if(is_data_type_quantized(input->info()->data_type()) && axis != 0)
+ if(is_data_type_quantized(input->info()->data_type()))
{
data_type_promoted = "uint";
}
@@ -144,6 +154,8 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
build_opts.add_option("-DDATA_TYPE_PROMOTED=" + data_type_promoted);
build_opts.add_option_if(op == ReductionOperation::SUM_SQUARE, "-DSUM_SQUARE=");
build_opts.add_option_if(op == ReductionOperation::MEAN_SUM, "-DMEAN");
+ build_opts.add_option_if(op == ReductionOperation::ARG_IDX_MAX, "-DARG_MAX");
+ build_opts.add_option_if(op == ReductionOperation::ARG_IDX_MIN, "-DARG_MIN");
switch(op)
{
@@ -154,6 +166,9 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
case ReductionOperation::MEAN_SUM:
build_opts.add_option(("-DOPERATION=sum"));
break;
+ case ReductionOperation::ARG_IDX_MAX:
+ case ReductionOperation::ARG_IDX_MIN:
+ break;
default:
ARM_COMPUTE_ERROR("Unsupported reduction operation");
}
@@ -161,11 +176,12 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
// Create kernel
cl::NDRange lws_hint = CLKernelLibrary::get().default_ndrange();
std::string kernel_axis_name;
+ const bool is_arg_op = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
switch(axis)
{
case 0:
{
- if(!is_data_type_quantized(input->info()->data_type()))
+ if(!is_data_type_quantized(input->info()->data_type()) && !is_arg_op)
{
build_opts.add_option_if(op == ReductionOperation::MEAN_SUM, "-DWIDTH=" + support::cpp11::to_string(width));
const unsigned int width_leftover = input->info()->dimension(0) % border_val;
@@ -181,7 +197,8 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
else
{
build_opts.add_option("-DWIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
- kernel_axis_name = "quantized_x";
+ build_opts.add_option_if_else(_input->info()->data_type() == DataType::F32, "-DCOND_DATA_TYPE=int", "-DCOND_DATA_TYPE=short");
+ kernel_axis_name = "non_parallel_x";
}
}
break;
@@ -204,7 +221,7 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("reduction_operation_" + kernel_axis_name, build_opts.options()));
// Configure kernel window
- auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis);
+ auto win_config = validate_and_configure_window(_input->info(), _output->info(), axis, op);
ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
@@ -214,7 +231,7 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
Status CLReductionOperationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op, unsigned int width)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, axis, op, width));
- ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis)));
+ ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), axis, op)));
return Status{};
}
@@ -224,12 +241,13 @@ void CLReductionOperationKernel::run(const Window &window, cl::CommandQueue &que
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
+ const bool is_arg_op = (_op == ReductionOperation::ARG_IDX_MAX || _op == ReductionOperation::ARG_IDX_MIN);
switch(_reduction_axis)
{
case 0:
{
// We use parallel reduction only in non quantized types
- if(!is_data_type_quantized(_input->info()->data_type()))
+ if(!is_data_type_quantized(_input->info()->data_type()) && !is_arg_op)
{
// Set out window
Window out_window(window);