aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLReductionOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLReductionOperationKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLReductionOperationKernel.cpp 18
1 file changed, 8 insertions, 10 deletions
diff --git a/src/core/CL/kernels/CLReductionOperationKernel.cpp b/src/core/CL/kernels/CLReductionOperationKernel.cpp
index 8e92b591d1..a085ab1683 100644
--- a/src/core/CL/kernels/CLReductionOperationKernel.cpp
+++ b/src/core/CL/kernels/CLReductionOperationKernel.cpp
@@ -33,6 +33,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "support/ToolchainSupport.h"
@@ -80,17 +81,15 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u
std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
{
// Output tensor auto initialization if not yet initialized
- TensorShape output_shape{ input->tensor_shape() };
- output_shape.set(axis, 1);
- const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
- DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
+ const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
+ const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis, !is_arg_min_max);
+ const DataType output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
const unsigned int num_elems_processed_per_iteration = (is_data_type_quantized(input->data_type()) && (axis == 0)) ? 1 : 16;
Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
bool window_changed = false;
- const bool is_serial_op = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN
- || op == ReductionOperation::MAX || is_data_type_quantized(input->data_type()));
+ const bool is_serial_op = needs_serialized_reduction(op, input->data_type(), axis);
switch(axis)
{
@@ -198,8 +197,8 @@ void CLReductionOperationKernel::configure(const ICLTensor *input, ICLTensor *ou
// Create kernel
cl::NDRange lws_hint = CLKernelLibrary::get().default_ndrange();
std::string kernel_axis_name;
- const bool is_serial_op = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX
- || is_data_type_quantized(input->info()->data_type()));
+ const bool is_serial_op = needs_serialized_reduction(_op, _input->info()->data_type(), _reduction_axis);
+
switch(axis)
{
case 0:
@@ -264,8 +263,7 @@ void CLReductionOperationKernel::run(const Window &window, cl::CommandQueue &que
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
- const bool is_serial_op = (_op == ReductionOperation::ARG_IDX_MAX || _op == ReductionOperation::ARG_IDX_MIN || _op == ReductionOperation::MIN || _op == ReductionOperation::MAX
- || is_data_type_quantized(_input->info()->data_type()));
+ const bool is_serial_op = needs_serialized_reduction(_op, _input->info()->data_type(), _reduction_axis);
switch(_reduction_axis)
{
case 0: