aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/runtime/CL/functions/CLConvolutionLayer.cpp10
-rw-r--r--src/runtime/NEON/functions/NEConvolutionLayer.cpp6
2 files changed, 14 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index 165d523100..d794cde1f4 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -188,11 +188,17 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *
}
else
{
- if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && ( CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
+ // SRGAN
+ if((input->dimension(idx_h) > 720U) && (output->dimension(idx_h) > 720U) && (weights->dimension(idx_h) == 9) && (conv_info.pad_top() < 3)
+ && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
+ {
+ return ConvolutionMethod::DIRECT;
+ }
+ if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && (CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
{
return ConvolutionMethod::FFT;
}
- if (input->dimension(idx_c) < 16)
+ if(input->dimension(idx_c) < 16)
{
return ConvolutionMethod::GEMM;
}
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index a62459b3e8..6379367dd8 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -165,6 +165,12 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo *
}
else
{
+ // SRGAN
+ if((input->dimension(idx_h) > 720U) && (output->dimension(idx_h) > 720U) && (weights->dimension(idx_h) == 9)
+ && (NEDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
+ {
+ return ConvolutionMethod::DIRECT;
+ }
if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && (NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
{
return ConvolutionMethod::FFT;