diff options
Diffstat (limited to 'src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp | 17 |
1 files changed, 6 insertions, 11 deletions
diff --git a/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp b/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp index 71218f5b52..69e5eff213 100644 --- a/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp +++ b/src/core/CL/kernels/CLDeconvolutionReshapeOutputKernel.cpp @@ -40,8 +40,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output, input_info, weights_info); const DataLayout data_layout = input_info->data_layout(); - const unsigned int stride_x = deconv_info.stride().first; - const unsigned int stride_y = deconv_info.stride().second; const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); @@ -77,8 +75,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con if(output->total_size() != 0) { - auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h), - 0, 0, stride_x, stride_y); + const PadStrideInfo stride_info(deconv_info.stride().first, deconv_info.stride().second); + auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h), stride_info); const TensorShape output_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input_info, *weights_info); @@ -92,14 +90,11 @@ std::pair<Status, Window> validate_and_configure_window(const ITensorInfo *input ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); const DataLayout data_layout = input_info->data_layout(); + const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const PadStrideInfo stride_info(deconv_info.stride().first, deconv_info.stride().second); - const unsigned int stride_x = deconv_info.stride().first; - const unsigned int stride_y = deconv_info.stride().second; - const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - - auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h), - 0, 0, stride_x, stride_y); + auto out_dims = deconvolution_output_dimensions(input_info->dimension(idx_w), input_info->dimension(idx_h), weights_info->dimension(idx_w), weights_info->dimension(idx_h), stride_info); const TensorShape output_shape = misc::shape_calculator::compute_deconvolution_output_shape(out_dims, *input_info, *weights_info); |