From 9a81cd82a8102ee0bd69bfe4939d5c867aed15e9 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Thu, 15 Apr 2021 17:44:55 +0100 Subject: Fix bug on Implicit Padding for NEON FFT2D Include paddings in address computation for input and output Resolves: COMPMID-4362 Change-Id: I1b34cf47e3b80b98d55fc8fbdeecbfd850d33197 Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5439 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas --- src/core/NEON/kernels/NEFFTRadixStageKernel.cpp | 341 ++++++++++++------------ src/core/NEON/kernels/NEFFTRadixStageKernel.h | 5 +- tests/validation/fixtures/FFTFixture.h | 3 +- 3 files changed, 174 insertions(+), 175 deletions(-) diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp index d71707a4fc..44c841f626 100644 --- a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp +++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp @@ -352,7 +352,7 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f } template -void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +void fft_radix_2_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -365,14 +365,14 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Load inputs if(first_stage) { - const auto ab = wrapper::vloadq(x + k); + const auto ab = wrapper::vloadq(in + k); a = wrapper::vgetlow(ab); b = wrapper::vgethigh(ab); } else { - a = wrapper::vload(x + k); - b = wrapper::vload(x + k + 2 * Nx); + a = wrapper::vload(in + k); + b = wrapper::vload(in + k + 2 * Nx); } // Base-case prime transform @@ -381,12 +381,12 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Write outputs if(first_stage) { - wrapper::vstore(X + k, wrapper::vcombine(a, b)); + wrapper::vstore(out + k, wrapper::vcombine(a, b)); } else { - wrapper::vstore(X + k, a); - wrapper::vstore(X + k + 2 * Nx, b); + wrapper::vstore(out + k, a); + wrapper::vstore(out + k + 2 * Nx, b); } } @@ -394,23 +394,23 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi } } -void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +void fft_radix_2_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix) { // Load inputs - float32x2_t a = wrapper::vload(x + M * k); - float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); + float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k); + float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx)); // Base-case prime transform fft_2(a, b, w); // Write outputs - wrapper::vstore(X + M * k, a); - wrapper::vstore(X + M * (k + 2 * Nx), b); + wrapper::vstore(out + (N + out_pad_x) * k, a); + wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b); } w = c_mul_neon(w, w_m); @@ -418,7 +418,7 @@ void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi } template -void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +void fft_radix_3_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -433,63 +433,63 @@ void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi float32x2_t c = { 0, 0 }; if(first_stage) { - const auto ab = wrapper::vloadq(x + k); + const auto ab = wrapper::vloadq(in + k); a = wrapper::vgetlow(ab); b = wrapper::vgethigh(ab); } else { - a = wrapper::vload(x + k); - b = wrapper::vload(x + k + 2 * Nx); + a = wrapper::vload(in + k); + b = wrapper::vload(in + k + 2 * Nx); } - c = wrapper::vload(x + k + 4 * Nx); + c = wrapper::vload(in + k + 4 * Nx); // Base-case prime transform fft_3(a, b, c, w, w2); if(first_stage) { - wrapper::vstore(X + k, wrapper::vcombine(a, b)); + wrapper::vstore(out + k, wrapper::vcombine(a, b)); } else { - wrapper::vstore(X + k, a); - wrapper::vstore(X + k + 2 * Nx, b); + wrapper::vstore(out + k, a); + wrapper::vstore(out + k + 2 * Nx, b); } - wrapper::vstore(X + k + 4 * Nx, c); + wrapper::vstore(out + k + 4 * Nx, c); } w = c_mul_neon(w, w_m); } } -void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +void fft_radix_3_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) { const auto w2 = c_mul_neon(w, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix) { // Load inputs - float32x2_t a = wrapper::vload(x + M * k); - float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); - float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); + float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k); + float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx)); // Base-case prime transform fft_3(a, b, c, w, w2); // Store the output - wrapper::vstore(X + M * k, a); - wrapper::vstore(X + M * (k + 2 * Nx), b); - wrapper::vstore(X + M * (k + 4 * Nx), c); + wrapper::vstore(out + (N + out_pad_x) * k, a); + wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b); + wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c); } w = c_mul_neon(w, w_m); } } template -void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +void fft_radix_4_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -505,8 +505,8 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi float32x2_t d = { 0, 0 }; if(first_stage) { - const auto ab = wrapper::vloadq(x + k); - const auto cd = wrapper::vloadq(x + k + 4 * Nx); + const auto ab = wrapper::vloadq(in + k); + const auto cd = wrapper::vloadq(in + k + 4 * Nx); a = wrapper::vgetlow(ab); b = wrapper::vgethigh(ab); c = wrapper::vgetlow(cd); @@ -515,10 +515,10 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi else { // Load inputs - a = wrapper::vload(x + k); - b = wrapper::vload(x + k + 2 * Nx); - c = wrapper::vload(x + k + 4 * Nx); - d = wrapper::vload(x + k + 6 * Nx); + a = wrapper::vload(in + k); + b = wrapper::vload(in + k + 2 * Nx); + c = wrapper::vload(in + k + 4 * Nx); + d = wrapper::vload(in + k + 6 * Nx); } // Base-case prime transform @@ -526,15 +526,15 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi if(first_stage) { - wrapper::vstore(X + k, wrapper::vcombine(a, b)); - wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d)); + wrapper::vstore(out + k, wrapper::vcombine(a, b)); + wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d)); } else { - wrapper::vstore(X + k, a); - wrapper::vstore(X + k + 2 * Nx, b); - wrapper::vstore(X + k + 4 * Nx, c); - wrapper::vstore(X + k + 6 * Nx, d); + wrapper::vstore(out + k, a); + wrapper::vstore(out + k + 2 * Nx, b); + wrapper::vstore(out + k + 4 * Nx, c); + wrapper::vstore(out + k + 6 * Nx, d); } } @@ -542,7 +542,7 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi } } -void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +void fft_radix_4_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -550,21 +550,21 @@ void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi const auto w2 = c_mul_neon(w, w); const auto w3 = c_mul_neon(w2, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix) { // Load inputs - float32x2_t a = wrapper::vload(x + M * k); - float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); - float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); - float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); + float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k); + float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx)); // Base-case prime transform fft_4(a, b, c, d, w, w2, w3); - wrapper::vstore(X + M * k, a); - wrapper::vstore(X + M * (k + 2 * Nx), b); - wrapper::vstore(X + M * (k + 4 * Nx), c); - wrapper::vstore(X + M * (k + 6 * Nx), d); + wrapper::vstore(out + (N + out_pad_x) * k, a); + wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b); + wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c); + wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d); } w = c_mul_neon(w, w_m); @@ -572,7 +572,7 @@ void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi } template -void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +void fft_radix_5_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -592,8 +592,8 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Load inputs if(first_stage) { - const auto ab = wrapper::vloadq(x + k); - const auto cd = wrapper::vloadq(x + k + 4 * Nx); + const auto ab = wrapper::vloadq(in + k); + const auto cd = wrapper::vloadq(in + k + 4 * Nx); a = wrapper::vgetlow(ab); b = wrapper::vgethigh(ab); @@ -602,12 +602,12 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi } else { - a = wrapper::vload(x + k); - b = wrapper::vload(x + k + 2 * Nx); - c = wrapper::vload(x + k + 4 * Nx); - d = wrapper::vload(x + k + 6 * Nx); + a = wrapper::vload(in + k); + b = wrapper::vload(in + k + 2 * Nx); + c = wrapper::vload(in + k + 4 * Nx); + d = wrapper::vload(in + k + 6 * Nx); } - e = wrapper::vload(x + k + 8 * Nx); + e = wrapper::vload(in + k + 8 * Nx); // Base-case prime transform fft_5(a, b, c, d, e, w, w2, w3, w4); @@ -615,24 +615,24 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Store outputs if(first_stage) { - wrapper::vstore(X + k, wrapper::vcombine(a, b)); - wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d)); + wrapper::vstore(out + k, wrapper::vcombine(a, b)); + wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d)); } else { - wrapper::vstore(X + k, a); - wrapper::vstore(X + k + 2 * Nx, b); - wrapper::vstore(X + k + 4 * Nx, c); - wrapper::vstore(X + k + 6 * Nx, d); + wrapper::vstore(out + k, a); + wrapper::vstore(out + k + 2 * Nx, b); + wrapper::vstore(out + k + 4 * Nx, c); + wrapper::vstore(out + k + 6 * Nx, d); } - wrapper::vstore(X + k + 8 * Nx, e); + wrapper::vstore(out + k + 8 * Nx, e); } w = c_mul_neon(w, w_m); } } -void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +void fft_radix_5_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -641,24 +641,24 @@ void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi const float32x2_t w3 = c_mul_neon(w2, w); const float32x2_t w4 = c_mul_neon(w3, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix) { // Load inputs - float32x2_t a = wrapper::vload(x + M * k); - float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); - float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); - float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); - float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx)); + float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k); + float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx)); + float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx)); // Base-case prime transform fft_5(a, b, c, d, e, w, w2, w3, w4); // Store outputs - wrapper::vstore(X + M * k, a); - wrapper::vstore(X + M * (k + 2 * Nx), b); - wrapper::vstore(X + M * (k + 4 * Nx), c); - wrapper::vstore(X + M * (k + 6 * Nx), d); - wrapper::vstore(X + M * (k + 8 * Nx), e); + wrapper::vstore(out + (N + out_pad_x) * k, a); + wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b); + wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c); + wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d); + wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e); } w = c_mul_neon(w, w_m); @@ -666,7 +666,7 @@ void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi } template -void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +void fft_radix_7_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -690,9 +690,9 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Load inputs if(first_stage) { - const auto ab = wrapper::vloadq(x + k); - const auto cd = wrapper::vloadq(x + k + 4 * Nx); - const auto ef = wrapper::vloadq(x + k + 8 * Nx); + const auto ab = wrapper::vloadq(in + k); + const auto cd = wrapper::vloadq(in + k + 4 * Nx); + const auto ef = wrapper::vloadq(in + k + 8 * Nx); a = wrapper::vgetlow(ab); b = wrapper::vgethigh(ab); @@ -703,41 +703,41 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi } else { - a = wrapper::vload(x + k); - b = wrapper::vload(x + k + 2 * Nx); - c = wrapper::vload(x + k + 4 * Nx); - d = wrapper::vload(x + k + 6 * Nx); - e = wrapper::vload(x + k + 8 * Nx); - f = wrapper::vload(x + k + 10 * Nx); + a = wrapper::vload(in + k); + b = wrapper::vload(in + k + 2 * Nx); + c = wrapper::vload(in + k + 4 * Nx); + d = wrapper::vload(in + k + 6 * Nx); + e = wrapper::vload(in + k + 8 * Nx); + f = wrapper::vload(in + k + 10 * Nx); } - g = wrapper::vload(x + k + 12 * Nx); + g = wrapper::vload(in + k + 12 * Nx); // Base-case prime transform fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6); if(first_stage) { - wrapper::vstore(X + k, wrapper::vcombine(a, b)); - wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d)); - wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f)); + wrapper::vstore(out + k, wrapper::vcombine(a, b)); + wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d)); + wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f)); } else { - wrapper::vstore(X + k, a); - wrapper::vstore(X + k + 2 * Nx, b); - wrapper::vstore(X + k + 4 * Nx, c); - wrapper::vstore(X + k + 6 * Nx, d); - wrapper::vstore(X + k + 8 * Nx, e); - wrapper::vstore(X + k + 10 * Nx, f); + wrapper::vstore(out + k, a); + wrapper::vstore(out + k + 2 * Nx, b); + wrapper::vstore(out + k + 4 * Nx, c); + wrapper::vstore(out + k + 6 * Nx, d); + wrapper::vstore(out + k + 8 * Nx, e); + wrapper::vstore(out + k + 10 * Nx, f); } - wrapper::vstore(X + k + 12 * Nx, g); + wrapper::vstore(out + k + 12 * Nx, g); } w = c_mul_neon(w, w_m); } } -void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +void fft_radix_7_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -748,28 +748,28 @@ void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi const float32x2_t w5 = c_mul_neon(w4, w); const float32x2_t w6 = c_mul_neon(w5, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix) { // Load inputs - float32x2_t a = wrapper::vload(x + M * k); - float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); - float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); - float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); - float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx)); - float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx)); - float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx)); + float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k); + float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx)); + float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx)); + float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx)); + float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx)); // Base-case prime transform fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6); // Store outputs - wrapper::vstore(X + M * k, a); - wrapper::vstore(X + M * (k + 2 * Nx), b); - wrapper::vstore(X + M * (k + 4 * Nx), c); - wrapper::vstore(X + M * (k + 6 * Nx), d); - wrapper::vstore(X + M * (k + 8 * Nx), e); - wrapper::vstore(X + M * (k + 10 * Nx), f); - wrapper::vstore(X + M * (k + 12 * Nx), g); + wrapper::vstore(out + (N + out_pad_x) * k, a); + wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b); + wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c); + wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d); + wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e); + wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f); + wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g); } w = c_mul_neon(w, w_m); @@ -777,7 +777,7 @@ void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi } template -void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) +void fft_radix_8_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -804,10 +804,10 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Base-case prime transform if(first_stage) { - const auto ab = wrapper::vloadq(x + k); - const auto cd = wrapper::vloadq(x + k + 4 * Nx); - const auto ef = wrapper::vloadq(x + k + 8 * Nx); - const auto gh = wrapper::vloadq(x + k + 12 * Nx); + const auto ab = wrapper::vloadq(in + k); + const auto cd = wrapper::vloadq(in + k + 4 * Nx); + const auto ef = wrapper::vloadq(in + k + 8 * Nx); + const auto gh = wrapper::vloadq(in + k + 12 * Nx); a = wrapper::vgetlow(ab); b = wrapper::vgethigh(ab); @@ -820,14 +820,14 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi } else { - a = wrapper::vload(x + k); - b = wrapper::vload(x + k + 2 * Nx); - c = wrapper::vload(x + k + 4 * Nx); - d = wrapper::vload(x + k + 6 * Nx); - e = wrapper::vload(x + k + 8 * Nx); - f = wrapper::vload(x + k + 10 * Nx); - g = wrapper::vload(x + k + 12 * Nx); - h = wrapper::vload(x + k + 14 * Nx); + a = wrapper::vload(in + k); + b = wrapper::vload(in + k + 2 * Nx); + c = wrapper::vload(in + k + 4 * Nx); + d = wrapper::vload(in + k + 6 * Nx); + e = wrapper::vload(in + k + 8 * Nx); + f = wrapper::vload(in + k + 10 * Nx); + g = wrapper::vload(in + k + 12 * Nx); + h = wrapper::vload(in + k + 14 * Nx); } // Apply twiddle factors @@ -836,21 +836,21 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi // Store outputs if(first_stage) { - wrapper::vstore(X + k, wrapper::vcombine(a, b)); - wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d)); - wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f)); - wrapper::vstore(X + k + 12 * Nx, wrapper::vcombine(g, h)); + wrapper::vstore(out + k, wrapper::vcombine(a, b)); + wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d)); + wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f)); + wrapper::vstore(out + k + 12 * Nx, wrapper::vcombine(g, h)); } else { - wrapper::vstore(X + k, a); - wrapper::vstore(X + k + 2 * Nx, b); - wrapper::vstore(X + k + 4 * Nx, c); - wrapper::vstore(X + k + 6 * Nx, d); - wrapper::vstore(X + k + 8 * Nx, e); - wrapper::vstore(X + k + 10 * Nx, f); - wrapper::vstore(X + k + 12 * Nx, g); - wrapper::vstore(X + k + 14 * Nx, h); + wrapper::vstore(out + k, a); + wrapper::vstore(out + k + 2 * Nx, b); + wrapper::vstore(out + k + 4 * Nx, c); + wrapper::vstore(out + k + 6 * Nx, d); + wrapper::vstore(out + k + 8 * Nx, e); + wrapper::vstore(out + k + 10 * Nx, f); + wrapper::vstore(out + k + 12 * Nx, g); + wrapper::vstore(out + k + 14 * Nx, h); } } @@ -858,7 +858,7 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi } } -void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N) +void fft_radix_8_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x) { float32x2_t w{ 1.0f, 0.0f }; for(unsigned int j = 0; j < Nx; j++) @@ -870,30 +870,30 @@ void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi const float32x2_t w6 = c_mul_neon(w5, w); const float32x2_t w7 = c_mul_neon(w6, w); - for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix) + for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix) { // Load inputs - float32x2_t a = wrapper::vload(x + M * k); - float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx)); - float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx)); - float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx)); - float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx)); - float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx)); - float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx)); - float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx)); + float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k); + float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx)); + float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx)); + float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx)); + float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx)); + float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx)); + float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx)); + float32x2_t h = wrapper::vload(in + (N + in_pad_x) * (k + 14 * Nx)); // Base-case prime transform fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7); // Store outputs - wrapper::vstore(X + M * k, a); - wrapper::vstore(X + M * (k + 2 * Nx), b); - wrapper::vstore(X + M * (k + 4 * Nx), c); - wrapper::vstore(X + M * (k + 6 * Nx), d); - wrapper::vstore(X + M * (k + 8 * Nx), e); - wrapper::vstore(X + M * (k + 10 * Nx), f); - wrapper::vstore(X + M * (k + 12 * Nx), g); - wrapper::vstore(X + M * (k + 14 * Nx), h); + wrapper::vstore(out + (N + out_pad_x) * k, a); + wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b); + wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c); + wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d); + wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e); + wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f); + wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g); + wrapper::vstore(out + (N + out_pad_x) * (k + 14 * Nx), h); } w = c_mul_neon(w, w_m); @@ -933,7 +933,7 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen } // namespace NEFFTRadixStageKernel::NEFFTRadixStageKernel() - : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1() + : _input(nullptr), _output(nullptr), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1() { } @@ -992,12 +992,11 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config)); - _input = input; - _output = output; - _run_in_place = (output == nullptr) || (output == input); - _Nx = config.Nx; - _axis = config.axis; - _radix = config.radix; + _input = input; + _output = (output == nullptr) ? input : output; + _Nx = config.Nx; + _axis = config.axis; + _radix = config.radix; switch(config.axis) { @@ -1013,7 +1012,7 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT } // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info(), config); + auto win_config = validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr, config); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); INEKernel::configure(win_config.second); } @@ -1045,7 +1044,7 @@ void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info) input_window.set(_axis, 0); Iterator in(_input, input_window); - Iterator out(_run_in_place ? _input : _output, input_window); + Iterator out(_output, input_window); // Precompute FFT constants const unsigned int NxRadix = _radix * _Nx; @@ -1067,7 +1066,9 @@ void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info) const unsigned int M = _input->info()->dimension(1); execute_window_loop(input_window, [&](const Coordinates &) { - _func_1(reinterpret_cast(out.ptr()), reinterpret_cast(in.ptr()), _Nx, NxRadix, w_m, N, M); + _func_1(reinterpret_cast(out.ptr()), reinterpret_cast(in.ptr()), _Nx, NxRadix, w_m, N, M, + _input->info()->padding().right + _input->info()->padding().left, + _output->info()->padding().right + _output->info()->padding().left); }, in, out); } diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.h b/src/core/NEON/kernels/NEFFTRadixStageKernel.h index 8a695b790f..2291a1068c 100644 --- a/src/core/NEON/kernels/NEFFTRadixStageKernel.h +++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -85,7 +85,6 @@ public: private: ITensor *_input; ITensor *_output; - bool _run_in_place; unsigned int _Nx; unsigned int _axis; unsigned int _radix; @@ -94,7 +93,7 @@ private: void set_radix_stage_axis1(const FFTRadixStageKernelInfo &config); using FFTFunctionPointerAxis0 = std::function; - using FFTFunctionPointerAxis1 = std::function; + using FFTFunctionPointerAxis1 = std::function; FFTFunctionPointerAxis0 _func_0; FFTFunctionPointerAxis1 _func_1; diff --git a/tests/validation/fixtures/FFTFixture.h b/tests/validation/fixtures/FFTFixture.h index 3a75135718..fc6b9df8de 100644 --- a/tests/validation/fixtures/FFTFixture.h +++ b/tests/validation/fixtures/FFTFixture.h @@ -91,8 +91,7 @@ protected: ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); - // TODO: uncomment after COMPMID-4362 - // add_padding_x({ &src, &dst }); + add_padding_x({ &src, &dst }); // Allocate tensors src.allocator()->allocate(); -- cgit v1.2.1