diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp | 41 |
1 files changed, 28 insertions, 13 deletions
diff --git a/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp b/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp index 1f09515b86..40b1eaca1f 100644 --- a/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp +++ b/src/gpu/cl/kernels/ClElementwiseUnaryKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -38,6 +38,8 @@ namespace kernels { namespace { +constexpr unsigned int vector_size_byte_opencl = 16; + Status validate_arguments(const ITensorInfo &src, const ITensorInfo &dst, const ElementWiseUnary op) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src); @@ -49,6 +51,10 @@ Status validate_arguments(const ITensorInfo &src, const ITensorInfo &dst, const { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32, DataType::S32); } + else if(op == ElementWiseUnary::RSQRT) // Allow quantized types for only RSQRT. + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + } else { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::F16, DataType::F32); @@ -78,17 +84,29 @@ void ClElementWiseUnaryKernel::configure(const CLCompileContext &compile_context auto padding_info = get_padding_info({ src, dst }); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src, *dst, op)); + const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst->element_size(), dst->dimension(0)); - const std::string kernel_name = "elementwise_unary"; - const int vec_size_x = 16 / dst->element_size(); - const int dst_width_x = dst->tensor_shape().x(); - const bool multi_access_x = (dst_width_x / vec_size_x > 0); - + std::string kernel_name = "elementwise_unary"; + const int vec_size_x = num_elems_processed_per_iteration; + const int dst_width_x = dst->dimension(0); + if(is_data_type_quantized(src->data_type())) + { + kernel_name += "_quantized"; + } // Set kernel build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type())); - build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); - build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(dst_width_x - vec_size_x, 0))); + build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x)); + build_opts.add_option("-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(dst_width_x - vec_size_x, 0))); + if(is_data_type_quantized(src->data_type())) + { + const UniformQuantizationInfo iqinfo = src->quantization_info().uniform(); + const UniformQuantizationInfo oqinfo = dst->quantization_info().uniform(); + build_opts.add_option("-DOFFSET_IN=" + support::cpp11::to_string(iqinfo.offset)); + build_opts.add_option("-DOFFSET_OUT=" + support::cpp11::to_string(oqinfo.offset)); + build_opts.add_option("-DSCALE_IN=" + float_to_string_with_full_precision(iqinfo.scale)); + build_opts.add_option("-DSCALE_OUT=" + float_to_string_with_full_precision(oqinfo.scale)); + } switch(op) { case ElementWiseUnary::RSQRT: @@ -124,11 +142,8 @@ void ClElementWiseUnaryKernel::configure(const CLCompileContext &compile_context // Configure kernel window Window win = calculate_max_window(*dst); - if(multi_access_x) - { - win.set(Window::DimX, - Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x)); - } + win.set(Window::DimX, Window::Dimension(win.x().start(), ceil_to_multiple(win.x().end(), vec_size_x), vec_size_x)); + ICLKernel::configure_internal(win); ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); |