aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ReduceMeanFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ReduceMeanFixture.h')
-rw-r--r--tests/validation/fixtures/ReduceMeanFixture.h28
1 files changed, 15 insertions, 13 deletions
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index d10292182f..72887616fe 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -47,10 +48,10 @@ class ReduceMeanValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info)
+ void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
- _target = compute_target(shape, data_type, axis, keep_dims, quantization_info);
- _reference = compute_reference(shape, data_type, axis, keep_dims, quantization_info);
+ _target = compute_target(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
+ _reference = compute_reference(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
}
protected:
@@ -71,11 +72,12 @@ protected:
}
}
- TensorType compute_target(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info)
+ TensorType compute_target(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info);
- TensorType dst;
+ TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info_input);
+ TensorShape dst_shape = arm_compute::misc::shape_calculator::calculate_reduce_mean_shape(src.info(), axis, keep_dims);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, quantization_info_output);
// Create and configure function
FunctionType reduction_mean;
@@ -100,10 +102,10 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info)
+ SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
// Create reference
- SimpleTensor<T> src{ src_shape, data_type, 1, quantization_info };
+ SimpleTensor<T> src{ src_shape, data_type, 1, quantization_info_input };
// Fill reference
fill(src);
@@ -113,7 +115,7 @@ protected:
{
TensorShape output_shape = i == 0 ? src_shape : out.shape();
output_shape.set(axis[i], 1);
- out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM);
+ out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, quantization_info_output);
}
if(!keep_dims)
@@ -139,9 +141,9 @@ class ReduceMeanQuantizedFixture : public ReduceMeanValidationFixture<TensorType
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info = QuantizationInfo())
+ void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims, QuantizationInfo quantization_info_input, QuantizationInfo quantization_info_output)
{
- ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, quantization_info);
+ ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, quantization_info_input, quantization_info_output);
}
};
@@ -152,7 +154,7 @@ public:
template <typename...>
void setup(TensorShape shape, DataType data_type, Coordinates axis, bool keep_dims)
{
- ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, QuantizationInfo());
+ ReduceMeanValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, keep_dims, QuantizationInfo(), QuantizationInfo());
}
};
} // namespace validation