aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLScaleKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLScaleKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLScaleKernel.cpp22
1 files changed, 9 insertions, 13 deletions
diff --git a/src/core/CL/kernels/CLScaleKernel.cpp b/src/core/CL/kernels/CLScaleKernel.cpp
index 5a7d5830fd..f3d2fa12d5 100644
--- a/src/core/CL/kernels/CLScaleKernel.cpp
+++ b/src/core/CL/kernels/CLScaleKernel.cpp
@@ -120,15 +120,8 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
break;
case DataLayout::NHWC:
{
- num_elems_processed_per_iteration = 1;
// Configure kernel window
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
- AccessWindowStatic input_access(input, -border.left, -border.top,
- input->dimension(0) + border.right,
- input->dimension(1) + border.bottom);
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
- window_changed = update_window_and_padding(win, input_access, output_access);
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ win = calculate_max_window(*output, Steps());
}
break;
default:
@@ -142,14 +135,13 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
BorderSize CLScaleKernel::border_size() const
{
- return BorderSize(1);
+ return BorderSize(static_cast<size_t>(_data_layout == DataLayout::NCHW));
}
Status CLScaleKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ScaleKernelInfo &info)
{
- BorderSize border = BorderSize(1);
-
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info));
+ BorderSize border = BorderSize(static_cast<size_t>(input->data_layout() == DataLayout::NCHW));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), info, border).first);
return Status{};
@@ -173,6 +165,7 @@ void CLScaleKernel::configure(const ICLTensor *input, ICLTensor *output, const S
void CLScaleKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, const ScaleKernelInfo &info)
{
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), info));
+ auto padding_info = get_padding_info({ input, output });
_input = input;
_output = output;
@@ -208,6 +201,7 @@ void CLScaleKernel::configure(const CLCompileContext &compile_context, const ICL
// Create kernel
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DCONSTANT_VALUE=" + string_from_pixel_value(info.constant_border_value, input->info()->data_type()));
build_opts.add_option("-DBORDER_SIZE=" + support::cpp11::to_string(border.right));
build_opts.add_option_if(info.border_mode == BorderMode::REPLICATE, "-DBORDER_MODE_REPLICATE");
build_opts.add_option_if(is_nhwc, "-DDEPTH_OUT=" + support::cpp11::to_string(output->info()->dimension(2)));
@@ -219,7 +213,6 @@ void CLScaleKernel::configure(const CLCompileContext &compile_context, const ICL
build_opts.add_option("-DSCALE=" + support::cpp11::to_string(qinfo.scale));
build_opts.add_option("-DOFFSET=" + support::cpp11::to_string(qinfo.offset));
}
-
std::string interpolation_name = string_from_interpolation_policy(interpolation_policy_to_use);
std::transform(interpolation_name.begin(), interpolation_name.end(), interpolation_name.begin(), ::tolower);
std::string kernel_name = "scale_" + interpolation_name;
@@ -250,13 +243,16 @@ void CLScaleKernel::configure(const CLCompileContext &compile_context, const ICL
_config_id += support::cpp11::to_string(output->info()->dimension(2));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(3));
+ if(is_nhwc)
+ {
+ ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info));
+ }
}
void CLScaleKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
-
switch(_data_layout)
{
case DataLayout::NCHW: