aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Manuel Bottini <manuel.bottini@arm.com> 2021-04-15 17:44:55 +0100
committer Manuel Bottini <manuel.bottini@arm.com> 2021-04-16 16:34:00 +0000
commit 9a81cd82a8102ee0bd69bfe4939d5c867aed15e9 (patch)
tree 1844c9e1da5d76ade1188ba1db8279c35b36bae4
parent d05f67f825fc2b4f65b95a1d81476b03c095e266 (diff)
download ComputeLibrary-9a81cd82a8102ee0bd69bfe4939d5c867aed15e9.tar.gz
Fix bug on Implicit Padding for NEON FFT2D
Include paddings in address computation for input and output

Resolves: COMPMID-4362
Change-Id: I1b34cf47e3b80b98d55fc8fbdeecbfd850d33197
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5439
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
-rw-r--r-- src/core/NEON/kernels/NEFFTRadixStageKernel.cpp 341
-rw-r--r-- src/core/NEON/kernels/NEFFTRadixStageKernel.h 5
-rw-r--r-- tests/validation/fixtures/FFTFixture.h 3
3 files changed, 174 insertions, 175 deletions
diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
index d71707a4fc..44c841f626 100644
--- a/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.cpp
@@ -352,7 +352,7 @@ void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, f
}
template <bool first_stage>
-void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_2_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -365,14 +365,14 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Load inputs
if(first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
+ const auto ab = wrapper::vloadq(in + k);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
}
// Base-case prime transform
@@ -381,12 +381,12 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Write outputs
if(first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
}
}
@@ -394,23 +394,23 @@ void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
}
-void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_2_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
// Base-case prime transform
fft_2(a, b, w);
// Write outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
}
w = c_mul_neon(w, w_m);
@@ -418,7 +418,7 @@ void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_3_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -433,63 +433,63 @@ void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
float32x2_t c = { 0, 0 };
if(first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
+ const auto ab = wrapper::vloadq(in + k);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
}
- c = wrapper::vload(x + k + 4 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
// Base-case prime transform
fft_3(a, b, c, w, w2);
if(first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
}
- wrapper::vstore(X + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 4 * Nx, c);
}
w = c_mul_neon(w, w_m);
}
}
-void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_3_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
{
const auto w2 = c_mul_neon(w, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
// Base-case prime transform
fft_3(a, b, c, w, w2);
// Store the output
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
}
w = c_mul_neon(w, w_m);
}
}
template <bool first_stage>
-void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_4_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -505,8 +505,8 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
float32x2_t d = { 0, 0 };
if(first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
c = wrapper::vgetlow(cd);
@@ -515,10 +515,10 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
else
{
// Load inputs
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
}
// Base-case prime transform
@@ -526,15 +526,15 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
if(first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
}
}
@@ -542,7 +542,7 @@ void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
}
-void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_4_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -550,21 +550,21 @@ void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const auto w2 = c_mul_neon(w, w);
const auto w3 = c_mul_neon(w2, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
// Base-case prime transform
fft_4(a, b, c, d, w, w2, w3);
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
}
w = c_mul_neon(w, w_m);
@@ -572,7 +572,7 @@ void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_5_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -592,8 +592,8 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Load inputs
if(first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
@@ -602,12 +602,12 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
}
- e = wrapper::vload(x + k + 8 * Nx);
+ e = wrapper::vload(in + k + 8 * Nx);
// Base-case prime transform
fft_5(a, b, c, d, e, w, w2, w3, w4);
@@ -615,24 +615,24 @@ void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Store outputs
if(first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
}
- wrapper::vstore(X + k + 8 * Nx, e);
+ wrapper::vstore(out + k + 8 * Nx, e);
}
w = c_mul_neon(w, w_m);
}
}
-void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_5_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -641,24 +641,24 @@ void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w3 = c_mul_neon(w2, w);
const float32x2_t w4 = c_mul_neon(w3, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
- float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
// Base-case prime transform
fft_5(a, b, c, d, e, w, w2, w3, w4);
// Store outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
- wrapper::vstore(X + M * (k + 8 * Nx), e);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
}
w = c_mul_neon(w, w_m);
@@ -666,7 +666,7 @@ void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_7_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -690,9 +690,9 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Load inputs
if(first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
- const auto ef = wrapper::vloadq(x + k + 8 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
+ const auto ef = wrapper::vloadq(in + k + 8 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
@@ -703,41 +703,41 @@ void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
- e = wrapper::vload(x + k + 8 * Nx);
- f = wrapper::vload(x + k + 10 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
+ e = wrapper::vload(in + k + 8 * Nx);
+ f = wrapper::vload(in + k + 10 * Nx);
}
- g = wrapper::vload(x + k + 12 * Nx);
+ g = wrapper::vload(in + k + 12 * Nx);
// Base-case prime transform
fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
if(first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
- wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
- wrapper::vstore(X + k + 8 * Nx, e);
- wrapper::vstore(X + k + 10 * Nx, f);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
+ wrapper::vstore(out + k + 8 * Nx, e);
+ wrapper::vstore(out + k + 10 * Nx, f);
}
- wrapper::vstore(X + k + 12 * Nx, g);
+ wrapper::vstore(out + k + 12 * Nx, g);
}
w = c_mul_neon(w, w_m);
}
}
-void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_7_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -748,28 +748,28 @@ void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w5 = c_mul_neon(w4, w);
const float32x2_t w6 = c_mul_neon(w5, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
- float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
- float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
- float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
+ float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
+ float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
// Base-case prime transform
fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
// Store outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
- wrapper::vstore(X + M * (k + 8 * Nx), e);
- wrapper::vstore(X + M * (k + 10 * Nx), f);
- wrapper::vstore(X + M * (k + 12 * Nx), g);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
}
w = c_mul_neon(w, w_m);
@@ -777,7 +777,7 @@ void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
template <bool first_stage>
-void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
+void fft_radix_8_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -804,10 +804,10 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Base-case prime transform
if(first_stage)
{
- const auto ab = wrapper::vloadq(x + k);
- const auto cd = wrapper::vloadq(x + k + 4 * Nx);
- const auto ef = wrapper::vloadq(x + k + 8 * Nx);
- const auto gh = wrapper::vloadq(x + k + 12 * Nx);
+ const auto ab = wrapper::vloadq(in + k);
+ const auto cd = wrapper::vloadq(in + k + 4 * Nx);
+ const auto ef = wrapper::vloadq(in + k + 8 * Nx);
+ const auto gh = wrapper::vloadq(in + k + 12 * Nx);
a = wrapper::vgetlow(ab);
b = wrapper::vgethigh(ab);
@@ -820,14 +820,14 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
else
{
- a = wrapper::vload(x + k);
- b = wrapper::vload(x + k + 2 * Nx);
- c = wrapper::vload(x + k + 4 * Nx);
- d = wrapper::vload(x + k + 6 * Nx);
- e = wrapper::vload(x + k + 8 * Nx);
- f = wrapper::vload(x + k + 10 * Nx);
- g = wrapper::vload(x + k + 12 * Nx);
- h = wrapper::vload(x + k + 14 * Nx);
+ a = wrapper::vload(in + k);
+ b = wrapper::vload(in + k + 2 * Nx);
+ c = wrapper::vload(in + k + 4 * Nx);
+ d = wrapper::vload(in + k + 6 * Nx);
+ e = wrapper::vload(in + k + 8 * Nx);
+ f = wrapper::vload(in + k + 10 * Nx);
+ g = wrapper::vload(in + k + 12 * Nx);
+ h = wrapper::vload(in + k + 14 * Nx);
}
// Apply twiddle factors
@@ -836,21 +836,21 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
// Store outputs
if(first_stage)
{
- wrapper::vstore(X + k, wrapper::vcombine(a, b));
- wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
- wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
- wrapper::vstore(X + k + 12 * Nx, wrapper::vcombine(g, h));
+ wrapper::vstore(out + k, wrapper::vcombine(a, b));
+ wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
+ wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
+ wrapper::vstore(out + k + 12 * Nx, wrapper::vcombine(g, h));
}
else
{
- wrapper::vstore(X + k, a);
- wrapper::vstore(X + k + 2 * Nx, b);
- wrapper::vstore(X + k + 4 * Nx, c);
- wrapper::vstore(X + k + 6 * Nx, d);
- wrapper::vstore(X + k + 8 * Nx, e);
- wrapper::vstore(X + k + 10 * Nx, f);
- wrapper::vstore(X + k + 12 * Nx, g);
- wrapper::vstore(X + k + 14 * Nx, h);
+ wrapper::vstore(out + k, a);
+ wrapper::vstore(out + k + 2 * Nx, b);
+ wrapper::vstore(out + k + 4 * Nx, c);
+ wrapper::vstore(out + k + 6 * Nx, d);
+ wrapper::vstore(out + k + 8 * Nx, e);
+ wrapper::vstore(out + k + 10 * Nx, f);
+ wrapper::vstore(out + k + 12 * Nx, g);
+ wrapper::vstore(out + k + 14 * Nx, h);
}
}
@@ -858,7 +858,7 @@ void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadi
}
}
-void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
+void fft_radix_8_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
float32x2_t w{ 1.0f, 0.0f };
for(unsigned int j = 0; j < Nx; j++)
@@ -870,30 +870,30 @@ void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadi
const float32x2_t w6 = c_mul_neon(w5, w);
const float32x2_t w7 = c_mul_neon(w6, w);
- for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
+ for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
{
// Load inputs
- float32x2_t a = wrapper::vload(x + M * k);
- float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
- float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
- float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
- float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
- float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
- float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
- float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx));
+ float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
+ float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
+ float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
+ float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
+ float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
+ float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
+ float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
+ float32x2_t h = wrapper::vload(in + (N + in_pad_x) * (k + 14 * Nx));
// Base-case prime transform
fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
// Store outputs
- wrapper::vstore(X + M * k, a);
- wrapper::vstore(X + M * (k + 2 * Nx), b);
- wrapper::vstore(X + M * (k + 4 * Nx), c);
- wrapper::vstore(X + M * (k + 6 * Nx), d);
- wrapper::vstore(X + M * (k + 8 * Nx), e);
- wrapper::vstore(X + M * (k + 10 * Nx), f);
- wrapper::vstore(X + M * (k + 12 * Nx), g);
- wrapper::vstore(X + M * (k + 14 * Nx), h);
+ wrapper::vstore(out + (N + out_pad_x) * k, a);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
+ wrapper::vstore(out + (N + out_pad_x) * (k + 14 * Nx), h);
}
w = c_mul_neon(w, w_m);
@@ -933,7 +933,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
} // namespace
NEFFTRadixStageKernel::NEFFTRadixStageKernel()
- : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
+ : _input(nullptr), _output(nullptr), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
{
}
@@ -992,12 +992,11 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
- _input = input;
- _output = output;
- _run_in_place = (output == nullptr) || (output == input);
- _Nx = config.Nx;
- _axis = config.axis;
- _radix = config.radix;
+ _input = input;
+ _output = (output == nullptr) ? input : output;
+ _Nx = config.Nx;
+ _axis = config.axis;
+ _radix = config.radix;
switch(config.axis)
{
@@ -1013,7 +1012,7 @@ void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFT
}
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info(), config);
+ auto win_config = validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr, config);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
}
@@ -1045,7 +1044,7 @@ void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
input_window.set(_axis, 0);
Iterator in(_input, input_window);
- Iterator out(_run_in_place ? _input : _output, input_window);
+ Iterator out(_output, input_window);
// Precompute FFT constants
const unsigned int NxRadix = _radix * _Nx;
@@ -1067,7 +1066,9 @@ void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
const unsigned int M = _input->info()->dimension(1);
execute_window_loop(input_window, [&](const Coordinates &)
{
- _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M);
+ _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M,
+ _input->info()->padding().right + _input->info()->padding().left,
+ _output->info()->padding().right + _output->info()->padding().left);
},
in, out);
}
diff --git a/src/core/NEON/kernels/NEFFTRadixStageKernel.h b/src/core/NEON/kernels/NEFFTRadixStageKernel.h
index 8a695b790f..2291a1068c 100644
--- a/src/core/NEON/kernels/NEFFTRadixStageKernel.h
+++ b/src/core/NEON/kernels/NEFFTRadixStageKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,6 @@ public:
private:
ITensor *_input;
ITensor *_output;
- bool _run_in_place;
unsigned int _Nx;
unsigned int _axis;
unsigned int _radix;
@@ -94,7 +93,7 @@ private:
void set_radix_stage_axis1(const FFTRadixStageKernelInfo &config);
using FFTFunctionPointerAxis0 = std::function<void(float *, float *, unsigned int, unsigned int, const float32x2_t &, unsigned int)>;
- using FFTFunctionPointerAxis1 = std::function<void(float *, float *, unsigned int, unsigned int, const float32x2_t &, unsigned int, unsigned int)>;
+ using FFTFunctionPointerAxis1 = std::function<void(float *, float *, unsigned int, unsigned int, const float32x2_t &, unsigned int, unsigned int, unsigned int, unsigned int)>;
FFTFunctionPointerAxis0 _func_0;
FFTFunctionPointerAxis1 _func_1;
diff --git a/tests/validation/fixtures/FFTFixture.h b/tests/validation/fixtures/FFTFixture.h
index 3a75135718..fc6b9df8de 100644
--- a/tests/validation/fixtures/FFTFixture.h
+++ b/tests/validation/fixtures/FFTFixture.h
@@ -91,8 +91,7 @@ protected:
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
- // TODO: uncomment after COMPMID-4362
- // add_padding_x({ &src, &dst });
+ add_padding_x({ &src, &dst });
// Allocate tensors
src.allocator()->allocate();