diff options
Diffstat (limited to 'src/core/CL/cl_kernels/common/softmax_layer.cl')
-rw-r--r-- | src/core/CL/cl_kernels/common/softmax_layer.cl | 371 |
1 files changed, 371 insertions, 0 deletions
diff --git a/src/core/CL/cl_kernels/common/softmax_layer.cl b/src/core/CL/cl_kernels/common/softmax_layer.cl new file mode 100644 index 0000000000..bfc0995bb8 --- /dev/null +++ b/src/core/CL/cl_kernels/common/softmax_layer.cl @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2017-2021, 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "helpers.h" + +#define MIN_VALUE_float -FLT_MAX +#define MIN_VALUE_half -HALF_MAX +#define MIN_VALUE_char CHAR_MIN +#define MIN_VALUE_uchar 0 + +#define MIN_VALUE_TYPE_STR(data_type) MIN_VALUE_##data_type +#define MIN_VALUE_TYPE(data_type) MIN_VALUE_TYPE_STR(data_type) +#define MIN_VALUE MIN_VALUE_TYPE(DATA_TYPE) + +#ifdef SOFTMAX_X + +/** 3-pass softmax in the x dimension. + * + * List of preprocessors: + * - DATA_TYPE: the input/output data type. + * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage. + * If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE. + * - IS_LOG (optional): indicating whether this is log softmax. + * - LENGTH: the number of elements in softmax axis in the input/output tensors. + * - BETA: the beta coefficient. + * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data. + * - VEC_SIZE: the size of the vector. + * + * Additional preprocessors in case IS_QUANTIZED is present: + * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor. + * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor. + * + * @param[in] src_ptr Pointer to the source tensor. + * @param[in] src_stride_0 Stride in bytes of the source tensor in the dimension corresponding to global ID 0. + * @param[in] src_stride_1 Stride in bytes of the source tensor in the dimension corresponding to global ID 1. + * @param[in] src_stride_2 Stride in bytes of the source tensor in the dimension corresponding to global ID 2. + * @param[in] src_offset_first_element Offset of the first element in the source tensor. + * @param[in] dst_ptr Pointer to the destination tensor. + * @param[in] dst_stride_0 Stride in bytes of the destination tensor in the dimension corresponding to global ID 0. + * @param[in] dst_stride_1 Stride in bytes of the destination tensor in the dimension corresponding to global ID 1. + * @param[in] dst_stride_2 Stride in bytes of the destination tensor in the dimension corresponding to global ID 2. + * @param[in] dst_offset_first_element Offset of the first element in the destination tensor. + * @param[in] tmp_ptr Pointer to the temporary tensor. + * @param[in] tmp_stride_0 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0. + * @param[in] tmp_stride_1 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1. + * @param[in] tmp_stride_2 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2. + * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor. + */ +__kernel void softmax_x( + __global uchar *src_ptr, + uint src_stride_0, + uint src_stride_1, + uint src_stride_2, + uint src_offset_first_element, + + __global uchar *dst_ptr, + uint dst_stride_0, + uint dst_stride_1, + uint dst_stride_2, + uint dst_offset_first_element + +#ifdef IS_QUANTIZED + , + __global uchar *tmp_ptr, + uint tmp_stride_0, + uint tmp_stride_1, + uint tmp_stride_2, + uint tmp_offset_first_element +#endif // IS_QUANTIZED +) +{ + const int dim_0 = get_global_id(0); + const int dim_1 = get_global_id(1); + const int dim_2 = get_global_id(2); + + src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0; + dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0; + +#ifdef IS_QUANTIZED + tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0; +#else // IS_QUANTIZED + __global uchar *tmp_ptr = dst_ptr; +#endif // IS_QUANTIZED + + // Calculate max value. + DATA_TYPE max_value = MIN_VALUE; + int i = 0; + + for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE) + { + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE))); + + max_value = max(max_value, MAX_REDUCE(data, VEC_SIZE)); + } + + for (; i < LENGTH; ++i) + { + DATA_TYPE data = *(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)); + + max_value = max(max_value, data); + } + + // Regularize the data. + TMP_DATA_TYPE sum_value = 0; + +#ifdef IS_QUANTIZED + TMP_DATA_TYPE max_value_f = (CONVERT(max_value, TMP_DATA_TYPE) - SRC_OFFSET) * SRC_SCALE; + TMP_DATA_TYPE regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA; +# define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset) +#else // IS_QUANTIZED +# define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA) +#endif // IS_QUANTIZED + + for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE) + { + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)); + + data = REGULARIZE(data); + +#ifdef IS_LOG + sum_value += SUM_REDUCE(exp(data), VEC_SIZE); +#else // IS_LOG + data = exp(data); + sum_value += SUM_REDUCE(data, VEC_SIZE); +#endif // IS_LOG + + VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE))); + } + + for (; i < LENGTH; ++i) + { + TMP_DATA_TYPE data = CONVERT(*(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)), TMP_DATA_TYPE); + + data = REGULARIZE(data); + +#ifdef IS_LOG + sum_value += exp(data); +#else // IS_LOG + data = exp(data); + sum_value += data; +#endif // IS_LOG + + *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)) = data; + } + +#undef REGULARIZE + + // Normalize the data. +#ifdef IS_QUANTIZED +# if IS_LOG + TMP_DATA_TYPE norm_offset = -log(sum_value) + DST_OFFSET; +# define NORMALIZE(SIZE, x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, SIZE), rte) +# else // IS_LOG + TMP_DATA_TYPE norm_div = sum_value * DST_SCALE; +# define NORMALIZE(SIZE, x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, SIZE)) +# endif // IS_LOG +#else // IS_QUANTIZED +# if IS_LOG +# define NORMALIZE(SIZE, x) ((x) - log(sum_value)) +# else // IS_LOG +# define NORMALIZE(SIZE, x) ((x) / sum_value) +# endif // IS_LOG +#endif // IS_QUANTIZED + + for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE) + { + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE))); + + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = NORMALIZE(VEC_SIZE, data); + + VSTORE(VEC_SIZE)(result, 0, (__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE))); + } + + for (; i < LENGTH; ++i) + { + TMP_DATA_TYPE data = *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)); + + DATA_TYPE result = NORMALIZE(1, data); + + *(__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)) = result; + } + +#undef NORMALIZE +} + +#endif // SOFTMAX_X + +#ifdef SOFTMAX_NON_X + +/** 3-pass softmax in any dimension higher than the x dimension. + * + * List of preprocessors: + * - DATA_TYPE: the input/output data type. + * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage. + * If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE. + * - IS_LOG (optional): indicating whether this is log softmax. + * - LENGTH: the number of elements in softmax axis in the input/output tensors. + * - BETA: the beta coefficient. + * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data. + * - VEC_SIZE: the size of the vector. + * - VEC_SIZE_LEFTOVER: the size of the leftover part. + * + * Additional preprocessors in case IS_QUANTIZED is present: + * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor. + * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor. + * + * @param[in] src_ptr Pointer to the source tensor. + * @param[in] src_stride_0 Stride in bytes of the source tensor in the dimension corresponding to global ID 0. + * @param[in] src_stride_1 Stride in bytes of the source tensor in the dimension corresponding to global ID 1. + * @param[in] src_stride_2 Stride in bytes of the source tensor in the dimension corresponding to global ID 2. + * @param[in] src_offset_first_element Offset of the first element in the source tensor. + * @param[in] dst_ptr Pointer to the destination tensor. + * @param[in] dst_stride_0 Stride in bytes of the destination tensor in the dimension corresponding to global ID 0. + * @param[in] dst_stride_1 Stride in bytes of the destination tensor in the dimension corresponding to global ID 1. + * @param[in] dst_stride_2 Stride in bytes of the destination tensor in the dimension corresponding to global ID 2. + * @param[in] dst_offset_first_element Offset of the first element in the destination tensor. + * @param[in] tmp_ptr Pointer to the temporary tensor. + * @param[in] tmp_stride_0 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0. + * @param[in] tmp_stride_1 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1. + * @param[in] tmp_stride_2 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2. + * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor. + */ +__kernel void softmax_non_x( + __global uchar *src_ptr, + uint src_stride_0, + uint src_stride_1, + uint src_stride_2, + uint src_offset_first_element, + + __global uchar *dst_ptr, + uint dst_stride_0, + uint dst_stride_1, + uint dst_stride_2, + uint dst_offset_first_element, + + __global uchar *tmp_ptr, + uint tmp_stride_0, + uint tmp_stride_1, + uint tmp_stride_2, + uint tmp_offset_first_element, + + uint src_stride_axis, + uint dst_stride_axis +) +{ + const int dim_0 = max((int)get_global_id(0) * VEC_SIZE - (VEC_SIZE - VEC_SIZE_LEFTOVER) % VEC_SIZE, 0); + const int dim_1 = get_global_id(1); + const int dim_2 = get_global_id(2); + + src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0; + dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0; + tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0; + + // In case of processing quantized data, i.e. DATA_TYPE is smaller than TMP_DATA_TYPE: + // + // In the first pass (finding max), the quantized data is copied from the input tensor to the temporary tensor. + // Dequantization is not needed to find the max value and since dequantization widens the data, we defer it + // to the second pass pass to reduce memory bandwidth of the first pass. + // + // In the second pass, it reads the quantized data from the temporary tensor and writes the dequantized data + // back to the temporary tensor. + // + // To avoid dequantized data overwritting the unprocessed quantized data in the temporary tensor, + // this extra offset is introduced to store the quantized data at the end of the temporary tensor. + // + // Note: Another approach is to perform the second pass in reverse order, but for unexplanable reason + // it doesn't work in some devices. + uint tmp_extra_offset = LENGTH * VEC_SIZE * (sizeof(TMP_DATA_TYPE) - sizeof(DATA_TYPE)); + + // Calculate max value and store the input data to the temporary tensor in suitable format. + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) max_value = MIN_VALUE; + int i = 0; + + for (i = 0; i < LENGTH; ++i) + { + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * src_stride_axis)); + + max_value = max(max_value, data); + + VSTORE(VEC_SIZE)(data, 0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE))); + } + + // Regularize the data. + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) sum_value = 0; + +#ifdef IS_QUANTIZED + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) max_value_f = (CONVERT(max_value, VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)) - SRC_OFFSET) * SRC_SCALE; + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA; +# define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset) +#else // IS_QUANTIZED +# define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA) +#endif // IS_QUANTIZED + + for (i = 0; i < LENGTH; ++i) + { + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)); + + data = REGULARIZE(data); + +#ifdef IS_LOG + sum_value += exp(data); +#else // IS_LOG + data = exp(data); + sum_value += data; +#endif // IS_LOG + + VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE))); + } + +#undef REGULARIZE + + // Normalize the data. +#ifdef IS_QUANTIZED +# if IS_LOG + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_offset = -log(sum_value) + DST_OFFSET; +# define NORMALIZE(x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), rte) +# else // IS_LOG + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_div = sum_value * DST_SCALE; +# define NORMALIZE(x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, VEC_SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)) +# endif // IS_LOG +#else // IS_QUANTIZED +# if IS_LOG +# define NORMALIZE(x) ((x) - log(sum_value)) +# else // IS_LOG +# define NORMALIZE(x) ((x) / sum_value) +# endif // IS_LOG +#endif // IS_QUANTIZED + + for (i = 0; i < LENGTH; ++i) + { + VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE))); + + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result0 = NORMALIZE(data); + + STORE_VECTOR_SELECT(result, DATA_TYPE, dst_ptr + i * dst_stride_axis, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0) + } + +#undef NORMALIZE +} + +#endif // SOFTMAX_NON_X + +#undef MIN_VALUE +#undef MIN_VALUE_TYPE +#undef MIN_VALUE_TYPE_STR + +#undef MIN_VALUE_float +#undef MIN_VALUE_half +#undef MIN_VALUE_char +#undef MIN_VALUE_uchar |