aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScaleFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r--tests/validation/fixtures/ScaleFixture.h80
1 files changed, 48 insertions, 32 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h
index b719a22fdf..86d89d71f7 100644
--- a/tests/validation/fixtures/ScaleFixture.h
+++ b/tests/validation/fixtures/ScaleFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,15 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_SCALE_FIXTURE
-#define ARM_COMPUTE_TEST_SCALE_FIXTURE
-
-#include "arm_compute/core/TensorShape.h"
-#include "arm_compute/core/Types.h"
-#include "tests/AssetsLibrary.h"
-#include "tests/Globals.h"
-#include "tests/IAccessor.h"
-#include "tests/framework/Asserts.h"
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H
+
+#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/Scale.h"
@@ -44,23 +39,23 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScaleValidationGenericFixture : public framework::Fixture
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy,
- bool align_corners, bool mixed_layout)
+ bool align_corners, bool mixed_layout, QuantizationInfo output_quantization_info)
{
- _shape = shape;
- _policy = policy;
- _border_mode = border_mode;
- _sampling_policy = sampling_policy;
- _data_type = data_type;
- _quantization_info = quantization_info;
- _align_corners = align_corners;
- _mixed_layout = mixed_layout;
+ _shape = shape;
+ _policy = policy;
+ _border_mode = border_mode;
+ _sampling_policy = sampling_policy;
+ _data_type = data_type;
+ _input_quantization_info = quantization_info;
+ _output_quantization_info = output_quantization_info;
+ _align_corners = align_corners;
+ _mixed_layout = mixed_layout;
generate_scale(shape);
- std::mt19937 generator(library->seed());
- std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
+ std::mt19937 generator(library->seed());
+ std::uniform_int_distribution<uint32_t> distribution_u8(0, 255);
_constant_border_value = static_cast<T>(distribution_u8(generator));
_target = compute_target(shape, data_layout);
@@ -144,7 +139,7 @@ protected:
}
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _quantization_info, data_layout);
+ TensorType src = create_tensor<TensorType>(shape, _data_type, 1, _input_quantization_info, data_layout);
const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
@@ -152,7 +147,7 @@ protected:
TensorShape shape_scaled(shape);
shape_scaled.set(idx_width, shape[idx_width] * _scale_x, /* apply_dim_correction = */ false);
shape_scaled.set(idx_height, shape[idx_height] * _scale_y, /* apply_dim_correction = */ false);
- TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _quantization_info, data_layout);
+ TensorType dst = create_tensor<TensorType>(shape_scaled, _data_type, 1, _output_quantization_info, data_layout);
// Create and configure function
FunctionType scale;
@@ -188,12 +183,12 @@ protected:
SimpleTensor<T> compute_reference(const TensorShape &shape)
{
// Create reference
- SimpleTensor<T> src{ shape, _data_type, 1, _quantization_info };
+ SimpleTensor<T> src{ shape, _data_type, 1, _input_quantization_info };
// Fill reference
fill(src);
- return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners);
+ return reference::scale<T>(src, _scale_x, _scale_y, _policy, _border_mode, _constant_border_value, _sampling_policy, /* ceil_policy_scale */ false, _align_corners, _output_quantization_info);
}
TensorType _target{};
@@ -204,7 +199,8 @@ protected:
T _constant_border_value{};
SamplingPolicy _sampling_policy{};
DataType _data_type{};
- QuantizationInfo _quantization_info{};
+ QuantizationInfo _input_quantization_info{};
+ QuantizationInfo _output_quantization_info{};
bool _align_corners{ false };
bool _mixed_layout{ false };
float _scale_x{ 1.f };
@@ -215,7 +211,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScaleValidationQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy,
bool align_corners)
{
@@ -227,14 +222,34 @@ public:
border_mode,
sampling_policy,
align_corners,
- mixed_layout);
+ mixed_layout,
+ quantization_info);
+ }
+};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
+class ScaleValidationDifferentOutputQuantizedFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ void setup(TensorShape shape, DataType data_type, QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, InterpolationPolicy policy,
+ BorderMode border_mode, SamplingPolicy sampling_policy,
+ bool align_corners)
+ {
+ ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
+ data_type,
+ input_quantization_info,
+ data_layout,
+ policy,
+ border_mode,
+ sampling_policy,
+ align_corners,
+ mixed_layout,
+ output_quantization_info);
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
class ScaleValidationFixture : public ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- template <typename...>
void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners)
{
ScaleValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
@@ -245,10 +260,11 @@ public:
border_mode,
sampling_policy,
align_corners,
- mixed_layout);
+ mixed_layout,
+ QuantizationInfo());
}
};
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_SCALE_FIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_SCALEFIXTURE_H