aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEFFTRadixStageKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEFFTRadixStageKernel.cpp903
1 files changed, 504 insertions, 399 deletions
diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
index 148bbe915a..4b58a7b9ac 100644
--- a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"
+#include "src/core/NEON/kernels/NEFFTRadixStageKernel.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
@@ -29,14 +29,17 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Window.h"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/WindowHelpers.h"
+#include "src/core/NEON/wrapper/traits.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+#include "support/ToolchainSupport.h"
+
#include <arm_neon.h>
#include <cmath>
#include <complex>
#include <map>
-#include "arm_compute/core/NEON/wrapper/traits.h"
-#include "arm_compute/core/NEON/wrapper/wrapper.h"
-
namespace arm_compute
{
namespace
@@ -68,7 +71,7 @@ float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
{
using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
- const float32x2_t mask = { -1.0, 1.0 };
+ const float32x2_t mask = {-1.0, 1.0};
const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
@@ -86,7 +89,7 @@ float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
const float a_r = wrapper::vgetlane(a, 0);
const float a_i = wrapper::vgetlane(a, 1);
- const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
+ const auto out = wrapper::vmul(float32x2_t{-a_i, a_r}, float32x2_t{img_constant, img_constant});
return out;
}
@@ -98,7 +101,8 @@ float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_
return wrapper::vadd(t2, e);
}
-float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
+float32x2_t reduce_sum_7(
+ float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
{
const auto t0 = wrapper::vadd(x1, x2);
const auto t1 = wrapper::vadd(x3, x4);
@@ -109,7 +113,14 @@ float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32
return wrapper::vadd(t00, t01);
}
-float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
+float32x2_t reduce_sum_8(float32x2_t x1,
+ float32x2_t x2,
+ float32x2_t x3,
+ float32x2_t x4,
+ float32x2_t x5,
+ float32x2_t x6,
+ float32x2_t x7,
+ float32x2_t x8)
{
const auto t0 = wrapper::vadd(x1, x2);
const auto t1 = wrapper::vadd(x3, x4);
@@ -139,15 +150,21 @@ void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w,
x = wrapper::vadd(a, b);
x = wrapper::vadd(x, c);
- const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
- const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));
+ const auto v1 = wrapper::vmul(float32x2_t{0.5f, 0.5}, wrapper::vadd(b, c));
+ const auto v2 = c_mul_neon(float32x2_t{0.f, -kSqrt3Div2}, wrapper::vsub(b, c));
y = z = wrapper::vsub(a, v1);
y = wrapper::vadd(y, v2);
z = wrapper::vsub(z, v2);
}
-void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
+void fft_4(float32x2_t &x1,
+ float32x2_t &x2,
+ float32x2_t &x3,
+ float32x2_t &x4,
+ const float32x2_t &w,
+ const float32x2_t &w2,
+ const float32x2_t &w3)
{
float32x2_t a = x1;
float32x2_t b = c_mul_neon(w, x2);
@@ -171,7 +188,15 @@ void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, c
x4 = wrapper::vadd(x41, x42);
}
-void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
+void fft_5(float32x2_t &x1,
+ float32x2_t &x2,
+ float32x2_t &x3,
+ float32x2_t &x4,
+ float32x2_t &x5,
+ const float32x2_t &w,
+ const float32x2_t &w2,
+ const float32x2_t &w3,
+ const float32x2_t &w4)
{
const auto a = x1;
const auto b = c_mul_neon(w, x2);
@@ -179,25 +204,25 @@ void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto d = c_mul_neon(w3, x4);
const auto e = c_mul_neon(w4, x5);
- const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
- const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
- const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
- const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);
+ const auto b0 = c_mul_neon(float32x2_t{kW5_0, -kW5_1}, b);
+ const auto b1 = c_mul_neon(float32x2_t{-kW5_2, -kW5_3}, b);
+ const auto b2 = c_mul_neon(float32x2_t{-kW5_2, kW5_3}, b);
+ const auto b3 = c_mul_neon(float32x2_t{kW5_0, kW5_1}, b);
- const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
- const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
- const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
- const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);
+ const auto c0 = c_mul_neon(float32x2_t{-kW5_2, -kW5_3}, c);
+ const auto c1 = c_mul_neon(float32x2_t{kW5_0, kW5_1}, c);
+ const auto c2 = c_mul_neon(float32x2_t{kW5_0, -kW5_1}, c);
+ const auto c3 = c_mul_neon(float32x2_t{-kW5_2, kW5_3}, c);
- const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
- const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
- const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
- const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);
+ const auto d0 = c_mul_neon(float32x2_t{-kW5_2, kW5_3}, d);
+ const auto d1 = c_mul_neon(float32x2_t{kW5_0, -kW5_1}, d);
+ const auto d2 = c_mul_neon(float32x2_t{kW5_0, kW5_1}, d);
+ const auto d3 = c_mul_neon(float32x2_t{-kW5_2, -kW5_3}, d);
- const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
- const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
- const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
- const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);
+ const auto e0 = c_mul_neon(float32x2_t{kW5_0, kW5_1}, e);
+ const auto e1 = c_mul_neon(float32x2_t{-kW5_2, kW5_3}, e);
+ const auto e2 = c_mul_neon(float32x2_t{-kW5_2, -kW5_3}, e);
+ const auto e3 = c_mul_neon(float32x2_t{kW5_0, -kW5_1}, e);
x1 = reduce_sum_5(a, b, c, d, e);
x2 = reduce_sum_5(a, b0, c0, d0, e0);
@@ -206,9 +231,19 @@ void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
x5 = reduce_sum_5(a, b3, c3, d3, e3);
}
-void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
+void fft_7(float32x2_t &x1,
+ float32x2_t &x2,
+ float32x2_t &x3,
+ float32x2_t &x4,
+ float32x2_t &x5,
+ float32x2_t &x6,
+ float32x2_t &x7,
+ const float32x2_t &w,
+ const float32x2_t &w2,
+ const float32x2_t &w3,
const float32x2_t &w4,
- const float32x2_t &w5, const float32x2_t &w6)
+ const float32x2_t &w5,
+ const float32x2_t &w6)
{
const auto a = x1;
const auto b = c_mul_neon(w, x2);
@@ -218,47 +253,47 @@ void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto f = c_mul_neon(w5, x6);
const auto g = c_mul_neon(w6, x7);
- const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
- const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
- const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
- const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
- const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
- const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);
-
- const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
- const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
- const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
- const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
- const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
- const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);
-
- const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
- const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
- const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
- const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
- const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
- const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);
-
- const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
- const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
- const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
- const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
- const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
- const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);
-
- const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
- const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
- const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
- const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
- const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
- const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);
-
- const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
- const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
- const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
- const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
- const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
- const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);
+ const auto b0 = c_mul_neon(float32x2_t{kW7_0, -kW7_1}, b);
+ const auto b1 = c_mul_neon(float32x2_t{-kW7_2, -kW7_3}, b);
+ const auto b2 = c_mul_neon(float32x2_t{-kW7_4, -kW7_5}, b);
+ const auto b3 = c_mul_neon(float32x2_t{-kW7_4, kW7_5}, b);
+ const auto b4 = c_mul_neon(float32x2_t{-kW7_2, kW7_3}, b);
+ const auto b5 = c_mul_neon(float32x2_t{kW7_0, kW7_1}, b);
+
+ const auto c0 = c_mul_neon(float32x2_t{-kW7_2, -kW7_3}, c);
+ const auto c1 = c_mul_neon(float32x2_t{-kW7_4, kW7_5}, c);
+ const auto c2 = c_mul_neon(float32x2_t{kW7_0, kW7_1}, c);
+ const auto c3 = c_mul_neon(float32x2_t{kW7_0, -kW7_1}, c);
+ const auto c4 = c_mul_neon(float32x2_t{-kW7_4, -kW7_5}, c);
+ const auto c5 = c_mul_neon(float32x2_t{-kW7_2, kW7_3}, c);
+
+ const auto d0 = c_mul_neon(float32x2_t{-kW7_4, -kW7_5}, d);
+ const auto d1 = c_mul_neon(float32x2_t{kW7_0, kW7_1}, d);
+ const auto d2 = c_mul_neon(float32x2_t{-kW7_2, -kW7_3}, d);
+ const auto d3 = c_mul_neon(float32x2_t{-kW7_2, +kW7_3}, d);
+ const auto d4 = c_mul_neon(float32x2_t{kW7_0, -kW7_1}, d);
+ const auto d5 = c_mul_neon(float32x2_t{-kW7_4, kW7_5}, d);
+
+ const auto e0 = c_mul_neon(float32x2_t{-kW7_4, kW7_5}, e);
+ const auto e1 = c_mul_neon(float32x2_t{kW7_0, -kW7_1}, e);
+ const auto e2 = c_mul_neon(float32x2_t{-kW7_2, kW7_3}, e);
+ const auto e3 = c_mul_neon(float32x2_t{-kW7_2, -kW7_3}, e);
+ const auto e4 = c_mul_neon(float32x2_t{kW7_0, kW7_1}, e);
+ const auto e5 = c_mul_neon(float32x2_t{-kW7_4, -kW7_5}, e);
+
+ const auto f0 = c_mul_neon(float32x2_t{-kW7_2, kW7_3}, f);
+ const auto f1 = c_mul_neon(float32x2_t{-kW7_4, -kW7_5}, f);
+ const auto f2 = c_mul_neon(float32x2_t{kW7_0, -kW7_1}, f);
+ const auto f3 = c_mul_neon(float32x2_t{kW7_0, kW7_1}, f);
+ const auto f4 = c_mul_neon(float32x2_t{-kW7_4, kW7_5}, f);
+ const auto f5 = c_mul_neon(float32x2_t{-kW7_2, -kW7_3}, f);
+
+ const auto g0 = c_mul_neon(float32x2_t{kW7_0, kW7_1}, g);
+ const auto g1 = c_mul_neon(float32x2_t{-kW7_2, kW7_3}, g);
+ const auto g2 = c_mul_neon(float32x2_t{-kW7_4, kW7_5}, g);
+ const auto g3 = c_mul_neon(float32x2_t{-kW7_4, -kW7_5}, g);
+ const auto g4 = c_mul_neon(float32x2_t{-kW7_2, -kW7_3}, g);
+ const auto g5 = c_mul_neon(float32x2_t{kW7_0, -kW7_1}, g);
x1 = reduce_sum_7(a, b, c, d, e, f, g);
x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
@@ -269,9 +304,20 @@ void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
}
-void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
+void fft_8(float32x2_t &x1,
+ float32x2_t &x2,
+ float32x2_t &x3,
+ float32x2_t &x4,
+ float32x2_t &x5,
+ float32x2_t &x6,
+ float32x2_t &x7,
+ float32x2_t &x8,
+ const float32x2_t &w,
+ const float32x2_t &w2,
const float32x2_t &w3,
- const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
+ const float32x2_t &w4,
+ const float32x2_t &w5,
+ const float32x2_t &w6,
const float32x2_t &w7)
{
const auto a = x1;
@@ -283,61 +329,61 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
const auto g = c_mul_neon(w6, x7);
const auto h = c_mul_neon(w7, x8);
- const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
- const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
- const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
- const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
- const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
- const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
- const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);
-
- const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
- const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
- const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
- const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
- const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
- const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
- const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);
-
- const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
- const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
- const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
- const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
- const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
- const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
- const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);
-
- const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
- const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
- const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
- const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
- const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
- const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
- const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);
-
- const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
- const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
- const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
- const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
- const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
- const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
- const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);
-
- const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
- const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
- const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
- const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
- const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
- const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
- const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);
-
- const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
- const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
- const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
- const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
- const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
- const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
- const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);
+ const auto b0 = c_mul_neon(float32x2_t{kSqrt2Div2, -kSqrt2Div2}, b);
+ const auto b1 = c_mul_neon(float32x2_t{0, -1}, b);
+ const auto b2 = c_mul_neon(float32x2_t{-kSqrt2Div2, -kSqrt2Div2}, b);
+ const auto b3 = c_mul_neon(float32x2_t{-1, 0}, b);
+ const auto b4 = c_mul_neon(float32x2_t{-kSqrt2Div2, kSqrt2Div2}, b);
+ const auto b5 = c_mul_neon(float32x2_t{0, 1}, b);
+ const auto b6 = c_mul_neon(float32x2_t{kSqrt2Div2, kSqrt2Div2}, b);
+
+ const auto c0 = c_mul_neon(float32x2_t{0, -1}, c);
+ const auto c1 = c_mul_neon(float32x2_t{-1, 0}, c);
+ const auto c2 = c_mul_neon(float32x2_t{0, 1}, c);
+ const auto c3 = c_mul_neon(float32x2_t{1, 0}, c);
+ const auto c4 = c_mul_neon(float32x2_t{0, -1}, c);
+ const auto c5 = c_mul_neon(float32x2_t{-1, 0}, c);
+ const auto c6 = c_mul_neon(float32x2_t{0, 1}, c);
+
+ const auto d0 = c_mul_neon(float32x2_t{-kSqrt2Div2, -kSqrt2Div2}, d);
+ const auto d1 = c_mul_neon(float32x2_t{0, 1}, d);
+ const auto d2 = c_mul_neon(float32x2_t{kSqrt2Div2, -kSqrt2Div2}, d);
+ const auto d3 = c_mul_neon(float32x2_t{-1, 0}, d);
+ const auto d4 = c_mul_neon(float32x2_t{kSqrt2Div2, kSqrt2Div2}, d);
+ const auto d5 = c_mul_neon(float32x2_t{0, -1}, d);
+ const auto d6 = c_mul_neon(float32x2_t{-kSqrt2Div2, kSqrt2Div2}, d);
+
+ const auto e0 = c_mul_neon(float32x2_t{-1, 0}, e);
+ const auto e1 = c_mul_neon(float32x2_t{1, 0}, e);
+ const auto e2 = c_mul_neon(float32x2_t{-1, 0}, e);
+ const auto e3 = c_mul_neon(float32x2_t{1, 0}, e);
+ const auto e4 = c_mul_neon(float32x2_t{-1, 0}, e);
+ const auto e5 = c_mul_neon(float32x2_t{1, 0}, e);
+ const auto e6 = c_mul_neon(float32x2_t{-1, 0}, e);
+
+ const auto f0 = c_mul_neon(float32x2_t{-kSqrt2Div2, kSqrt2Div2}, f);
+ const auto f1 = c_mul_neon(float32x2_t{0, -1}, f);
+ const auto f2 = c_mul_neon(float32x2_t{kSqrt2Div2, kSqrt2Div2}, f);
+ const auto f3 = c_mul_neon(float32x2_t{-1, 0}, f);
+ const auto f4 = c_mul_neon(float32x2_t{kSqrt2Div2, -kSqrt2Div2}, f);
+ const auto f5 = c_mul_neon(float32x2_t{0, 1}, f);
+ const auto f6 = c_mul_neon(float32x2_t{-kSqrt2Div2, -kSqrt2Div2}, f);
+
+ const auto g0 = c_mul_neon(float32x2_t{0, 1}, g);
+ const auto g1 = c_mul_neon(float32x2_t{-1, 0}, g);
+ const auto g2 = c_mul_neon(float32x2_t{0, -1}, g);
+ const auto g3 = c_mul_neon(float32x2_t{1, 0}, g);
+ const auto g4 = c_mul_neon(float32x2_t{0, 1}, g);
+ const auto g5 = c_mul_neon(float32x2_t{-1, 0}, g);
+ const auto g6 = c_mul_neon(float32x2_t{0, -1}, g);
+
+ const auto h0 = c_mul_neon(float32x2_t{kSqrt2Div2, kSqrt2Div2}, h);
+ const auto h1 = c_mul_neon(float32x2_t{0, 1}, h);
+ const auto h2 = c_mul_neon(float32x2_t{-kSqrt2Div2, kSqrt2Div2}, h);
+ const auto h3 = c_mul_neon(float32x2_t{-1, 0}, h);
+ const auto h4 = c_mul_neon(float32x2_t{-kSqrt2Div2, -kSqrt2Div2}, h);
+ const auto h5 = c_mul_neon(float32x2_t{0, -1}, h);
+ const auto h6 = c_mul_neon(float32x2_t{kSqrt2Div2, -kSqrt2Div2}, h);
x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
@@ -350,41 +396,42 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
}
template <bool first_stage>
-void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_2_axes_0(
+ float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
- auto a = float32x2_t{ 0, 0 };
- auto b = float32x2_t{ 0, 0 };
+ auto a = float32x2_t{0, 0};
+ auto b = float32x2_t{0, 0};
// Load inputs
- if(first_stage)
+ if (first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
+ const auto ab = wrapper::vloadq(in + k);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
}
// Base-case prime transform
fft_2(a, b, w);
// Write outputs
- if(first_stage)
+ if (first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
}
}
@@ -392,23 +439,31 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
}
-void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_2_axes_1(float *out,
+ float *in,
+ unsigned int Nx,
+ unsigned int NxRadix,
+ const float32x2_t &w_m,
+ unsigned int N,
+ unsigned int M,
+ unsigned int in_pad_x,
+ unsigned int out_pad_x)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
// Base-case prime transform
fft_2(a, b, w);
// Write outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
}
w = c_mul_neon(w, w_m);
@@ -416,95 +471,105 @@ void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_3_axes_0(
+ float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = { 0, 0 };
- float32x2_t b = { 0, 0 };
- float32x2_t c = { 0, 0 };
- if(first_stage)
+ float32x2_t a = {0, 0};
+ float32x2_t b = {0, 0};
+ float32x2_t c = {0, 0};
+ if (first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
+ const auto ab = wrapper::vloadq(in + k);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
}
- c = wrapper::vload(x + k + 4 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
// Base-case prime transform
fft_3(a, b, c, w, w2);
- if(first_stage)
+ if (first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
}
- wrapper::vstore(X + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 4 * Nx, c);
}
w = c_mul_neon(w, w_m);
}
}
-void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_3_axes_1(float *out,
+ float *in,
+ unsigned int Nx,
+ unsigned int NxRadix,
+ const float32x2_t &w_m,
+ unsigned int N,
+ unsigned int M,
+ unsigned int in_pad_x,
+ unsigned int out_pad_x)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
// Base-case prime transform
fft_3(a, b, c, w, w2);
// Store the output
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
}
w = c_mul_neon(w, w_m);
}
}
template <bool first_stage>
-void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_4_axes_0(
+ float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
const auto w3 = c_mul_neon(w2, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
- float32x2_t a = { 0, 0 };
- float32x2_t b = { 0, 0 };
- float32x2_t c = { 0, 0 };
- float32x2_t d = { 0, 0 };
- if(first_stage)
+ float32x2_t a = {0, 0};
+ float32x2_t b = {0, 0};
+ float32x2_t c = {0, 0};
+ float32x2_t d = {0, 0};
+ if (first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
c = wrapper::vgetlow(cd);
@@ -513,26 +578,26 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
else
{
// Load inputs
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
}
// Base-case prime transform
fft_4(a, b, c, d, w, w2, w3);
- if(first_stage)
+ if (first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
}
}
@@ -540,29 +605,37 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
}
-void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_4_axes_1(float *out,
+ float *in,
+ unsigned int Nx,
+ unsigned int NxRadix,
+ const float32x2_t &w_m,
+ unsigned int N,
+ unsigned int M,
+ unsigned int in_pad_x,
+ unsigned int out_pad_x)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
const auto w3 = c_mul_neon(w2, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
// Base-case prime transform
fft_4(a, b, c, d, w, w2, w3);
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
}
w = c_mul_neon(w, w_m);
@@ -570,28 +643,29 @@ void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_5_axes_0(
+ float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
const float32x2_t w4 = c_mul_neon(w3, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
- float32x2_t a = { 0, 0 };
- float32x2_t b = { 0, 0 };
- float32x2_t c = { 0, 0 };
- float32x2_t d = { 0, 0 };
- float32x2_t e = { 0, 0 };
+ float32x2_t a = {0, 0};
+ float32x2_t b = {0, 0};
+ float32x2_t c = {0, 0};
+ float32x2_t d = {0, 0};
+ float32x2_t e = {0, 0};
// Load inputs
- if(first_stage)
+ if (first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
@@ -600,63 +674,71 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
}
- e = wrapper::vload(x + k + 8 * Nx);
+ e = wrapper::vload(in + k + 8 * Nx);
// Base-case prime transform
fft_5(a, b, c, d, e, w, w2, w3, w4);
// Store outputs
- if(first_stage)
+ if (first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
}
- wrapper::vstore(X + k + 8 * Nx, e);
+ wrapper::vstore(out + k + 8 * Nx, e);
}
w = c_mul_neon(w, w_m);
}
}
-void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_5_axes_1(float *out,
+ float *in,
+ unsigned int Nx,
+ unsigned int NxRadix,
+ const float32x2_t &w_m,
+ unsigned int N,
+ unsigned int M,
+ unsigned int in_pad_x,
+ unsigned int out_pad_x)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
const float32x2_t w4 = c_mul_neon(w3, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
- float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
// Base-case prime transform
fft_5(a, b, c, d, e, w, w2, w3, w4);
// Store outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
- wrapper::vstore(X + M * (k + 8 * Nx), e);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
}
w = c_mul_neon(w, w_m);
@@ -664,10 +746,11 @@ void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_7_axes_0(
+ float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
@@ -675,22 +758,22 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w5 = c_mul_neon(w4, w);
const float32x2_t w6 = c_mul_neon(w5, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
- float32x2_t a = { 0, 0 };
- float32x2_t b = { 0, 0 };
- float32x2_t c = { 0, 0 };
- float32x2_t d = { 0, 0 };
- float32x2_t e = { 0, 0 };
- float32x2_t f = { 0, 0 };
- float32x2_t g = { 0, 0 };
+ float32x2_t a = {0, 0};
+ float32x2_t b = {0, 0};
+ float32x2_t c = {0, 0};
+ float32x2_t d = {0, 0};
+ float32x2_t e = {0, 0};
+ float32x2_t f = {0, 0};
+ float32x2_t g = {0, 0};
// Load inputs
- if(first_stage)
+ if (first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
- const auto ef = wrapper::vloadq(x + k + 8 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
+ const auto ef = wrapper::vloadq(in + k + 8 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
@@ -701,44 +784,52 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
- e = wrapper::vload(x + k + 8 * Nx);
- f = wrapper::vload(x + k + 10 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
+ e = wrapper::vload(in + k + 8 * Nx);
+ f = wrapper::vload(in + k + 10 * Nx);
}
- g = wrapper::vload(x + k + 12 * Nx);
+ g = wrapper::vload(in + k + 12 * Nx);
// Base-case prime transform
fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
- if(first_stage)
+ if (first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
- wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
- wrapper::vstore(X + k + 8 * Nx, e);
- wrapper::vstore(X + k + 10 * Nx, f);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
+ wrapper::vstore(out + k + 8 * Nx, e);
+ wrapper::vstore(out + k + 10 * Nx, f);
}
- wrapper::vstore(X + k + 12 * Nx, g);
+ wrapper::vstore(out + k + 12 * Nx, g);
}
w = c_mul_neon(w, w_m);
}
}
-void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_7_axes_1(float *out,
+ float *in,
+ unsigned int Nx,
+ unsigned int NxRadix,
+ const float32x2_t &w_m,
+ unsigned int N,
+ unsigned int M,
+ unsigned int in_pad_x,
+ unsigned int out_pad_x)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
@@ -746,28 +837,28 @@ void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w5 = c_mul_neon(w4, w);
const float32x2_t w6 = c_mul_neon(w5, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
- float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
- float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
- float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
+ float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
+ float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
// Base-case prime transform
fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
// Store outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
- wrapper::vstore(X + M * (k + 8 * Nx), e);
- wrapper::vstore(X + M * (k + 10 * Nx), f);
- wrapper::vstore(X + M * (k + 12 * Nx), g);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
}
w = c_mul_neon(w, w_m);
@@ -775,10 +866,11 @@ void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_8_axes_0(
+ float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
@@ -787,25 +879,25 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w6 = c_mul_neon(w5, w);
const float32x2_t w7 = c_mul_neon(w6, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = { 0, 0 };
- float32x2_t b = { 0, 0 };
- float32x2_t c = { 0, 0 };
- float32x2_t d = { 0, 0 };
- float32x2_t e = { 0, 0 };
- float32x2_t f = { 0, 0 };
- float32x2_t g = { 0, 0 };
- float32x2_t h = { 0, 0 };
+ float32x2_t a = {0, 0};
+ float32x2_t b = {0, 0};
+ float32x2_t c = {0, 0};
+ float32x2_t d = {0, 0};
+ float32x2_t e = {0, 0};
+ float32x2_t f = {0, 0};
+ float32x2_t g = {0, 0};
+ float32x2_t h = {0, 0};
// Base-case prime transform
- if(first_stage)
+ if (first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
- const auto ef = wrapper::vloadq(x + k + 8 * Nx);
- const auto gh = wrapper::vloadq(x + k + 12 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
+ const auto ef = wrapper::vloadq(in + k + 8 * Nx);
+ const auto gh = wrapper::vloadq(in + k + 12 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
@@ -818,37 +910,37 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
- e = wrapper::vload(x + k + 8 * Nx);
- f = wrapper::vload(x + k + 10 * Nx);
- g = wrapper::vload(x + k + 12 * Nx);
- h = wrapper::vload(x + k + 14 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
+ e = wrapper::vload(in + k + 8 * Nx);
+ f = wrapper::vload(in + k + 10 * Nx);
+ g = wrapper::vload(in + k + 12 * Nx);
+ h = wrapper::vload(in + k + 14 * Nx);
}
// Apply twiddle factors
fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
// Store outputs
- if(first_stage)
+ if (first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
- wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
- wrapper::vstore(X + k + 12 * Nx, wrapper::vcombine(g, h));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
+ wrapper::vstore(out + k + 12 * Nx, wrapper::vcombine(g, h));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
- wrapper::vstore(X + k + 8 * Nx, e);
- wrapper::vstore(X + k + 10 * Nx, f);
- wrapper::vstore(X + k + 12 * Nx, g);
- wrapper::vstore(X + k + 14 * Nx, h);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
+ wrapper::vstore(out + k + 8 * Nx, e);
+ wrapper::vstore(out + k + 10 * Nx, f);
+ wrapper::vstore(out + k + 12 * Nx, g);
+ wrapper::vstore(out + k + 14 * Nx, h);
}
}
@@ -856,10 +948,18 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
}
-void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_8_axes_1(float *out,
+ float *in,
+ unsigned int Nx,
+ unsigned int NxRadix,
+ const float32x2_t &w_m,
+ unsigned int N,
+ unsigned int M,
+ unsigned int in_pad_x,
+ unsigned int out_pad_x)
{
- float32x2_t w{ 1.0f, 0.0f };
- for(unsigned int j = 0; j < Nx; j++)
+ float32x2_t w{1.0f, 0.0f};
+ for (unsigned int j = 0; j < Nx; j++)
{
const float32x2_t w2 = c_mul_neon(w, w);
const float32x2_t w3 = c_mul_neon(w2, w);
@@ -868,30 +968,30 @@ void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w6 = c_mul_neon(w5, w);
const float32x2_t w7 = c_mul_neon(w6, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for (unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
- float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
- float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
- float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
- float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
+ float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
+ float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
+ float32x2_t h = wrapper::vload(in + (N + in_pad_x) * (k + 14 * Nx));
// Base-case prime transform
fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
// Store outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
- wrapper::vstore(X + M * (k + 8 * Nx), e);
- wrapper::vstore(X + M * (k + 10 * Nx), f);
- wrapper::vstore(X + M * (k + 12 * Nx), g);
- wrapper::vstore(X + M * (k + 14 * Nx), h);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 14 * Nx), h);
}
w = c_mul_neon(w, w_m);
@@ -906,7 +1006,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
ARM_COMPUTE_UNUSED(config);
// Checks performed when output is configured
- if((output != nullptr) && (output->total_size() != 0))
+ if ((output != nullptr) && (output->total_size() != 0))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
@@ -915,27 +1015,24 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
+std::pair<Status, Window>
+validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
ARM_COMPUTE_UNUSED(config);
- if(output != nullptr)
+ if (output != nullptr)
{
auto_init_if_empty(*output, *input);
}
Window win = calculate_max_window(*input, Steps());
- if(output != nullptr)
- {
- output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
- }
return std::make_pair(Status{}, win);
}
} // namespace
NEFFTRadixStageKernel::NEFFTRadixStageKernel()
- : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
+ : _input(nullptr), _output(nullptr), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
{
}
@@ -944,7 +1041,7 @@ void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo
// FFT table axis 0: [radix, first_stage]
static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
- if(fft_table_axis0.empty())
+ if (fft_table_axis0.empty())
{
fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
@@ -969,7 +1066,7 @@ void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo
// FFT table axis 1: [radix, first_stage]
static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
- if(fft_table_axis1.empty())
+ if (fft_table_axis1.empty())
{
fft_table_axis1[2] = &fft_radix_2_axes_1;
fft_table_axis1[3] = &fft_radix_3_axes_1;
@@ -987,21 +1084,21 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
// Output auto inizialitation if not yet initialized
- if(output != nullptr)
+ if (output != nullptr)
{
auto_init_if_empty(*output->info(), *input->info()->clone());
}
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
+ ARM_COMPUTE_ERROR_THROW_ON(
+ validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
- _input = input;
- _output = output;
- _run_in_place = (output == nullptr) || (output == input);
- _Nx = config.Nx;
- _axis = config.axis;
- _radix = config.radix;
+ _input = input;
+ _output = (output == nullptr) ? input : output;
+ _Nx = config.Nx;
+ _axis = config.axis;
+ _radix = config.radix;
- switch(config.axis)
+ switch (config.axis)
{
case 0:
set_radix_stage_axis0(config);
@@ -1015,26 +1112,28 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT
}
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info(), config);
+ auto win_config =
+ validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr, config);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
}
-Status NEFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
+Status NEFFTRadixStageKernel::validate(const ITensorInfo *input,
+ const ITensorInfo *output,
+ const FFTRadixStageKernelInfo &config)
{
const bool run_in_place = (output == nullptr) || (output == input);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
- (run_in_place) ? nullptr : output->clone().get(),
- config)
- .first);
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ validate_and_configure_window(input->clone().get(), (run_in_place) ? nullptr : output->clone().get(), config)
+ .first);
return Status{};
}
std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
{
- return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
+ return std::set<unsigned int>{2, 3, 4, 5, 7, 8};
}
void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
@@ -1047,31 +1146,37 @@ void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
input_window.set(_axis, 0);
Iterator in(_input, input_window);
- Iterator out(_run_in_place ? _input : _output, input_window);
+ Iterator out(_output, input_window);
// Precompute FFT constants
const unsigned int NxRadix = _radix * _Nx;
const float alpha = 2.0f * kPi / float(NxRadix);
- const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
+ const float32x2_t w_m{cosf(alpha), -sinf(alpha)};
- if(_axis == 0)
+ if (_axis == 0)
{
const unsigned int N = _input->info()->dimension(0);
- execute_window_loop(input_window, [&](const Coordinates &)
- {
- _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
- },
- in, out);
+ execute_window_loop(
+ input_window,
+ [&](const Coordinates &) {
+ _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m,
+ N);
+ },
+ in, out);
}
else
{
const unsigned int N = _input->info()->dimension(0);
const unsigned int M = _input->info()->dimension(1);
- execute_window_loop(input_window, [&](const Coordinates &)
- {
- _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M);
- },
- in, out);
+ execute_window_loop(
+ input_window,
+ [&](const Coordinates &)
+ {
+ _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N,
+ M, _input->info()->padding().right + _input->info()->padding().left,
+ _output->info()->padding().right + _output->info()->padding().left);
+ },
+ in, out);
}
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);