diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEConvolutionLayer.cpp | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp index 1755e9a774..dcd26fc1cd 100644 --- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp @@ -102,7 +102,7 @@ Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info)); break; case ConvolutionMethod::DIRECT: - //Validate Gemm-based Convolution + //Validate Direct Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info)); break; case ConvolutionMethod::FFT: @@ -167,7 +167,8 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * else { // SRGAN - if((input->dimension(idx_h) > 720U) && (output->dimension(idx_h) > 720U) && (weights->dimension(idx_h) == 9) + // Output might not be initialized when it is an internal tensor of the layer using the convolution + if(input->total_size() > 1e7 && (weights->dimension(idx_h) > 7) && (NEDirectConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) { return ConvolutionMethod::DIRECT; |