aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ReduceMeanFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ReduceMeanFixture.h')
-rw-r--r--tests/validation/fixtures/ReduceMeanFixture.h62
1 files changed, 38 insertions, 24 deletions
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index 44bb9fca6a..e61941435c 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -46,50 +47,59 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReduceMeanValidationFixture : public framework::Fixture
{
public:
- template <typename...>
- void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info)
+ void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
- _target = compute_target(shape, data_type, axis, keep_dims, quantization_info);
- _reference = compute_reference(shape, data_type, axis, keep_dims, quantization_info);
+ _target = compute_target(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
+ _reference = compute_reference(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
}
protected:
template <typename U>
void fill(U &&tensor)
{
- if(!is_data_type_quantized(tensor.data_type()))
+ if(tensor.data_type() == DataType::F32)
{
- std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+ std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
library->fill(tensor, distribution, 0);
}
- else
+ else if(tensor.data_type() == DataType::F16)
+ {
+ arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+ library->fill(tensor, distribution, 0);
+ }
+ else if(is_data_type_quantized(tensor.data_type()))
{
std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, 0);
}
+ else
+ {
+ library->fill_tensor_uniform(tensor, 0);
+ }
}
- TensorType compute_target(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info)
+ TensorType compute_target(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info);
- TensorType dst;
+ TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info_input);
+ TensorShape dst_shape = arm_compute::misc::shape_calculator::calculate_reduce_mean_shape(src.info(), axis, keep_dims);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, quantization_info_output);
// Create and configure function
FunctionType reduction_mean;
reduction_mean.configure(&src, axis, keep_dims, &dst);
- ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(src.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
// Allocate tensors
src.allocator()->allocate();
dst.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
// Fill tensors
fill(AccessorType(src));
@@ -100,10 +110,10 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info)
+ SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
// Create reference
- SimpleTensor<T> src{ src_shape, data_type, 1, quantization_info };
+ SimpleTensor<T> src{ src_shape, data_type, 1, quantization_info_input };
// Fill reference
fill(src);
@@ -113,7 +123,13 @@ protected:
{
TensorShape output_shape = i == 0 ? src_shape : out.shape();
output_shape.set(axis[i], 1);
- out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM);
+ bool is_opencl = false;
+
+#ifdef ARM_COMPUTE_OPENCL_ENABLED
+ is_opencl = std::is_same<CLTensor, TensorType>::value; // Round down to zero on opencl to match kernel
+#endif /* ARM_COMPUTE_OPENCL_ENABLED */
+ out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, data_type, quantization_info_output,
+ is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP);
}
if(!keep_dims)
@@ -122,7 +138,7 @@ protected:
std::sort(axis.begin(), axis.begin() + axis.num_dimensions());
for(unsigned int i = 0; i < axis.num_dimensions(); ++i)
{
- output_shape.remove_dimension(axis[i] - i);
+ output_shape.remove_dimension(axis[i] - i, false);
}
out = reference::reshape_layer(out, output_shape);
@@ -138,10 +154,9 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReduceMeanQuantizedFixture : public ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- template <typename...>
- void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info = QuantizationInfo())
+ void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
- ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, quantization_info);
+ ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
}
};
@@ -149,10 +164,9 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ReduceMeanFixture : public ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims)
{
- ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, QuantizationInfo());
+ ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, QuantizationInfo(), QuantizationInfo());
}
};
} // namespace validation