aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp149
1 files changed, 83 insertions, 66 deletions
diff --git a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
index 261437f07d..a5969cd497 100644
--- a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
@@ -28,6 +28,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
@@ -37,16 +38,19 @@ namespace arm_compute
{
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, const FFTDigitReverseKernelInfo &config)
+Status validate_arguments(const ITensorInfo *input,
+ const ITensorInfo *output,
+ const ITensorInfo *idx,
+ const FFTDigitReverseKernelInfo &config)
{
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON(input->num_channels() > 2);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(idx, 1, DataType::U32);
- ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({ 0, 1 }).count(config.axis) == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({0, 1}).count(config.axis) == 0);
ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[config.axis] != idx->tensor_shape().x());
// Checks performed when output is configured
- if((output != nullptr) && (output->total_size() != 0))
+ if ((output != nullptr) && (output->total_size() != 0))
{
ARM_COMPUTE_RETURN_ERROR_ON(output->num_channels() != 2);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
@@ -56,7 +60,10 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *idx, const FFTDigitReverseKernelInfo &config)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input,
+ ITensorInfo *output,
+ ITensorInfo *idx,
+ const FFTDigitReverseKernelInfo &config)
{
ARM_COMPUTE_UNUSED(idx, config);
@@ -68,12 +75,14 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
}
} // namespace
-NEFFTDigitReverseKernel::NEFFTDigitReverseKernel()
- : _func(nullptr), _input(nullptr), _output(nullptr), _idx(nullptr)
+NEFFTDigitReverseKernel::NEFFTDigitReverseKernel() : _func(nullptr), _input(nullptr), _output(nullptr), _idx(nullptr)
{
}
-void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, const ITensor *idx, const FFTDigitReverseKernelInfo &config)
+void NEFFTDigitReverseKernel::configure(const ITensor *input,
+ ITensor *output,
+ const ITensor *idx,
+ const FFTDigitReverseKernelInfo &config)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, idx);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), idx->info(), config));
@@ -91,11 +100,11 @@ void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, c
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
- if(axis == 0)
+ if (axis == 0)
{
- if(is_input_complex)
+ if (is_input_complex)
{
- if(is_conj)
+ if (is_conj)
{
_func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0<true, true>;
}
@@ -109,11 +118,11 @@ void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, c
_func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0<false, false>;
}
}
- else if(axis == 1)
+ else if (axis == 1)
{
- if(is_input_complex)
+ if (is_input_complex)
{
- if(is_conj)
+ if (is_conj)
{
_func = &NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1<true, true>;
}
@@ -133,10 +142,14 @@ void NEFFTDigitReverseKernel::configure(const ITensor *input, ITensor *output, c
}
}
-Status NEFFTDigitReverseKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *idx, const FFTDigitReverseKernelInfo &config)
+Status NEFFTDigitReverseKernel::validate(const ITensorInfo *input,
+ const ITensorInfo *output,
+ const ITensorInfo *idx,
+ const FFTDigitReverseKernelInfo &config)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, idx, config));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), idx->clone().get(), config).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ validate_and_configure_window(input->clone().get(), output->clone().get(), idx->clone().get(), config).first);
return Status{};
}
@@ -159,38 +172,40 @@ void NEFFTDigitReverseKernel::digit_reverse_kernel_axis_0(const Window &window)
std::vector<float> buffer_row_out(2 * N);
std::vector<float> buffer_row_in(2 * N);
- execute_window_loop(slice, [&](const Coordinates &)
- {
- if(is_input_complex)
+ execute_window_loop(
+ slice,
+ [&](const Coordinates &)
{
- // Load
- memcpy(buffer_row_in.data(), reinterpret_cast<float *>(in.ptr()), 2 * N * sizeof(float));
-
- // Shuffle
- for(size_t x = 0; x < 2 * N; x += 2)
+ if (is_input_complex)
{
- size_t idx = buffer_idx[x / 2];
- buffer_row_out[x] = buffer_row_in[2 * idx];
- buffer_row_out[x + 1] = (is_conj ? -buffer_row_in[2 * idx + 1] : buffer_row_in[2 * idx + 1]);
- }
- }
- else
- {
- // Load
- memcpy(buffer_row_in.data(), reinterpret_cast<float *>(in.ptr()), N * sizeof(float));
+ // Load
+ memcpy(buffer_row_in.data(), reinterpret_cast<float *>(in.ptr()), 2 * N * sizeof(float));
- // Shuffle
- for(size_t x = 0; x < N; ++x)
+ // Shuffle
+ for (size_t x = 0; x < 2 * N; x += 2)
+ {
+ size_t idx = buffer_idx[x / 2];
+ buffer_row_out[x] = buffer_row_in[2 * idx];
+ buffer_row_out[x + 1] = (is_conj ? -buffer_row_in[2 * idx + 1] : buffer_row_in[2 * idx + 1]);
+ }
+ }
+ else
{
- size_t idx = buffer_idx[x];
- buffer_row_out[2 * x] = buffer_row_in[idx];
+ // Load
+ memcpy(buffer_row_in.data(), reinterpret_cast<float *>(in.ptr()), N * sizeof(float));
+
+ // Shuffle
+ for (size_t x = 0; x < N; ++x)
+ {
+ size_t idx = buffer_idx[x];
+ buffer_row_out[2 * x] = buffer_row_in[idx];
+ }
}
- }
- // Copy back
- memcpy(reinterpret_cast<float *>(out.ptr()), buffer_row_out.data(), 2 * N * sizeof(float));
- },
- in, out);
+ // Copy back
+ memcpy(reinterpret_cast<float *>(out.ptr()), buffer_row_out.data(), 2 * N * sizeof(float));
+ },
+ in, out);
}
template <bool is_input_complex, bool is_conj>
@@ -215,39 +230,41 @@ void NEFFTDigitReverseKernel::digit_reverse_kernel_axis_1(const Window &window)
const size_t stride_z = _input->info()->strides_in_bytes()[2];
const size_t stride_w = _input->info()->strides_in_bytes()[3];
- execute_window_loop(slice, [&](const Coordinates & id)
- {
- auto *out_ptr = reinterpret_cast<float *>(out.ptr());
- auto *in_ptr = reinterpret_cast<float *>(_input->buffer() + id.z() * stride_z + id[3] * stride_w);
- const size_t y_shuffled = buffer_idx[id.y()];
-
- if(is_input_complex)
+ execute_window_loop(
+ slice,
+ [&](const Coordinates &id)
{
- // Shuffle the entire row into the output
- memcpy(out_ptr, in_ptr + 2 * Nx * y_shuffled, 2 * Nx * sizeof(float));
+ auto *out_ptr = reinterpret_cast<float *>(out.ptr());
+ auto *in_ptr = reinterpret_cast<float *>(_input->buffer() + id.z() * stride_z + id[3] * stride_w);
+ const size_t y_shuffled = buffer_idx[id.y()];
- // Conjugate if necessary
- if(is_conj)
+ if (is_input_complex)
{
- for(size_t x = 0; x < 2 * Nx; x += 2)
+ // Shuffle the entire row into the output
+ memcpy(out_ptr, in_ptr + 2 * Nx * y_shuffled, 2 * Nx * sizeof(float));
+
+ // Conjugate if necessary
+ if (is_conj)
{
- out_ptr[x + 1] = -out_ptr[x + 1];
+ for (size_t x = 0; x < 2 * Nx; x += 2)
+ {
+ out_ptr[x + 1] = -out_ptr[x + 1];
+ }
}
}
- }
- else
- {
- // Shuffle the entire row into the buffer
- memcpy(buffer_row.data(), in_ptr + Nx * y_shuffled, Nx * sizeof(float));
-
- // Copy the buffer to the output, with a zero imaginary part
- for(size_t x = 0; x < 2 * Nx; x += 2)
+ else
{
- out_ptr[x] = buffer_row[x / 2];
+ // Shuffle the entire row into the buffer
+ memcpy(buffer_row.data(), in_ptr + Nx * y_shuffled, Nx * sizeof(float));
+
+ // Copy the buffer to the output, with a zero imaginary part
+ for (size_t x = 0; x < 2 * Nx; x += 2)
+ {
+ out_ptr[x] = buffer_row[x / 2];
+ }
}
- }
- },
- out);
+ },
+ out);
}
void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info)