diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEConvolutionLayer.cpp | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp index 5059162032..a62459b3e8 100644 --- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp @@ -73,6 +73,13 @@ void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const _function = std::move(f); break; } + case ConvolutionMethod::FFT: + { + auto f = arm_compute::support::cpp14::make_unique<NEFFTConvolutionLayer>(_memory_manager); + f->configure(input, weights, biases, output, conv_info, act_info); + _function = std::move(f); + break; + } default: ARM_COMPUTE_ERROR("Not supported."); break; @@ -97,6 +104,10 @@ Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo case ConvolutionMethod::DIRECT: //Validate Gemm-based Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info)); + case ConvolutionMethod::FFT: + // Validate FFT-based convolution layer + ARM_COMPUTE_RETURN_ON_ERROR(NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)); + break; default: ARM_COMPUTE_ERROR("Not supported."); break; @@ -148,12 +159,22 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * return (*found).second; } - if(dilation != Size2D(1U, 1U) || input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) <= 16) + if(dilation != Size2D(1U, 1U)) { return ConvolutionMethod::GEMM; } - - return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM; + else + { + if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && (NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + { + return ConvolutionMethod::FFT; + } + if(input->dimension(idx_c) < 16) + { + return ConvolutionMethod::GEMM; + } + return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM; + } } void NEConvolutionLayer::run() |