aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLScaleKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLScaleKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLScaleKernel.cpp38
1 files changed, 19 insertions, 19 deletions
diff --git a/src/core/CL/kernels/CLScaleKernel.cpp b/src/core/CL/kernels/CLScaleKernel.cpp
index f3d2fa12d5..c2c78c8f6d 100644
--- a/src/core/CL/kernels/CLScaleKernel.cpp
+++ b/src/core/CL/kernels/CLScaleKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,15 +41,15 @@
#include <set>
#include <string>
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
-inline std::pair<float, float> calculate_scale_factors(const ITensorInfo &input, const ITensorInfo &output, bool align_corners)
+inline std::pair<float, float> calculate_scale_factors(const ITensorInfo &input, const ITensorInfo &output, const ScaleKernelInfo &info)
{
- DataLayout data_layout = input.data_layout();
- const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const DataLayout data_layout = info.data_layout == DataLayout::UNKNOWN ? input.data_layout() : info.data_layout;
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
// Compute the ratio between source width/height and destination width/height
const unsigned int input_width = input.dimension(idx_width);
@@ -57,8 +57,8 @@ inline std::pair<float, float> calculate_scale_factors(const ITensorInfo &input,
const unsigned int output_width = output.dimension(idx_width);
const unsigned int output_height = output.dimension(idx_height);
- float wr = arm_compute::scale_utils::calculate_resize_ratio(input_width, output_width, align_corners);
- float hr = arm_compute::scale_utils::calculate_resize_ratio(input_height, output_height, align_corners);
+ float wr = arm_compute::scale_utils::calculate_resize_ratio(input_width, output_width, info.align_corners);
+ float hr = arm_compute::scale_utils::calculate_resize_ratio(input_height, output_height, info.align_corners);
return std::make_pair(wr, hr);
}
@@ -73,11 +73,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
ARM_COMPUTE_RETURN_ERROR_ON(output == input);
ARM_COMPUTE_RETURN_ERROR_ON(info.align_corners && !arm_compute::scale_utils::is_align_corners_allowed_sampling_policy(info.sampling_policy));
- const bool will_use_align_corners = info.align_corners;
-
float wr = 0.f;
float hr = 0.f;
- std::tie(wr, hr) = calculate_scale_factors(*input, *output, will_use_align_corners);
+ std::tie(wr, hr) = calculate_scale_factors(*input, *output, info);
ARM_COMPUTE_RETURN_ERROR_ON(info.interpolation_policy == InterpolationPolicy::AREA && (wr > 1.f || hr > 1.f));
@@ -86,10 +84,10 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const ScaleKernelInfo &info, BorderSize &border)
{
- Window win{};
- bool window_changed{};
- unsigned int num_elems_processed_per_iteration = 0;
- DataLayout data_layout = input->data_layout();
+ Window win{};
+ bool window_changed{};
+ unsigned int num_elems_processed_per_iteration = 0;
+ const DataLayout data_layout = info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : info.data_layout;
switch(data_layout)
{
@@ -141,7 +139,8 @@ BorderSize CLScaleKernel::border_size() const
Status CLScaleKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ScaleKernelInfo &info)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info));
- BorderSize border = BorderSize(static_cast<size_t>(input->data_layout() == DataLayout::NCHW));
+ const DataLayout data_layout = info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : info.data_layout;
+ BorderSize border = BorderSize(static_cast<size_t>(data_layout == DataLayout::NCHW));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), info, border).first);
return Status{};
@@ -170,12 +169,12 @@ void CLScaleKernel::configure(const CLCompileContext &compile_context, const ICL
_input = input;
_output = output;
_interpolation_policy = info.interpolation_policy;
- _data_layout = input->info()->data_layout();
+ _data_layout = info.data_layout == DataLayout::UNKNOWN ? input->info()->data_layout() : info.data_layout;
_align_corners = info.align_corners;
float wr = 0.f;
float hr = 0.f;
- std::tie(wr, hr) = calculate_scale_factors(*input->info(), *output->info(), _align_corners);
+ std::tie(wr, hr) = calculate_scale_factors(*input->info(), *output->info(), info);
const bool call_quantized_kernel = is_data_type_quantized_asymmetric(input->info()->data_type()) && _interpolation_policy == InterpolationPolicy::BILINEAR;
@@ -284,3 +283,4 @@ void CLScaleKernel::run(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR("Data layout not supported");
}
}
+} // namespace arm_compute