Diffstat (limited to 'src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp')
-rw-r--r--  src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp | 194
1 file changed, 104 insertions(+), 90 deletions(-)
diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
index e41b0be860..44ea3a0881 100644
--- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
@@ -275,111 +275,93 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
int n_gemms = 0;
int N_BLOCK = 0; // Size of block used by GEMM.
- const bool square_kernel = kernel_size.width == kernel_size.height;
-
- if(square_kernel)
- {
- switch(kernel_size.width)
- {
- case 3:
- {
- if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
- {
- using config = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else
- {
- using config = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- break;
- }
- case 5:
- {
- using config = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Not supported.");
- break;
- }
- }
- }
- else
+ if(kernel_size == Size2D(3, 3))
{
- if(kernel_size == Size2D(1, 3))
- {
- using config = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(3, 1))
- {
- using config = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(1, 5))
- {
- using config = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(5, 1))
- {
- using config = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(1, 7))
+ if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
{
- using config = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
+ using config = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
n_gemms = config::WinogradBase::N_GEMMS;
N_BLOCK = config::WinogradConv::N_BLOCK;
}
- else if(kernel_size == Size2D(7, 1))
+ else
{
- using config = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
+ using config = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
n_gemms = config::WinogradBase::N_GEMMS;
N_BLOCK = config::WinogradConv::N_BLOCK;
}
- else
- {
- ARM_COMPUTE_ERROR("Not supported.");
- }
+ }
+ else if(kernel_size == Size2D(5, 5))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(1, 3))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(3, 1))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(1, 5))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(5, 1))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(1, 7))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(7, 1))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported.");
}
const PaddingType use_padding_type = (conv_info.pad_top() != 0u || conv_info.pad_left() != 0) ? PADDING_SAME : PADDING_VALID;
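For reference, the dispatch added in the hunk above pairs each supported kernel shape with a fixed Winograd output-tile shape before instantiating NEWinogradLayerConfiguration<float, float, OutputTileRows, OutputTileCols, KernelRows, KernelCols>. The standalone sketch below restates that mapping; select_winograd_tile() and TileShape are hypothetical illustrations, not part of the Compute Library API, and the only heuristic reproduced is the 3x3 case, which prefers 4x4 output tiles when both spatial input dimensions exceed 4.

// Hypothetical illustration only: mirrors the kernel-size -> output-tile
// selection performed in NEWinogradConvolutionLayer::configure() above.
// Kernel sizes are given as (width, height) to match arm_compute::Size2D.
#include <cstdio>
#include <stdexcept>

struct TileShape
{
    int rows; // output tile rows    (template parameter OutputTileRows)
    int cols; // output tile columns (template parameter OutputTileCols)
};

inline TileShape select_winograd_tile(int kernel_w, int kernel_h, int input_w, int input_h)
{
    if(kernel_w == 3 && kernel_h == 3)
    {
        // Larger inputs use the 4x4 output tile, small inputs fall back to 2x2.
        return (input_w > 4 && input_h > 4) ? TileShape{ 4, 4 } : TileShape{ 2, 2 };
    }
    if(kernel_w == 5 && kernel_h == 5) return TileShape{ 2, 2 };
    if(kernel_w == 1 && kernel_h == 3) return TileShape{ 6, 1 }; // config <..., 6, 1, 3, 1>
    if(kernel_w == 3 && kernel_h == 1) return TileShape{ 1, 6 }; // config <..., 1, 6, 1, 3>
    if(kernel_w == 1 && kernel_h == 5) return TileShape{ 4, 1 }; // config <..., 4, 1, 5, 1>
    if(kernel_w == 5 && kernel_h == 1) return TileShape{ 1, 4 }; // config <..., 1, 4, 1, 5>
    if(kernel_w == 1 && kernel_h == 7) return TileShape{ 2, 1 }; // config <..., 2, 1, 7, 1>
    if(kernel_w == 7 && kernel_h == 1) return TileShape{ 1, 2 }; // config <..., 1, 2, 1, 7>
    throw std::invalid_argument("Kernel size not supported by the Winograd path");
}

int main()
{
    const TileShape t = select_winograd_tile(3, 3, 224, 224);
    std::printf("3x3 kernel on a 224x224 input -> %dx%d output tile\n", t.rows, t.cols);
    return 0;
}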
@@ -612,34 +594,66 @@ Status NEWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
if(kernel_size == Size2D(3, 3))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != conv_info.pad_left(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_bottom(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_left(), "Only SAME or VALID padding supported");
return validate_kernel_3x3(input_dims, input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(5, 5))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != conv_info.pad_left(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_bottom(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_left(), "Only SAME or VALID padding supported");
return validate_kernel_5x5(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
if(kernel_size == Size2D(3, 1))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported");
return validate_kernel_3x1(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(1, 3))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported");
return validate_kernel_1x3(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(5, 1))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported");
return validate_kernel_5x1(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(1, 5))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported");
return validate_kernel_1x5(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(7, 1))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported");
return validate_kernel_7x1(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(1, 7))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported");
return validate_kernel_1x7(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else