-rw-r--r--  src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp | 194
-rw-r--r--  tests/datasets/LargeConvolutionLayerDataset.h             |  30
2 files changed, 119 insertions(+), 105 deletions(-)
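This change flattens the kernel-size dispatch in NEWinogradConvolutionLayer::configure() from a square/non-square split with a nested switch into a single if/else-if chain over kernel_size, adds per-kernel-size padding checks to validate() so that only SAME or VALID padding is accepted, and updates LargeConvolutionLayerDataset.h so the Winograd test configurations use matching SAME/VALID paddings.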
diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
index e41b0be860..44ea3a0881 100644
--- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
@@ -275,111 +275,93 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
int n_gemms = 0;
int N_BLOCK = 0; // Size of block used by GEMM.
- const bool square_kernel = kernel_size.width == kernel_size.height;
-
- if(square_kernel)
- {
- switch(kernel_size.width)
- {
- case 3:
- {
- if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
- {
- using config = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else
- {
- using config = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- break;
- }
- case 5:
- {
- using config = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Not supported.");
- break;
- }
- }
- }
- else
+ if(kernel_size == Size2D(3, 3))
{
- if(kernel_size == Size2D(1, 3))
- {
- using config = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(3, 1))
- {
- using config = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(1, 5))
- {
- using config = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(5, 1))
- {
- using config = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
- transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
- transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
- transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
- n_gemms = config::WinogradBase::N_GEMMS;
- N_BLOCK = config::WinogradConv::N_BLOCK;
- }
- else if(kernel_size == Size2D(1, 7))
+ if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
{
- using config = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
+ using config = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
n_gemms = config::WinogradBase::N_GEMMS;
N_BLOCK = config::WinogradConv::N_BLOCK;
}
- else if(kernel_size == Size2D(7, 1))
+ else
{
- using config = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
+ using config = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
n_gemms = config::WinogradBase::N_GEMMS;
N_BLOCK = config::WinogradConv::N_BLOCK;
}
- else
- {
- ARM_COMPUTE_ERROR("Not supported.");
- }
+ }
+ else if(kernel_size == Size2D(5, 5))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(1, 3))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(3, 1))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(1, 5))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(5, 1))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(1, 7))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else if(kernel_size == Size2D(7, 1))
+ {
+ using config = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
+ transform_input_kernel = support::cpp14::make_unique<config::TransformInputKernel>();
+ transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+ transform_output_kernel = support::cpp14::make_unique<config::TransformOutputKernel>();
+ n_gemms = config::WinogradBase::N_GEMMS;
+ N_BLOCK = config::WinogradConv::N_BLOCK;
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported.");
}
const PaddingType use_padding_type = (conv_info.pad_top() != 0u || conv_info.pad_left() != 0) ? PADDING_SAME : PADDING_VALID;
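Every branch in the chain above repeats the same five statements and differs only in the NEWinogradLayerConfiguration template arguments. As a minimal sketch (not part of this patch; SELECT_WINOGRAD_CONFIG is a hypothetical name), the repetition could be collapsed with a small macro built only from identifiers that appear in the code above:

#define SELECT_WINOGRAD_CONFIG(otr, otc, kr, kc)                                                  \
    do                                                                                            \
    {                                                                                             \
        /* Hypothetical helper: select the transform kernels for one tile/kernel geometry. */    \
        using config             = NEWinogradLayerConfiguration<float, float, otr, otc, kr, kc>;  \
        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();   \
        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>(); \
        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();  \
        n_gemms                  = config::WinogradBase::N_GEMMS;                                 \
        N_BLOCK                  = config::WinogradConv::N_BLOCK;                                 \
    } while(false)

// e.g. the Size2D(3, 3) large-input branch would reduce to:
// SELECT_WINOGRAD_CONFIG(4, 4, 3, 3);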
@@ -612,34 +594,66 @@ Status NEWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
if(kernel_size == Size2D(3, 3))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != conv_info.pad_left(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_bottom(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_left(), "Only SAME or VALID padding supported");
return validate_kernel_3x3(input_dims, input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(5, 5))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != conv_info.pad_left(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_bottom(), "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_left(), "Only SAME or VALID padding supported");
return validate_kernel_5x5(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
if(kernel_size == Size2D(3, 1))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported");
return validate_kernel_3x1(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(1, 3))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 1, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported");
return validate_kernel_1x3(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(5, 1))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported");
return validate_kernel_5x1(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(1, 5))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 2, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported");
return validate_kernel_1x5(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(7, 1))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported");
return validate_kernel_7x1(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else if(kernel_size == Size2D(1, 7))
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 3, "Only SAME or VALID padding supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported");
return validate_kernel_1x7(input, &input0, &input1, &batched_mm_output, weights, biases, output, winograd_info, act_info);
}
else
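The pad values checked above follow from the stride-1 output formula out = in + pad_left + pad_right - kernel + 1: SAME output (out == in) needs pad (kernel - 1)/2 on each padded side, i.e. 1 for a 3-tap, 2 for a 5-tap and 3 for a 7-tap dimension, while VALID needs 0. The 1-D kernels (3x1, 1x3, 5x1, 1x5, 7x1, 1x7) additionally require zero padding along their unit dimension.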
diff --git a/tests/datasets/LargeConvolutionLayerDataset.h b/tests/datasets/LargeConvolutionLayerDataset.h
index 170d562f6c..ca38abb43a 100644
--- a/tests/datasets/LargeConvolutionLayerDataset.h
+++ b/tests/datasets/LargeConvolutionLayerDataset.h
@@ -46,15 +46,15 @@ public:
// Batch size 1
add_config(TensorShape(224U, 222U, 64U), TensorShape(3U, 3U, 64U, 64U), TensorShape(64U), TensorShape(224U, 222U, 64U), PadStrideInfo(1, 1, 1, 1));
add_config(TensorShape(112U, 113U, 64U), TensorShape(3U, 3U, 64U, 128U), TensorShape(128U), TensorShape(112U, 113U, 128U), PadStrideInfo(1, 1, 1, 1));
- add_config(TensorShape(112U, 112U, 128U), TensorShape(3U, 3U, 128U, 129U), TensorShape(129U), TensorShape(112U, 110U, 129U), PadStrideInfo(1, 1, 1, 0));
- add_config(TensorShape(53U, 56U, 125U), TensorShape(3U, 3U, 125U, 256U), TensorShape(256U), TensorShape(51U, 56U, 256U), PadStrideInfo(1, 1, 0, 1));
- add_config(TensorShape(56U, 56U, 256U), TensorShape(3U, 3U, 256U, 256U), TensorShape(256U), TensorShape(56U, 54U, 256U), PadStrideInfo(1, 1, 1, 0));
- add_config(TensorShape(28U, 28U, 257U), TensorShape(3U, 3U, 257U, 512U), TensorShape(512U), TensorShape(26U, 28U, 512U), PadStrideInfo(1, 1, 0, 1));
+ add_config(TensorShape(112U, 112U, 128U), TensorShape(3U, 3U, 128U, 129U), TensorShape(129U), TensorShape(112U, 112U, 129U), PadStrideInfo(1, 1, 1, 1));
+ add_config(TensorShape(53U, 56U, 125U), TensorShape(3U, 3U, 125U, 256U), TensorShape(256U), TensorShape(51U, 54U, 256U), PadStrideInfo(1, 1, 0, 0));
+ add_config(TensorShape(56U, 56U, 256U), TensorShape(3U, 3U, 256U, 256U), TensorShape(256U), TensorShape(54U, 54U, 256U), PadStrideInfo(1, 1, 0, 0));
+ add_config(TensorShape(28U, 28U, 257U), TensorShape(3U, 3U, 257U, 512U), TensorShape(512U), TensorShape(28U, 28U, 512U), PadStrideInfo(1, 1, 1, 1));
add_config(TensorShape(28U, 28U, 512U), TensorShape(3U, 3U, 512U, 512U), TensorShape(512U), TensorShape(28U, 28U, 512U), PadStrideInfo(1, 1, 1, 1));
add_config(TensorShape(14U, 14U, 512U), TensorShape(3U, 3U, 512U, 512U), TensorShape(512U), TensorShape(12U, 12U, 512U), PadStrideInfo(1, 1, 0, 0));
// Batch size 3, 2 and 4
add_config(TensorShape(224U, 222U, 64U, 3U), TensorShape(3U, 3U, 64U, 64U), TensorShape(64U), TensorShape(224U, 222U, 64U, 3U), PadStrideInfo(1, 1, 1, 1));
- add_config(TensorShape(112U, 113U, 64U, 2U), TensorShape(3U, 3U, 64U, 128U), TensorShape(128U), TensorShape(110U, 113U, 128U, 2U), PadStrideInfo(1, 1, 0, 1));
+ add_config(TensorShape(112U, 113U, 64U, 2U), TensorShape(3U, 3U, 64U, 128U), TensorShape(128U), TensorShape(110U, 111U, 128U, 2U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(111U, 112U, 127U, 4U), TensorShape(3U, 3U, 127U, 128U), TensorShape(128U), TensorShape(111U, 112U, 128U, 4U), PadStrideInfo(1, 1, 1, 1));
}
};
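The corrected output shapes can be checked with the same rule, out = in + 2*pad - 3 + 1 per dimension: TensorShape(53U, 56U, 125U) with pad 0 gives 51x54, and TensorShape(28U, 28U, 257U) with pad 1 keeps 28x28, matching the updated rows.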
@@ -110,14 +110,14 @@ public:
{
// Kernel size 5
// Batch size 1
- add_config(TensorShape(224U, 224U, 3U), TensorShape(5U, 5U, 3U, 64U), TensorShape(64U), TensorShape(222U, 222U, 64U), PadStrideInfo(1, 1, 1, 1));
- add_config(TensorShape(123U, 134U, 16U), TensorShape(5U, 5U, 16U, 7U), TensorShape(7U), TensorShape(121U, 130U, 7U), PadStrideInfo(1, 1, 1, 0));
+ add_config(TensorShape(224U, 224U, 3U), TensorShape(5U, 5U, 3U, 64U), TensorShape(64U), TensorShape(220U, 220U, 64U), PadStrideInfo(1, 1, 0, 0));
+ add_config(TensorShape(123U, 134U, 16U), TensorShape(5U, 5U, 16U, 7U), TensorShape(7U), TensorShape(123U, 134U, 7U), PadStrideInfo(1, 1, 2, 2));
add_config(TensorShape(181U, 152U, 42U), TensorShape(5U, 5U, 42U, 100U), TensorShape(100U), TensorShape(177U, 148U, 100U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(200U, 201U, 24U), TensorShape(5U, 5U, 24U, 61), TensorShape(61U), TensorShape(200U, 201U, 61), PadStrideInfo(1, 1, 2, 2));
// Batch size 2, 3 and 4
- add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(5U, 5U, 3U, 64U), TensorShape(64U), TensorShape(222U, 222U, 64U, 2U), PadStrideInfo(1, 1, 1, 1));
- add_config(TensorShape(123U, 134U, 16U, 3U), TensorShape(5U, 5U, 16U, 7U), TensorShape(7U), TensorShape(121U, 130U, 7U, 3U), PadStrideInfo(1, 1, 1, 0));
+ add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(5U, 5U, 3U, 64U), TensorShape(64U), TensorShape(220U, 220U, 64U, 2U), PadStrideInfo(1, 1, 0, 0));
+ add_config(TensorShape(123U, 134U, 16U, 3U), TensorShape(5U, 5U, 16U, 7U), TensorShape(7U), TensorShape(123U, 134U, 7U, 3U), PadStrideInfo(1, 1, 2, 2));
add_config(TensorShape(181U, 152U, 42U, 4U), TensorShape(5U, 5U, 42U, 100U), TensorShape(100U), TensorShape(177U, 148U, 100U, 4U), PadStrideInfo(1, 1, 0, 0));
}
};
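Likewise for the corrected 5x5 rows: 224 - 5 + 1 = 220 with VALID padding, and 123 + 2*2 - 5 + 1 = 123 in width (similarly 134 in height) with SAME pad 2.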
@@ -128,14 +128,14 @@ public:
LargeWinogradConvolutionLayer5x1Dataset()
{
// Batch size 1
- add_config(TensorShape(224U, 224U, 3U), TensorShape(5U, 1U, 3U, 64U), TensorShape(64U), TensorShape(222U, 224U, 64U), PadStrideInfo(1, 1, 1, 0));
- add_config(TensorShape(123U, 134U, 16U), TensorShape(5U, 1U, 16U, 7U), TensorShape(7U), TensorShape(121U, 134U, 7U), PadStrideInfo(1, 1, 1, 0));
+ add_config(TensorShape(224U, 224U, 3U), TensorShape(5U, 1U, 3U, 64U), TensorShape(64U), TensorShape(224U, 224U, 64U), PadStrideInfo(1, 1, 2, 0));
+ add_config(TensorShape(123U, 134U, 16U), TensorShape(5U, 1U, 16U, 7U), TensorShape(7U), TensorShape(123U, 134U, 7U), PadStrideInfo(1, 1, 2, 0));
add_config(TensorShape(181U, 152U, 42U), TensorShape(5U, 1U, 42U, 100U), TensorShape(100U), TensorShape(177U, 152U, 100U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(200U, 201U, 24U), TensorShape(5U, 1U, 24U, 61), TensorShape(61U), TensorShape(200U, 201U, 61), PadStrideInfo(1, 1, 2, 0));
// Batch size 2, 3 and 4
- add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(5U, 1U, 3U, 64U), TensorShape(64U), TensorShape(222U, 224U, 64U, 2U), PadStrideInfo(1, 1, 1, 0));
- add_config(TensorShape(123U, 134U, 16U, 3U), TensorShape(5U, 1U, 16U, 7U), TensorShape(7U), TensorShape(121U, 134U, 7U, 3U), PadStrideInfo(1, 1, 1, 0));
+ add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(5U, 1U, 3U, 64U), TensorShape(64U), TensorShape(224U, 224U, 64U, 2U), PadStrideInfo(1, 1, 2, 0));
+ add_config(TensorShape(123U, 134U, 16U, 3U), TensorShape(5U, 1U, 16U, 7U), TensorShape(7U), TensorShape(123U, 134U, 7U, 3U), PadStrideInfo(1, 1, 2, 0));
add_config(TensorShape(181U, 152U, 42U, 4U), TensorShape(5U, 1U, 42U, 100U), TensorShape(100U), TensorShape(177U, 152U, 100U, 4U), PadStrideInfo(1, 1, 0, 0));
}
};
@@ -146,13 +146,13 @@ public:
LargeWinogradConvolutionLayer1x5Dataset()
{
// Batch size 1
- add_config(TensorShape(224U, 224U, 3U), TensorShape(1U, 5U, 3U, 64U), TensorShape(64U), TensorShape(224U, 222U, 64U), PadStrideInfo(1, 1, 0, 1));
+ add_config(TensorShape(224U, 224U, 3U), TensorShape(1U, 5U, 3U, 64U), TensorShape(64U), TensorShape(224U, 224U, 64U), PadStrideInfo(1, 1, 0, 2));
add_config(TensorShape(123U, 134U, 16U), TensorShape(1U, 5U, 16U, 7U), TensorShape(7U), TensorShape(123U, 130U, 7U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(181U, 152U, 42U), TensorShape(1U, 5U, 42U, 100U), TensorShape(100U), TensorShape(181U, 148U, 100U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(200U, 201U, 24U), TensorShape(1U, 5U, 24U, 61), TensorShape(61U), TensorShape(200U, 201U, 61), PadStrideInfo(1, 1, 0, 2));
// Batch size 2, 3 and 4
- add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(1U, 5U, 3U, 64U), TensorShape(64U), TensorShape(224U, 222U, 64U, 2U), PadStrideInfo(1, 1, 0, 1));
+ add_config(TensorShape(224U, 224U, 3U, 2U), TensorShape(1U, 5U, 3U, 64U), TensorShape(64U), TensorShape(224U, 224U, 64U, 2U), PadStrideInfo(1, 1, 0, 2));
add_config(TensorShape(123U, 134U, 16U, 3U), TensorShape(1U, 5U, 16U, 7U), TensorShape(7U), TensorShape(123U, 130U, 7U, 3U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(181U, 152U, 42U, 4U), TensorShape(1U, 5U, 42U, 100U), TensorShape(100U), TensorShape(181U, 148U, 100U, 4U), PadStrideInfo(1, 1, 0, 0));
}
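Note the PadStrideInfo argument order (stride_x, stride_y, pad_x, pad_y): the 5x1 rows pad only along x (pad 2) to preserve width, e.g. 224 + 2*2 - 5 + 1 = 224, and the 1x5 rows mirror this along y, which is why the corrected entries use (1, 1, 2, 0) and (1, 1, 0, 2) respectively.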