diff options
Diffstat (limited to 'tests/validation/fixtures/ReduceMeanFixture.h')
-rw-r--r-- | tests/validation/fixtures/ReduceMeanFixture.h | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h index 304630e9f5..e61941435c 100644 --- a/tests/validation/fixtures/ReduceMeanFixture.h +++ b/tests/validation/fixtures/ReduceMeanFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -47,7 +47,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ReduceMeanValidationFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output) { _target = compute_target(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output); @@ -124,7 +123,13 @@ protected: { TensorShape output_shape = i == 0 ? src_shape : out.shape(); output_shape.set(axis[i], 1); - out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, quantization_info_output); + bool is_opencl = false; + +#ifdef ARM_COMPUTE_OPENCL_ENABLED + is_opencl = std::is_same<CLTensor, TensorType>::value; // Round down to zero on opencl to match kernel +#endif /* ARM_COMPUTE_OPENCL_ENABLED */ + out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, data_type, quantization_info_output, + is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP); } if(!keep_dims) @@ -133,7 +138,7 @@ protected: std::sort(axis.begin(), axis.begin() + axis.num_dimensions()); for(unsigned int i = 0; i < axis.num_dimensions(); ++i) { - output_shape.remove_dimension(axis[i] - i); + output_shape.remove_dimension(axis[i] - i, false); } out = reference::reshape_layer(out, output_shape); @@ -149,7 +154,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ReduceMeanQuantizedFixture : public ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output) { ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output); @@ -160,7 +164,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class ReduceMeanFixture : public ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T> { public: - template <typename...> void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims) { ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, QuantizationInfo(), QuantizationInfo()); |