diff options
Diffstat (limited to 'src/core/CL')
-rw-r--r-- | src/core/CL/kernels/CLReductionOperationKernel.cpp | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/src/core/CL/kernels/CLReductionOperationKernel.cpp b/src/core/CL/kernels/CLReductionOperationKernel.cpp index 8e92b591d1..a085ab1683 100644 --- a/src/core/CL/kernels/CLReductionOperationKernel.cpp +++ b/src/core/CL/kernels/CLReductionOperationKernel.cpp @@ -33,6 +33,7 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "support/ToolchainSupport.h" @@ -80,17 +81,15 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op) { // Output tensor auto initialization if not yet initialized - TensorShape output_shape{ input->tensor_shape() }; - output_shape.set(axis, 1); - const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); - DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type(); + const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX); + const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis, !is_arg_min_max); + const DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type(); auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true)); const unsigned int num_elems_processed_per_iteration = (is_data_type_quantized(input->data_type()) && (axis == 0)) ? 1 : 16; Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); bool window_changed = false; - const bool is_serial_op = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN - || op == ReductionOperation::MAX || is_data_type_quantized(input->data_type())); + const bool is_serial_op = needs_serialized_reduction(op, input->data_type(), axis); switch(axis) { @@ -198,8 +197,8 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou // Create kernel cl::NDRange lws_hint = CLKernelLibrary::get().default_ndrange(); std::string kernel_axis_name; - const bool is_serial_op = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX - || is_data_type_quantized(input->info()->data_type())); + const bool is_serial_op = needs_serialized_reduction(_op, _input->info()->data_type(), _reduction_axis); + switch(axis) { case 0: @@ -264,8 +263,7 @@ void CLReductionOperationKernel::run(const Window &window, cl::CommandQueue &que ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); - const bool is_serial_op = (_op == ReductionOperation::ARG_IDX_MAX || _op == ReductionOperation::ARG_IDX_MIN || _op == ReductionOperation::MIN || _op == ReductionOperation::MAX - || is_data_type_quantized(_input->info()->data_type())); + const bool is_serial_op = needs_serialized_reduction(_op, _input->info()->data_type(), _reduction_axis); switch(_reduction_axis) { case 0: |