From e55a013bfdd8238addad8449daa1fb91378eadae Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 26 Oct 2018 10:48:56 +0100 Subject: COMPMID-1451: Fix validation issue in CLReduceMean Change-Id: Ie1bcdd9dca2dc3b26003790a19cc80bb953385b2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155373 Reviewed-by: Georgios Pinitas Tested-by: bsgcomp --- src/runtime/CL/functions/CLReduceMean.cpp | 2 +- src/runtime/CL/functions/CLReductionOperation.cpp | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) (limited to 'src/runtime/CL') diff --git a/src/runtime/CL/functions/CLReduceMean.cpp b/src/runtime/CL/functions/CLReduceMean.cpp index 02e341a35c..1016ff76ea 100644 --- a/src/runtime/CL/functions/CLReduceMean.cpp +++ b/src/runtime/CL/functions/CLReduceMean.cpp @@ -103,7 +103,7 @@ Status CLReduceMean::validate(const ITensorInfo *input, const Coordinates &reduc ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(reduction_axis[i]) != 1); } - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM, 0)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperation::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM)); } return Status{}; diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp index 52a5d91cb8..c5447ffd6b 100644 --- a/src/runtime/CL/functions/CLReductionOperation.cpp +++ b/src/runtime/CL/functions/CLReductionOperation.cpp @@ -80,18 +80,35 @@ Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInf sums_vector[i].set_num_channels(input->num_channels()); } + ReductionOperation first_kernel_op; + ReductionOperation last_kernel_op; + switch(op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + first_kernel_op = ReductionOperation::SUM; + last_kernel_op = op; + break; + case ReductionOperation::SUM_SQUARE: + first_kernel_op = ReductionOperation::SUM_SQUARE; + last_kernel_op = ReductionOperation::SUM; + break; + default: + ARM_COMPUTE_ERROR("Not supported"); + } + // Validate ReductionOperation only on first kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, op)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, first_kernel_op)); // Validate ReductionOperation on intermediate stages for(unsigned int i = 1; i < num_of_stages - 1; ++i) { - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, op)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, ReductionOperation::SUM)); } // Validate ReductionOperation on the last stage const unsigned int last_stage = num_of_stages - 1; - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, op)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, last_kernel_op, input->dimension(0))); } else { -- cgit v1.2.1