aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/runtime/CL/functions/CLConvolutionLayer.cpp23
1 files changed, 11 insertions, 12 deletions
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index aaabbe0cfc..1082a475b2 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -224,21 +224,20 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *
{
return ConvolutionMethod::DIRECT;
}
- if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
+ if(gpu_target == GPUTarget::G71)
{
- if(gpu_target == GPUTarget::G71)
+ if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c))
+ && (CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)))
{
- if(CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math))
- {
- return ConvolutionMethod::FFT;
- }
- else
- {
- return ConvolutionMethod::GEMM;
- }
+ return ConvolutionMethod::FFT;
+ }
+ }
+ else
+ {
+ if((weights->dimension(idx_h) >= 5) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
+ {
+ return ConvolutionMethod::DIRECT;
}
-
- return ConvolutionMethod::DIRECT;
}
if(input->dimension(idx_c) < 16)
{