diff options
author | giuros01 <giuseppe.rossini@arm.com> | 2019-03-26 17:44:40 +0000 |
---|---|---|
committer | Giuseppe Rossini <giuseppe.rossini@arm.com> | 2019-05-09 12:38:22 +0000 |
commit | 154bc1c3e6a0182e2130c7966af3944ee6ca20b3 (patch) | |
tree | 6cf717250870f311c99a4fbb6cdae4dfa84d5aae /src/runtime/NEON/functions/NEConvolutionLayer.cpp | |
parent | ae1a89ed670956b9722fe57c2dc36c75e5f948ec (diff) | |
download | ComputeLibrary-154bc1c3e6a0182e2130c7966af3944ee6ca20b3.tar.gz |
COMPMID-1973: Implement FFTConvolutionLayer on NEON
Change-Id: I2e667c0411bda0164a616ffe44473a78de6752c9
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1066
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEConvolutionLayer.cpp | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp index 5059162032..a62459b3e8 100644 --- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp @@ -73,6 +73,13 @@ void NEConvolutionLayer::configure(ITensor *input, const ITensor *weights, const _function = std::move(f); break; } + case ConvolutionMethod::FFT: + { + auto f = arm_compute::support::cpp14::make_unique<NEFFTConvolutionLayer>(_memory_manager); + f->configure(input, weights, biases, output, conv_info, act_info); + _function = std::move(f); + break; + } default: ARM_COMPUTE_ERROR("Not supported."); break; @@ -97,6 +104,10 @@ Status NEConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo case ConvolutionMethod::DIRECT: //Validate Gemm-based Convolution ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info)); + case ConvolutionMethod::FFT: + // Validate FFT-based convolution layer + ARM_COMPUTE_RETURN_ON_ERROR(NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info)); + break; default: ARM_COMPUTE_ERROR("Not supported."); break; @@ -148,12 +159,22 @@ ConvolutionMethod NEConvolutionLayer::get_convolution_method(const ITensorInfo * return (*found).second; } - if(dilation != Size2D(1U, 1U) || input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)) <= 16) + if(dilation != Size2D(1U, 1U)) { return ConvolutionMethod::GEMM; } - - return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM; + else + { + if((weights->dimension(idx_h) > 7) && (input->dimension(idx_c) > output->dimension(idx_c)) && (NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info))) + { + return ConvolutionMethod::FFT; + } + if(input->dimension(idx_c) < 16) + { + return ConvolutionMethod::GEMM; + } + return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM; + } } void NEConvolutionLayer::run() |