diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2018-10-26 10:48:56 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:55:45 +0000 |
commit | e55a013bfdd8238addad8449daa1fb91378eadae (patch) | |
tree | f745cc4831e837914449010749dd8bd49cd2e579 /src/runtime/CL/functions | |
parent | d775cd796e9b74323047992003d8acd4e8bb5047 (diff) | |
download | ComputeLibrary-e55a013bfdd8238addad8449daa1fb91378eadae.tar.gz |
COMPMID-1451: Fix validation issue in CLReduceMean
Change-Id: Ie1bcdd9dca2dc3b26003790a19cc80bb953385b2
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155373
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions')
-rw-r--r-- | src/runtime/CL/functions/CLReduceMean.cpp | 2 | ||||
-rw-r--r-- | src/runtime/CL/functions/CLReductionOperation.cpp | 23 |
2 files changed, 21 insertions, 4 deletions
diff --git a/src/runtime/CL/functions/CLReduceMean.cpp b/src/runtime/CL/functions/CLReduceMean.cpp index 02e341a35c..1016ff76ea 100644 --- a/src/runtime/CL/functions/CLReduceMean.cpp +++ b/src/runtime/CL/functions/CLReduceMean.cpp @@ -103,7 +103,7 @@ Status CLReduceMean::validate(const ITensorInfo *input, const Coordinates &reduc ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(reduction_axis[i]) != 1); } - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM, 0)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperation::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM)); } return Status{}; diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp index 52a5d91cb8..c5447ffd6b 100644 --- a/src/runtime/CL/functions/CLReductionOperation.cpp +++ b/src/runtime/CL/functions/CLReductionOperation.cpp @@ -80,18 +80,35 @@ Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInf sums_vector[i].set_num_channels(input->num_channels()); } + ReductionOperation first_kernel_op; + ReductionOperation last_kernel_op; + switch(op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + first_kernel_op = ReductionOperation::SUM; + last_kernel_op = op; + break; + case ReductionOperation::SUM_SQUARE: + first_kernel_op = ReductionOperation::SUM_SQUARE; + last_kernel_op = ReductionOperation::SUM; + break; + default: + ARM_COMPUTE_ERROR("Not supported"); + } + // Validate ReductionOperation only on first kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, op)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, first_kernel_op)); // Validate ReductionOperation on intermediate stages for(unsigned int i = 1; i < num_of_stages - 1; ++i) { - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, op)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, ReductionOperation::SUM)); } // Validate ReductionOperation on the last stage const unsigned int last_stage = num_of_stages - 1; - ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, op)); + ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, last_kernel_op, input->dimension(0))); } else { |