diff options
Diffstat (limited to 'src/core/CL/kernels/CLFFTRadixStageKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLFFTRadixStageKernel.cpp | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLFFTRadixStageKernel.cpp b/src/core/CL/kernels/CLFFTRadixStageKernel.cpp index 87a12b9da9..83d55b7092 100644 --- a/src/core/CL/kernels/CLFFTRadixStageKernel.cpp +++ b/src/core/CL/kernels/CLFFTRadixStageKernel.cpp @@ -38,12 +38,13 @@ namespace arm_compute { namespace { -Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelDescriptor &config) +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0); ARM_COMPUTE_RETURN_ERROR_ON(CLFFTRadixStageKernel::supported_radix().count(config.radix) == 0); + ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({ 0, 1 }).count(config.axis) == 0); + ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[config.axis] % config.radix); // Checks performed when output is configured if((output != nullptr) && (output->total_size() != 0)) @@ -55,14 +56,18 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c return Status{}; } -std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelDescriptor &config) +std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config) { if(output != nullptr) { auto_init_if_empty(*output, *input); } - Window win = calculate_max_window(*input, Steps(config.radix)); + // Setup window steps + Steps steps; + steps.set(config.axis, config.radix); + + Window win = calculate_max_window(*input, steps); if(output != nullptr) { output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape())); @@ -77,7 +82,7 @@ CLFFTRadixStageKernel::CLFFTRadixStageKernel() { } -void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const FFTRadixStageKernelDescriptor &config) +void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const FFTRadixStageKernelInfo &config) { ARM_COMPUTE_ERROR_ON_NULLPTR(input); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config)); @@ -105,7 +110,7 @@ void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const unsigned int idx = (1 + (_run_in_place ? 0 : 1)) * num_arguments_per_3D_tensor(); // Skip the input and output parameters _kernel.setArg<cl_uint>(idx++, config.Nx); _kernel.setArg<cl_uint>(idx++, Ni); - _kernel.setArg<cl_float>(idx++, exp_const); + _kernel.setArg<cl_float>(idx, exp_const); } // Configure kernel window @@ -123,7 +128,7 @@ void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const _config_id += support::cpp11::to_string(input->info()->dimension(1)); } -Status CLFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelDescriptor &config) +Status CLFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config) { const bool run_in_place = (output == nullptr) || (output == input); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config)); |