about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp 10
1 file changed, 8 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
index 7e82dc4ecd..672684d14f 100644
--- a/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWinogradConvolutionLayerKernel.cpp
@@ -173,8 +173,10 @@ Status validate_arguments_winograd_output_trans(const ITensorInfo *input, const
const Size2D kernel_dims = winograd_info.kernel_size;
// Number of tiles along the X and Y direction
- const unsigned int num_tiles_x = std::ceil((winograd_info.input_dimensions.x() - (kernel_dims.width - 1) + conv_info.pad_left() + conv_info.pad_right()) / 2.f);
- const unsigned int num_tiles_y = std::ceil((winograd_info.input_dimensions.y() - (kernel_dims.height - 1) + conv_info.pad_top() + conv_info.pad_bottom()) / 2.f);
+ const unsigned int num_tiles_x = std::ceil((winograd_info.input_dimensions.x() - (kernel_dims.width - 1) + conv_info.pad_left() + conv_info.pad_right()) / static_cast<float>
+ (winograd_info.output_tile_size.width));
+ const unsigned int num_tiles_y = std::ceil((winograd_info.input_dimensions.y() - (kernel_dims.height - 1) + conv_info.pad_top() + conv_info.pad_bottom()) / static_cast<float>
+ (winograd_info.output_tile_size.height));
const Size2D num_tiles = Size2D(num_tiles_x, num_tiles_y);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
@@ -301,6 +303,7 @@ Status NEWinogradLayerBatchedGEMMKernel<TIn, TOut, OutputTileRows, OutputTileCol
}
template class NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 3, 3>;
+template class NEWinogradLayerBatchedGEMMKernel<float, float, 4, 4, 3, 3>;
template class NEWinogradLayerBatchedGEMMKernel<float, float, 2, 2, 5, 5>;
// Weights transform
@@ -369,6 +372,7 @@ Status NEWinogradLayerTransformWeightsKernel<T, OutputTileRows, OutputTileCols,
}
template class NEWinogradLayerTransformWeightsKernel<float, 2, 2, 3, 3>;
+template class NEWinogradLayerTransformWeightsKernel<float, 4, 4, 3, 3>;
template class NEWinogradLayerTransformWeightsKernel<float, 2, 2, 5, 5>;
// Input transform
@@ -442,6 +446,7 @@ Status NEWinogradLayerTransformInputKernel<T, OutputTileRows, OutputTileCols, Ke
}
template class NEWinogradLayerTransformInputKernel<float, 2, 2, 3, 3>;
+template class NEWinogradLayerTransformInputKernel<float, 4, 4, 3, 3>;
template class NEWinogradLayerTransformInputKernel<float, 2, 2, 5, 5>;
// Output transform
@@ -544,6 +549,7 @@ Status NEWinogradLayerTransformOutputKernel<T, OutputTileRows, OutputTileCols, K
}
template class NEWinogradLayerTransformOutputKernel<float, 2, 2, 3, 3>;
+template class NEWinogradLayerTransformOutputKernel<float, 4, 4, 3, 3>;
template class NEWinogradLayerTransformOutputKernel<float, 2, 2, 5, 5>;
} // namespace arm_compute