aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ReduceMeanFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ReduceMeanFixture.h')
-rw-r--r--tests/validation/fixtures/ReduceMeanFixture.h23
1 files changed, 13 insertions, 10 deletions
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index 36bf14b27e..e61941435c 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,7 +47,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReduceMeanValidationFixture : public framework::Fixture
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
_target = compute_target(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
@@ -92,15 +91,15 @@ protected:
FunctionType reduction_mean;
reduction_mean.configure(&src, axis, keep_dims, &dst);
- ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(src.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
// Allocate tensors
src.allocator()->allocate();
dst.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
// Fill tensors
fill(AccessorType(src));
@@ -124,7 +123,13 @@ protected:
{
TensorShape output_shape = i == 0 ? src_shape : out.shape();
output_shape.set(axis[i], 1);
- out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, quantization_info_output);
+ bool is_opencl = false;
+
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+ is_opencl = std::is_same<CLTensor, TensorType>::value; // Round down to zero on opencl to match kernel
+#endif /* ARM_COMPUTE_OPENCL_ENABLED */
+ out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, data_type, quantization_info_output,
+ is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP);
}
if(!keep_dims)
@@ -133,7 +138,7 @@ protected:
std::sort(axis.begin(), axis.begin() + axis.num_dimensions());
for(unsigned int i = 0; i < axis.num_dimensions(); ++i)
{
- output_shape.remove_dimension(axis[i] - i);
+ output_shape.remove_dimension(axis[i] - i, false);
}
out = reference::reshape_layer(out, output_shape);
@@ -149,7 +154,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReduceMeanQuantizedFixture : public ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
@@ -160,7 +164,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReduceMeanFixture : public ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims)
{
ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, QuantizationInfo(), QuantizationInfo());