aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp17
1 files changed, 6 insertions, 11 deletions
diff --git a/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp b/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp
index 71218f5b52..69e5eff213 100644
--- a/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp
+++ b/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp
@@ -40,8 +40,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output, input_info, weights_info);
const DataLayout data_layout = input_info->data_layout();
- const unsigned int stride_x = deconv_info.stride().first;
- const unsigned int stride_y = deconv_info.stride().second;
const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
@@ -77,8 +75,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
if(output->total_size() != 0)
{
- auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h),
- 0, 0, stride_x, stride_y);
+ const PadStrideInfo stride_info(deconv_info.stride().first, deconv_info.stride().second);
+ auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h), stride_info);
const TensorShape output_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input_info, *weights_info);
@@ -92,14 +90,11 @@ std::pair<Status, Window> validate_and_configure_window(const ITensorInfo *input
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
const DataLayout data_layout = input_info->data_layout();
+ const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const PadStrideInfo stride_info(deconv_info.stride().first, deconv_info.stride().second);
- const unsigned int stride_x = deconv_info.stride().first;
- const unsigned int stride_y = deconv_info.stride().second;
- const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-
- auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h),
- 0, 0, stride_x, stride_y);
+ auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h), stride_info);
const TensorShape output_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input_info, *weights_info);