aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLFFTRadixStageKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLFFTRadixStageKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLFFTRadixStageKernel.cpp19
1 files changed, 12 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLFFTRadixStageKernel.cpp b/src/core/CL/kernels/CLFFTRadixStageKernel.cpp
index 87a12b9da9..83d55b7092 100644
--- a/src/core/CL/kernels/CLFFTRadixStageKernel.cpp
+++ b/src/core/CL/kernels/CLFFTRadixStageKernel.cpp
@@ -38,12 +38,13 @@ namespace arm_compute
{
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelDescriptor &config)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0);
ARM_COMPUTE_RETURN_ERROR_ON(CLFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({ 0, 1 }).count(config.axis) == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[config.axis] % config.radix);
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
@@ -55,14 +56,18 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelDescriptor &config)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
if(output != nullptr)
{
auto_init_if_empty(*output, *input);
}
- Window win = calculate_max_window(*input, Steps(config.radix));
+ // Setup window steps
+ Steps steps;
+ steps.set(config.axis, config.radix);
+
+ Window win = calculate_max_window(*input, steps);
if(output != nullptr)
{
output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
@@ -77,7 +82,7 @@ CLFFTRadixStageKernel::CLFFTRadixStageKernel()
{
}
-void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const FFTRadixStageKernelDescriptor &config)
+void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const FFTRadixStageKernelInfo &config)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
@@ -105,7 +110,7 @@ void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const
unsigned int idx = (1 + (_run_in_place ? 0 : 1)) * num_arguments_per_3D_tensor(); // Skip the input and output parameters
_kernel.setArg<cl_uint>(idx++, config.Nx);
_kernel.setArg<cl_uint>(idx++, Ni);
- _kernel.setArg<cl_float>(idx++, exp_const);
+ _kernel.setArg<cl_float>(idx, exp_const);
}
// Configure kernel window
@@ -123,7 +128,7 @@ void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const
_config_id += support::cpp11::to_string(input->info()->dimension(1));
}
-Status CLFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelDescriptor &config)
+Status CLFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
const bool run_in_place = (output == nullptr) || (output == input);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));