diff options
-rw-r--r-- | src/runtime/CL/functions/CLConvolutionLayer.cpp | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp index aaabbe0cfc..1082a475b2 100644 --- a/src/runtime/CL/functions/CLConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp @@ -224,21 +224,20 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo * { return ConvolutionMethod::DIRECT; } - if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + if(gpu_target == GPUTarget::G71) { - if(gpu_target == GPUTarget::G71) + if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) + && (CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math))) { - if(CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) - { - return ConvolutionMethod::FFT; - } - else - { - return ConvolutionMethod::GEMM; - } + return ConvolutionMethod::FFT; + } + } + else + { + if((weights->dimension(idx_h) >= 5) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + { + return ConvolutionMethod::DIRECT; } - - return ConvolutionMethod::DIRECT; } if(input->dimension(idx_c) < 16) { |