aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEConvolutionLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEConvolutionLayer.cpp5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index 1755e9a774..dcd26fc1cd 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -102,7 +102,7 @@ Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info));
break;
case ConvolutionMethod::DIRECT:
- //Validate Gemm-based Convolution
+ //Validate Direct Convolution
ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info));
break;
case ConvolutionMethod::FFT:
@@ -167,7 +167,8 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo *
else
{
// SRGAN
- if((input->dimension(idx_h) > 720U) && (output->dimension(idx_h) > 720U) && (weights->dimension(idx_h) == 9)
+ // Output might not be initialized when it is an internal tensor of the layer using the convolution
+ if(input->total_size() > 1e7 && (weights->dimension(idx_h) > 7)
&& (NEDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)))
{
return ConvolutionMethod::DIRECT;