From 201e0fee596dafcf9c869a550fae29779aad2394 Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Wed, 27 Jan 2021 13:14:56 +0000 Subject: Make Softmax kernels on OpenCL stateless * ClSoftmaxKernel and ClSoftmax are created. * ClSoftmaxKernel is now state-less and ClSoftmax handles the internal tensors required for computation. * add_const_tensor() is added to TensorPack not only to have symmetric interface but also to benefit from implicit conversion. Implements: COMPMID-3998 Change-Id: I4f823121777be24260fd12b2cd71a6ff718c4eed Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5087 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- Android.bp | 3 +- arm_compute/core/ITensorPack.h | 7 + arm_compute/core/KernelDescriptors.h | 1 + arm_compute/core/experimental/Types.h | 1 + arm_compute/runtime/CL/functions/CLSoftmaxLayer.h | 40 +-- docs/00_introduction.dox | 6 +- src/core/CL/CLKernels.h | 1 - src/core/CL/kernels/CLSoftmaxLayerKernel.cpp | 370 ---------------------- src/core/CL/kernels/CLSoftmaxLayerKernel.h | 158 --------- src/core/ITensorPack.cpp | 7 +- src/core/gpu/cl/kernels/ClPermuteKernel.h | 1 + src/core/gpu/cl/kernels/ClSoftmaxKernel.cpp | 355 +++++++++++++++++++++ src/core/gpu/cl/kernels/ClSoftmaxKernel.h | 126 ++++++++ src/runtime/CL/functions/CLSoftmaxLayer.cpp | 172 ++++------ src/runtime/gpu/cl/operators/ClSoftmax.cpp | 276 ++++++++++++++++ src/runtime/gpu/cl/operators/ClSoftmax.h | 119 +++++++ 16 files changed, 965 insertions(+), 678 deletions(-) delete mode 100644 src/core/CL/kernels/CLSoftmaxLayerKernel.cpp delete mode 100644 src/core/CL/kernels/CLSoftmaxLayerKernel.h create mode 100644 src/core/gpu/cl/kernels/ClSoftmaxKernel.cpp create mode 100644 src/core/gpu/cl/kernels/ClSoftmaxKernel.h create mode 100644 src/runtime/gpu/cl/operators/ClSoftmax.cpp create mode 100644 src/runtime/gpu/cl/operators/ClSoftmax.h diff --git a/Android.bp b/Android.bp index 93ce568936..8fd47514ee 100644 --- a/Android.bp +++ b/Android.bp @@ -146,7 +146,6 @@ cc_library_static { "src/core/CL/kernels/CLReverseKernel.cpp", "src/core/CL/kernels/CLScaleKernel.cpp", "src/core/CL/kernels/CLSelectKernel.cpp", - "src/core/CL/kernels/CLSoftmaxLayerKernel.cpp", "src/core/CL/kernels/CLSpaceToBatchLayerKernel.cpp", "src/core/CL/kernels/CLSpaceToDepthLayerKernel.cpp", "src/core/CL/kernels/CLStackLayerKernel.cpp", @@ -387,6 +386,7 @@ cc_library_static { "src/core/gpu/cl/kernels/ClPermuteKernel.cpp", "src/core/gpu/cl/kernels/ClPoolingKernel.cpp", "src/core/gpu/cl/kernels/ClReshapeKernel.cpp", + "src/core/gpu/cl/kernels/ClSoftmaxKernel.cpp", "src/core/gpu/cl/kernels/ClWidthConcatenate2TensorsKernel.cpp", "src/core/gpu/cl/kernels/ClWidthConcatenate4TensorsKernel.cpp", "src/core/gpu/cl/kernels/ClWidthConcatenateKernel.cpp", @@ -681,6 +681,7 @@ cc_library_static { "src/runtime/gpu/cl/operators/ClPermute.cpp", "src/runtime/gpu/cl/operators/ClPooling.cpp", "src/runtime/gpu/cl/operators/ClReshape.cpp", + "src/runtime/gpu/cl/operators/ClSoftmax.cpp", "src/runtime/gpu/cl/operators/ClSub.cpp", "utils/CommonGraphOptions.cpp", "utils/GraphUtils.cpp", diff --git a/arm_compute/core/ITensorPack.h b/arm_compute/core/ITensorPack.h index c06e1d9a73..8aea880bb6 100644 --- a/arm_compute/core/ITensorPack.h +++ b/arm_compute/core/ITensorPack.h @@ -69,6 +69,13 @@ public: * @param[in] tensor Tensor to add */ void add_tensor(int id, const ITensor *tensor); + + /** Add const tensor to the pack + * + * @param[in] id ID/type of the tensor to add + * @param[in] tensor Tensor to add + */ + void add_const_tensor(int id, const ITensor *tensor); /** Get tensor of a given id from the pac * * @param[in] id ID of tensor to extract diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h index e381220695..1f3cee2dd1 100644 --- a/arm_compute/core/KernelDescriptors.h +++ b/arm_compute/core/KernelDescriptors.h @@ -114,6 +114,7 @@ struct SoftmaxKernelInfo float beta{ 1.f }; /**< A scaling factor for the exponent with default value 1.0 */ bool is_log{ false }; /**< Flag used to perform Log Softmax operation */ DataType input_data_type{ DataType::UNKNOWN }; /**< Input tensor data type */ + int32_t axis{ 0 }; /**< The dimension in which to apply softmax. */ }; /** Descriptor used by the direct convolution layer output stage kernels */ diff --git a/arm_compute/core/experimental/Types.h b/arm_compute/core/experimental/Types.h index f615678e31..2a4bd89385 100644 --- a/arm_compute/core/experimental/Types.h +++ b/arm_compute/core/experimental/Types.h @@ -52,6 +52,7 @@ enum TensorType : int32_t ACL_INT_1 = 51, ACL_INT_2 = 52, ACL_INT_3 = 53, + ACL_INT_4 = 54, ACL_SRC_VEC = 256, }; diff --git a/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h b/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h index ab10a64de4..ddb35ae56f 100644 --- a/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h +++ b/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,8 +24,6 @@ #ifndef ARM_COMPUTE_CLSOFTMAXLAYER_H #define ARM_COMPUTE_CLSOFTMAXLAYER_H -#include "arm_compute/runtime/CL/CLTensor.h" -#include "arm_compute/runtime/CL/functions/CLPermute.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/MemoryGroup.h" @@ -34,11 +32,9 @@ namespace arm_compute { -class CLCompileContext; -class CLLogits1DMaxShiftExpSumKernel; -class CLLogits1DNormKernel; class ICLTensor; class ITensorInfo; +class CLCompileContext; /** Basic function to compute a SoftmaxLayer. * @@ -48,11 +44,11 @@ class ITensorInfo; * Log Softmax is calculated by : * @f[ out = (x - max(x) * beta) - log(\sum{e^{x - max(x) * beta}}) @f] * - * This function runs the following kernels: + * This function runs the following operators/kernels: * -# If axis is not 0: - * -# @ref CLPermute - * -# @ref CLLogits1DNormKernel - * -# @ref CLLogits1DMaxShiftExpSumKernel + * -# @ref opencl::ClPermute + * -# @ref opencl::kernels::ClLogits1DNormKernel + * -# @ref opencl::kernels::ClLogits1DMaxShiftExpSumKernel */ template class CLSoftmaxLayerGeneric : public IFunction @@ -60,14 +56,6 @@ class CLSoftmaxLayerGeneric : public IFunction public: /** Constructor */ CLSoftmaxLayerGeneric(std::shared_ptr memory_manager = nullptr); - /** Prevent instances of this class from being copied */ - CLSoftmaxLayerGeneric(const CLSoftmaxLayerGeneric &) = delete; - /** Prevent instances of this class from being copied */ - CLSoftmaxLayerGeneric &operator=(const CLSoftmaxLayerGeneric &) = delete; - /** Prevent instances of this class to be moved */ - CLSoftmaxLayerGeneric(CLSoftmaxLayerGeneric &&) = delete; - /** Prevent instances of this class to be moved */ - CLSoftmaxLayerGeneric &operator=(CLSoftmaxLayerGeneric &&) = delete; /** Default destructor */ ~CLSoftmaxLayerGeneric(); /** Set the input and output tensors. @@ -105,17 +93,11 @@ public: void run() override; private: - MemoryGroup _memory_group; - CLPermute _permute_input; - CLPermute _permute_output; - std::unique_ptr _max_shift_exp_sum_kernel; - std::unique_ptr _norm_kernel; - CLTensor _max; - CLTensor _sum; - CLTensor _tmp; - CLTensor _input_permuted; - CLTensor _output_permuted; - bool _needs_permute; + struct Impl; + std::unique_ptr _impl; + + /** Allocate workspace required by the operator */ + void allocate_workspace(); }; using CLSoftmaxLayer = CLSoftmaxLayerGeneric; diff --git a/docs/00_introduction.dox b/docs/00_introduction.dox index 8616cb6d13..eb8256f61d 100644 --- a/docs/00_introduction.dox +++ b/docs/00_introduction.dox @@ -295,8 +295,8 @@ v20.11 Public major release - CLWidthConcatenateLayerKernel - CLWidthConcatenate4TensorsKernel - CLWidthConcatenate2TensorsKernel - - @ref CLLogits1DMaxShiftExpSumKernel - - @ref CLLogits1DNormKernel + - CLLogits1DMaxShiftExpSumKernel + - CLLogits1DNormKernel - CLHeightConcatenateLayerKernel - @ref CLGEMMMatrixMultiplyKernel - @ref CLGEMMLowpQuantizeDownInt32ScaleKernel @@ -1400,7 +1400,7 @@ v17.03 Sources preview v17.02.1 Sources preview - New OpenCL kernels / functions: - - CLLogits1DMaxKernel, CLLogits1DShiftExpSumKernel, @ref CLLogits1DNormKernel / @ref CLSoftmaxLayer + - CLLogits1DMaxKernel, CLLogits1DShiftExpSumKernel, CLLogits1DNormKernel / @ref CLSoftmaxLayer - CLPoolingLayerKernel / @ref CLPoolingLayer - @ref CLIm2ColKernel, @ref CLCol2ImKernel, CLConvolutionLayerWeightsReshapeKernel / CLConvolutionLayer - @ref CLRemapKernel / @ref CLRemap diff --git a/src/core/CL/CLKernels.h b/src/core/CL/CLKernels.h index 22c9cd9c0c..45e27f2b1b 100644 --- a/src/core/CL/CLKernels.h +++ b/src/core/CL/CLKernels.h @@ -89,7 +89,6 @@ #include "src/core/CL/kernels/CLReverseKernel.h" #include "src/core/CL/kernels/CLScaleKernel.h" #include "src/core/CL/kernels/CLSelectKernel.h" -#include "src/core/CL/kernels/CLSoftmaxLayerKernel.h" #include "src/core/CL/kernels/CLSpaceToBatchLayerKernel.h" #include "src/core/CL/kernels/CLSpaceToDepthLayerKernel.h" #include "src/core/CL/kernels/CLStackLayerKernel.h" diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp deleted file mode 100644 index 526d9e187d..0000000000 --- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp +++ /dev/null @@ -1,370 +0,0 @@ -/* - * Copyright (c) 2017-2020 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#include "src/core/CL/kernels/CLSoftmaxLayerKernel.h" - -#include "arm_compute/core/utils/quantization/AsymmHelpers.h" -#include "src/core/CL/CLValidate.h" -#include "src/core/helpers/AutoConfiguration.h" -#include "src/core/helpers/WindowHelpers.h" -#include "support/StringSupport.h" - -namespace arm_compute -{ -namespace -{ -/** Calculates softmax parameters from the quantized input scale and scaling factor for the exponent and places them as build options. - * - * Prepares these build options: - * -INPUT_BETA_MULTIPLIER, INPUT_BETA_LEFT_SHIFT - quantized representation of beta multiplier. - * -DIFF_MIN - threshold difference between maximum value of input data and current processed value, - * it defines whether the value will be taken into account or not. - * - * @param[in] build_opts Build options to extend - * @param[in] input_scale Input scaling factor - * @param[in] beta Exponent scaling factor beta - */ -CLBuildOptions prepare_quantized_softmax_build_options(float input_scale, float beta) -{ - // Number of integer bits in temporary fixed-point representation of current-to-max difference - static const int scaled_diff_int_bits = 5; - // Number of integer bits used in temporary fixed-point representation of exponent accumulator - static const int exp_accumulation_in_bits = 12; - - const double beta_multiplier = std::min( - 1.0 * beta * input_scale * (1 << (31 - scaled_diff_int_bits)), - (1LL << 31) - 1.0); - int input_beta_multiplier; - int input_beta_left_shift; - quantization::calculate_quantized_multiplier_greater_than_one(beta_multiplier, &input_beta_multiplier, &input_beta_left_shift); - - const double max_input_rescaled = 1.0 * ((1 << scaled_diff_int_bits) - 1) * (1LL << (31 - scaled_diff_int_bits)) / (1LL << input_beta_left_shift); - const int diff_min = -1.f * std::floor(max_input_rescaled); - - CLBuildOptions build_opts; - build_opts.add_option("-DSCALED_DIFF_INT_BITS=" + support::cpp11::to_string(scaled_diff_int_bits)); - build_opts.add_option("-DEXP_ACCUMULATION_INT_BITS=" + support::cpp11::to_string(exp_accumulation_in_bits)); - build_opts.add_option("-DINPUT_BETA_MULTIPLIER=" + support::cpp11::to_string(input_beta_multiplier)); - build_opts.add_option("-DINPUT_BETA_LEFT_SHIFT=" + support::cpp11::to_string(input_beta_left_shift)); - build_opts.add_option("-DDIFF_MIN=" + support::cpp11::to_string(diff_min)); - - return build_opts; -} - -Status validate_arguments_1DMaxShiftExpSum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum) -{ - ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output); - - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max); - - const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->data_type()); - - // Checks performed when output is configured - if(output->total_size() != 0) - { - if(is_quantized_asymmetric) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); - } - - // Checks performed when sum is configured - if(sum->total_size() != 0) - { - if(is_quantized_asymmetric) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(sum, 1, DataType::S32); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(max, sum); - } - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum); - } - - return Status{}; -} - -Status validate_arguments_1DNorm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::S32, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(sum, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum); - ARM_COMPUTE_RETURN_ERROR_ON(info.is_log && !is_data_type_float(info.input_data_type)); - - // Note: output should always have a scale of 1/256 and offset 0 - const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log); - const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(info.input_data_type); - - // Checks performed when output is configured - if(output->total_size() != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); - if(!is_quantized_asymmetric) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); - ARM_COMPUTE_RETURN_ERROR_ON(output->quantization_info() != allowed_quantization_info); - } - } - - return Status{}; -} -} // namespace - -/**< Grid size (obtained through auto-tuning) */ -const unsigned int CLLogits1DMaxShiftExpSumKernel::_grid_size = 64; -/**< Vector size in the serial case (obtained through auto-tuning) */ -const unsigned int CLLogits1DMaxShiftExpSumKernel::_serial_vector_size = 8; -/**< Vector size in the parallel case (obtained through auto-tuning, enables the best memory access pattern for Bifrost) .*/ -const unsigned int CLLogits1DMaxShiftExpSumKernel::_parallel_vector_size = 4; - -CLLogits1DMaxShiftExpSumKernel::CLLogits1DMaxShiftExpSumKernel() - : _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr) -{ -} - -void CLLogits1DMaxShiftExpSumKernel::configure(const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, const SoftmaxKernelInfo &info) -{ - configure(CLKernelLibrary::get().get_compile_context(), input, max, output, sum, info); -} - -void CLLogits1DMaxShiftExpSumKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, const SoftmaxKernelInfo &info) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output); - - auto padding_info = get_padding_info({ input, max, output, sum }); - - // Output auto initialization if not yet initialized - auto_init_if_empty(*sum->info(), input->info()->clone()->set_tensor_shape(max->info()->tensor_shape())); - auto_init_if_empty(*output->info(), *input->info()->clone()); - - // Perform validation step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_1DMaxShiftExpSum(input->info(), max->info(), output->info(), sum->info())); - - _input = input; - _max = max; - _output = output; - _sum = sum; - - const DataType dt = input->info()->data_type(); - const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform(); - const size_t reduction_dim_size = input->info()->dimension(0); - const float beta = info.beta; - const auto is_signed_qasymm8 = is_data_type_quantized_asymmetric_signed(info.input_data_type); - const int min_value = is_signed_qasymm8 ? CL_SCHAR_MIN : 0; - - ParallelReductionInfo parallel_reduction_info = is_parallel_reduction(reduction_dim_size); - const unsigned int vector_size = adjust_vec_size(std::get<1>(parallel_reduction_info), reduction_dim_size); - - // Set build options - CLBuildOptions build_opts; - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dt)); - build_opts.add_option("-DMIN_VALUE=" + support::cpp11::to_string(min_value)); - build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(vector_size)); - build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(reduction_dim_size)); - build_opts.add_option("-DVECTOR_SIZE_LEFTOVER=" + support::cpp11::to_string(reduction_dim_size % vector_size)); - build_opts.add_option("-DLOG_VECTOR_SIZE=" + support::cpp11::to_string(lround(log2(vector_size)))); - build_opts.add_option_if((reduction_dim_size % vector_size) != 0, "-DNON_MULTIPLE_OF_VECTOR_SIZE"); - build_opts.add_option_if(is_signed_qasymm8, "-DQASYMM8_SIGNED"); - build_opts.add_option_if(is_data_type_float(dt) && (beta != 1.0f), "-DBETA=" + float_to_string_with_full_precision(beta)); - build_opts.add_option_if(is_data_type_float(dt) && info.is_log, "-DLOG_SOFTMAX"); - build_opts.add_option_if(is_data_type_float(dt), "-DMINVAL=" + ((dt == DataType::F16) ? std::string("-HALF_MAX") : std::string("-FLT_MAX"))); - build_opts.add_options_if(is_data_type_quantized_asymmetric(dt), prepare_quantized_softmax_build_options(qinfo.scale, beta).options()); - - cl::NDRange lws_hint(cl::NullRange); - std::string kernel_name = std::string("softmax_layer_max_shift_exp_sum_") + (is_data_type_quantized_asymmetric(dt) ? "quantized_" : ""); - - // Configure parallel kernel if needed - if(std::get<0>(parallel_reduction_info)) - { - kernel_name += "parallel"; - bool is_grid_size_pow2 = (_grid_size != 0) && ((_grid_size & (_grid_size - 1)) == 0); - build_opts.add_option_if(is_grid_size_pow2 && _grid_size <= 256, "-DGRID_SIZE=" + support::cpp11::to_string(_grid_size)); - - // Handle boundary conditions. - const unsigned int multiple_grid_size = (reduction_dim_size / vector_size) % _grid_size; - build_opts.add_option_if((multiple_grid_size != 0) || ((reduction_dim_size % vector_size) != 0), "-DNON_MULTIPLE_OF_GRID_SIZE"); - // Setting _lws_hint in this way can also communicate grid_size to CLLogits1DMaxShiftExpSumKernel::run(). - // A single workgroup performs reduction in dimension 0 in the parallel case, hence lws[0]==gws[0]. - lws_hint = cl::NDRange(_grid_size); - } - else - { - kernel_name += "serial"; - } - - // Create kernel. - _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); - - // Configure window - Window win = calculate_max_window(*(input->info()), Steps(reduction_dim_size)); - ICLKernel::configure_internal(win, lws_hint); - - ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); -} - -Status CLLogits1DMaxShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_1DMaxShiftExpSum(input, max, output, sum)); - return Status{}; -} - -CLLogits1DMaxShiftExpSumKernel::ParallelReductionInfo CLLogits1DMaxShiftExpSumKernel::is_parallel_reduction(size_t size) -{ - bool is_parallel_reduction = (size >= (_grid_size * _serial_vector_size)) && (_grid_size > 1); - unsigned int vector_size = is_parallel_reduction ? _parallel_vector_size : _serial_vector_size; - return std::make_tuple(is_parallel_reduction, vector_size); -} - -void CLLogits1DMaxShiftExpSumKernel::run(const Window &window, cl::CommandQueue &queue) -{ - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); - - // Collapse window in Z dimension - Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); - - // Reconfigure window in case of parallel reduction - ParallelReductionInfo parallel_reduction_info = is_parallel_reduction(_input->info()->dimension(0)); - if(std::get<0>(parallel_reduction_info)) - { - // Launch grid_size parallel work items - window_collapsed.set(Window::DimX, Window::Dimension(0, _grid_size, 1)); - } - - // Get slices - Window slice = window_collapsed.first_slice_window_3D(); - do - { - unsigned int idx = 0; - // Set inputs - add_3D_tensor_argument(idx, _input, slice); - add_3D_tensor_argument(idx, _max, slice); - add_3D_tensor_argument(idx, _output, slice); - add_3D_tensor_argument(idx, _sum, slice); - enqueue(queue, *this, slice, lws_hint()); - } - while(window_collapsed.slide_window_slice_3D(slice)); -} - -CLLogits1DNormKernel::CLLogits1DNormKernel() - : _input(nullptr), _sum(nullptr), _output(nullptr) -{ -} - -void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, const SoftmaxKernelInfo &info) -{ - configure(CLKernelLibrary::get().get_compile_context(), input, sum, output, info); -} - -void CLLogits1DNormKernel::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, const SoftmaxKernelInfo &info) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output); - - auto padding_info = get_padding_info({ input, output, sum }); - - // Note: output should always have a scale of 1/256 and offset 0 - const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(info.input_data_type); - const DataType output_data_type = info.input_data_type; - const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log); - const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform(); - - // Output auto initialization if not yet initialized - auto_init_if_empty(*output->info(), - input->info()->clone()->set_data_type(output_data_type).set_quantization_info(allowed_quantization_info)); - - // Perform validation step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_1DNorm(input->info(), sum->info(), output->info(), info)); - - _input = input; - _sum = sum; - _output = output; - - const auto is_signed_qasymm8 = is_data_type_quantized_asymmetric_signed(info.input_data_type); - const int min_value = is_signed_qasymm8 ? CL_SCHAR_MIN : 0; - const unsigned int vector_size = adjust_vec_size(16, input->info()->dimension(0)); - - // Set build options - CLBuildOptions build_opts; - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(info.input_data_type)); - build_opts.add_option("-DMIN_VALUE=" + support::cpp11::to_string(min_value)); - build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(vector_size)); - build_opts.add_option("-DVECTOR_SIZE_LEFTOVER=" + support::cpp11::to_string(input->info()->dimension(0) % vector_size)); - build_opts.add_option_if(is_data_type_quantized_asymmetric_signed(info.input_data_type), "-DQASYMM8_SIGNED"); - build_opts.add_options_if(is_quantized_asymmetric, - prepare_quantized_softmax_build_options(qinfo.scale, info.beta).options()); - build_opts.add_option_if(info.is_log, "-DLOG_SOFTMAX"); - - // Create kernel - std::string kernel_name = std::string("softmax_layer_norm") + (is_quantized_asymmetric ? "_quantized" : ""); - _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); - - // Configure window - auto win = calculate_max_window(*(input->info()), Steps(vector_size)); - ICLKernel::configure_internal(win); - - ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); -} - -Status CLLogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_1DNorm(input, sum, output, info)); - - return Status{}; -} - -void CLLogits1DNormKernel::run(const Window &window, cl::CommandQueue &queue) -{ - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); - - Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); - Window slice = window_collapsed.first_slice_window_3D(); - - do - { - Window sum_slice = slice; - sum_slice.set(Window::DimX, Window::Dimension(0, 1, 1)); - - unsigned int idx = 0; - // Set inputs - add_3D_tensor_argument(idx, _input, slice); - add_3D_tensor_argument(idx, _sum, sum_slice); - add_3D_tensor_argument(idx, _output, slice); - enqueue(queue, *this, slice, lws_hint()); - } - while(window_collapsed.slide_window_slice_3D(slice)); -} -} // namespace arm_compute \ No newline at end of file diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.h b/src/core/CL/kernels/CLSoftmaxLayerKernel.h deleted file mode 100644 index 29e0f63e46..0000000000 --- a/src/core/CL/kernels/CLSoftmaxLayerKernel.h +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Copyright (c) 2017-2020 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#ifndef ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H -#define ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H - -#include "arm_compute/core/KernelDescriptors.h" -#include "src/core/CL/ICLSimple3DKernel.h" - -namespace arm_compute -{ -class ICLTensor; - -/** Interface for max, shifting, exponentiating and summing the logits */ -class CLLogits1DMaxShiftExpSumKernel : public ICLKernel -{ -public: - /** Info for whether a parallel reduction will be run and the vector size of the execution. */ - using ParallelReductionInfo = std::tuple; - -public: - /** Default constructor */ - CLLogits1DMaxShiftExpSumKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ - CLLogits1DMaxShiftExpSumKernel(const CLLogits1DMaxShiftExpSumKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - CLLogits1DMaxShiftExpSumKernel &operator=(const CLLogits1DMaxShiftExpSumKernel &) = delete; - /** Allow instances of this class to be moved */ - CLLogits1DMaxShiftExpSumKernel(CLLogits1DMaxShiftExpSumKernel &&) = default; - /** Allow instances of this class to be moved */ - CLLogits1DMaxShiftExpSumKernel &operator=(CLLogits1DMaxShiftExpSumKernel &&) = default; - /** Set the input and output tensors. - * - * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 - * @param[in,out] max Max values tensor. Data types supported: same as @p input - * @param[out] output Destination tensor. Data types supported: same as @p input - * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p input - * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. - */ - void configure(const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, const SoftmaxKernelInfo &info); - /** Set the input and output tensors. - * - * @param[in] compile_context The compile context to be used. - * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 - * @param[in,out] max Max values tensor. Data types supported: same as @p input - * @param[out] output Destination tensor. Data types supported: same as @p input - * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p input - * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. - */ - void configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, const SoftmaxKernelInfo &info); - /** Static function to check if given info will lead to a valid configuration of @ref CLLogits1DMaxShiftExpSumKernel - * - * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 - * @param[in] max Max values tensor. Data types supported: same as @p input - * @param[in] output Destination tensor. Data types supported: same as @p input - * @param[in] sum Sum of 1D logits tensor. Data types supported: same as @p input - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum); - /** Checks if the given size is eligible for parallel reduction - * - * @note Serial reduction is launched for width < (_grid_size * _serial_vector_size). - * @note Parallel reduction is launched for width >= (_grid_size * _serial_vector_size) and vector_size is forced to 4. - * - * @param[in] size Size to check - * - * @return A two-element tuple where the first element is a boolean specifying if a parallel reduction will be run, - * while the second element is the vector size of the execution. - */ - static ParallelReductionInfo is_parallel_reduction(size_t size); - - // Inherited methods overridden: - void run(const Window &window, cl::CommandQueue &queue) override; - -private: - const ICLTensor *_input; - ICLTensor *_max; - ICLTensor *_output; - ICLTensor *_sum; - -private: - static const unsigned int _grid_size; - static const unsigned int _serial_vector_size; - static const unsigned int _parallel_vector_size; -}; -/** Interface for calculating the final step of the Softmax Layer where each logit value is multiplied by the inverse of the sum of the logits. */ -class CLLogits1DNormKernel : public ICLKernel -{ -public: - /** Default constructor */ - CLLogits1DNormKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ - CLLogits1DNormKernel(const CLLogits1DNormKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - CLLogits1DNormKernel &operator=(const CLLogits1DNormKernel &) = delete; - /** Allow instances of this class to be moved */ - CLLogits1DNormKernel(CLLogits1DNormKernel &&) = default; - /** Allow instances of this class to be moved */ - CLLogits1DNormKernel &operator=(CLLogits1DNormKernel &&) = default; - /** Set the input and output tensors. - * - * @param[in] input Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. - * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input - * @param[out] output Destination tensor. Data types supported: QASYMM8/QASYMM8_SIGNED for S32 @p input, or same as @p input - * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. - */ - void configure(const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, const SoftmaxKernelInfo &info); - /** Set the input and output tensors. - * - * @param[in] compile_context The compile context to be used. - * @param[in] input Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. - * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input - * @param[out] output Destination tensor. Data types supported: QASYMM8/QASYMM8_SIGNED for S32 @p input, or same as @p input - * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. - */ - void configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, const SoftmaxKernelInfo &info); - /** Static function to check if given info will lead to a valid configuration of @ref CLLogits1DNormKernel - * - * @param[in] input Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. - * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input - * @param[in] output Destination tensor. Data types supported: QASYMM8 for S32 @p input, or same as @p input - * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info); - - // Inherited methods overridden: - void run(const Window &window, cl::CommandQueue &queue) override; - -private: - const ICLTensor *_input; - const ICLTensor *_sum; - ICLTensor *_output; -}; -} // namespace arm_compute -#endif /*ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H */ diff --git a/src/core/ITensorPack.cpp b/src/core/ITensorPack.cpp index 7a54a8bc6b..546f669985 100644 --- a/src/core/ITensorPack.cpp +++ b/src/core/ITensorPack.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -37,6 +37,11 @@ void ITensorPack::add_tensor(int id, const ITensor *tensor) _pack[id] = PackElement(tensor); } +void ITensorPack::add_const_tensor(int id, const ITensor *tensor) +{ + add_tensor(id, tensor); +} + const ITensor *ITensorPack::get_const_tensor(int id) const { auto it = _pack.find(id); diff --git a/src/core/gpu/cl/kernels/ClPermuteKernel.h b/src/core/gpu/cl/kernels/ClPermuteKernel.h index 4cc72491bd..ae3996fca1 100644 --- a/src/core/gpu/cl/kernels/ClPermuteKernel.h +++ b/src/core/gpu/cl/kernels/ClPermuteKernel.h @@ -41,6 +41,7 @@ namespace kernels class ClPermuteKernel : public ICLKernel { public: + /** Default constructor */ ClPermuteKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClPermuteKernel); /** Set the src and dst of the kernel. diff --git a/src/core/gpu/cl/kernels/ClSoftmaxKernel.cpp b/src/core/gpu/cl/kernels/ClSoftmaxKernel.cpp new file mode 100644 index 0000000000..000c9ad04d --- /dev/null +++ b/src/core/gpu/cl/kernels/ClSoftmaxKernel.cpp @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h" +#include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/experimental/Types.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" +#include "src/core/CL/CLValidate.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" +#include "support/Cast.h" +#include "support/StringSupport.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace +{ +/** Calculates softmax parameters from the quantized input scale and scaling factor for the exponent and places them as build options. + * + * Prepares these build options: + * -INPUT_BETA_MULTIPLIER, INPUT_BETA_LEFT_SHIFT - quantized representation of beta multiplier. + * -DIFF_MIN - threshold difference between maximum value of input data and current processed value, + * it defines whether the value will be taken into account or not. + * + * @param[in] build_opts Build options to extend + * @param[in] input_scale Input scaling factor + * @param[in] beta Exponent scaling factor beta + */ +CLBuildOptions prepare_quantized_softmax_build_options(float input_scale, float beta) +{ + // Number of integer bits in temporary fixed-point representation of current-to-max difference + static const int scaled_diff_int_bits = 5; + // Number of integer bits used in temporary fixed-point representation of exponent accumulator + static const int exp_accumulation_in_bits = 12; + + const double beta_multiplier = std::min( + 1.0 * beta * input_scale * (1 << (31 - scaled_diff_int_bits)), + (1LL << 31) - 1.0); + int input_beta_multiplier; + int input_beta_left_shift; + quantization::calculate_quantized_multiplier_greater_than_one(beta_multiplier, &input_beta_multiplier, &input_beta_left_shift); + + const double max_input_rescaled = 1.0 * ((1 << scaled_diff_int_bits) - 1) * (1LL << (31 - scaled_diff_int_bits)) / (1LL << input_beta_left_shift); + const int diff_min = -1.f * std::floor(max_input_rescaled); + + CLBuildOptions build_opts; + build_opts.add_option("-DSCALED_DIFF_INT_BITS=" + support::cpp11::to_string(scaled_diff_int_bits)); + build_opts.add_option("-DEXP_ACCUMULATION_INT_BITS=" + support::cpp11::to_string(exp_accumulation_in_bits)); + build_opts.add_option("-DINPUT_BETA_MULTIPLIER=" + support::cpp11::to_string(input_beta_multiplier)); + build_opts.add_option("-DINPUT_BETA_LEFT_SHIFT=" + support::cpp11::to_string(input_beta_left_shift)); + build_opts.add_option("-DDIFF_MIN=" + support::cpp11::to_string(diff_min)); + + return build_opts; +} + +Status validate_arguments_1DMaxShiftExpSum(const ITensorInfo &src, const ITensorInfo &max, const ITensorInfo &dst, const ITensorInfo &sum) +{ + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &max); + + const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type()); + + // Checks performed when output is configured + if(dst.total_size() != 0) + { + if(is_quantized_asymmetric) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst); + } + + // Checks performed when sum is configured + if(sum.total_size() != 0) + { + if(is_quantized_asymmetric) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&sum, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&max, &sum); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&max, &sum); + } + + return Status{}; +} + +Status validate_arguments_1DNorm(const ITensorInfo &src, const ITensorInfo &sum, const ITensorInfo &dst, const SoftmaxKernelInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::S32, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &sum); + ARM_COMPUTE_RETURN_ERROR_ON(info.is_log && !is_data_type_float(info.input_data_type)); + + // Note: output should always have a scale of 1/256 and offset 0 + const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log); + const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(info.input_data_type); + + // Checks performed when output is configured + if(dst.total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&src, &dst); + if(!is_quantized_asymmetric) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src, &dst); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + ARM_COMPUTE_RETURN_ERROR_ON(dst.quantization_info() != allowed_quantization_info); + } + } + + return Status{}; +} +} // namespace + +/**< Grid size (obtained through auto-tuning) */ +const unsigned int ClLogits1DMaxShiftExpSumKernel::_grid_size = 64; +/**< Vector size in the serial case (obtained through auto-tuning) */ +const unsigned int ClLogits1DMaxShiftExpSumKernel::_serial_vector_size = 8; +/**< Vector size in the parallel case (obtained through auto-tuning, enables the best memory access pattern for Bifrost) .*/ +const unsigned int ClLogits1DMaxShiftExpSumKernel::_parallel_vector_size = 4; + +void ClLogits1DMaxShiftExpSumKernel::configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &max, ITensorInfo &dst, ITensorInfo &sum, const SoftmaxKernelInfo &info) +{ + auto padding_info = get_padding_info({ &src, &max, &dst, &sum }); + + // Output auto initialization if not yet initialized + auto_init_if_empty(sum, src.clone()->set_tensor_shape(max.tensor_shape())); + auto_init_if_empty(dst, *src.clone()); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_1DMaxShiftExpSum(src, max, dst, sum)); + + const DataType dt = src.data_type(); + const UniformQuantizationInfo qinfo = src.quantization_info().uniform(); + const size_t reduction_dim_size = src.dimension(0); + const float beta = info.beta; + const auto is_signed_qasymm8 = is_data_type_quantized_asymmetric_signed(info.input_data_type); + const int min_value = is_signed_qasymm8 ? CL_SCHAR_MIN : 0; + + ParallelReductionInfo parallel_reduction_info = is_parallel_reduction(reduction_dim_size); + const unsigned int vector_size = adjust_vec_size(std::get<1>(parallel_reduction_info), reduction_dim_size); + + // Set build options + CLBuildOptions build_opts; + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dt)); + build_opts.add_option("-DMIN_VALUE=" + support::cpp11::to_string(min_value)); + build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(vector_size)); + build_opts.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(reduction_dim_size)); + build_opts.add_option("-DVECTOR_SIZE_LEFTOVER=" + support::cpp11::to_string(reduction_dim_size % vector_size)); + build_opts.add_option("-DLOG_VECTOR_SIZE=" + support::cpp11::to_string(lround(log2(vector_size)))); + build_opts.add_option_if((reduction_dim_size % vector_size) != 0, "-DNON_MULTIPLE_OF_VECTOR_SIZE"); + build_opts.add_option_if(is_signed_qasymm8, "-DQASYMM8_SIGNED"); + build_opts.add_option_if(is_data_type_float(dt) && (beta != 1.0f), "-DBETA=" + float_to_string_with_full_precision(beta)); + build_opts.add_option_if(is_data_type_float(dt) && info.is_log, "-DLOG_SOFTMAX"); + build_opts.add_option_if(is_data_type_float(dt), "-DMINVAL=" + ((dt == DataType::F16) ? std::string("-HALF_MAX") : std::string("-FLT_MAX"))); + build_opts.add_options_if(is_data_type_quantized_asymmetric(dt), prepare_quantized_softmax_build_options(qinfo.scale, beta).options()); + + cl::NDRange lws_hint(cl::NullRange); + std::string kernel_name = std::string("softmax_layer_max_shift_exp_sum_") + (is_data_type_quantized_asymmetric(dt) ? "quantized_" : ""); + + // Configure parallel kernel if needed + if(std::get<0>(parallel_reduction_info)) + { + kernel_name += "parallel"; + bool is_grid_size_pow2 = (_grid_size != 0) && ((_grid_size & (_grid_size - 1)) == 0); + build_opts.add_option_if(is_grid_size_pow2 && _grid_size <= 256, "-DGRID_SIZE=" + support::cpp11::to_string(_grid_size)); + + // Handle boundary conditions. + const unsigned int multiple_grid_size = (reduction_dim_size / vector_size) % _grid_size; + build_opts.add_option_if((multiple_grid_size != 0) || ((reduction_dim_size % vector_size) != 0), "-DNON_MULTIPLE_OF_GRID_SIZE"); + // Setting _lws_hint in this way can also communicate grid_size to ClLogits1DMaxShiftExpSumKernel::run(). + // A single workgroup performs reduction in dimension 0 in the parallel case, hence lws[0]==gws[0]. + lws_hint = cl::NDRange(_grid_size); + } + else + { + kernel_name += "serial"; + } + + // Create kernel. + _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Configure window + Window win = calculate_max_window(src, Steps(reduction_dim_size)); + IClKernel::configure_internal(win, lws_hint); + + ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); +} + +Status ClLogits1DMaxShiftExpSumKernel::validate(const ITensorInfo &src, const ITensorInfo &max, const ITensorInfo &dst, const ITensorInfo &sum) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_1DMaxShiftExpSum(src, max, dst, sum)); + return Status{}; +} + +ClLogits1DMaxShiftExpSumKernel::ParallelReductionInfo ClLogits1DMaxShiftExpSumKernel::is_parallel_reduction(size_t size) +{ + bool is_parallel_reduction = (size >= (_grid_size * _serial_vector_size)) && (_grid_size > 1); + unsigned int vector_size = is_parallel_reduction ? _parallel_vector_size : _serial_vector_size; + return std::make_tuple(is_parallel_reduction, vector_size); +} + +void ClLogits1DMaxShiftExpSumKernel::run_op(ITensorPack &tensors, const Window &window, ::cl::CommandQueue &queue) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); + + auto src = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC)); + auto dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); + auto max = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_INT_0)); + auto sum = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_INT_1)); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, max, sum); + + // Collapse window in Z dimension + Window window_collapsed = window.collapse_if_possible(IClKernel::window(), Window::DimZ); + + // Reconfigure window in case of parallel reduction + ParallelReductionInfo parallel_reduction_info = is_parallel_reduction(src->info()->dimension(0)); + if(std::get<0>(parallel_reduction_info)) + { + // Launch grid_size parallel work items + window_collapsed.set(Window::DimX, Window::Dimension(0, _grid_size, 1)); + } + + // Get slices + Window slice = window_collapsed.first_slice_window_3D(); + do + { + unsigned int idx = 0; + // Set inputs + add_3D_tensor_argument(idx, src, slice); + add_3D_tensor_argument(idx, max, slice); + add_3D_tensor_argument(idx, dst, slice); + add_3D_tensor_argument(idx, sum, slice); + enqueue(queue, *this, slice, lws_hint()); + } + while(window_collapsed.slide_window_slice_3D(slice)); +} + +void ClLogits1DNormKernel::configure(const CLCompileContext &compile_context, const ITensorInfo &src, const ITensorInfo &sum, ITensorInfo &dst, const SoftmaxKernelInfo &info) +{ + auto padding_info = get_padding_info({ &src, &dst, &sum }); + + // Note: output should always have a scale of 1/256 and offset 0 + const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(info.input_data_type); + const DataType output_data_type = info.input_data_type; + const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log); + const UniformQuantizationInfo qinfo = src.quantization_info().uniform(); + + // Output auto initialization if not yet initialized + auto_init_if_empty(dst, src.clone()->set_data_type(output_data_type).set_quantization_info(allowed_quantization_info)); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_1DNorm(src, sum, dst, info)); + + const auto is_signed_qasymm8 = is_data_type_quantized_asymmetric_signed(info.input_data_type); + const int min_value = is_signed_qasymm8 ? CL_SCHAR_MIN : 0; + const unsigned int vector_size = adjust_vec_size(16, src.dimension(0)); + + // Set build options + CLBuildOptions build_opts; + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(info.input_data_type)); + build_opts.add_option("-DMIN_VALUE=" + support::cpp11::to_string(min_value)); + build_opts.add_option("-DVECTOR_SIZE=" + support::cpp11::to_string(vector_size)); + build_opts.add_option("-DVECTOR_SIZE_LEFTOVER=" + support::cpp11::to_string(src.dimension(0) % vector_size)); + build_opts.add_option_if(is_data_type_quantized_asymmetric_signed(info.input_data_type), "-DQASYMM8_SIGNED"); + build_opts.add_options_if(is_quantized_asymmetric, + prepare_quantized_softmax_build_options(qinfo.scale, info.beta).options()); + build_opts.add_option_if(info.is_log, "-DLOG_SOFTMAX"); + + // Create kernel + std::string kernel_name = std::string("softmax_layer_norm") + (is_quantized_asymmetric ? "_quantized" : ""); + _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Configure window + auto win = calculate_max_window(src, Steps(vector_size)); + ICLKernel::configure_internal(win); + + ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); +} + +Status ClLogits1DNormKernel::validate(const ITensorInfo &src, const ITensorInfo &sum, const ITensorInfo &dst, const SoftmaxKernelInfo &info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_1DNorm(src, sum, dst, info)); + + return Status{}; +} + +void ClLogits1DNormKernel::run_op(ITensorPack &tensors, const Window &window, ::cl::CommandQueue &queue) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); + + auto src = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC)); + auto dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); + auto sum = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_INT_0)); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, sum); + + Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); + Window slice = window_collapsed.first_slice_window_3D(); + + do + { + Window sum_slice = slice; + sum_slice.set(Window::DimX, Window::Dimension(0, 1, 1)); + + unsigned int idx = 0; + // Set inputs + add_3D_tensor_argument(idx, src, slice); + add_3D_tensor_argument(idx, sum, sum_slice); + add_3D_tensor_argument(idx, dst, slice); + enqueue(queue, *this, slice, lws_hint()); + } + while(window_collapsed.slide_window_slice_3D(slice)); +} +} // namespace kernels +} // namespace opencl +} // namespace arm_compute \ No newline at end of file diff --git a/src/core/gpu/cl/kernels/ClSoftmaxKernel.h b/src/core/gpu/cl/kernels/ClSoftmaxKernel.h new file mode 100644 index 0000000000..af980eaa8e --- /dev/null +++ b/src/core/gpu/cl/kernels/ClSoftmaxKernel.h @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H +#define ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H + +#include "arm_compute/core/Error.h" +#include "arm_compute/core/KernelDescriptors.h" +#include "src/core/common/Macros.h" +#include "src/core/gpu/cl/ClCompileContext.h" +#include "src/core/gpu/cl/IClKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +/** Interface for max, shifting, exponentiating and summing the logits */ +class ClLogits1DMaxShiftExpSumKernel : public IClKernel +{ + /**< Grid size (obtained through auto-tuning) */ + static const unsigned int _grid_size; + /**< Vector size in the serial case (obtained through auto-tuning) */ + static const unsigned int _serial_vector_size; + /**< Vector size in the parallel case (obtained through auto-tuning, enables the best memory access pattern for Bifrost) .*/ + static const unsigned int _parallel_vector_size; + +public: + /** Info for whether a parallel reduction will be run and the vector size of the execution. */ + using ParallelReductionInfo = std::tuple; + + /** Default constructor */ + ClLogits1DMaxShiftExpSumKernel() = default; + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClLogits1DMaxShiftExpSumKernel); + /** Configure the kernel using the given information about tensors + * + * @param[in] compile_context The compile context to be used. + * @param[in] src Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + * @param[in,out] max Max values tensor. Data types supported: same as @p src + * @param[out] dst Destination tensor. Data types supported: same as @p src + * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p src + * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &max, ITensorInfo &dst, ITensorInfo &sum, const SoftmaxKernelInfo &info); + /** Static function to check if given info will lead to a valid configuration of @ref ClLogits1DMaxShiftExpSumKernel + * + * @param[in] src Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + * @param[in] max Max values tensor. Data types supported: same as @p src + * @param[in] dst Destination tensor. Data types supported: same as @p src + * @param[in] sum Sum of 1D logits tensor. Data types supported: same as @p src + * + * @return a status + */ + static Status validate(const ITensorInfo &src, const ITensorInfo &max, const ITensorInfo &dst, const ITensorInfo &sum); + /** Checks if the given size is eligible for parallel reduction + * + * @note Serial reduction is launched for width < (_grid_size * _serial_vector_size). + * @note Parallel reduction is launched for width >= (_grid_size * _serial_vector_size) and vector_size is forced to 4. + * + * @param[in] size Size to check + * + * @return A two-element tuple where the first element is a boolean specifying if a parallel reduction will be run, + * while the second element is the vector size of the execution. + */ + static ParallelReductionInfo is_parallel_reduction(size_t size); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, ::cl::CommandQueue &queue) override; +}; + +/** Interface for calculating the final step of the Softmax Layer where each logit value is multiplied by the inverse of the sum of the logits. */ +class ClLogits1DNormKernel : public IClKernel +{ +public: + /** Default constructor */ + ClLogits1DNormKernel() = default; + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClLogits1DNormKernel); + + /** Set the input and output tensors. + * + * @param[in] compile_context The compile context to be used. + * @param[in] src Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. + * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input + * @param[out] dst Destination tensor. Data types supported: QASYMM8/QASYMM8_SIGNED for S32 @p input, or same as @p input + * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo &src, const ITensorInfo &sum, ITensorInfo &dst, const SoftmaxKernelInfo &info); + /** Static function to check if given info will lead to a valid configuration of @ref ClLogits1DNormKernel + * + * @param[in] src Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. + * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input + * @param[in] dst Destination tensor. Data types supported: QASYMM8 for S32 @p input, or same as @p input + * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. + * + * @return a status + */ + static Status validate(const ITensorInfo &src, const ITensorInfo &sum, const ITensorInfo &dst, const SoftmaxKernelInfo &info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, ::cl::CommandQueue &queue) override; +}; +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /*ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H */ diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp index 93e63dd779..938a10a7c0 100644 --- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp +++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -22,34 +22,35 @@ * SOFTWARE. */ #include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h" - #include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/Helpers.h" +#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/runtime/CL/CLScheduler.h" -#include "src/core/CL/ICLKernel.h" -#include "src/core/CL/kernels/CLFillBorderKernel.h" -#include "src/core/CL/kernels/CLSoftmaxLayerKernel.h" -#include "src/core/helpers/SoftmaxHelpers.h" +#include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h" +#include "src/runtime/gpu/cl/operators/ClPermute.h" +#include "src/runtime/gpu/cl/operators/ClSoftmax.h" namespace arm_compute { +using OperatorType = opencl::ClSoftmax; + +template +struct CLSoftmaxLayerGeneric::Impl +{ + const ICLTensor *src{ nullptr }; + ICLTensor *dst{ nullptr }; + std::unique_ptr op{ nullptr }; + MemoryGroup memory_group{}; + std::vector>> workspace_tensors{}; +}; + template CLSoftmaxLayerGeneric::CLSoftmaxLayerGeneric(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), - _permute_input(), - _permute_output(), - _max_shift_exp_sum_kernel(std::make_unique()), - _norm_kernel(std::make_unique()), - _max(), - _sum(), - _tmp(), - _input_permuted(), - _output_permuted(), - _needs_permute() + : _impl(std::make_unique()) { + _impl->memory_group = MemoryGroup(std::move(memory_manager)); } template @@ -64,118 +65,59 @@ void CLSoftmaxLayerGeneric::configure(const ICLTensor *input, ICLTensor template void CLSoftmaxLayerGeneric::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output, float beta, int32_t axis) { - // Perform validation step - ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_ERROR_THROW_ON(CLSoftmaxLayerGeneric::validate(input->info(), output->info(), beta, axis)); - - const size_t actual_axis = static_cast(wrap_around(axis, static_cast(input->info()->num_dimensions()))); + _impl->src = input; + _impl->dst = output; + _impl->op = std::make_unique(); - _needs_permute = actual_axis != 0; - ICLTensor *tmp_output = output; - const ICLTensor *tmp_input = _needs_permute ? &_input_permuted : input; - if(_needs_permute) - { - _memory_group.manage(&_input_permuted); - _memory_group.manage(&_output_permuted); - _permute_input.configure(compile_context, input, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis)); - tmp_output = &_output_permuted; - } - - // Create intermediate tensors - DataType tmp_data_type = is_data_type_quantized_asymmetric(tmp_input->info()->data_type()) ? DataType::S32 : tmp_input->info()->data_type(); - TensorInfo tensor_info_tmp(tmp_input->info()->clone()->set_data_type(tmp_data_type)); - _tmp.allocator()->init(tensor_info_tmp); - TensorShape max_sum_shape = tmp_input->info()->tensor_shape(); - max_sum_shape.set(0, 1); - _max.allocator()->init(tmp_input->info()->clone()->set_tensor_shape(max_sum_shape)); - _sum.allocator()->init(tmp_input->info()->clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type)); - - // Set GPU target to kernels - _max_shift_exp_sum_kernel->set_target(CLScheduler::get().target()); - - // Manage intermediate buffers - _memory_group.manage(&_tmp); - _memory_group.manage(&_max); - _memory_group.manage(&_sum); - - SoftmaxKernelInfo softmax_info; - softmax_info.beta = beta; - softmax_info.is_log = IS_LOG; - softmax_info.input_data_type = tmp_input->info()->data_type(); - - // Configure kernels - _max_shift_exp_sum_kernel->configure(compile_context, tmp_input, &_max, &_tmp, &_sum, softmax_info); - _norm_kernel->configure(compile_context, &_tmp, &_sum, tmp_output, softmax_info); - - // Allocate intermediate buffers - _tmp.allocator()->allocate(); - _max.allocator()->allocate(); - _sum.allocator()->allocate(); - if(_needs_permute) - { - _permute_output.configure(compile_context, &_output_permuted, output, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis)); - _input_permuted.allocator()->allocate(); - _output_permuted.allocator()->allocate(); - } + SoftmaxKernelInfo softmax_info{ beta, IS_LOG, input->info()->data_type(), axis }; + _impl->op->configure(compile_context, *input->info(), *output->info(), softmax_info); } template Status CLSoftmaxLayerGeneric::validate(const ITensorInfo *input, const ITensorInfo *output, float beta, int32_t axis) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() > 4, "Only up to 4 dimensions are supported"); - ARM_COMPUTE_UNUSED(beta); - ARM_COMPUTE_RETURN_ERROR_ON(axis < static_cast(-input->num_dimensions()) || static_cast(input->num_dimensions()) <= axis); - - const size_t actual_axis = static_cast(wrap_around(axis, static_cast(input->num_dimensions()))); - const bool needs_permute = actual_axis != 0; - if(needs_permute) + SoftmaxKernelInfo softmax_info{ beta, IS_LOG, input->data_type(), axis }; + return OperatorType::validate(*input, *output, softmax_info); +} + +template +void CLSoftmaxLayerGeneric::allocate_workspace() +{ + const auto memory_requirements = _impl->op->workspace(); + std::for_each(memory_requirements.begin(), memory_requirements.end(), [this](const experimental::MemoryInfo & memory_info) { - const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis); - const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(*input, permutation_vector); - TensorInfo input_permuted(input->clone()->set_tensor_shape(permuted_shape)); - ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(input, &input_permuted, permutation_vector)); - TensorInfo output_permuted(output->clone()->set_tensor_shape(permuted_shape)); - ARM_COMPUTE_RETURN_ON_ERROR(CLPermute::validate(&output_permuted, output, permutation_vector)); - } - - // Create intermediate tensor info - DataType tmp_data_type = is_data_type_quantized_asymmetric(input->data_type()) ? DataType::S32 : input->data_type(); - TensorInfo tensor_info_tmp(input->clone()->set_data_type(tmp_data_type).set_is_resizable(true)); - - TensorShape max_sum_shape = input->tensor_shape(); - max_sum_shape.set(0, 1); - TensorInfo tensor_info_max(input->clone()->set_tensor_shape(max_sum_shape).set_is_resizable(true)); - TensorInfo tensor_info_sum(input->clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(QuantizationInfo()).set_is_resizable(true)); - - SoftmaxKernelInfo softmax_info; - softmax_info.beta = beta; - softmax_info.is_log = IS_LOG; - softmax_info.input_data_type = input->data_type(); - - ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DMaxShiftExpSumKernel::validate(input, &tensor_info_max, &tensor_info_tmp, &tensor_info_sum)); - ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DNormKernel::validate(&tensor_info_tmp, &tensor_info_sum, output, softmax_info)); - - return Status{}; + auto tensor_info = TensorInfo{ TensorShape(memory_info.size), 1, DataType::U8 }; + _impl->workspace_tensors.emplace_back(memory_info.type, std::make_unique()); + auto tensor = _impl->workspace_tensors.back().second.get(); + ARM_COMPUTE_ERROR_ON_NULLPTR(tensor); + _impl->memory_group.manage(tensor); + tensor->allocator()->init(tensor_info); + tensor->allocator()->allocate(); + }); } template void CLSoftmaxLayerGeneric::run() { - MemoryGroupResourceScope scope_mg(_memory_group); + allocate_workspace(); - if(_needs_permute) - { - _permute_input.run(); - } + // Acquire all the temporaries + MemoryGroupResourceScope scope_mg(_impl->memory_group); + + ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst); - CLScheduler::get().enqueue(*_max_shift_exp_sum_kernel, false); - CLScheduler::get().enqueue(*_norm_kernel, !_needs_permute); + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, _impl->src); + pack.add_tensor(TensorType::ACL_DST, _impl->dst); - if(_needs_permute) + std::for_each(_impl->workspace_tensors.begin(), _impl->workspace_tensors.end(), [&pack](std::pair> &wt) { - _permute_output.run(); - } + auto tensor = wt.second.get(); + ARM_COMPUTE_ERROR_ON_NULLPTR(tensor); + pack.add_tensor(wt.first, tensor); + }); + + _impl->op->run(pack); } template class CLSoftmaxLayerGeneric; diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.cpp b/src/runtime/gpu/cl/operators/ClSoftmax.cpp new file mode 100644 index 0000000000..c3ec7cc0da --- /dev/null +++ b/src/runtime/gpu/cl/operators/ClSoftmax.cpp @@ -0,0 +1,276 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/gpu/cl/operators/ClSoftmax.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h" +#include "src/core/helpers/SoftmaxHelpers.h" +#include "src/runtime/gpu/cl/operators/ClPermute.h" +#include "support/Cast.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace +{ +void run_permute(ClPermute *op, const ITensor *src, ITensor *dst) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, op); + ITensorPack pack; + pack.add_const_tensor(TensorType::ACL_SRC, src); + pack.add_tensor(TensorType::ACL_DST, dst); + op->run(pack); +} +} // namespace + +ClSoftmax::ClSoftmax() + : _permute_input(std::make_unique()), + _permute_output(std::make_unique()), + _max_shift_exp_sum_kernel(std::make_unique()), + _norm_kernel(std::make_unique()), + _max_info(_internal_info[static_cast(InternalTensorIdx::MAX)]), + _sum_info(_internal_info[static_cast(InternalTensorIdx::SUM)]), + _tmp_info(_internal_info[static_cast(InternalTensorIdx::TMP)]), + _permuted_src_info(_internal_info[static_cast(InternalTensorIdx::PERMUTED_SRC)]), + _permuted_dst_info(_internal_info[static_cast(InternalTensorIdx::PERMUTED_DST)]) +{ +} + +TensorType ClSoftmax::convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const +{ + switch(idx) + { + case InternalTensorIdx::MAX: + return TensorType::ACL_INT_0; + case InternalTensorIdx::SUM: + return TensorType::ACL_INT_1; + case InternalTensorIdx::TMP: + return TensorType::ACL_INT_2; + case InternalTensorIdx::PERMUTED_SRC: + return TensorType::ACL_INT_3; + case InternalTensorIdx::PERMUTED_DST: + return TensorType::ACL_INT_4; + default: + ARM_COMPUTE_ERROR("invalid internal tensor index is given."); + break; + }; + return TensorType::ACL_UNKNOWN; +} + +void ClSoftmax::create_internal_tensor(TensorInfo &info, InternalTensorIdx idx) +{ + const auto tensor_idx = static_cast(idx); + if(!_internal_tensor[tensor_idx]) + { + _internal_tensor[tensor_idx] = std::make_unique(); + } + _internal_tensor[tensor_idx]->allocator()->init(info); +} + +void ClSoftmax::create_internal_tensor() +{ + for(uint32_t i = 0; i < static_cast(InternalTensorIdx::COUNT); i++) + { + const auto tensor_idx = static_cast(i); + + if(!_needs_permute && (tensor_idx == InternalTensorIdx::PERMUTED_DST || tensor_idx == InternalTensorIdx::PERMUTED_SRC)) + { + continue; + } + create_internal_tensor(_internal_info[i], static_cast(i)); + } +} + +void ClSoftmax::configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &dst, const SoftmaxKernelInfo &info) +{ + ARM_COMPUTE_ERROR_THROW_ON(validate(src, dst, info)); + + const size_t actual_axis = static_cast(wrap_around(info.axis, static_cast(src.num_dimensions()))); + + _needs_permute = actual_axis != 0; + + const ITensorInfo &tmp_input_info = _needs_permute ? _permuted_src_info : src; + ITensorInfo &tmp_output_info = _needs_permute ? _permuted_dst_info : dst; + + if(_needs_permute) + { + const auto perm_info = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis); + _permute_input->configure(compile_context, &src, &_permuted_src_info, perm_info); + } + + DataType tmp_data_type = is_data_type_quantized_asymmetric(tmp_input_info.data_type()) ? DataType::S32 : tmp_input_info.data_type(); + _tmp_info = tmp_input_info.clone()->set_data_type(tmp_data_type); + + TensorShape max_sum_shape = tmp_input_info.tensor_shape(); + _max_info = tmp_input_info.clone()->set_tensor_shape(max_sum_shape); + _sum_info = tmp_input_info.clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type); + + // Set GPU target to kernels + _max_shift_exp_sum_kernel->set_target(CLScheduler::get().target()); + + _max_shift_exp_sum_kernel->configure(compile_context, tmp_input_info, _max_info, _tmp_info, _sum_info, info); + _norm_kernel->configure(compile_context, _tmp_info, _sum_info, tmp_output_info, info); + + if(_needs_permute) + { + const auto perm_info = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis); + _permute_output->configure(compile_context, &_permuted_dst_info, &dst, perm_info); + } +} + +Status ClSoftmax::validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src.num_dimensions() > 4, "Only up to 4 dimensions are supported"); + ARM_COMPUTE_UNUSED(info.beta); + ARM_COMPUTE_RETURN_ERROR_ON(info.axis < static_cast(-src.num_dimensions()) || static_cast(src.num_dimensions()) <= info.axis); + + const size_t actual_axis = static_cast(wrap_around(info.axis, static_cast(src.num_dimensions()))); + const bool needs_permute = actual_axis != 0; + if(needs_permute) + { + const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis); + const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(src, permutation_vector); + TensorInfo input_permuted(src.clone()->set_tensor_shape(permuted_shape)); + ARM_COMPUTE_RETURN_ON_ERROR(ClPermute::validate(&src, &input_permuted, permutation_vector)); + TensorInfo output_permuted(dst.clone()->set_tensor_shape(permuted_shape)); + ARM_COMPUTE_RETURN_ON_ERROR(ClPermute::validate(&output_permuted, &dst, permutation_vector)); + } + + // Create intermediate tensor info + DataType tmp_data_type = is_data_type_quantized_asymmetric(src.data_type()) ? DataType::S32 : src.data_type(); + TensorInfo tensor_info_tmp(src.clone()->set_data_type(tmp_data_type).set_is_resizable(true)); + + TensorShape max_sum_shape = src.tensor_shape(); + max_sum_shape.set(0, 1); + TensorInfo tensor_info_max(src.clone()->set_tensor_shape(max_sum_shape).set_is_resizable(true)); + TensorInfo tensor_info_sum(src.clone()->set_tensor_shape(max_sum_shape).set_data_type(tmp_data_type).set_quantization_info(QuantizationInfo()).set_is_resizable(true)); + + ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClLogits1DMaxShiftExpSumKernel::validate(src, tensor_info_max, tensor_info_tmp, tensor_info_sum)); + ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClLogits1DNormKernel::validate(tensor_info_tmp, tensor_info_sum, dst, info)); + + return Status{}; +} + +void ClSoftmax::import_workspace_memory(ITensorPack &tensors) +{ + auto import_workspace_memory = [this, &tensors](InternalTensorIdx idx) + { + const auto workspace_idx = convert_internal_idx_to_tensor_type(idx); + auto imported_tensor = tensors.get_tensor(workspace_idx); + if(imported_tensor) + { + auto imported_memory = utils::cast::polymorphic_downcast(imported_tensor)->cl_buffer(); + _internal_tensor[static_cast(idx)].get()->allocator()->import_memory(imported_memory); + } + }; + + import_workspace_memory(InternalTensorIdx::PERMUTED_SRC); + import_workspace_memory(InternalTensorIdx::PERMUTED_DST); + import_workspace_memory(InternalTensorIdx::MAX); + import_workspace_memory(InternalTensorIdx::SUM); + import_workspace_memory(InternalTensorIdx::TMP); +} + +void ClSoftmax::run_source_permute(const ITensor *src) +{ + if(_needs_permute) + { + auto permuted_src = _internal_tensor[static_cast(InternalTensorIdx::PERMUTED_SRC)].get(); + run_permute(_permute_input.get(), src, permuted_src); + } +} + +void ClSoftmax::run_destination_permute(ITensor *dst) +{ + if(_needs_permute) + { + auto permuted_dst = _internal_tensor[static_cast(InternalTensorIdx::PERMUTED_DST)].get(); + run_permute(_permute_output.get(), permuted_dst, dst); + } +} + +void ClSoftmax::run_max_sum(const ITensor *src) +{ + auto max = _internal_tensor[static_cast(InternalTensorIdx::MAX)].get(); + auto sum = _internal_tensor[static_cast(InternalTensorIdx::SUM)].get(); + auto tmp = _internal_tensor[static_cast(InternalTensorIdx::TMP)].get(); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src, tmp, max, sum); + + ITensorPack sum_pack; + sum_pack.add_const_tensor(TensorType::ACL_SRC, src); + sum_pack.add_tensor(TensorType::ACL_DST, tmp); + sum_pack.add_tensor(TensorType::ACL_INT_0, max); + sum_pack.add_tensor(TensorType::ACL_INT_1, sum); + + CLScheduler::get().enqueue_op(*_max_shift_exp_sum_kernel.get(), sum_pack, false); +} + +void ClSoftmax::run_norm(ITensor *dst) +{ + auto sum = _internal_tensor[static_cast(InternalTensorIdx::SUM)].get(); + auto tmp = _internal_tensor[static_cast(InternalTensorIdx::TMP)].get(); + + ARM_COMPUTE_ERROR_ON_NULLPTR(tmp, sum, dst); + + ITensorPack norm_pack; + norm_pack.add_const_tensor(TensorType::ACL_SRC, tmp); + norm_pack.add_tensor(TensorType::ACL_DST, dst); + norm_pack.add_tensor(TensorType::ACL_INT_0, sum); + + CLScheduler::get().enqueue_op(*_norm_kernel.get(), norm_pack, false); +} + +void ClSoftmax::run(ITensorPack &tensors) +{ + create_internal_tensor(); + + auto src = tensors.get_const_tensor(TensorType::ACL_SRC); + auto dst = tensors.get_tensor(TensorType::ACL_DST); + + import_workspace_memory(tensors); + run_source_permute(src); + run_max_sum(!_needs_permute ? src : _internal_tensor[static_cast(InternalTensorIdx::PERMUTED_SRC)].get()); + run_norm(!_needs_permute ? dst : _internal_tensor[static_cast(InternalTensorIdx::PERMUTED_DST)].get()); + run_destination_permute(dst); +} + +experimental::MemoryRequirements ClSoftmax::workspace() const +{ + experimental::MemoryRequirements req{}; + + req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::SUM), _sum_info.total_size(), 0); + req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::TMP), _tmp_info.total_size(), 0); + req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::MAX), _max_info.total_size(), 0); + + if(_needs_permute) + { + req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info.total_size(), 0); + req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info.total_size(), 0); + } + + return req; +} +} // namespace opencl +} // namespace arm_compute \ No newline at end of file diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.h b/src/runtime/gpu/cl/operators/ClSoftmax.h new file mode 100644 index 0000000000..e38b7c595a --- /dev/null +++ b/src/runtime/gpu/cl/operators/ClSoftmax.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_SOFTMAX_H +#define ARM_COMPUTE_CL_SOFTMAX_H + +#include "arm_compute/runtime/CL/CLTensor.h" +#include "src/core/gpu/cl/ClCompileContext.h" +#include "src/runtime/gpu/cl/IClOperator.h" + +namespace arm_compute +{ +struct SoftmaxKernelInfo; + +namespace opencl +{ +class ClPermute; +namespace kernels +{ +class ClLogits1DMaxShiftExpSumKernel; +class ClLogits1DNormKernel; +} // namespace kernels +class ClSoftmax : public IClOperator +{ +public: + /** Constructor */ + ClSoftmax(); + /** Configure the operator + * + * @param[in] compile_context The compile context to be used. + * @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 for Softmax and F16/F32 for Log Softmax + * @param[out] dst Destination tensor info. Data types supported: same as @p src + * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. + * + */ + void configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &dst, const SoftmaxKernelInfo &info); + /** Static function to check if the given info will lead to a valid configuration + * + * @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 for Softmax and F16/F32 for Log Softmax + * @param[out] dst Destination tensor info. Data types supported: same as @p src + * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. + * + */ + static Status validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info); + // Inherited methods overridden: + void run(ITensorPack &tensors) override; + experimental::MemoryRequirements workspace() const override; + +private: + enum class InternalTensorIdx + { + MAX = 0, + SUM, + TMP, + PERMUTED_SRC, + PERMUTED_DST, + COUNT + }; + + /** Create a single internal tensor + * + * @param[in] info The information used to create a tensor + * @param[in] idx The index within the internal array the created tensor will be held + */ + void create_internal_tensor(TensorInfo &info, InternalTensorIdx idx); + /** Create all required internal tensors */ + void create_internal_tensor(); + /** Function to convert from internal tensor index to @ref TensorType used externally */ + TensorType convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const; + /** Function to import workspace memory allocated by the caller into internal tensor instances */ + void import_workspace_memory(ITensorPack &tensors); + /** Function to permute the given source tensor when permutation is required */ + void run_source_permute(const ITensor *src); + /** Function to permute the intemediate tensor to the final destination tensor when permutation is required */ + void run_destination_permute(ITensor *dst); + /** Function to run @ref arm_compute::opencl::kernels::ClLogits1DMaxShiftExpSumKernel */ + void run_max_sum(const ITensor *src); + /** Function to run @ref kernels::ClLogits1DNormKernel */ + void run_norm(ITensor *dst); + + std::unique_ptr _permute_input; + std::unique_ptr _permute_output; + std::unique_ptr _max_shift_exp_sum_kernel; + std::unique_ptr _norm_kernel; + bool _needs_permute{ false }; + + std::array(InternalTensorIdx::COUNT)> _internal_info{}; + std::array, static_cast(InternalTensorIdx::COUNT)> _internal_tensor{}; + + TensorInfo &_max_info; + TensorInfo &_sum_info; + TensorInfo &_tmp_info; + TensorInfo &_permuted_src_info; + TensorInfo &_permuted_dst_info; +}; + +} // opencl +} // arm_compute +#endif /* ARM_COMPUTE_CL_SOFTMAX_H */ \ No newline at end of file -- cgit v1.2.1