From bb123bd6f64444141161562aad06cb406762d47a Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Fri, 3 Jan 2020 10:57:30 +0000 Subject: MLCE-139 add align_corners parameter to NEScale Change-Id: I497ceb54c5fd8af1af8c529f90fd5a00a45263c8 Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/2538 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Pablo Marquez --- tests/validation/fixtures/ScaleFixture.h | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) (limited to 'tests/validation/fixtures/ScaleFixture.h') diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index 2be02ec4d6..8d851ce574 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -45,7 +45,8 @@ class ScaleValidationGenericFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy) + void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, + bool align_corners) { constexpr float max_width = 8192.0f; constexpr float max_height = 6384.0f; @@ -56,6 +57,7 @@ public: _sampling_policy = sampling_policy; _data_type = data_type; _quantization_info = quantization_info; + _align_corners = align_corners; std::mt19937 generator(library->seed()); std::uniform_real_distribution distribution_float(0.25, 3); @@ -120,7 +122,7 @@ protected: // Create and configure function FunctionType scale; - scale.configure(&src, &dst, policy, border_mode, constant_border_value, sampling_policy); + scale.configure(&src, &dst, policy, border_mode, constant_border_value, sampling_policy, /* use_padding */ true, _align_corners); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -150,7 +152,7 @@ protected: // Fill reference fill(src); - return reference::scale(src, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy); + return reference::scale(src, scale_x, scale_y, policy, border_mode, constant_border_value, sampling_policy, /* ceil_policy_scale */ false, _align_corners); } TensorType _target{}; @@ -161,6 +163,7 @@ protected: SamplingPolicy _sampling_policy{}; DataType _data_type{}; QuantizationInfo _quantization_info{}; + bool _align_corners{ false }; }; template @@ -168,7 +171,8 @@ class ScaleValidationQuantizedFixture : public ScaleValidationGenericFixture - void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy) + void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, + bool align_corners) { ScaleValidationGenericFixture::setup(shape, data_type, @@ -176,7 +180,8 @@ public: data_layout, policy, border_mode, - sampling_policy); + sampling_policy, + align_corners); } }; template @@ -184,7 +189,7 @@ class ScaleValidationFixture : public ScaleValidationGenericFixture - void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy) + void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, BorderMode border_mode, SamplingPolicy sampling_policy, bool align_corners) { ScaleValidationGenericFixture::setup(shape, data_type, @@ -192,7 +197,8 @@ public: data_layout, policy, border_mode, - sampling_policy); + sampling_policy, + align_corners); } }; } // namespace validation -- cgit v1.2.1