diff options
Diffstat (limited to 'src/core/CL/kernels/CLWinogradInputTransformKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLWinogradInputTransformKernel.cpp | 19 |
1 files changed, 2 insertions, 17 deletions
diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp index 695e1cbbf1..392edda615 100644 --- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp +++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp @@ -126,23 +126,6 @@ void CLWinogradInputTransformKernel::configure(const CLCompileContext &compile_c const size_t idx_w = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); const size_t idx_h = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); - // Compute number of elements to process in the X and Y direction - const int num_elements_x = input->info()->dimension(idx_w) - (kernel_size.width - 1) + conv_info.pad_left() + conv_info.pad_right(); - const int num_elements_y = input->info()->dimension(idx_h) - (kernel_size.height - 1) + conv_info.pad_top() + conv_info.pad_bottom(); - - if(_data_layout == DataLayout::NCHW) - { - // Check if we need to extend the right or bottom border - const unsigned int extra_border_right = ((num_elements_x % output_tile_size.width) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.width - 1); - const unsigned int extra_border_bottom = ((num_elements_y % output_tile_size.height) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.height - 1); - - _border_size = BorderSize(conv_info.pad_top(), conv_info.pad_right() + extra_border_right, conv_info.pad_bottom() + extra_border_bottom, conv_info.pad_left()); - } - else - { - _border_size = BorderSize(); - } - // Compute the number of output tiles along the x and y direction of size "output_tile_size" const Size2D num_tiles = compute_winograd_convolution_tiles(Size2D(input->info()->dimension(idx_w), input->info()->dimension(idx_h)), kernel_size, @@ -206,6 +189,8 @@ void CLWinogradInputTransformKernel::configure(const CLCompileContext &compile_c ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICLKernel::configure_internal(win_config.second, cl::NDRange(1, 1, 8)); + _border_size = BorderSize(_input->info()->padding()); + ARM_COMPUTE_ERROR_ON((input->info()->data_layout() == DataLayout::NHWC) && has_padding_changed(padding_info)); _config_id = kernel_name; |