diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-06-04 17:23:04 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-06-25 16:01:33 +0000 |
commit | 26dcbc7ec604eefce46d728d946878e16a470274 (patch) | |
tree | 4138c25390af53c4d9b3c1eb288391edd2953967 /src/runtime/CL/functions/CLConvolutionLayer.cpp | |
parent | bc415af5ee9517fd113e9ea0f01fdc84f9693dc4 (diff) | |
download | ComputeLibrary-26dcbc7ec604eefce46d728d946878e16a470274.tar.gz |
COMPMID-2158 Add SRGAN shapes in ConvolutionLayer heurists to dispatch direct convolution
Change-Id: I94b853a6ade6f027b7a404174ebca6c600050c28
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1400
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLConvolutionLayer.cpp | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp index 165d523100..d794cde1f4 100644 --- a/src/runtime/CL/functions/CLConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp @@ -188,11 +188,17 @@ ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo * } else { - if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && ( CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + // SRGAN + if((input->dimension(idx_h) > 720U) && (output->dimension(idx_h) > 720U) && (weights->dimension(idx_h) == 9) && (conv_info.pad_top() < 3) + && (CLDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + { + return ConvolutionMethod::DIRECT; + } + if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && (CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) { return ConvolutionMethod::FFT; } - if (input->dimension(idx_c) < 16) + if(input->dimension(idx_c) < 16) { return ConvolutionMethod::GEMM; } |