aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-10-26 10:48:56 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:55:45 +0000
commite55a013bfdd8238addad8449daa1fb91378eadae (patch)
treef745cc4831e837914449010749dd8bd49cd2e579
parentd775cd796e9b74323047992003d8acd4e8bb5047 (diff)
downloadComputeLibrary-e55a013bfdd8238addad8449daa1fb91378eadae.tar.gz
COMPMID-1451: Fix validation issue in CLReduceMean
Change-Id: Ie1bcdd9dca2dc3b26003790a19cc80bb953385b2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155373 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: bsgcomp <bsgcomp@arm.com>
-rw-r--r--src/runtime/CL/functions/CLReduceMean.cpp2
-rw-r--r--src/runtime/CL/functions/CLReductionOperation.cpp23
2 files changed, 21 insertions, 4 deletions
diff --git a/src/runtime/CL/functions/CLReduceMean.cpp b/src/runtime/CL/functions/CLReduceMean.cpp
index 02e341a35c..1016ff76ea 100644
--- a/src/runtime/CL/functions/CLReduceMean.cpp
+++ b/src/runtime/CL/functions/CLReduceMean.cpp
@@ -103,7 +103,7 @@ Status CLReduceMean::validate(const ITensorInfo *input, const Coordinates &reduc
ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(reduction_axis[i]) != 1);
}
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM, 0));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperation::validate(input, output, reduction_axis[i], ReductionOperation::MEAN_SUM));
}
return Status{};
diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp
index 52a5d91cb8..c5447ffd6b 100644
--- a/src/runtime/CL/functions/CLReductionOperation.cpp
+++ b/src/runtime/CL/functions/CLReductionOperation.cpp
@@ -80,18 +80,35 @@ Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInf
sums_vector[i].set_num_channels(input->num_channels());
}
+ ReductionOperation first_kernel_op;
+ ReductionOperation last_kernel_op;
+ switch(op)
+ {
+ case ReductionOperation::SUM:
+ case ReductionOperation::MEAN_SUM:
+ first_kernel_op = ReductionOperation::SUM;
+ last_kernel_op = op;
+ break;
+ case ReductionOperation::SUM_SQUARE:
+ first_kernel_op = ReductionOperation::SUM_SQUARE;
+ last_kernel_op = ReductionOperation::SUM;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ }
+
// Validate ReductionOperation only on first kernel
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, op));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, sums_vector.get(), axis, first_kernel_op));
// Validate ReductionOperation on intermediate stages
for(unsigned int i = 1; i < num_of_stages - 1; ++i)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, op));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + i - 1, sums_vector.get() + i, axis, ReductionOperation::SUM));
}
// Validate ReductionOperation on the last stage
const unsigned int last_stage = num_of_stages - 1;
- ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, op));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(sums_vector.get() + last_stage - 1, output, axis, last_kernel_op, input->dimension(0)));
}
else
{