aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLScaleKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLScaleKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLScaleKernel.cpp30
1 files changed, 8 insertions, 22 deletions
diff --git a/src/core/CL/kernels/CLScaleKernel.cpp b/src/core/CL/kernels/CLScaleKernel.cpp
index 872ba5b6cc..f3acc3b31c 100644
--- a/src/core/CL/kernels/CLScaleKernel.cpp
+++ b/src/core/CL/kernels/CLScaleKernel.cpp
@@ -35,6 +35,8 @@
#include "arm_compute/core/TensorInfo.h"
#include "support/StringSupport.h"
+#include "src/core/utils/ScaleUtils.h"
+
#include <set>
#include <string>
@@ -54,8 +56,8 @@ inline std::pair<float, float> calculate_scale_factors(const ITensorInfo &input,
const unsigned int output_width = output.dimension(idx_width);
const unsigned int output_height = output.dimension(idx_height);
- float wr = arm_compute::calculate_resize_ratio(input_width, output_width, align_corners);
- float hr = arm_compute::calculate_resize_ratio(input_height, output_height, align_corners);
+ float wr = arm_compute::scale_utils::calculate_resize_ratio(input_width, output_width, align_corners);
+ float hr = arm_compute::scale_utils::calculate_resize_ratio(input_height, output_height, align_corners);
return std::make_pair(wr, hr);
}
@@ -68,29 +70,13 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(output == input);
- ARM_COMPUTE_RETURN_ERROR_ON(info.align_corners && !is_align_corners_allowed(info.sampling_policy));
+ ARM_COMPUTE_RETURN_ERROR_ON(info.align_corners && !arm_compute::scale_utils::is_align_corners_allowed_sampling_policy(info.sampling_policy));
- if(info.align_corners)
- {
- // For bilinear method with aligned corners, the resize ratio will
- // be calculated by (input_size - 1)/(output_size - 1). Belows are
- // checking possible overflows.
- const auto data_layout = input->data_layout();
- const auto width_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const auto height_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-
- const auto input_width = input->dimension(width_index);
- const auto input_height = input->dimension(height_index);
- const auto output_width = output->dimension(width_index);
- const auto output_height = output->dimension(height_index);
-
- ARM_COMPUTE_RETURN_ERROR_ON(input_width == 0 || input_height == 0 || output_width == 0 || output_height == 0);
- ARM_COMPUTE_RETURN_ERROR_ON((output_width - 1 == 0) || (output_height - 1 == 0));
- }
+ const bool will_use_align_corners = info.align_corners && arm_compute::scale_utils::is_align_corners_allowed_output_shape(output->tensor_shape(), output->data_layout());
float wr = 0.f;
float hr = 0.f;
- std::tie(wr, hr) = calculate_scale_factors(*input, *output, info.align_corners);
+ std::tie(wr, hr) = calculate_scale_factors(*input, *output, will_use_align_corners);
ARM_COMPUTE_RETURN_ERROR_ON(info.interpolation_policy == InterpolationPolicy::AREA && (wr > 1.f || hr > 1.f));
@@ -191,7 +177,7 @@ void CLScaleKernel::configure(const CLCompileContext &compile_context, const ICL
_output = output;
_interpolation_policy = info.interpolation_policy;
_data_layout = input->info()->data_layout();
- _align_corners = info.align_corners;
+ _align_corners = info.align_corners && arm_compute::scale_utils::is_align_corners_allowed_output_shape(output->info()->tensor_shape(), _data_layout);
float wr = 0.f;
float hr = 0.f;