aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScaleFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r--tests/validation/fixtures/ScaleFixture.h24
1 files changed, 13 insertions, 11 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h
index 476985e066..894260a02f 100644
--- a/tests/validation/fixtures/ScaleFixture.h
+++ b/tests/validation/fixtures/ScaleFixture.h
@@ -44,15 +44,16 @@ class ScaleValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode)
+ void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
{
constexpr float max_width = 8192.0f;
constexpr float max_height = 6384.0f;
- _shape = shape;
- _policy = policy;
- _border_mode = border_mode;
- _data_type = data_type;
+ _shape = shape;
+ _policy = policy;
+ _border_mode = border_mode;
+ _sampling_policy = sampling_policy;
+ _data_type = data_type;
std::mt19937 generator(library->seed());
std::uniform_real_distribution<float> distribution_float(0.25, 3);
@@ -65,8 +66,8 @@ public:
std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
T constant_border_value = static_cast<T>(distribution_u8(generator));
- _target = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value);
- _reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value);
+ _target = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
+ _reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
}
protected:
@@ -77,7 +78,7 @@ protected:
}
TensorType compute_target(const TensorShape &shape, const float scale_x, const float scale_y,
- InterpolationPolicy policy, BorderMode border_mode, T constant_border_value)
+ InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
{
// Create tensors
TensorType src = create_tensor<TensorType>(shape, _data_type);
@@ -89,7 +90,7 @@ protected:
// Create and configure function
FunctionType scale;
- scale.configure(&src, &dst, policy, border_mode, constant_border_value);
+ scale.configure(&src, &dst, policy, border_mode, constant_border_value, sampling_policy);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -110,7 +111,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &shape, const float scale_x, const float scale_y,
- InterpolationPolicy policy, BorderMode border_mode, T constant_border_value)
+ InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
{
// Create reference
SimpleTensor<T> src{ shape, _data_type };
@@ -118,7 +119,7 @@ protected:
// Fill reference
fill(src);
- return reference::scale<T>(src, scale_x, scale_y, policy, border_mode, constant_border_value);
+ return reference::scale<T>(src, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
}
TensorType _target{};
@@ -126,6 +127,7 @@ protected:
TensorShape _shape{};
InterpolationPolicy _policy{};
BorderMode _border_mode{};
+ SamplingPolicy _sampling_policy{};
DataType _data_type{};
};
} // namespace validation