diff options
Diffstat (limited to 'src/core/CL/kernels/CLFFTRadixStageKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLFFTRadixStageKernel.cpp | 54 |
1 files changed, 29 insertions, 25 deletions
diff --git a/src/core/CL/kernels/CLFFTRadixStageKernel.cpp b/src/core/CL/kernels/CLFFTRadixStageKernel.cpp index 5db3cb6bf2..3729e6b77d 100644 --- a/src/core/CL/kernels/CLFFTRadixStageKernel.cpp +++ b/src/core/CL/kernels/CLFFTRadixStageKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,8 @@ #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/StringUtils.h" + #include "src/core/CL/CLValidate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -45,11 +47,11 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON(CLFFTRadixStageKernel::supported_radix().count(config.radix) == 0); - ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({ 0, 1 }).count(config.axis) == 0); + ARM_COMPUTE_RETURN_ERROR_ON(std::set<unsigned int>({0, 1}).count(config.axis) == 0); ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape()[config.axis] % config.radix); // Checks performed when output is configured - if((output != nullptr) && (output->total_size() != 0)) + if ((output != nullptr) && (output->total_size() != 0)) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); @@ -58,9 +60,10 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c return Status{}; } -std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config) +std::pair<Status, Window> +validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config) { - if(output != nullptr) + if (output != nullptr) { auto_init_if_empty(*output, *input); } @@ -70,18 +73,14 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen steps.set(config.axis, config.radix); Window win = calculate_max_window(*input, steps); - if(output != nullptr) - { - output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape())); - } return std::make_pair(Status{}, win); } } // namespace -CLFFTRadixStageKernel::CLFFTRadixStageKernel() - : _input(nullptr), _output(nullptr), _run_in_place(false) +CLFFTRadixStageKernel::CLFFTRadixStageKernel() : _input(nullptr), _output(nullptr), _run_in_place(false) { + _type = CLKernelType::ELEMENTWISE; } void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const FFTRadixStageKernelInfo &config) @@ -89,11 +88,15 @@ void CLFFTRadixStageKernel::configure(ICLTensor *input, ICLTensor *output, const configure(CLKernelLibrary::get().get_compile_context(), input, output, config); } -void CLFFTRadixStageKernel::configure(const CLCompileContext &compile_context, ICLTensor *input, ICLTensor *output, const FFTRadixStageKernelInfo &config) +void CLFFTRadixStageKernel::configure(const CLCompileContext &compile_context, + ICLTensor *input, + ICLTensor *output, + const FFTRadixStageKernelInfo &config) { ARM_COMPUTE_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config)); - auto padding_info = get_padding_info({ input, output }); + ARM_COMPUTE_ERROR_THROW_ON( + validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config)); + auto padding_info = get_padding_info({input, output}); _input = input; _output = output; @@ -112,11 +115,12 @@ void CLFFTRadixStageKernel::configure(const CLCompileContext &compile_context, I _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); // Set static arguments if not the first stage - if(!config.is_first_stage) + if (!config.is_first_stage) { const unsigned int Ni = config.Nx * config.radix; const float exp_const = (-2.0 * M_PI) / static_cast<float>(Ni); - unsigned int idx = (1 + (_run_in_place ? 0 : 1)) * num_arguments_per_3D_tensor(); // Skip the input and output parameters + unsigned int idx = + (1 + (_run_in_place ? 0 : 1)) * num_arguments_per_3D_tensor(); // Skip the input and output parameters _kernel.setArg<cl_uint>(idx++, config.Nx); _kernel.setArg<cl_uint>(idx++, Ni); _kernel.setArg<cl_float>(idx, exp_const); @@ -138,21 +142,22 @@ void CLFFTRadixStageKernel::configure(const CLCompileContext &compile_context, I ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); } -Status CLFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config) +Status CLFFTRadixStageKernel::validate(const ITensorInfo *input, + const ITensorInfo *output, + const FFTRadixStageKernelInfo &config) { const bool run_in_place = (output == nullptr) || (output == input); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), - (run_in_place) ? nullptr : output->clone().get(), - config) - .first); + ARM_COMPUTE_RETURN_ON_ERROR( + validate_and_configure_window(input->clone().get(), (run_in_place) ? nullptr : output->clone().get(), config) + .first); return Status{}; } std::set<unsigned int> CLFFTRadixStageKernel::supported_radix() { - return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 }; + return std::set<unsigned int>{2, 3, 4, 5, 7, 8}; } void CLFFTRadixStageKernel::run(const Window &window, cl::CommandQueue &queue) @@ -167,12 +172,11 @@ void CLFFTRadixStageKernel::run(const Window &window, cl::CommandQueue &queue) { unsigned int idx = 0; add_3D_tensor_argument(idx, _input, slice); - if(!_run_in_place) + if (!_run_in_place) { add_3D_tensor_argument(idx, _output, slice); } enqueue(queue, *this, slice, lws_hint()); - } - while(collapsed.slide_window_slice_3D(slice)); + } while (collapsed.slide_window_slice_3D(slice)); } } // namespace arm_compute |