diff options
Diffstat (limited to 'src/core/CL/kernels/CLDequantizationLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLDequantizationLayerKernel.cpp | 46 |
1 files changed, 20 insertions, 26 deletions
diff --git a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp index e653c59550..e2c49fbf66 100644 --- a/src/core/CL/kernels/CLDequantizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLDequantizationLayerKernel.cpp @@ -53,22 +53,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output) return Status{}; } - -std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) -{ - // Configure kernel window - Window win = calculate_max_window(*input, Steps()); - - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::F32); - - // CLDequantizationLayerKernel doesn't need padding so update_window_and_padding() can be skipped - Coordinates coord; - coord.set_num_dimensions(output->num_dimensions()); - output->set_valid_region(ValidRegion(coord, output->tensor_shape())); - - return std::make_tuple(Status{}, win); -} } // namespace CLDequantizationLayerKernel::CLDequantizationLayerKernel() @@ -84,6 +68,12 @@ void CLDequantizationLayerKernel::configure(const ICLTensor *input, ICLTensor *o void CLDequantizationLayerKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + + // Output tensor auto initialization if not yet initialized + auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, DataType::F32); + + auto padding_info = get_padding_info({ input, output }); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info())); _input = input; @@ -93,15 +83,6 @@ void CLDequantizationLayerKernel::configure(const CLCompileContext &compile_cont const int output_width_x = output->info()->tensor_shape().x(); const bool multi_access_x = (output_width_x / vec_size_x > 0); - // Create and update the window (if needed) - Window win = calculate_max_window(*output->info()); - if(multi_access_x) - { - win.set(Window::DimX, - Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x)); - } - ICLKernel::configure_internal(win); - const bool is_quantized_per_channel = is_data_type_quantized_per_channel(input->info()->data_type()); std::string kernel_name = "dequantization_layer"; @@ -127,12 +108,25 @@ void CLDequantizationLayerKernel::configure(const CLCompileContext &compile_cont // Create kernel name _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Configure kernel window + Window win = calculate_max_window(*output->info()); + if(multi_access_x) + { + win.set(Window::DimX, + Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x)); + } + ICLKernel::configure_internal(win); + + // Set output valid region + output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape())); + + ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); } Status CLDequantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output)); - ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get()))); return Status{}; } |