diff options
Diffstat (limited to 'src/cpu')
-rw-r--r-- | src/cpu/operators/CpuConv2d.cpp | 22 |
1 files changed, 17 insertions, 5 deletions
diff --git a/src/cpu/operators/CpuConv2d.cpp b/src/cpu/operators/CpuConv2d.cpp index 19311733db..26ca2ee783 100644 --- a/src/cpu/operators/CpuConv2d.cpp +++ b/src/cpu/operators/CpuConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021, 2023 Arm Limited. + * Copyright (c) 2017-2021, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -209,12 +209,24 @@ ConvolutionMethod CpuConv2d::get_convolution_method(const ITensorInfo *i } else { + const bool gemmDirectConv2d_validates = + bool(CpuGemmDirectConv2d::validate(input, weights, nullptr, output, info)); + // SRGAN // Output might not be initialized when it is an internal tensor of the layer using the convolution - if (input->total_size() > 1e7 && (weights->dimension(idx_h) > 7) && - (CpuDirectConv2d::validate(input, weights, nullptr, output, conv_info, act_info))) + if (input->total_size() > 1e7 && weights->dimension(idx_h) > 7) { - return ConvolutionMethod::DIRECT; + // This configuration is memory demanding for GEMM method. GEMM_CONV2D which uses indirect convolution + // kernels underneath is the best option. + if (gemmDirectConv2d_validates) + { + return ConvolutionMethod::GEMM_CONV2D; + } + else if (bool(CpuDirectConv2d::validate(input, weights, nullptr, output, conv_info, act_info))) + { + // NCHW data layout is not supported by GEMM_CONV2D + return ConvolutionMethod::DIRECT; + } } if (input->dimension(idx_c) < 16) { @@ -270,7 +282,7 @@ ConvolutionMethod CpuConv2d::get_convolution_method(const ITensorInfo *i { return ConvolutionMethod::WINOGRAD; } - if (bool(CpuGemmDirectConv2d::validate(input, weights, nullptr, output, info))) + if (gemmDirectConv2d_validates) { return ConvolutionMethod::GEMM_CONV2D; } |