author    giuros01 <giuseppe.rossini@arm.com>  2019-03-26 17:44:40 +0000
committer Giuseppe Rossini <giuseppe.rossini@arm.com>  2019-05-01 14:00:38 +0000
commit    05fb448bf48e31d723dfd9f4bbf3899ff65f0fba (patch)
tree      610576f2df7f1fc616a165c516f2a9475981f819 /src
parent    a4f378dcd39addd4a63db1c0848f2c120804f4eb (diff)
COMPMID-1963: Implement FFT (2D) on NEON
Change-Id: I3b564be8d7949e00c6544071ef62dd51de838c96
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1048
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src')
-rw-r--r--  src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp |  21
-rw-r--r--  src/core/NEON/kernels/NEFFTRadixStageKernel.cpp   | 579
-rw-r--r--  src/core/NEON/kernels/NEFFTScaleKernel.cpp        | 136
-rw-r--r--  src/runtime/NEON/functions/NEFFT1D.cpp            |  41
-rw-r--r--  src/runtime/NEON/functions/NEFFT2D.cpp            |  95
5 files changed, 669 insertions, 203 deletions
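
For orientation: the 2D FFT added by this patch is the classic row-column decomposition, i.e. a 1D FFT along one axis followed by a 1D FFT along the other, which is how NEFFT2D chains two NEFFT1D passes below. A minimal sketch of that decomposition, assuming an interleaved-complex buffer and a hypothetical fft_1d() helper (not part of this patch):

#include <complex>

// Hypothetical strided 1D FFT over `length` complex elements, `stride` apart.
void fft_1d(std::complex<float> *p, unsigned int length, unsigned int stride);

// Sketch only: row-column 2D FFT over an M x N complex image.
void fft_2d(std::complex<float> *data, unsigned int M, unsigned int N)
{
    for(unsigned int r = 0; r < M; ++r)
    {
        fft_1d(data + r * N, N, 1); // first pass: rows (axis 0)
    }
    for(unsigned int c = 0; c < N; ++c)
    {
        fft_1d(data + c, M, N); // second pass: columns (axis 1)
    }
}
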
diff --git a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
index 845fcef4f3..b2ffb01e99 100644
--- a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
@@ -37,7 +37,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(idx, 1, DataType::U32);
- ARM_COMPUTE_RETURN_ERROR_ON(axis != 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis > 1);
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
@@ -96,15 +96,24 @@ void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info)
Iterator out(_output, window);
const size_t element_size = _input->info()->element_size();
+ // Pointers to the buffers
+ const size_t offset = _input->info()->offset_first_element_in_bytes();
+ auto *idx_ptr = reinterpret_cast<unsigned int *>(_idx->buffer());
+ uint8_t *input_ptr = offset + _input->buffer();
+
+ // Strides
+ const size_t stride_x = _input->info()->strides_in_bytes()[0];
+ const size_t stride_y = _input->info()->strides_in_bytes()[1];
+ const size_t stride_z = _input->info()->strides_in_bytes()[2];
+ const size_t stride_w = _input->info()->strides_in_bytes()[3];
+
execute_window_loop(window, [&](const Coordinates & id)
{
- unsigned int in_index_1d = *reinterpret_cast<unsigned int *>(_idx->ptr_to_element(Coordinates(id.x())));
-
- auto reverse_id = id;
+ unsigned int in_index_1d = idx_ptr[id[_axis]];
+ auto reverse_id = id;
reverse_id.set(_axis, in_index_1d);
- memcpy(out.ptr(), _input->ptr_to_element(reverse_id), 2 * element_size);
-
+ memcpy(out.ptr(), input_ptr + reverse_id.x() * stride_x + reverse_id.y() * stride_y + reverse_id.z() * stride_z + reverse_id[3] * stride_w, element_size);
},
out);
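
In scalar terms, the rewritten run() above is a gather along _axis through the precomputed _idx tensor: each output element is fetched from its digit-reversed input position. A sketch for a single row, with illustrative names (digit_reverse_row and idx are not from the patch):

#include <cstring>

// Sketch: permute one interleaved-complex row of length N, where idx[i]
// is the digit-reversed source position of output element i.
void digit_reverse_row(float *dst, const float *src, const unsigned int *idx, unsigned int N)
{
    for(unsigned int i = 0; i < N; ++i)
    {
        std::memcpy(dst + 2 * i, src + 2 * idx[i], 2 * sizeof(float)); // (real, imag) pair
    }
}
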
diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
index b264590791..148bbe915a 100644
--- a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
@@ -24,8 +24,6 @@
#include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"
#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/NEON/wrapper/traits.h"
-#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
@@ -34,28 +32,53 @@
#include <arm_neon.h>
#include <cmath>
#include <complex>
+#include <map>
+
+#include "arm_compute/core/NEON/wrapper/traits.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
namespace arm_compute
{
namespace
{
-constexpr float PI = 3.141592653589793f;
+// PI constant (from cmath)
+constexpr float kPi = float(M_PI);
+
+// Constant used in the fft_3 kernel
+constexpr float kSqrt3Div2 = 0.866025403784438;
+
+// Constants used in the fft_5 kernel
+constexpr float kW5_0 = 0.30901699437494f;
+constexpr float kW5_1 = 0.95105651629515f;
+constexpr float kW5_2 = 0.80901699437494f;
+constexpr float kW5_3 = 0.58778525229247f;
+
+// Constants used in the fft_7 kernel
+constexpr float kW7_0 = 0.62348980185873f;
+constexpr float kW7_1 = 0.78183148246802f;
+constexpr float kW7_2 = 0.22252093395631f;
+constexpr float kW7_3 = 0.97492791218182f;
+constexpr float kW7_4 = 0.90096886790241f;
+constexpr float kW7_5 = 0.43388373911755f;
+
+// Constant used in the fft_8 kernel
+constexpr float kSqrt2Div2 = 0.707106781186548;
float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
{
- float32x2_t tmp = wrapper::vmul(a, b);
+ using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
- const float P1 = wrapper::vgetlane(tmp, 0);
- const float P2 = wrapper::vgetlane(tmp, 1);
+ const float32x2_t mask = { -1.0, 1.0 };
+ const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
+ const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
- const float a_r = wrapper::vgetlane(a, 0);
- const float a_i = wrapper::vgetlane(a, 1);
- const float b_r = wrapper::vgetlane(b, 0);
- const float b_i = wrapper::vgetlane(b, 1);
+ float32x2_t res = wrapper::vmul(tmp0, b);
- const float P3 = (a_r + a_i) * (b_r + b_i);
- float32x2_t out = { P1 - P2, P3 - P2 - P1 };
- return out;
+ b = wrapper::vrev64(b);
+ b = wrapper::vmul(b, mask);
+ res = wrapper::vmla(res, tmp1, b);
+
+ return res;
}
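// Reference note (not part of the patch): with a = (a_r, a_i) and b = (b_r, b_i),
// the sequence above yields the standard complex product
//     res = (a_r*b_r - a_i*b_i, a_r*b_i + a_i*b_r).
// vrev64 swaps b into (b_i, b_r), the { -1, 1 } mask turns it into (-b_i, b_r),
// and vmla accumulates a_i * (-b_i, b_r) onto a_r * (b_r, b_i), replacing the
// previous three-multiply (Karatsuba-style) formulation with two vector multiplies.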
float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
@@ -107,7 +130,6 @@ void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
y = wrapper::vsub(a, b);
}
-constexpr float sqrt3div2 = 0.866025403784438;
void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
{
float32x2_t a = x;
@@ -118,7 +140,7 @@ void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w,
x = wrapper::vadd(x, c);
const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
- const auto v2 = c_mul_neon(float32x2_t{ 0.f, -sqrt3div2 }, wrapper::vsub(b, c));
+ const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));
y = z = wrapper::vsub(a, v1);
y = wrapper::vadd(y, v2);
@@ -149,10 +171,6 @@ void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, c
x4 = wrapper::vadd(x41, x42);
}
-constexpr float W5_0 = 0.30901699437494f;
-constexpr float W5_1 = 0.95105651629515f;
-constexpr float W5_2 = 0.80901699437494f;
-constexpr float W5_3 = 0.58778525229247f;
void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
{
const auto a = x1;
@@ -161,25 +179,25 @@ void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto d = c_mul_neon(w3, x4);
const auto e = c_mul_neon(w4, x5);
- const auto b0 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, b);
- const auto b1 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, b);
- const auto b2 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, b);
- const auto b3 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, b);
+ const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
+ const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
+ const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
+ const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);
- const auto c0 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, c);
- const auto c1 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, c);
- const auto c2 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, c);
- const auto c3 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, c);
+ const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
+ const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
+ const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
+ const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);
- const auto d0 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, d);
- const auto d1 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, d);
- const auto d2 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, d);
- const auto d3 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, d);
+ const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
+ const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
+ const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
+ const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);
- const auto e0 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, e);
- const auto e1 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, e);
- const auto e2 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, e);
- const auto e3 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, e);
+ const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
+ const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
+ const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
+ const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);
x1 = reduce_sum_5(a, b, c, d, e);
x2 = reduce_sum_5(a, b0, c0, d0, e0);
@@ -188,12 +206,6 @@ void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
x5 = reduce_sum_5(a, b3, c3, d3, e3);
}
-constexpr float W7_0 = 0.62348980185873f;
-constexpr float W7_1 = 0.78183148246802f;
-constexpr float W7_2 = 0.22252093395631f;
-constexpr float W7_3 = 0.97492791218182f;
-constexpr float W7_4 = 0.90096886790241f;
-constexpr float W7_5 = 0.43388373911755f;
void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
const float32x2_t &w4,
const float32x2_t &w5, const float32x2_t &w6)
@@ -206,47 +218,47 @@ void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto f = c_mul_neon(w5, x6);
const auto g = c_mul_neon(w6, x7);
- const auto b0 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, b);
- const auto b1 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, b);
- const auto b2 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, b);
- const auto b3 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, b);
- const auto b4 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, b);
- const auto b5 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, b);
-
- const auto c0 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, c);
- const auto c1 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, c);
- const auto c2 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, c);
- const auto c3 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, c);
- const auto c4 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, c);
- const auto c5 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, c);
-
- const auto d0 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, d);
- const auto d1 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, d);
- const auto d2 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, d);
- const auto d3 = c_mul_neon(float32x2_t{ -W7_2, +W7_3 }, d);
- const auto d4 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, d);
- const auto d5 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, d);
-
- const auto e0 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, e);
- const auto e1 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, e);
- const auto e2 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, e);
- const auto e3 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, e);
- const auto e4 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, e);
- const auto e5 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, e);
-
- const auto f0 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, f);
- const auto f1 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, f);
- const auto f2 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, f);
- const auto f3 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, f);
- const auto f4 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, f);
- const auto f5 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, f);
-
- const auto g0 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, g);
- const auto g1 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, g);
- const auto g2 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, g);
- const auto g3 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, g);
- const auto g4 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, g);
- const auto g5 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, g);
+ const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
+ const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
+ const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
+ const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
+ const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
+ const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);
+
+ const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
+ const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
+ const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
+ const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
+ const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
+ const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);
+
+ const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
+ const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
+ const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
+ const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
+ const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
+ const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);
+
+ const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
+ const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
+ const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
+ const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
+ const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
+ const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);
+
+ const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
+ const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
+ const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
+ const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
+ const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
+ const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);
+
+ const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
+ const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
+ const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
+ const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
+ const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
+ const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);
x1 = reduce_sum_7(a, b, c, d, e, f, g);
x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
@@ -257,7 +269,6 @@ void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
}
-constexpr float sqrt2div2 = 0.707106781186548;
void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
const float32x2_t &w3,
const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
@@ -272,13 +283,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto g = c_mul_neon(w6, x7);
const auto h = c_mul_neon(w7, x8);
- const auto b0 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, b);
+ const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
- const auto b2 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, b);
+ const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
- const auto b4 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, b);
+ const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
- const auto b6 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, b);
+ const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);
const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
@@ -288,13 +299,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);
- const auto d0 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, d);
+ const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
- const auto d2 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, d);
+ const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
- const auto d4 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, d);
+ const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
- const auto d6 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, d);
+ const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);
const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
@@ -304,13 +315,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);
- const auto f0 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, f);
+ const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
- const auto f2 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, f);
+ const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
- const auto f4 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, f);
+ const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
- const auto f6 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, f);
+ const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);
const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
@@ -320,13 +331,13 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);
- const auto h0 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, h);
+ const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
- const auto h2 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, h);
+ const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
- const auto h4 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, h);
+ const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
- const auto h6 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, h);
+ const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);
x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
@@ -339,17 +350,12 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
}
template <bool first_stage>
-void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
+void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- unsigned int Nx2 = 2 * Nx;
- float alpha = 2 * PI / Nx2;
-
- float32x2_t w{ 1, 0 };
- const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
-
+ float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx2)
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
auto a = float32x2_t{ 0, 0 };
auto b = float32x2_t{ 0, 0 };
@@ -386,19 +392,38 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
}
}
-template <bool first_stage>
-void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
+void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
{
- const unsigned int Nx3 = 3 * Nx;
- const float alpha = 2 * PI / float(Nx3);
- float32x2_t w{ 1, 0 };
- const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+ float32x2_t w{ 1.0f, 0.0f };
+ for(unsigned int j = 0; j < Nx; j++)
+ {
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ {
+ // Load inputs
+ float32x2_t a = wrapper::vload(x + M * k);
+ float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ // Base-case prime transform
+ fft_2(a, b, w);
+
+ // Write outputs
+ wrapper::vstore(X + M * k, a);
+ wrapper::vstore(X + M * (k + 2 * Nx), b);
+ }
+
+ w = c_mul_neon(w, w_m);
+ }
+}
+
+template <bool first_stage>
+void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+{
+ float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx3)
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
// Load inputs
float32x2_t a = { 0, 0 };
@@ -435,21 +460,42 @@ void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
}
}
-template <bool first_stage>
-void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
+void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
{
- unsigned int Nx4 = 4 * Nx;
- const float alpha = 2 * PI / float(Nx4);
+ float32x2_t w{ 1.0f, 0.0f };
+ for(unsigned int j = 0; j < Nx; j++)
+ {
+ const auto w2 = c_mul_neon(w, w);
+
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ {
+ // Load inputs
+ float32x2_t a = wrapper::vload(x + M * k);
+ float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+
+ // Base-case prime transform
+ fft_3(a, b, c, w, w2);
- float32x2_t w{ 1, 0 };
- float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+ // Store the output
+ wrapper::vstore(X + M * k, a);
+ wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(X + M * (k + 4 * Nx), c);
+ }
+ w = c_mul_neon(w, w_m);
+ }
+}
+template <bool first_stage>
+void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+{
+ float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
const auto w3 = c_mul_neon(w2, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx4)
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
float32x2_t a = { 0, 0 };
float32x2_t b = { 0, 0 };
@@ -494,22 +540,46 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
}
}
-template <bool first_stage>
-void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
+void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
{
- unsigned int Nx5 = 5 * Nx;
- const float alpha = 2 * PI / float(Nx5);
+ float32x2_t w{ 1.0f, 0.0f };
+ for(unsigned int j = 0; j < Nx; j++)
+ {
+ const auto w2 = c_mul_neon(w, w);
+ const auto w3 = c_mul_neon(w2, w);
+
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ {
+ // Load inputs
+ float32x2_t a = wrapper::vload(x + M * k);
+ float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+
+ // Base-case prime transform
+ fft_4(a, b, c, d, w, w2, w3);
- float32x2_t w{ 1, 0 };
- float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+ wrapper::vstore(X + M * k, a);
+ wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(X + M * (k + 4 * Nx), c);
+ wrapper::vstore(X + M * (k + 6 * Nx), d);
+ }
+
+ w = c_mul_neon(w, w_m);
+ }
+}
+template <bool first_stage>
+void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+{
+ float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
const float32x2_t w4 = c_mul_neon(w3, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx5)
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
float32x2_t a = { 0, 0 };
float32x2_t b = { 0, 0 };
@@ -560,15 +630,43 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
}
}
-template <bool first_stage>
-void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
+void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
{
- unsigned int Nx7 = 7 * Nx;
- const float alpha = 2 * PI / float(Nx7);
+ float32x2_t w{ 1.0f, 0.0f };
+ for(unsigned int j = 0; j < Nx; j++)
+ {
+ const float32x2_t w2 = c_mul_neon(w, w);
+ const float32x2_t w3 = c_mul_neon(w2, w);
+ const float32x2_t w4 = c_mul_neon(w3, w);
- float32x2_t w{ 1, 0 };
- float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ {
+ // Load inputs
+ float32x2_t a = wrapper::vload(x + M * k);
+ float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
+
+ // Base-case prime transform
+ fft_5(a, b, c, d, e, w, w2, w3, w4);
+
+ // Store outputs
+ wrapper::vstore(X + M * k, a);
+ wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(X + M * (k + 4 * Nx), c);
+ wrapper::vstore(X + M * (k + 6 * Nx), d);
+ wrapper::vstore(X + M * (k + 8 * Nx), e);
+ }
+
+ w = c_mul_neon(w, w_m);
+ }
+}
+template <bool first_stage>
+void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+{
+ float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
@@ -577,7 +675,7 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
const float32x2_t w5 = c_mul_neon(w4, w);
const float32x2_t w6 = c_mul_neon(w5, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx7)
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
float32x2_t a = { 0, 0 };
float32x2_t b = { 0, 0 };
@@ -637,15 +735,49 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
}
}
-template <bool first_stage>
-void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
+void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
{
- unsigned int Nx8 = 8 * Nx;
- const float alpha = 2 * PI / float(Nx8);
+ float32x2_t w{ 1.0f, 0.0f };
+ for(unsigned int j = 0; j < Nx; j++)
+ {
+ const float32x2_t w2 = c_mul_neon(w, w);
+ const float32x2_t w3 = c_mul_neon(w2, w);
+ const float32x2_t w4 = c_mul_neon(w3, w);
+ const float32x2_t w5 = c_mul_neon(w4, w);
+ const float32x2_t w6 = c_mul_neon(w5, w);
- float32x2_t w{ 1, 0 };
- const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ {
+ // Load inputs
+ float32x2_t a = wrapper::vload(x + M * k);
+ float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
+ float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
+ float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
+
+ // Base-case prime transform
+ fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
+
+ // Store outputs
+ wrapper::vstore(X + M * k, a);
+ wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(X + M * (k + 4 * Nx), c);
+ wrapper::vstore(X + M * (k + 6 * Nx), d);
+ wrapper::vstore(X + M * (k + 8 * Nx), e);
+ wrapper::vstore(X + M * (k + 10 * Nx), f);
+ wrapper::vstore(X + M * (k + 12 * Nx), g);
+ }
+
+ w = c_mul_neon(w, w_m);
+ }
+}
+template <bool first_stage>
+void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+{
+ float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
@@ -655,7 +787,7 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
const float32x2_t w6 = c_mul_neon(w5, w);
const float32x2_t w7 = c_mul_neon(w6, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx8)
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
// Load inputs
float32x2_t a = { 0, 0 };
@@ -724,11 +856,54 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
}
}
+void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+{
+ float32x2_t w{ 1.0f, 0.0f };
+ for(unsigned int j = 0; j < Nx; j++)
+ {
+ const float32x2_t w2 = c_mul_neon(w, w);
+ const float32x2_t w3 = c_mul_neon(w2, w);
+ const float32x2_t w4 = c_mul_neon(w3, w);
+ const float32x2_t w5 = c_mul_neon(w4, w);
+ const float32x2_t w6 = c_mul_neon(w5, w);
+ const float32x2_t w7 = c_mul_neon(w6, w);
+
+ for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ {
+ // Load inputs
+ float32x2_t a = wrapper::vload(x + M * k);
+ float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
+ float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
+ float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
+ float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx));
+
+ // Base-case prime transform
+ fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
+
+ // Store outputs
+ wrapper::vstore(X + M * k, a);
+ wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(X + M * (k + 4 * Nx), c);
+ wrapper::vstore(X + M * (k + 6 * Nx), d);
+ wrapper::vstore(X + M * (k + 8 * Nx), e);
+ wrapper::vstore(X + M * (k + 10 * Nx), f);
+ wrapper::vstore(X + M * (k + 12 * Nx), g);
+ wrapper::vstore(X + M * (k + 14 * Nx), h);
+ }
+
+ w = c_mul_neon(w, w_m);
+ }
+}
+
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
+ ARM_COMPUTE_UNUSED(config);
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
@@ -742,12 +917,14 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
+ ARM_COMPUTE_UNUSED(config);
+
if(output != nullptr)
{
auto_init_if_empty(*output, *input);
}
- Window win = calculate_max_window(*input, Steps(config.radix));
+ Window win = calculate_max_window(*input, Steps());
if(output != nullptr)
{
output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
@@ -758,36 +935,51 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
} // namespace
NEFFTRadixStageKernel::NEFFTRadixStageKernel()
- : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _func()
+ : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
{
}
-template <bool first_stage>
-void NEFFTRadixStageKernel::set_radix_stage_fun(unsigned int radix)
+void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
{
- switch(radix)
+ // FFT table axis 0: [radix, first_stage]
+ static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
+
+ if(fft_table_axis0.empty())
{
- case 2:
- _func = &fft_radix_2_axes_0<first_stage>;
- break;
- case 3:
- _func = &fft_radix_3_axes_0<first_stage>;
- break;
- case 4:
- _func = &fft_radix_4_axes_0<first_stage>;
- break;
- case 5:
- _func = &fft_radix_5_axes_0<first_stage>;
- break;
- case 7:
- _func = &fft_radix_7_axes_0<first_stage>;
- break;
- case 8:
- _func = &fft_radix_8_axes_0<first_stage>;
- break;
- default:
- ARM_COMPUTE_ERROR("Radix not supported");
+ fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
+ fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
+ fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
+ fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
+ fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
+ fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;
+
+ fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
+ fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
+ fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
+ fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
+ fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
+ fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
+ }
+
+ _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
+}
+
+void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
+{
+ // FFT table axis 1: [radix] (no first_stage specialization on this axis)
+ static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
+
+ if(fft_table_axis1.empty())
+ {
+ fft_table_axis1[2] = &fft_radix_2_axes_1;
+ fft_table_axis1[3] = &fft_radix_3_axes_1;
+ fft_table_axis1[4] = &fft_radix_4_axes_1;
+ fft_table_axis1[5] = &fft_radix_5_axes_1;
+ fft_table_axis1[7] = &fft_radix_7_axes_1;
+ fft_table_axis1[8] = &fft_radix_8_axes_1;
}
+
+ _func_1 = fft_table_axis1[config.radix];
}
void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
@@ -806,14 +998,20 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT
_output = output;
_run_in_place = (output == nullptr) || (output == input);
_Nx = config.Nx;
+ _axis = config.axis;
+ _radix = config.radix;
- if(config.is_first_stage)
- {
- set_radix_stage_fun<true>(config.radix);
- }
- else
+ switch(config.axis)
{
- set_radix_stage_fun<false>(config.radix);
+ case 0:
+ set_radix_stage_axis0(config);
+ break;
+ case 1:
+ set_radix_stage_axis1(config);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Axis not supported");
+ break;
}
// Configure kernel window
@@ -841,23 +1039,40 @@ std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
{
- ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+ ARM_COMPUTE_UNUSED(info);
Window input_window = window;
- input_window.set(Window::DimX, 0);
-
- unsigned int N = _input->info()->dimension(0);
+ input_window.set(_axis, 0);
Iterator in(_input, input_window);
Iterator out(_run_in_place ? _input : _output, input_window);
- execute_window_loop(input_window, [&](const Coordinates &)
+ // Precompute FFT constants
+ const unsigned int NxRadix = _radix * _Nx;
+ const float alpha = 2.0f * kPi / float(NxRadix);
+ const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+
+ if(_axis == 0)
+ {
+ const unsigned int N = _input->info()->dimension(0);
+ execute_window_loop(input_window, [&](const Coordinates &)
+ {
+ _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
+ },
+ in, out);
+ }
+ else
{
- _func(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, N);
- },
- in, out);
+ const unsigned int N = _input->info()->dimension(0);
+ const unsigned int M = _input->info()->dimension(1);
+ execute_window_loop(input_window, [&](const Coordinates &)
+ {
+ _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M);
+ },
+ in, out);
+ }
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
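
The run() hunk above hoists the twiddle setup out of the radix functions: for a stage of span NxRadix = radix * Nx it precomputes alpha = 2*pi / NxRadix and w_m = (cos(alpha), -sin(alpha)) = e^(-i*alpha), and each radix function then advances w by one multiplication with w_m per butterfly group. A scalar sketch of the same recurrence using std::complex (stage_twiddles is an illustrative name, not library API):

#include <cmath>
#include <complex>
#include <vector>

// Sketch: the twiddle factors e^{-i * 2*pi * j / (radix * Nx)}, j = 0..Nx-1,
// generated by the same w = w * w_m recurrence as the kernels above.
std::vector<std::complex<float>> stage_twiddles(unsigned int radix, unsigned int Nx)
{
    const float alpha = 2.0f * 3.14159265358979f / float(radix * Nx);
    const std::complex<float> w_m = std::polar(1.0f, -alpha); // (cos(alpha), -sin(alpha))
    std::vector<std::complex<float>> w(Nx, { 1.0f, 0.0f });
    for(unsigned int j = 1; j < Nx; ++j)
    {
        w[j] = w[j - 1] * w_m; // same step as w = c_mul_neon(w, w_m)
    }
    return w;
}
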
diff --git a/src/core/NEON/kernels/NEFFTScaleKernel.cpp b/src/core/NEON/kernels/NEFFTScaleKernel.cpp
new file mode 100644
index 0000000000..6568755e5d
--- /dev/null
+++ b/src/core/NEON/kernels/NEFFTScaleKernel.cpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/NEON/kernels/NEFFTScaleKernel.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
+
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace
+{
+void scale_complex(float *c_in, float *c_out, bool is_conjugate, float scale)
+{
+ const auto a = wrapper::vload(c_in);
+ auto b = wrapper::vdiv(a, float32x2_t{ scale, scale });
+ if(is_conjugate)
+ {
+ const float img_part = wrapper::vgetlane(b, 1);
+ b = wrapper::vsetlane(-img_part, b, 1);
+ }
+
+ wrapper::vstore(c_out, b);
+}
+
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
+
+ // Checks performed when output is configured
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() != 1 && output->num_channels() != 2);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output)
+{
+ // Configure kernel window
+ Window win = calculate_max_window(*input, Steps());
+
+ if(output != nullptr)
+ {
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(*output, *input->clone());
+
+ // NEFFTScaleKernel doesn't need padding so update_window_and_padding() can be skipped
+ Coordinates coord;
+ coord.set_num_dimensions(output->num_dimensions());
+ output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
+ }
+
+ return std::make_pair(Status{}, win);
+}
+} // namespace
+
+NEFFTScaleKernel::NEFFTScaleKernel()
+ : _input(nullptr), _output(nullptr), _scale(), _run_in_place(false), _is_conj(false)
+{
+}
+
+void NEFFTScaleKernel::configure(ITensor *input, ITensor *output, const FFTScaleKernelInfo &config)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr));
+
+ _input = input;
+ _output = output;
+ _run_in_place = (output == nullptr) || (output == input);
+ _is_conj = config.conjugate;
+ _scale = config.scale;
+
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input->info(), _run_in_place ? nullptr : output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
+}
+
+Status NEFFTScaleKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTScaleKernelInfo &config)
+{
+ ARM_COMPUTE_UNUSED(config);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first);
+
+ return Status{};
+}
+
+void NEFFTScaleKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+ ARM_COMPUTE_UNUSED(info);
+
+ Window input_window = window;
+ input_window.set(Window::DimX, 0);
+
+ Iterator in(_input, input_window);
+ Iterator out(_run_in_place ? _input : _output, input_window);
+
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ scale_complex(reinterpret_cast<float *>(in.ptr()), reinterpret_cast<float *>(out.ptr()), _is_conj, _scale);
+ },
+ in, out);
+}
+} // namespace arm_compute
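
The scale kernel divides every complex element by config.scale and can conjugate on the way out, which is the standard way to finish an inverse transform (for an inverse 1D FFT the scale would be the transform length N). A scalar sketch of the same normalization (scale_buffer is an illustrative name):

#include <complex>

// Sketch: divide an interleaved-complex buffer of N elements by `scale`,
// optionally conjugating, mirroring scale_complex() above.
void scale_buffer(std::complex<float> *buf, unsigned int N, float scale, bool conjugate)
{
    for(unsigned int i = 0; i < N; ++i)
    {
        buf[i] /= scale;
        if(conjugate)
        {
            buf[i] = std::conj(buf[i]);
        }
    }
}
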
diff --git a/src/runtime/NEON/functions/NEFFT1D.cpp b/src/runtime/NEON/functions/NEFFT1D.cpp
index d3ff674a2a..665efeb440 100644
--- a/src/runtime/NEON/functions/NEFFT1D.cpp
+++ b/src/runtime/NEON/functions/NEFFT1D.cpp
@@ -31,7 +31,7 @@
namespace arm_compute
{
NEFFT1D::NEFFT1D(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _digit_reversed_input(), _digit_reverse_indices(), _digit_reverse_kernel(), _fft_kernels(), _n_ffts(0)
+ : _memory_group(std::move(memory_manager)), _digit_reverse_kernel(), _fft_kernels(), _scale_kernel(), _digit_reversed_input(), _digit_reverse_indices(), _num_ffts(0), _axis(0), _run_scale(false)
{
}
@@ -43,6 +43,11 @@ void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo &
const auto decomposed_vector = arm_compute::helpers::fft::decompose_stages(N, supported_radix);
ARM_COMPUTE_ERROR_ON(decomposed_vector.empty());
+ // Flags
+ _run_scale = config.direction == FFTDirection::Inverse;
+ _axis = config.axis;
+ const bool is_c2r = input->info()->num_channels() == 2 && output->info()->num_channels() == 1;
+
// Configure digit reverse
TensorInfo digit_reverse_indices_info(TensorShape(input->info()->tensor_shape()[config.axis]), 1, DataType::U32);
_digit_reverse_indices.allocator()->init(digit_reverse_indices_info);
@@ -51,19 +56,19 @@ void NEFFT1D::configure(const ITensor *input, ITensor *output, const FFT1DInfo &
// Create and configure FFT kernels
unsigned int Nx = 1;
- _n_ffts = decomposed_vector.size();
- _fft_kernels.resize(_n_ffts);
- for(unsigned int i = 0; i < _n_ffts; ++i)
+
+ _num_ffts = decomposed_vector.size();
+ _fft_kernels.resize(_num_ffts);
+ for(unsigned int i = 0; i < _num_ffts; ++i)
{
const unsigned int radix_for_stage = decomposed_vector.at(i);
- FFTRadixStageKernelInfo fft_kernel_desc;
- fft_kernel_desc.axis = config.axis;
- fft_kernel_desc.radix = radix_for_stage;
- fft_kernel_desc.Nx = Nx;
- fft_kernel_desc.is_first_stage = (i == 0);
- _fft_kernels[i].configure(&_digit_reversed_input, i == (_n_ffts - 1) ? output : nullptr, fft_kernel_desc);
-
+ FFTRadixStageKernelInfo fft_kernel_info;
+ fft_kernel_info.axis = config.axis;
+ fft_kernel_info.radix = radix_for_stage;
+ fft_kernel_info.Nx = Nx;
+ fft_kernel_info.is_first_stage = (i == 0);
+ _fft_kernels[i].configure(&_digit_reversed_input, i == (_num_ffts - 1) && !is_c2r ? output : nullptr, fft_kernel_info);
Nx *= radix_for_stage;
}
@@ -80,7 +85,7 @@ Status NEFFT1D::validate(const ITensorInfo *input, const ITensorInfo *output, co
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
// Check if FFT is decomposable
const auto supported_radix = NEFFTRadixStageKernel::supported_radix();
@@ -102,11 +107,17 @@ void NEFFT1D::run()
{
MemoryGroupResourceScope scope_mg(_memory_group);
- NEScheduler::get().schedule(&_digit_reverse_kernel, Window::DimY);
+ NEScheduler::get().schedule(&_digit_reverse_kernel, (_axis == 0 ? Window::DimY : Window::DimX));
+
+ for(unsigned int i = 0; i < _num_ffts; ++i)
+ {
+ NEScheduler::get().schedule(&_fft_kernels[i], (_axis == 0 ? Window::DimY : Window::DimX));
+ }
- for(unsigned int i = 0; i < _n_ffts; ++i)
+ // Run output scaling
+ if(_run_scale)
{
- NEScheduler::get().schedule(&_fft_kernels[i], Window::DimY);
+ NEScheduler::get().schedule(&_scale_kernel, Window::DimY);
}
}
} // namespace arm_compute
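
NEFFT1D::configure depends on arm_compute::helpers::fft::decompose_stages() factoring the transform length into the radixes the kernels support (2, 3, 4, 5, 7 and 8, matching the function tables above). A sketch of one such greedy factorization (illustrative, not the library's implementation):

#include <set>
#include <vector>

// Sketch: split N into supported radix stages, trying larger radixes first.
// An empty result means N cannot be decomposed (the FFT size is unsupported).
std::vector<unsigned int> decompose(unsigned int N, const std::set<unsigned int> &supported)
{
    if(N == 0)
    {
        return {};
    }
    std::vector<unsigned int> stages;
    for(auto r = supported.rbegin(); r != supported.rend(); ++r)
    {
        while(N % *r == 0)
        {
            stages.push_back(*r);
            N /= *r;
        }
    }
    return (N == 1) ? stages : std::vector<unsigned int>{};
}
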
diff --git a/src/runtime/NEON/functions/NEFFT2D.cpp b/src/runtime/NEON/functions/NEFFT2D.cpp
new file mode 100644
index 0000000000..9210ecfa2e
--- /dev/null
+++ b/src/runtime/NEON/functions/NEFFT2D.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEFFT2D.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/Scheduler.h"
+
+namespace arm_compute
+{
+NEFFT2D::NEFFT2D(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(memory_manager), _first_pass_func(memory_manager), _second_pass_func(memory_manager), _first_pass_tensor()
+{
+}
+
+void NEFFT2D::configure(const ITensor *input, ITensor *output, const FFT2DInfo &config)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(NEFFT2D::validate(input->info(), output->info(), config));
+
+ // Setup first pass
+ FFT1DInfo first_pass_config;
+ first_pass_config.axis = config.axes.first;
+ first_pass_config.direction = config.direction;
+ _memory_group.manage(&_first_pass_tensor);
+ _first_pass_func.configure(input, &_first_pass_tensor, first_pass_config);
+
+ // Setup second pass
+ FFT1DInfo second_pass_config;
+ second_pass_config.axis = config.axes.second;
+ second_pass_config.direction = config.direction;
+ _second_pass_func.configure(&_first_pass_tensor, output, second_pass_config);
+ _first_pass_tensor.allocator()->allocate();
+}
+
+Status NEFFT2D::validate(const ITensorInfo *input, const ITensorInfo *output, const FFT2DInfo &config)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+
+ // Create intermediate tensor info
+ TensorInfo first_pass_tensor(input->clone()->set_is_resizable(true).reset_padding().set_num_channels(2));
+
+ // Validate first pass
+ FFT1DInfo first_pass_config;
+ first_pass_config.axis = config.axes.first;
+ first_pass_config.direction = config.direction;
+ ARM_COMPUTE_RETURN_ON_ERROR(NEFFT1D::validate(input, &first_pass_tensor, first_pass_config));
+
+ // Validate second pass
+ FFT1DInfo second_pass_config;
+ second_pass_config.axis = config.axes.second;
+ second_pass_config.direction = config.direction;
+ ARM_COMPUTE_RETURN_ON_ERROR(NEFFT1D::validate(&first_pass_tensor, output, second_pass_config));
+
+ // Checks performed when output is configured
+ if((output != nullptr) && (output->total_size() != 0))
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+
+ return Status{};
+}
+
+void NEFFT2D::run()
+{
+ _memory_group.acquire();
+
+ _first_pass_func.run();
+ _second_pass_func.run();
+
+ _memory_group.release();
+}
+} // namespace arm_compute
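
A minimal usage sketch for the new function, assuming the API shown above (config.axes.first/.second suggest a std::pair, and tensors are 2-channel F32 holding interleaved complex values):

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEFFT2D.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void fft2d_example()
{
    // 128x128 complex image: 2 channels of F32
    Tensor src, dst;
    src.allocator()->init(TensorInfo(TensorShape(128U, 128U), 2, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(128U, 128U), 2, DataType::F32));
    src.allocator()->allocate();
    dst.allocator()->allocate();

    FFT2DInfo config;
    config.axes      = std::make_pair(0U, 1U); // row pass, then column pass
    config.direction = FFTDirection::Forward;

    NEFFT2D fft;
    fft.configure(&src, &dst, config);
    fft.run(); // fill src before running, read the spectrum from dst after
}
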