From 05fb448bf48e31d723dfd9f4bbf3899ff65f0fba Mon Sep 17 00:00:00 2001
From: giuros01
Date: Tue, 26 Mar 2019 17:44:40 +0000
Subject: COMPMID-1963: Implement FFT (2D) on NEON

Change-Id: I3b564be8d7949e00c6544071ef62dd51de838c96
Signed-off-by: giuros01
Reviewed-on: https://review.mlplatform.org/c/1048
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Reviewed-by: Georgios Pinitas
---
 arm_compute/core/NEON/NEKernels.h                  |   1 +
 .../core/NEON/kernels/NEFFTRadixStageKernel.h      |  19 +-
 arm_compute/core/NEON/kernels/NEFFTScaleKernel.h   |  84 +++
 arm_compute/runtime/NEON/NEFunctions.h             |   1 +
 arm_compute/runtime/NEON/functions/NEFFT1D.h       |  19 +-
 arm_compute/runtime/NEON/functions/NEFFT2D.h       |  76 +++
 src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp  |  21 +-
 src/core/NEON/kernels/NEFFTRadixStageKernel.cpp    | 579 ++++++++++++++-------
 src/core/NEON/kernels/NEFFTScaleKernel.cpp         | 136 +++++
 src/runtime/NEON/functions/NEFFT1D.cpp             |  41 +-
 src/runtime/NEON/functions/NEFFT2D.cpp             |  95 ++++
 tests/validation/NEON/FFT.cpp                      |  74 ++-
 12 files changed, 928 insertions(+), 218 deletions(-)
 create mode 100644 arm_compute/core/NEON/kernels/NEFFTScaleKernel.h
 create mode 100644 arm_compute/runtime/NEON/functions/NEFFT2D.h
 create mode 100644 src/core/NEON/kernels/NEFFTScaleKernel.cpp
 create mode 100644 src/runtime/NEON/functions/NEFFT2D.cpp

diff --git a/arm_compute/core/NEON/NEKernels.h b/arm_compute/core/NEON/NEKernels.h
index b8ae467c6d..b9716b1e43 100644
--- a/arm_compute/core/NEON/NEKernels.h
+++ b/arm_compute/core/NEON/NEKernels.h
@@ -64,6 +64,7 @@
 #include "arm_compute/core/NEON/kernels/NEErodeKernel.h"
 #include "arm_compute/core/NEON/kernels/NEFFTDigitReverseKernel.h"
 #include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"
+#include "arm_compute/core/NEON/kernels/NEFFTScaleKernel.h"
 #include "arm_compute/core/NEON/kernels/NEFastCornersKernel.h"
 #include "arm_compute/core/NEON/kernels/NEFillArrayKernel.h"
 #include "arm_compute/core/NEON/kernels/NEFillBorderKernel.h"
diff --git a/arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h b/arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h
index a4c4be6f35..8498d3c613 100644
--- a/arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h
+++ b/arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h
@@ -24,10 +24,10 @@
 #ifndef __ARM_COMPUTE_NEFFTRADIXSTAGEKERNEL_H__
 #define __ARM_COMPUTE_NEFFTRADIXSTAGEKERNEL_H__

-#include "arm_compute/core/NEON/INEKernel.h"
-
 #include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/NEON/INEKernel.h"

+#include <arm_neon.h>
 #include <set>

 namespace arm_compute
@@ -87,12 +87,17 @@ private:
    ITensor      *_output;
    bool          _run_in_place;
    unsigned int  _Nx;
+   unsigned int  _axis;
+   unsigned int  _radix;
+
+   void set_radix_stage_axis0(const FFTRadixStageKernelInfo &config);
+   void set_radix_stage_axis1(const FFTRadixStageKernelInfo &config);

-   template <bool first_stage>
-   void set_radix_stage_fun(unsigned int radix);
+   using FFTFunctionPointerAxis0 = std::function<void(float *, float *, unsigned int, unsigned int, const float32x2_t &, unsigned int)>;
+   using FFTFunctionPointerAxis1 = std::function<void(float *, float *, unsigned int, unsigned int, const float32x2_t &, unsigned int, unsigned int)>;

-   using FFTFunctionPointerInPlace = std::function<void(float *, float *, unsigned int, unsigned int)>;
-   FFTFunctionPointerInPlace _func;
+   FFTFunctionPointerAxis0 _func_0;
+   FFTFunctionPointerAxis1 _func_1;
 };
 } // namespace arm_compute
-#endif /*__ARM_COMPUTE_NEFFTKERNEL_H__ */
+#endif /*__ARM_COMPUTE_NEFFTRADIXSTAGEKERNEL_H__ */
diff --git a/arm_compute/core/NEON/kernels/NEFFTScaleKernel.h b/arm_compute/core/NEON/kernels/NEFFTScaleKernel.h
new file mode 100644
index 0000000000..5a19af7e62
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/NEFFTScaleKernel.h
@@ -0,0 +1,84 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEFFTSCALEKERNEL_H__
+#define __ARM_COMPUTE_NEFFTSCALEKERNEL_H__
+
+#include "arm_compute/core/NEON/INEKernel.h"
+
+#include "arm_compute/core/KernelDescriptors.h"
+
+namespace arm_compute
+{
+// Forward declarations
+class ITensor;
+
+/** Interface for the inverse FFT scale kernel. */
+class NEFFTScaleKernel : public INEKernel
+{
+public:
+    const char *name() const override
+    {
+        return "NEFFTScaleKernel";
+    }
+    /** Constructor */
+    NEFFTScaleKernel();
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    NEFFTScaleKernel(const NEFFTScaleKernel &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers) */
+    NEFFTScaleKernel &operator=(const NEFFTScaleKernel &) = delete;
+    /** Default move constructor */
+    NEFFTScaleKernel(NEFFTScaleKernel &&) = default;
+    /** Default move assignment operator */
+    NEFFTScaleKernel &operator=(NEFFTScaleKernel &&) = default;
+    /** Default destructor */
+    ~NEFFTScaleKernel() = default;
+    /** Set the input and output tensors.
+     *
+     * @param[in,out] input  Source tensor. Data types supported: F32.
+     * @param[out]    output Destination tensor. Data type supported: same as @p input
+     * @param[in]     config Kernel configuration
+     */
+    void configure(ITensor *input, ITensor *output, const FFTScaleKernelInfo &config);
+    /** Static function to check if given info will lead to a valid configuration of @ref NEFFTScaleKernel
+     *
+     * @param[in] input  Source tensor info. Data types supported: F32.
+     * @param[in] output Destination tensor info. Data type supported: same as @p input
+     * @param[in] config Kernel configuration
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input, const ITensorInfo *output, const FFTScaleKernelInfo &config);
+
+    // Inherited methods overridden:
+    void run(const Window &window, const ThreadInfo &info) override;
+
+private:
+    ITensor *_input;
+    ITensor *_output;
+    float    _scale;
+    bool     _run_in_place;
+    bool     _is_conj;
+};
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_NEFFTSCALEKERNEL_H__ */
diff --git a/arm_compute/runtime/NEON/NEFunctions.h b/arm_compute/runtime/NEON/NEFunctions.h
index d8f54ea231..869eb523dd 100644
--- a/arm_compute/runtime/NEON/NEFunctions.h
+++ b/arm_compute/runtime/NEON/NEFunctions.h
@@ -64,6 +64,7 @@
 #include "arm_compute/runtime/NEON/functions/NEEqualizeHistogram.h"
 #include "arm_compute/runtime/NEON/functions/NEErode.h"
 #include "arm_compute/runtime/NEON/functions/NEFFT1D.h"
+#include "arm_compute/runtime/NEON/functions/NEFFT2D.h"
 #include "arm_compute/runtime/NEON/functions/NEFastCorners.h"
 #include "arm_compute/runtime/NEON/functions/NEFillBorder.h"
 #include "arm_compute/runtime/NEON/functions/NEFlattenLayer.h"
diff --git a/arm_compute/runtime/NEON/functions/NEFFT1D.h b/arm_compute/runtime/NEON/functions/NEFFT1D.h
index 9b5ada746a..c706936f77 100644
--- a/arm_compute/runtime/NEON/functions/NEFFT1D.h
+++ b/arm_compute/runtime/NEON/functions/NEFFT1D.h
@@ -26,6 +26,7 @@

 #include "arm_compute/core/NEON/kernels/NEFFTDigitReverseKernel.h"
 #include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"
+#include "arm_compute/core/NEON/kernels/NEFFTScaleKernel.h"
 #include "arm_compute/runtime/IFunction.h"

 #include "arm_compute/runtime/FunctionDescriptors.h"
@@ -37,24 +38,25 @@ namespace arm_compute
 // Forward declaration
 class ITensor;

-/** Basic function to execute one dimensional FFT. This function calls the following OpenCL kernels:
+/** Basic function to execute one dimensional FFT. This function calls the following NEON kernels:
  *
- * -# @ref CLFFTDigitReverseKernel Performs digit reverse
+ * -# @ref NEFFTDigitReverseKernel Performs digit reverse
  * -# @ref NEFFTRadixStageKernel   A list of FFT kernels depending on the radix decomposition
+ * -# @ref NEFFTScaleKernel        Performs output scaling in case of inverse FFT
  */
 class NEFFT1D : public IFunction
 {
 public:
     /** Default Constructor */
     NEFFT1D(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
-    /** Initialise the function's source, destinations and border mode.
+    /** Initialise the function's source and destinations.
      *
      * @param[in]  input  Source tensor. Data types supported: F32.
      * @param[out] output Destination tensor. Data types and data layouts supported: Same as @p input.
      * @param[in]  config FFT related configuration
      */
     void configure(const ITensor *input, ITensor *output, const FFT1DInfo &config);
-    /** Static function to check if given info will lead to a valid configuration of @ref CLFFT1D.
+    /** Static function to check if given info will lead to a valid configuration of @ref NEFFT1D.
      *
      * @param[in] input  Source tensor info. Data types supported: F32.
      * @param[in] output Destination tensor info. Data types and data layouts supported: Same as @p input.
@@ -69,11 +71,14 @@ public:

 protected:
     MemoryGroup _memory_group;
-    Tensor _digit_reversed_input;
-    Tensor _digit_reverse_indices;
     NEFFTDigitReverseKernel _digit_reverse_kernel;
     std::vector<NEFFTRadixStageKernel> _fft_kernels;
-    unsigned int _n_ffts;
+    NEFFTScaleKernel _scale_kernel;
+    Tensor _digit_reversed_input;
+    Tensor _digit_reverse_indices;
+    unsigned int _num_ffts;
+    unsigned int _axis;
+    bool _run_scale;
 };
 } // namespace arm_compute
 #endif /*__ARM_COMPUTE_NEFFT1D_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEFFT2D.h b/arm_compute/runtime/NEON/functions/NEFFT2D.h
new file mode 100644
index 0000000000..9911cea290
--- /dev/null
+++ b/arm_compute/runtime/NEON/functions/NEFFT2D.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEFFT2D_H__
+#define __ARM_COMPUTE_NEFFT2D_H__
+
+#include "arm_compute/runtime/IFunction.h"
+
+#include "arm_compute/runtime/FunctionDescriptors.h"
+#include "arm_compute/runtime/MemoryGroup.h"
+#include "arm_compute/runtime/NEON/functions/NEFFT1D.h"
+#include "arm_compute/runtime/Tensor.h"
+
+namespace arm_compute
+{
+// Forward declaration
+class ITensor;
+
+/** Basic function to execute two dimensional FFT. This function calls the following NEON kernels:
+ *
+ * -# @ref NEFFT1D 1D FFT is performed on the first given axis
+ * -# @ref NEFFT1D 1D FFT is performed on the second given axis
+ */
+class NEFFT2D : public IFunction
+{
+public:
+    /** Default Constructor */
+    NEFFT2D(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    /** Initialise the function's source and destinations
+     *
+     * @param[in]  input  Source tensor. Data types supported: F32.
+     * @param[out] output Destination tensor. Data types and data layouts supported: Same as @p input.
+     * @param[in]  config FFT related configuration
+     */
+    void configure(const ITensor *input, ITensor *output, const FFT2DInfo &config);
+    /** Static function to check if given info will lead to a valid configuration of @ref NEFFT2D.
+     *
+     * @param[in] input  Source tensor info. Data types supported: F32.
+     * @param[in] output Destination tensor info. Data types and data layouts supported: Same as @p input.
+     * @param[in] config FFT related configuration
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input, const ITensorInfo *output, const FFT2DInfo &config);
+
+    // Inherited methods overridden:
+    void run() override;
+
+protected:
+    MemoryGroup _memory_group;
+    NEFFT1D     _first_pass_func;
+    NEFFT1D     _second_pass_func;
+    Tensor      _first_pass_tensor;
+};
+} // namespace arm_compute
+#endif /*__ARM_COMPUTE_NEFFT2D_H__ */
diff --git a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
index 845fcef4f3..b2ffb01e99 100644
--- a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
@@ -37,7 +37,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
 {
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(idx, 1, DataType::U32);
-    ARM_COMPUTE_RETURN_ERROR_ON(axis != 0);
+    ARM_COMPUTE_RETURN_ERROR_ON(axis > 1);

     // Checks performed when output is configured
     if((output != nullptr) && (output->total_size() != 0))
@@ -96,15 +96,24 @@ void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info)
     Iterator out(_output, window);
     const size_t element_size = _input->info()->element_size();

+    // Pointers to the buffers
+    const size_t offset    = _input->info()->offset_first_element_in_bytes();
+    auto        *idx_ptr   = reinterpret_cast<unsigned int *>(_idx->buffer());
+    uint8_t     *input_ptr = offset + _input->buffer();
+
+    // Strides
+    const size_t stride_x = _input->info()->strides_in_bytes()[0];
+    const size_t stride_y = _input->info()->strides_in_bytes()[1];
+    const size_t stride_z = _input->info()->strides_in_bytes()[2];
+    const size_t stride_w = _input->info()->strides_in_bytes()[3];
+
     execute_window_loop(window, [&](const Coordinates & id)
     {
-        unsigned int in_index_1d = *reinterpret_cast<unsigned int *>(_idx->ptr_to_element(Coordinates(id.x())));
-
-        auto reverse_id = id;
+        unsigned int in_index_1d = idx_ptr[id[_axis]];
+        auto         reverse_id  = id;
         reverse_id.set(_axis, in_index_1d);

-        memcpy(out.ptr(), _input->ptr_to_element(reverse_id), 2 * element_size);
-
+        memcpy(out.ptr(), input_ptr + reverse_id.x() * stride_x + reverse_id.y() * stride_y + reverse_id.z() * stride_z + reverse_id[3] * stride_w, element_size);
     },
     out);
diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
index b264590791..148bbe915a 100644
--- a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
@@ -24,8 +24,6 @@
 #include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"

 #include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/NEON/wrapper/traits.h"
-#include "arm_compute/core/NEON/wrapper/wrapper.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
@@ -34,28 +32,53 @@
 #include <arm_neon.h>
 #include <cmath>
 #include <complex>
+#include <map>
+
+#include "arm_compute/core/NEON/wrapper/traits.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"

 namespace arm_compute
 {
 namespace
 {
-constexpr float PI = 3.141592653589793f;
+// PI constant (from cmath)
+constexpr float kPi = float(M_PI);
+
+// Constant used in the fft_3 kernel
+constexpr float kSqrt3Div2 = 0.866025403784438;
+
+// Constants used in the fft_5 kernel
+constexpr float kW5_0 = 0.30901699437494f;
+constexpr float kW5_1 = 0.95105651629515f;
+constexpr float kW5_2 = 0.80901699437494f;
+constexpr float kW5_3 = 0.58778525229247f;
+
+// Constants used in the fft_7 kernel
+constexpr float kW7_0 = 0.62348980185873f;
+constexpr float kW7_1 = 0.78183148246802f;
+constexpr float kW7_2 = 0.22252093395631f;
+constexpr float kW7_3 = 0.97492791218182f;
+constexpr float kW7_4 = 0.90096886790241f;
+constexpr float kW7_5 = 0.43388373911755f;
+
+// Constant used in the fft_8 kernel
+constexpr float kSqrt2Div2 = 0.707106781186548;

 float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
 {
-    float32x2_t tmp = wrapper::vmul(a, b);
+    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

-    const float P1 = wrapper::vgetlane(tmp, 0);
-    const float P2 = wrapper::vgetlane(tmp, 1);
+    const float32x2_t mask = { -1.0, 1.0 };
+    const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
+    const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});

-    const float a_r = wrapper::vgetlane(a, 0);
-    const float a_i = wrapper::vgetlane(a, 1);
-    const float b_r = wrapper::vgetlane(b, 0);
-    const float b_i = wrapper::vgetlane(b, 1);
+    float32x2_t res = wrapper::vmul(tmp0, b);

-    const float P3 = (a_r + a_i) * (b_r + b_i);
-    float32x2_t out = { P1 - P2, P3 - P2 - P1 };
-    return out;
+    b = wrapper::vrev64(b);
+    b = wrapper::vmul(b, mask);
+    res = wrapper::vmla(res, tmp1, b);
+
+    return res;
 }

 float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
@@ -107,7 +130,6 @@ void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
     y = wrapper::vsub(a, b);
 }

-constexpr float sqrt3div2 = 0.866025403784438;
 void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
 {
     float32x2_t a = x;
@@ -118,7 +140,7 @@ void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w,
     x = wrapper::vadd(x, c);

     const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
-    const auto v2 = c_mul_neon(float32x2_t{ 0.f, -sqrt3div2 }, wrapper::vsub(b, c));
+    const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));

     y = z = wrapper::vsub(a, v1);
     y = wrapper::vadd(y, v2);
@@ -149,10 +171,6 @@ void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, c
     x4 = wrapper::vadd(x41, x42);
 }

-constexpr float W5_0 = 0.30901699437494f;
-constexpr float W5_1 = 0.95105651629515f;
-constexpr float W5_2 = 0.80901699437494f;
-constexpr float W5_3 = 0.58778525229247f;
 void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
 {
     const auto a = x1;
@@ -161,25 +179,25 @@ void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
     const auto d = c_mul_neon(w3, x4);
     const auto e = c_mul_neon(w4, x5);

-    const auto b0 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, b);
-    const auto b1 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, b);
-    const auto b2 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, b);
-    const auto b3 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, b);
+    const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
+    const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
+    const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
+    const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);

-    const auto c0 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, c);
-    const auto c1 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, c);
-    const auto c2 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, c);
-    const auto c3 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, c);
+    const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 
}, c); + const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c); + const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c); + const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c); - const auto d0 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, d); - const auto d1 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, d); - const auto d2 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, d); - const auto d3 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, d); + const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d); + const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d); + const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d); + const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d); - const auto e0 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, e); - const auto e1 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, e); - const auto e2 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, e); - const auto e3 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, e); + const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e); + const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e); + const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e); + const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e); x1 = reduce_sum_5(a, b, c, d, e); x2 = reduce_sum_5(a, b0, c0, d0, e0); @@ -188,12 +206,6 @@ void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f x5 = reduce_sum_5(a, b3, c3, d3, e3); } -constexpr float W7_0 = 0.62348980185873f; -constexpr float W7_1 = 0.78183148246802f; -constexpr float W7_2 = 0.22252093395631f; -constexpr float W7_3 = 0.97492791218182f; -constexpr float W7_4 = 0.90096886790241f; -constexpr float W7_5 = 0.43388373911755f; void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6) @@ -206,47 +218,47 @@ void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f const auto f = c_mul_neon(w5, x6); const auto g = c_mul_neon(w6, x7); - const auto b0 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, b); - const auto b1 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, b); - const auto b2 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, b); - const auto b3 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, b); - const auto b4 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, b); - const auto b5 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, b); - - const auto c0 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, c); - const auto c1 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, c); - const auto c2 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, c); - const auto c3 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, c); - const auto c4 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, c); - const auto c5 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, c); - - const auto d0 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, d); - const auto d1 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, d); - const auto d2 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, d); - const auto d3 = c_mul_neon(float32x2_t{ -W7_2, +W7_3 }, d); - const auto d4 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, d); - const auto d5 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, d); - - const auto e0 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, e); - const auto e1 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, e); - const auto e2 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, e); - const auto e3 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, e); - const auto e4 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, e); - const auto e5 = c_mul_neon(float32x2_t{ 
-W7_4, -W7_5 }, e); - - const auto f0 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, f); - const auto f1 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, f); - const auto f2 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, f); - const auto f3 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, f); - const auto f4 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, f); - const auto f5 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, f); - - const auto g0 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, g); - const auto g1 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, g); - const auto g2 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, g); - const auto g3 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, g); - const auto g4 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, g); - const auto g5 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, g); + const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b); + const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b); + const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b); + const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b); + const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b); + const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b); + + const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c); + const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c); + const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c); + const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c); + const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c); + const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c); + + const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d); + const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d); + const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d); + const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d); + const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d); + const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d); + + const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e); + const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e); + const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e); + const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e); + const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e); + const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e); + + const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f); + const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f); + const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f); + const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f); + const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f); + const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f); + + const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g); + const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g); + const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g); + const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g); + const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g); + const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g); x1 = reduce_sum_7(a, b, c, d, e, f, g); x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0); @@ -257,7 +269,6 @@ void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5); } -constexpr float sqrt2div2 = 0.707106781186548; void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const 
float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6, @@ -272,13 +283,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f const auto g = c_mul_neon(w6, x7); const auto h = c_mul_neon(w7, x8); - const auto b0 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, b); + const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b); const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b); - const auto b2 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, b); + const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b); const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b); - const auto b4 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, b); + const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b); const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b); - const auto b6 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, b); + const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b); const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c); const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c); @@ -288,13 +299,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c); const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c); - const auto d0 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, d); + const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d); const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d); - const auto d2 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, d); + const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d); const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d); - const auto d4 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, d); + const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d); const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d); - const auto d6 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, d); + const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d); const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e); const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e); @@ -304,13 +315,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e); const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e); - const auto f0 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, f); + const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f); const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f); - const auto f2 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, f); + const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f); const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f); - const auto f4 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, f); + const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f); const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f); - const auto f6 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, f); + const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f); const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g); const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g); @@ -320,13 +331,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g); const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g); - const auto h0 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, h); + const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h); const auto h1 = 
c_mul_neon(float32x2_t{ 0, 1 }, h); - const auto h2 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, h); + const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h); const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h); - const auto h4 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, h); + const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h); const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h); - const auto h6 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, h); + const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h); x1 = reduce_sum_8(a, b, c, d, e, f, g, h); x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0); @@ -339,17 +350,12 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f } template -void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) +void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { - unsigned int Nx2 = 2 * Nx; - float alpha = 2 * PI / Nx2; - - float32x2_t w{ 1, 0 }; - const float32x2_t w_m{ cosf(alpha), -sinf(alpha) }; - + float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx2) + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) { auto a = float32x2_t{ 0, 0 }; auto b = float32x2_t{ 0, 0 }; @@ -386,19 +392,38 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) } } -template -void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) +void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) { - const unsigned int Nx3 = 3 * Nx; - const float alpha = 2 * PI / float(Nx3); - float32x2_t w{ 1, 0 }; - const float32x2_t w_m{ cosf(alpha), -sinf(alpha) }; + float32x2_t w{ 1.0f, 0.0f }; + for(unsigned int j = 0; j < Nx; j++) + { + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + { + // Load inputs + float32x2_t a = wrapper::vload(x + M * k); + float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); + // Base-case prime transform + fft_2(a, b, w); + + // Write outputs + wrapper::vstore(X + M * k, a); + wrapper::vstore(X + M * (k + 2 * Nx), b); + } + + w = c_mul_neon(w, w_m); + } +} + +template +void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +{ + float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { const auto w2 = c_mul_neon(w, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx3) + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) { // Load inputs float32x2_t a = { 0, 0 }; @@ -435,21 +460,42 @@ void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) } } -template -void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) +void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) { - unsigned int Nx4 = 4 * Nx; - const float alpha = 2 * PI / float(Nx4); + float32x2_t w{ 1.0f, 0.0f }; + for(unsigned int j = 0; j < Nx; j++) + { + const auto w2 = c_mul_neon(w, w); + + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + { + // Load inputs + float32x2_t a = wrapper::vload(x + M * k); + float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); + + // Base-case prime transform + fft_3(a, b, c, w, w2); - float32x2_t w{ 1, 0 }; - 
float32x2_t w_m{ cosf(alpha), -sinf(alpha) }; + // Store the output + wrapper::vstore(X + M * k, a); + wrapper::vstore(X + M * (k + 2 * Nx), b); + wrapper::vstore(X + M * (k + 4 * Nx), c); + } + w = c_mul_neon(w, w_m); + } +} +template +void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +{ + float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { const auto w2 = c_mul_neon(w, w); const auto w3 = c_mul_neon(w2, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx4) + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) { float32x2_t a = { 0, 0 }; float32x2_t b = { 0, 0 }; @@ -494,22 +540,46 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) } } -template -void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) +void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) { - unsigned int Nx5 = 5 * Nx; - const float alpha = 2 * PI / float(Nx5); + float32x2_t w{ 1.0f, 0.0f }; + for(unsigned int j = 0; j < Nx; j++) + { + const auto w2 = c_mul_neon(w, w); + const auto w3 = c_mul_neon(w2, w); + + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + { + // Load inputs + float32x2_t a = wrapper::vload(x + M * k); + float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); + + // Base-case prime transform + fft_4(a, b, c, d, w, w2, w3); - float32x2_t w{ 1, 0 }; - float32x2_t w_m{ cosf(alpha), -sinf(alpha) }; + wrapper::vstore(X + M * k, a); + wrapper::vstore(X + M * (k + 2 * Nx), b); + wrapper::vstore(X + M * (k + 4 * Nx), c); + wrapper::vstore(X + M * (k + 6 * Nx), d); + } + + w = c_mul_neon(w, w_m); + } +} +template +void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +{ + float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { const float32x2_t w2 = c_mul_neon(w, w); const float32x2_t w3 = c_mul_neon(w2, w); const float32x2_t w4 = c_mul_neon(w3, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx5) + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) { float32x2_t a = { 0, 0 }; float32x2_t b = { 0, 0 }; @@ -560,15 +630,43 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) } } -template -void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) +void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) { - unsigned int Nx7 = 7 * Nx; - const float alpha = 2 * PI / float(Nx7); + float32x2_t w{ 1.0f, 0.0f }; + for(unsigned int j = 0; j < Nx; j++) + { + const float32x2_t w2 = c_mul_neon(w, w); + const float32x2_t w3 = c_mul_neon(w2, w); + const float32x2_t w4 = c_mul_neon(w3, w); - float32x2_t w{ 1, 0 }; - float32x2_t w_m{ cosf(alpha), -sinf(alpha) }; + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + { + // Load inputs + float32x2_t a = wrapper::vload(x + M * k); + float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); + float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx)); + + // Base-case prime transform + fft_5(a, b, c, d, e, w, w2, w3, w4); + + // Store outputs + wrapper::vstore(X + M * k, a); + wrapper::vstore(X + M * (k + 2 * 
Nx), b); + wrapper::vstore(X + M * (k + 4 * Nx), c); + wrapper::vstore(X + M * (k + 6 * Nx), d); + wrapper::vstore(X + M * (k + 8 * Nx), e); + } + + w = c_mul_neon(w, w_m); + } +} +template +void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +{ + float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { const float32x2_t w2 = c_mul_neon(w, w); @@ -577,7 +675,7 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) const float32x2_t w5 = c_mul_neon(w4, w); const float32x2_t w6 = c_mul_neon(w5, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx7) + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) { float32x2_t a = { 0, 0 }; float32x2_t b = { 0, 0 }; @@ -637,15 +735,49 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) } } -template -void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) +void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) { - unsigned int Nx8 = 8 * Nx; - const float alpha = 2 * PI / float(Nx8); + float32x2_t w{ 1.0f, 0.0f }; + for(unsigned int j = 0; j < Nx; j++) + { + const float32x2_t w2 = c_mul_neon(w, w); + const float32x2_t w3 = c_mul_neon(w2, w); + const float32x2_t w4 = c_mul_neon(w3, w); + const float32x2_t w5 = c_mul_neon(w4, w); + const float32x2_t w6 = c_mul_neon(w5, w); - float32x2_t w{ 1, 0 }; - const float32x2_t w_m{ cosf(alpha), -sinf(alpha) }; + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + { + // Load inputs + float32x2_t a = wrapper::vload(x + M * k); + float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); + float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx)); + float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx)); + float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx)); + + // Base-case prime transform + fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6); + + // Store outputs + wrapper::vstore(X + M * k, a); + wrapper::vstore(X + M * (k + 2 * Nx), b); + wrapper::vstore(X + M * (k + 4 * Nx), c); + wrapper::vstore(X + M * (k + 6 * Nx), d); + wrapper::vstore(X + M * (k + 8 * Nx), e); + wrapper::vstore(X + M * (k + 10 * Nx), f); + wrapper::vstore(X + M * (k + 12 * Nx), g); + } + + w = c_mul_neon(w, w_m); + } +} +template +void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +{ + float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { const float32x2_t w2 = c_mul_neon(w, w); @@ -655,7 +787,7 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) const float32x2_t w6 = c_mul_neon(w5, w); const float32x2_t w7 = c_mul_neon(w6, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx8) + for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) { // Load inputs float32x2_t a = { 0, 0 }; @@ -724,11 +856,54 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N) } } +void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +{ + float32x2_t w{ 1.0f, 0.0f }; + for(unsigned int j = 0; j < Nx; j++) + { + const float32x2_t w2 = c_mul_neon(w, w); + const float32x2_t w3 = c_mul_neon(w2, w); + const float32x2_t w4 = c_mul_neon(w3, w); + const float32x2_t w5 = c_mul_neon(w4, w); + const 
float32x2_t w6 = c_mul_neon(w5, w);
+        const float32x2_t w7 = c_mul_neon(w6, w);
+
+        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+        {
+            // Load inputs
+            float32x2_t a = wrapper::vload(x + M * k);
+            float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+            float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+            float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+            float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
+            float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
+            float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
+            float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx));
+
+            // Base-case prime transform
+            fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
+
+            // Store outputs
+            wrapper::vstore(X + M * k, a);
+            wrapper::vstore(X + M * (k + 2 * Nx), b);
+            wrapper::vstore(X + M * (k + 4 * Nx), c);
+            wrapper::vstore(X + M * (k + 6 * Nx), d);
+            wrapper::vstore(X + M * (k + 8 * Nx), e);
+            wrapper::vstore(X + M * (k + 10 * Nx), f);
+            wrapper::vstore(X + M * (k + 12 * Nx), g);
+            wrapper::vstore(X + M * (k + 14 * Nx), h);
+        }
+
+        w = c_mul_neon(w, w_m);
+    }
+}
+
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0);
+    ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
     ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
+    ARM_COMPUTE_UNUSED(config);

     // Checks performed when output is configured
     if((output != nullptr) && (output->total_size() != 0))
@@ -742,12 +917,14 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c

 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
 {
+    ARM_COMPUTE_UNUSED(config);
+
     if(output != nullptr)
     {
         auto_init_if_empty(*output, *input);
     }

-    Window win = calculate_max_window(*input, Steps(config.radix));
+    Window win = calculate_max_window(*input, Steps());
     if(output != nullptr)
     {
         output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
@@ -758,36 +935,51 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
 } // namespace

 NEFFTRadixStageKernel::NEFFTRadixStageKernel()
-    : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _func()
+    : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
 {
 }

-template <bool first_stage>
-void NEFFTRadixStageKernel::set_radix_stage_fun(unsigned int radix)
+void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
 {
-    switch(radix)
+    // FFT table axis 0: [radix, first_stage]
+    static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
+
+    if(fft_table_axis0.empty())
     {
-        case 2:
-            _func = &fft_radix_2_axes_0<first_stage>;
-            break;
-        case 3:
-            _func = &fft_radix_3_axes_0<first_stage>;
-            break;
-        case 4:
-            _func = &fft_radix_4_axes_0<first_stage>;
-            break;
-        case 5:
-            _func = &fft_radix_5_axes_0<first_stage>;
-            break;
-        case 7:
-            _func = &fft_radix_7_axes_0<first_stage>;
-            break;
-        case 8:
-            _func = &fft_radix_8_axes_0<first_stage>;
-            break;
-        default:
-            ARM_COMPUTE_ERROR("Radix not supported");
+        fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
+        fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
+        fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
+        fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
+        fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
+        fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;
+
+        fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
+        fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
+        fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
+        fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
+        fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
+        fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
+    }
+
+    _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
+}
+
+void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
+{
+    // FFT table axis 1: [radix, first_stage]
+    static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
+
+    if(fft_table_axis1.empty())
+    {
+        fft_table_axis1[2] = &fft_radix_2_axes_1;
+        fft_table_axis1[3] = &fft_radix_3_axes_1;
+        fft_table_axis1[4] = &fft_radix_4_axes_1;
+        fft_table_axis1[5] = &fft_radix_5_axes_1;
+        fft_table_axis1[7] = &fft_radix_7_axes_1;
+        fft_table_axis1[8] = &fft_radix_8_axes_1;
     }
+
+    _func_1 = fft_table_axis1[config.radix];
 }

 void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
@@ -806,14 +998,20 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT
     _output = output;
     _run_in_place = (output == nullptr) || (output == input);
     _Nx = config.Nx;
+    _axis = config.axis;
+    _radix = config.radix;

-    if(config.is_first_stage)
-    {
-        set_radix_stage_fun<true>(config.radix);
-    }
-    else
+    switch(config.axis)
     {
-        set_radix_stage_fun<false>(config.radix);
+        case 0:
+            set_radix_stage_axis0(config);
+            break;
+        case 1:
+            set_radix_stage_axis1(config);
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Axis not supported");
+            break;
     }

     // Configure kernel window
@@ -841,23 +1039,40 @@ std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()

 void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
 {
-    ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+    ARM_COMPUTE_UNUSED(info);

     Window input_window = window;
-    input_window.set(Window::DimX, 0);
-
-    unsigned int N = _input->info()->dimension(0);
+    input_window.set(_axis, 0);

     Iterator in(_input, input_window);
     Iterator out(_run_in_place ? _input : _output, input_window);

-    execute_window_loop(input_window, [&](const Coordinates &)
+    // Precompute FFT constants
+    const unsigned int NxRadix = _radix * _Nx;
+    const float        alpha   = 2.0f * kPi / float(NxRadix);
+    const float32x2_t  w_m{ cosf(alpha), -sinf(alpha) };
+
+    if(_axis == 0)
+    {
+        const unsigned int N = _input->info()->dimension(0);
+        execute_window_loop(input_window, [&](const Coordinates &)
+        {
+            _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
+        },
+        in, out);
+    }
+    else
     {
-        _func(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, N);
-    },
-    in, out);
+        const unsigned int N = _input->info()->dimension(0);
+        const unsigned int M = _input->info()->dimension(1);
+        execute_window_loop(input_window, [&](const Coordinates &)
+        {
+            _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M);
+        },
+        in, out);
+    }

     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
diff --git a/src/core/NEON/kernels/NEFFTScaleKernel.cpp b/src/core/NEON/kernels/NEFFTScaleKernel.cpp
new file mode 100644
index 0000000000..6568755e5d
--- /dev/null
+++ b/src/core/NEON/kernels/NEFFTScaleKernel.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEFFTScaleKernel.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace
+{
+void scale_complex(float *c_in, float *c_out, bool is_conjugate, float scale)
+{
+    const auto a = wrapper::vload(c_in);
+    auto       b = wrapper::vdiv(a, float32x2_t{ scale, scale });
+    if(is_conjugate)
+    {
+        const float img_part = wrapper::vgetlane(b, 1);
+        b                    = wrapper::vsetlane(-img_part, b, 1);
+    }
+
+    wrapper::vstore(c_out, b);
+}
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
+
+    // Checks performed when output is configured
+    if((output != nullptr) && (output->total_size() != 0))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() != 1 && output->num_channels() != 2);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    }
+
+    return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
+{
+    // Configure kernel window
+    Window win = calculate_max_window(*input, Steps());
+
+    if(output != nullptr)
+    {
+        // Output auto initialization if not yet initialized
+        auto_init_if_empty(*output, *input->clone());
+
+        // NEFFTScaleKernel doesn't need padding so update_window_and_padding() can be skipped
+        Coordinates coord;
+        coord.set_num_dimensions(output->num_dimensions());
+        output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
+    }
+
+    return std::make_pair(Status{}, win);
+}
+} // namespace
+
+NEFFTScaleKernel::NEFFTScaleKernel()
+    : _input(nullptr), _output(nullptr), _scale(), _run_in_place(false), _is_conj(false)
+{
+}
+
+void NEFFTScaleKernel::configure(ITensor *input, ITensor *output, const FFTScaleKernelInfo &config)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr));
+
+    _input        = input;
+    _output       = output;
+    _run_in_place = (output == nullptr) || (output == input);
+    _is_conj      = config.conjugate;
+    _scale        = config.scale;
+
+    // Configure kernel window
+    auto win_config = validate_and_configure_window(input->info(), _run_in_place ? nullptr : output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    INEKernel::configure(win_config.second);
+}
+
+Status NEFFTScaleKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTScaleKernelInfo &config)
+{
+    ARM_COMPUTE_UNUSED(config);
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first);
+
+    return Status{};
+}
+
+void NEFFTScaleKernel::run(const Window &window, const ThreadInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+    ARM_COMPUTE_UNUSED(info);
+
+    Window input_window = window;
+    input_window.set(Window::DimX, 0);
+
+    Iterator in(_input, input_window);
+    Iterator out(_run_in_place ? _input : _output, input_window);
+
+    execute_window_loop(window, [&](const Coordinates &)
+    {
+        scale_complex(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _is_conj, _scale);
+    },
+    in, out);
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEFFT1D.cpp b/src/runtime/NEON/functions/NEFFT1D.cpp
index d3ff674a2a..665efeb440 100644
--- a/src/runtime/NEON/functions/NEFFT1D.cpp
+++ b/src/runtime/NEON/functions/NEFFT1D.cpp
@@ -31,7 +31,7 @@
 namespace arm_compute
 {
 NEFFT1D::NEFFT1D(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _digit_reversed_input(), _digit_reverse_indices(), _digit_reverse_kernel(), _fft_kernels(), _n_ffts(0)
+    : _memory_group(std::move(memory_manager)), _digit_reverse_kernel(), _fft_kernels(), _scale_kernel(), _digit_reversed_input(), _digit_reverse_indices(), _num_ffts(0), _axis(0), _run_scale(false)
 {
 }

@@ -43,6 +43,11 @@ void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo &
     const auto decomposed_vector = arm_compute::helpers::fft::decompose_stages(N, supported_radix);
     ARM_COMPUTE_ERROR_ON(decomposed_vector.empty());

+    // Flags
+    _run_scale = config.direction == FFTDirection::Inverse;
+    _axis      = config.axis;
+    const bool is_c2r = input->info()->num_channels() == 2 && output->info()->num_channels() == 1;
+
     // Configure digit reverse
     TensorInfo digit_reverse_indices_info(TensorShape(input->info()->tensor_shape()[config.axis]), 1, DataType::U32);
     _digit_reverse_indices.allocator()->init(digit_reverse_indices_info);
@@ -51,19 +56,19 @@

     // Create and configure FFT kernels
     unsigned int Nx = 1;
-    _n_ffts = decomposed_vector.size();
-    _fft_kernels.resize(_n_ffts);
-    for(unsigned int i = 0; i < _n_ffts; ++i)
+
+    _num_ffts = decomposed_vector.size();
+    _fft_kernels.resize(_num_ffts);
+    for(unsigned int i = 0; i < _num_ffts; ++i)
     {
         const unsigned int radix_for_stage = decomposed_vector.at(i);

-        FFTRadixStageKernelInfo fft_kernel_desc;
-        fft_kernel_desc.axis = config.axis;
-        fft_kernel_desc.radix = radix_for_stage;
-        fft_kernel_desc.Nx = Nx;
-        fft_kernel_desc.is_first_stage = (i == 0);
-        _fft_kernels[i].configure(&_digit_reversed_input, i == (_n_ffts - 1) ? output : nullptr, fft_kernel_desc);
-
+        FFTRadixStageKernelInfo fft_kernel_info;
+        fft_kernel_info.axis = config.axis;
+        fft_kernel_info.radix = radix_for_stage;
+        fft_kernel_info.Nx = Nx;
+        fft_kernel_info.is_first_stage = (i == 0);
+        _fft_kernels[i].configure(&_digit_reversed_input, i == (_num_ffts - 1) && !is_c2r ? output : nullptr, fft_kernel_info);
         Nx *= radix_for_stage;
     }

@@ -80,7 +85,7 @@ Status NEFFT1D::validate(const ITensorInfo *input, const ITensorInfo *output, co
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0);
+    ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);

     // Check if FFT is decomposable
     const auto supported_radix = NEFFTRadixStageKernel::supported_radix();
@@ -102,11 +107,17 @@ void NEFFT1D::run()
 {
     MemoryGroupResourceScope scope_mg(_memory_group);

-    NEScheduler::get().schedule(&_digit_reverse_kernel, Window::DimY);
+    NEScheduler::get().schedule(&_digit_reverse_kernel, (_axis == 0 ? Window::DimY : Window::DimX));
+
+    for(unsigned int i = 0; i < _num_ffts; ++i)
+    {
+        NEScheduler::get().schedule(&_fft_kernels[i], (_axis == 0 ? Window::DimY : Window::DimX));
+    }

-    for(unsigned int i = 0; i < _n_ffts; ++i)
+    // Run output scaling
+    if(_run_scale)
     {
-        NEScheduler::get().schedule(&_fft_kernels[i], Window::DimY);
+        NEScheduler::get().schedule(&_scale_kernel, Window::DimY);
     }
 }
 } // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEFFT2D.cpp b/src/runtime/NEON/functions/NEFFT2D.cpp
new file mode 100644
index 0000000000..9210ecfa2e
--- /dev/null
+++ b/src/runtime/NEON/functions/NEFFT2D.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEFFT2D.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/Scheduler.h"
+
+namespace arm_compute
+{
+NEFFT2D::NEFFT2D(std::shared_ptr<IMemoryManager> memory_manager)
+    : _memory_group(memory_manager), _first_pass_func(memory_manager), _second_pass_func(memory_manager), _first_pass_tensor()
+{
+}
+
+void NEFFT2D::configure(const ITensor *input, ITensor *output, const FFT2DInfo &config)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+    ARM_COMPUTE_ERROR_THROW_ON(NEFFT2D::validate(input->info(), output->info(), config));
+
+    // Setup first pass
+    FFT1DInfo first_pass_config;
+    first_pass_config.axis      = config.axes.first;
+    first_pass_config.direction = config.direction;
+    _memory_group.manage(&_first_pass_tensor);
+    _first_pass_func.configure(input, &_first_pass_tensor, first_pass_config);
+
+    // Setup second pass
+    FFT1DInfo second_pass_config;
+    second_pass_config.axis      = config.axes.second;
+    second_pass_config.direction = config.direction;
+    _second_pass_func.configure(&_first_pass_tensor, output, second_pass_config);
+    _first_pass_tensor.allocator()->allocate();
+}
+
+Status NEFFT2D::validate(const ITensorInfo *input, const ITensorInfo *output, const FFT2DInfo &config)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+
+    // Create intermediate tensor info
+    TensorInfo first_pass_tensor(input->clone()->set_is_resizable(true).reset_padding().set_num_channels(2));
+
+    // Validate first pass
+    FFT1DInfo first_pass_config;
+    first_pass_config.axis      = config.axes.first;
+    first_pass_config.direction = config.direction;
+    ARM_COMPUTE_RETURN_ON_ERROR(NEFFT1D::validate(input, &first_pass_tensor, first_pass_config));
+
+    // Validate second pass
+    FFT1DInfo second_pass_config;
+    second_pass_config.axis      = config.axes.second;
+    second_pass_config.direction = config.direction;
+    ARM_COMPUTE_RETURN_ON_ERROR(NEFFT1D::validate(&first_pass_tensor, output, second_pass_config));
+
+    // Checks performed when output is configured
+    if((output != nullptr) && (output->total_size() != 0))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    }
+
+    return Status{};
+}
+
+void NEFFT2D::run()
+{
+    _memory_group.acquire();
+
+    _first_pass_func.run();
+    _second_pass_func.run();
+
+    _memory_group.release();
+}
+} // namespace arm_compute
diff --git a/tests/validation/NEON/FFT.cpp b/tests/validation/NEON/FFT.cpp
index 14fb7e2518..598c8bb10d 100644
--- a/tests/validation/NEON/FFT.cpp
+++ b/tests/validation/NEON/FFT.cpp
@@ -23,6 +23,7 @@
  */
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/NEON/functions/NEFFT1D.h"
+#include "arm_compute/runtime/NEON/functions/NEFFT2D.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "tests/NEON/Accessor.h"
 #include "tests/framework/Asserts.h"
@@ -49,6 +50,13 @@ const auto shapes_1d = framework::dataset::make("TensorShape", { TensorShape(2U
                                                                  TensorShape(96U, 2U, 2U)
                                                });

+const auto shapes_2d = framework::dataset::make("TensorShape", { TensorShape(2U, 2U, 3U), TensorShape(3U, 6U, 3U),
+                                                                 TensorShape(4U, 5U, 3U), TensorShape(5U, 7U, 3U),
+                                                                 TensorShape(7U, 25U, 3U), TensorShape(8U, 2U, 3U),
+                                                                 TensorShape(9U, 16U, 3U), TensorShape(25U, 32U, 3U),
+                                                                 TensorShape(192U, 128U, 2U)
+                                               });
+
 const auto ActivationFunctionsSmallDataset = framework::dataset::make("ActivationInfo",
 {
     ActivationLayerInfo(),
@@ -127,8 +135,72 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEFFT1DFixture<float>, framework::DatasetMode::
 }
 TEST_SUITE_END() // FP32
 TEST_SUITE_END() // Float
-
 TEST_SUITE_END() // FFT1D
+
+TEST_SUITE(FFT2D)
+
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(shapes_2d, data_types),
+               shape, data_type)
+{
+    // Create tensors
+    Tensor src = create_tensor<Tensor>(shape, data_type, 2);
+    Tensor dst = create_tensor<Tensor>(shape, data_type, 2);
+
+    ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+    // Create and configure function
+    NEFFT2D fft2d;
+    fft2d.configure(&src, &dst, FFT2DInfo());
+
+    // Validate valid region
+    const ValidRegion valid_region = shape_to_valid_region(shape);
+    validate(src.info()->valid_region(), valid_region);
+    validate(dst.info()->valid_region(), valid_region);
+
+    // Validate padding
+    validate(src.info()->padding(), PaddingSize());
+    validate(dst.info()->padding(), PaddingSize());
+}
+
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
+               framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32), // Mismatching data types
+                                                       TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32), // Mismatching shapes
+                                                       TensorInfo(TensorShape(32U, 25U, 2U), 3, DataType::F32), // Invalid channels
+                                                       TensorInfo(TensorShape(32U, 13U, 2U), 2, DataType::F32), // Undecomposable FFT
+                                                       TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32),
+               }),
+               framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F16),
+                                                       TensorInfo(TensorShape(16U, 25U, 2U), 2, DataType::F32),
+                                                       TensorInfo(TensorShape(32U, 25U, 2U), 1, DataType::F32),
+                                                       TensorInfo(TensorShape(32U, 13U, 2U), 2, DataType::F32),
+                                                       TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32),
+               })),
+               framework::dataset::make("Expected", { false, false, false, false, true })),
+               input_info, output_info, expected)
+{
+    const Status s = NEFFT2D::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), FFT2DInfo());
+    ARM_COMPUTE_EXPECT(bool(s) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
+template <typename T>
+using NEFFT2DFixture = FFTValidationFixture<Tensor, Accessor, NEFFT2D, FFT2DInfo, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEFFT2DFixture<float>, framework::DatasetMode::ALL, combine(shapes_2d, framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, tolerance_f32, tolerance_num);
+}
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // FFT2D
+
 TEST_SUITE_END() // NEON
 } // namespace validation
 } // namespace test
--
cgit v1.2.1
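
Reviewer note (not part of the patch): the rewritten c_mul_neon replaces the previous three-multiplication (Gauss/Karatsuba-style) complex product with a duplicate/reverse/multiply-accumulate sequence, which maps better onto NEON lanes. A scalar C++ sketch of what the two lanes hold, for readers following the vector code; the helper name is illustrative:

// Illustrative only -- not part of this patch. Scalar reference for the lane
// arithmetic performed by c_mul_neon: a float32x2_t holds one complex number
// as { real, imag }.
#include <array>

std::array<float, 2> c_mul_scalar(std::array<float, 2> a, std::array<float, 2> b)
{
    // res  = a.real * { b.real, b.imag }    -> vmul(tmp0, b)
    // res += a.imag * { -b.imag, b.real }   -> vrev64(b), multiply by mask { -1, 1 }, vmla
    return { a[0] * b[0] - a[1] * b[1],   // real part
             a[0] * b[1] + a[1] * b[0] }; // imaginary part
}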
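
Reviewer note (not part of the patch): a minimal usage sketch for the new NEFFT2D function, mirroring the Configuration test above. It assumes a default-constructed FFT2DInfo behaves as in the tests (forward transform, first axis then second axis); the tensor shape and variable names are illustrative:

// Illustrative only -- not part of this patch. Forward 2D FFT of a 32x25
// complex image stored as a 2-channel F32 tensor.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEFFT2D.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    Tensor src{};
    Tensor dst{};
    // 2 channels: channel 0 = real part, channel 1 = imaginary part
    src.allocator()->init(TensorInfo(TensorShape(32U, 25U), 2, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(32U, 25U), 2, DataType::F32));

    // Configure while the tensors are still unallocated, as in the tests above
    NEFFT2D fft2d;
    fft2d.configure(&src, &dst, FFT2DInfo());

    src.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill src with interleaved real/imaginary data ...

    fft2d.run();
    return 0;
}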