diff options
Diffstat (limited to 'src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp | 10 |
1 files changed, 3 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp index 779df637f6..e6c713e5e7 100644 --- a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp +++ b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp @@ -25,7 +25,6 @@ #include "arm_compute/core/AccessWindowStatic.h" #include "arm_compute/core/CL/CLHelpers.h" -#include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/Helpers.h" @@ -54,12 +53,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH); const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd filter transform only supports 3x3 and 5x5 kernels"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && output_tile_size != Size2D(4U, 4U), "Winograd filter transform only supports 4x4 output tile for NHWC data layout"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U) - && output_tile_size != Size2D(4U, 4U), - "Winograd filter transform only supports 2x2 or 4x4 output tile for 3x3 kernels"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd filter transform only supports 4x4 output tile for 5x5 kernels"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!cl_winograd_convolution_layer_supported(output_tile_size, kernel_size, input->data_layout()), "Winograd filter transform not supported"); ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_w) != kernel_size.width || input->dimension(idx_h) != kernel_size.height); ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4); @@ -115,6 +109,8 @@ void CLWinogradFilterTransformKernel::configure(const ICLTensor *input, ICLTenso // Set build options CLBuildOptions build_opts; build_opts.add_option("-DSRC_DIM_Z=" + support::cpp11::to_string(input->info()->dimension(2))); + build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL"); + build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_FILTER_TRANSFORM_VERTICAL"); const Size2D kernel_size = winograd_info.kernel_size; const Size2D output_tile_size = winograd_info.output_tile_size; |