From 154bc1c3e6a0182e2130c7966af3944ee6ca20b3 Mon Sep 17 00:00:00 2001 From: giuros01 Date: Tue, 26 Mar 2019 17:44:40 +0000 Subject: COMPMID-1973: Implement FFTConvolutionLayer on NEON Change-Id: I2e667c0411bda0164a616ffe44473a78de6752c9 Signed-off-by: giuros01 Reviewed-on: https://review.mlplatform.org/c/1066 Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- src/runtime/NEON/functions/NEFFT1D.cpp | 39 ++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 9 deletions(-) (limited to 'src/runtime/NEON/functions/NEFFT1D.cpp') diff --git a/src/runtime/NEON/functions/NEFFT1D.cpp b/src/runtime/NEON/functions/NEFFT1D.cpp index 665efeb440..25ba1c8391 100644 --- a/src/runtime/NEON/functions/NEFFT1D.cpp +++ b/src/runtime/NEON/functions/NEFFT1D.cpp @@ -37,6 +37,9 @@ NEFFT1D::NEFFT1D(std::shared_ptr memory_manager) void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo &config) { + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_ERROR_THROW_ON(NEFFT1D::validate(input->info(), output->info(), config)); + // Decompose size to radix factors const auto supported_radix = NEFFTRadixStageKernel::supported_radix(); const unsigned int N = input->info()->tensor_shape()[config.axis]; @@ -44,21 +47,25 @@ void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo & ARM_COMPUTE_ERROR_ON(decomposed_vector.empty()); // Flags - _run_scale = config.direction == FFTDirection::Inverse; - _axis = config.axis; + _run_scale = config.direction == FFTDirection::Inverse; + const bool is_c2r = input->info()->num_channels() == 2 && output->info()->num_channels() == 1; // Configure digit reverse + FFTDigitReverseKernelInfo digit_reverse_config; + digit_reverse_config.axis = config.axis; + digit_reverse_config.conjugate = config.direction == FFTDirection::Inverse; TensorInfo digit_reverse_indices_info(TensorShape(input->info()->tensor_shape()[config.axis]), 1, DataType::U32); _digit_reverse_indices.allocator()->init(digit_reverse_indices_info); _memory_group.manage(&_digit_reversed_input); - _digit_reverse_kernel.configure(input, &_digit_reversed_input, &_digit_reverse_indices, config.axis); + _digit_reverse_kernel.configure(input, &_digit_reversed_input, &_digit_reverse_indices, digit_reverse_config); // Create and configure FFT kernels unsigned int Nx = 1; - - _num_ffts = decomposed_vector.size(); + _num_ffts = decomposed_vector.size(); _fft_kernels.resize(_num_ffts); + _axis = config.axis; + for(unsigned int i = 0; i < _num_ffts; ++i) { const unsigned int radix_for_stage = decomposed_vector.at(i); @@ -68,10 +75,20 @@ void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo & fft_kernel_info.radix = radix_for_stage; fft_kernel_info.Nx = Nx; fft_kernel_info.is_first_stage = (i == 0); - _fft_kernels[i].configure(&_digit_reversed_input, i == (_num_ffts - 1) && !is_c2r ? output : nullptr, fft_kernel_info); + _fft_kernels[i].configure(&_digit_reversed_input, ((i == (_num_ffts - 1)) && !is_c2r) ? output : nullptr, fft_kernel_info); + Nx *= radix_for_stage; } + // Configure scale kernel + if(_run_scale) + { + FFTScaleKernelInfo scale_config; + scale_config.scale = static_cast(N); + scale_config.conjugate = config.direction == FFTDirection::Inverse; + is_c2r ? _scale_kernel.configure(&_digit_reversed_input, output, scale_config) : _scale_kernel.configure(output, nullptr, scale_config); + } + // Allocate tensors _digit_reversed_input.allocator()->allocate(); _digit_reverse_indices.allocator()->allocate(); @@ -84,8 +101,9 @@ void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo & Status NEFFT1D::validate(const ITensorInfo *input, const ITensorInfo *output, const FFT1DInfo &config) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1); + ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(std::set({ 0, 1 }).count(config.axis) == 0); // Check if FFT is decomposable const auto supported_radix = NEFFTRadixStageKernel::supported_radix(); @@ -96,6 +114,9 @@ Status NEFFT1D::validate(const ITensorInfo *input, const ITensorInfo *output, co // Checks performed when output is configured if((output != nullptr) && (output->total_size() != 0)) { + // All combinations are supported except real input with real output (i.e., both input channels set to 1) + ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() == 1 && input->num_channels() == 1); + ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() > 2); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); } @@ -107,7 +128,7 @@ void NEFFT1D::run() { MemoryGroupResourceScope scope_mg(_memory_group); - NEScheduler::get().schedule(&_digit_reverse_kernel, (_axis == 0 ? Window::DimY : Window::DimX)); + NEScheduler::get().schedule(&_digit_reverse_kernel, (_axis == 0 ? Window::DimY : Window::DimZ)); for(unsigned int i = 0; i < _num_ffts; ++i) { -- cgit v1.2.1