diff options
-rw-r--r-- | src/runtime/CL/functions/CLConvolutionLayer.cpp | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp index ac18b966af..aaabbe0cfc 100644 --- a/src/runtime/CL/functions/CLConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp @@ -226,6 +226,18 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo * } if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) { + if(gpu_target == GPUTarget::G71) + { + if(CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) + { + return ConvolutionMethod::FFT; + } + else + { + return ConvolutionMethod::GEMM; + } + } + return ConvolutionMethod::DIRECT; } if(input->dimension(idx_c) < 16) |