aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ScaleFixture.h
diff options
context:
space:
mode:
authorDaniil Efremov <daniil.efremov@xored.com>2017-11-22 00:26:51 +0700
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:41:04 +0000
commit02bf80d4554cfc824a76008905921cb564bee999 (patch)
treeb86ebbed4d330af69c1107c10ce5e765705e88dd /tests/validation/fixtures/ScaleFixture.h
parent6194145681232bf59e0455434f15aba42956145b (diff)
downloadComputeLibrary-02bf80d4554cfc824a76008905921cb564bee999.tar.gz
COMPMID-661: Fix scale border issue (#38)
Change-Id: If1dcca724e5e5f5ab363ffc16b0ef8c943e0b657 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110105 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ScaleFixture.h')
-rw-r--r--tests/validation/fixtures/ScaleFixture.h24
1 files changed, 13 insertions, 11 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h
index 476985e066..894260a02f 100644
--- a/tests/validation/fixtures/ScaleFixture.h
+++ b/tests/validation/fixtures/ScaleFixture.h
@@ -44,15 +44,16 @@ class ScaleValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode)
+ void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy)
{
constexpr float max_width = 8192.0f;
constexpr float max_height = 6384.0f;
- _shape = shape;
- _policy = policy;
- _border_mode = border_mode;
- _data_type = data_type;
+ _shape = shape;
+ _policy = policy;
+ _border_mode = border_mode;
+ _sampling_policy = sampling_policy;
+ _data_type = data_type;
std::mt19937 generator(library->seed());
std::uniform_real_distribution<float> distribution_float(0.25, 3);
@@ -65,8 +66,8 @@ public:
std::uniform_int_distribution<uint8_t> distribution_u8(0, 255);
T constant_border_value = static_cast<T>(distribution_u8(generator));
- _target = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value);
- _reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value);
+ _target = compute_target(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
+ _reference = compute_reference(shape, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
}
protected:
@@ -77,7 +78,7 @@ protected:
}
TensorType compute_target(const TensorShape &shape, const float scale_x, const float scale_y,
- InterpolationPolicy policy, BorderMode border_mode, T constant_border_value)
+ InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
{
// Create tensors
TensorType src = create_tensor<TensorType>(shape, _data_type);
@@ -89,7 +90,7 @@ protected:
// Create and configure function
FunctionType scale;
- scale.configure(&src, &dst, policy, border_mode, constant_border_value);
+ scale.configure(&src, &dst, policy, border_mode, constant_border_value, sampling_policy);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -110,7 +111,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &shape, const float scale_x, const float scale_y,
- InterpolationPolicy policy, BorderMode border_mode, T constant_border_value)
+ InterpolationPolicy policy, BorderMode border_mode, T constant_border_value, SamplingPolicy sampling_policy)
{
// Create reference
SimpleTensor<T> src{ shape, _data_type };
@@ -118,7 +119,7 @@ protected:
// Fill reference
fill(src);
- return reference::scale<T>(src, scale_x, scale_y, policy, border_mode, constant_border_value);
+ return reference::scale<T>(src, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy);
}
TensorType _target{};
@@ -126,6 +127,7 @@ protected:
TensorShape _shape{};
InterpolationPolicy _policy{};
BorderMode _border_mode{};
+ SamplingPolicy _sampling_policy{};
DataType _data_type{};
};
} // namespace validation