diff options
Diffstat (limited to 'src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp | 20 |
1 file changed, 17 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp index 8f990712e8..f5609b6f5c 100644 --- a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp +++ b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp @@ -42,7 +42,7 @@ namespace { inline bool is_kernel_size_supported(Size2D size) { - const std::array<Size2D, 4> supported_input_sizes = { { Size2D(1, 3), Size2D(3, 1), Size2D(5, 5), Size2D(3, 3) } }; + const std::array<Size2D, 8> supported_input_sizes = { { Size2D(1, 3), Size2D(3, 1), Size2D(5, 5), Size2D(3, 3), Size2D(1, 5), Size2D(5, 1), Size2D(7, 1), Size2D(1, 7) } }; return std::end(supported_input_sizes) != std::find(std::begin(supported_input_sizes), std::end(supported_input_sizes), size); } @@ -56,10 +56,10 @@ Status validate_arguments_winograd_weight_trans(const ITensorInfo *input, const const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); const auto input_width = input->dimension(idx_width); const auto input_height = input->dimension(idx_height); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(Size2D(input_width, input_height)), "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(Size2D(input_width, input_height)), "Only 1x3, 3x1, 1x5, 5x1, 7x1, 1x7, 3x3 and 5x5 kernels are supported"); ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4); const Size2D &output_tile = winograd_info.output_tile_size; - const std::array<Size2D, 4> supported_tile_sizes = { { Size2D(2U, 2U), Size2D(4U, 4U), Size2D(1U, 6U), Size2D(6U, 1U) } }; + const std::array<Size2D, 8> supported_tile_sizes = { { Size2D(2U, 2U), Size2D(4U, 4U), Size2D(1U, 6U), Size2D(6U, 1U), Size2D(4, 1), Size2D(1, 4), Size2D(2, 1), Size2D(1, 2) } }; ARM_COMPUTE_RETURN_ERROR_ON(std::end(supported_tile_sizes) == std::find(std::begin(supported_tile_sizes), 
std::end(supported_tile_sizes), output_tile)); // Checks performed when output is configured @@ -305,6 +305,10 @@ template class NEWinogradLayerTransformWeightsKernel<float, 2, 2, 5, 5>; template class NEWinogradLayerTransformWeightsKernel<float, 1, 6, 1, 3>; template class NEWinogradLayerTransformWeightsKernel<float, 6, 1, 3, 1>; +template class NEWinogradLayerTransformWeightsKernel<float, 1, 4, 1, 5>; +template class NEWinogradLayerTransformWeightsKernel<float, 4, 1, 5, 1>; +template class NEWinogradLayerTransformWeightsKernel<float, 1, 2, 1, 7>; +template class NEWinogradLayerTransformWeightsKernel<float, 2, 1, 7, 1>; // Input transform template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols> @@ -401,6 +405,11 @@ template class NEWinogradLayerTransformInputKernel<float, 2, 2, 5, 5>; template class NEWinogradLayerTransformInputKernel<float, 1, 6, 1, 3>; template class NEWinogradLayerTransformInputKernel<float, 6, 1, 3, 1>; +template class NEWinogradLayerTransformInputKernel<float, 1, 4, 1, 5>; +template class NEWinogradLayerTransformInputKernel<float, 4, 1, 5, 1>; +template class NEWinogradLayerTransformInputKernel<float, 1, 2, 1, 7>; +template class NEWinogradLayerTransformInputKernel<float, 2, 1, 7, 1>; + // Output transform template <typename T, int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols> @@ -513,4 +522,9 @@ template class NEWinogradLayerTransformOutputKernel<float, 2, 2, 5, 5>; template class NEWinogradLayerTransformOutputKernel<float, 1, 6, 1, 3>; template class NEWinogradLayerTransformOutputKernel<float, 6, 1, 3, 1>; +template class NEWinogradLayerTransformOutputKernel<float, 1, 4, 1, 5>; +template class NEWinogradLayerTransformOutputKernel<float, 4, 1, 5, 1>; +template class NEWinogradLayerTransformOutputKernel<float, 1, 2, 1, 7>; +template class NEWinogradLayerTransformOutputKernel<float, 2, 1, 7, 1>; + } // namespace arm_compute |