aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuConv2d.cpp')
-rw-r--r--src/cpu/operators/CpuConv2d.cpp22
1 files changed, 17 insertions, 5 deletions
diff --git a/src/cpu/operators/CpuConv2d.cpp b/src/cpu/operators/CpuConv2d.cpp
index 19311733db..26ca2ee783 100644
--- a/src/cpu/operators/CpuConv2d.cpp
+++ b/src/cpu/operators/CpuConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -209,12 +209,24 @@ ConvolutionMethod CpuConv2d::get_convolution_method(const ITensorInfo *i
}
else
{
+ const bool gemmDirectConv2d_validates =
+ bool(CpuGemmDirectConv2d::validate(input, weights, nullptr, output, info));
+
// SRGAN
// Output might not be initialized when it is an internal tensor of the layer using the convolution
- if (input->total_size() > 1e7 && (weights->dimension(idx_h) > 7) &&
- (CpuDirectConv2d::validate(input, weights, nullptr, output, conv_info, act_info)))
+ if (input->total_size() > 1e7 && weights->dimension(idx_h) > 7)
{
- return ConvolutionMethod::DIRECT;
+ // This configuration is memory demanding for GEMM method. GEMM_CONV2D which uses indirect convolution
+ // kernels underneath is the best option.
+ if (gemmDirectConv2d_validates)
+ {
+ return ConvolutionMethod::GEMM_CONV2D;
+ }
+ else if (bool(CpuDirectConv2d::validate(input, weights, nullptr, output, conv_info, act_info)))
+ {
+ // NCHW data layout is not supported by GEMM_CONV2D
+ return ConvolutionMethod::DIRECT;
+ }
}
if (input->dimension(idx_c) < 16)
{
@@ -270,7 +282,7 @@ ConvolutionMethod CpuConv2d::get_convolution_method(const ITensorInfo *i
{
return ConvolutionMethod::WINOGRAD;
}
- if (bool(CpuGemmDirectConv2d::validate(input, weights, nullptr, output, info)))
+ if (gemmDirectConv2d_validates)
{
return ConvolutionMethod::GEMM_CONV2D;
}