From 3ae6580d26cee273b750e69b5c06f6efa4caf3fb Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 30 Apr 2021 09:55:26 +0100 Subject: Update heuristic for CLConvolutionLayer - Call direct convolution when filter size height is greater than or equal to 5 Resolves COMPMID-4439 Change-Id: Ie8ccccc0629eb4c74bd62c4bb4ced47f6898a945 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5589 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- src/runtime/CL/functions/CLConvolutionLayer.cpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp index aaabbe0cfc..1082a475b2 100644 --- a/src/runtime/CL/functions/CLConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp @@ -224,21 +224,20 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo * { return ConvolutionMethod::DIRECT; } - if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + if(gpu_target == GPUTarget::G71) { - if(gpu_target == GPUTarget::G71) + if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) >= output->dimension(idx_c)) + && (CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math))) { - if(CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) - { - return ConvolutionMethod::FFT; - } - else - { - return ConvolutionMethod::GEMM; - } + return ConvolutionMethod::FFT; + } + } + else + { + if((weights->dimension(idx_h) >= 5) && (input->dimension(idx_c) >= output->dimension(idx_c)) && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + { + return ConvolutionMethod::DIRECT; } - - return ConvolutionMethod::DIRECT; } if(input->dimension(idx_c) < 16) { -- cgit v1.2.1