aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/common/softmax_layer.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/common/softmax_layer.cl')
-rw-r--r--src/core/CL/cl_kernels/common/softmax_layer.cl371
1 files changed, 371 insertions, 0 deletions
diff --git a/src/core/CL/cl_kernels/common/softmax_layer.cl b/src/core/CL/cl_kernels/common/softmax_layer.cl
new file mode 100644
index 0000000000..bfc0995bb8
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/softmax_layer.cl
@@ -0,0 +1,371 @@
+/*
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "helpers.h"
+
+#define MIN_VALUE_float -FLT_MAX
+#define MIN_VALUE_half -HALF_MAX
+#define MIN_VALUE_char CHAR_MIN
+#define MIN_VALUE_uchar 0
+
+#define MIN_VALUE_TYPE_STR(data_type) MIN_VALUE_##data_type
+#define MIN_VALUE_TYPE(data_type) MIN_VALUE_TYPE_STR(data_type)
+#define MIN_VALUE MIN_VALUE_TYPE(DATA_TYPE)
+
+#ifdef SOFTMAX_X
+
+/** 3-pass softmax in the x dimension.
+ *
+ * List of preprocessors:
+ * - DATA_TYPE: the input/output data type.
+ * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage.
+ * If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE.
+ * - IS_LOG (optional): indicating whether this is log softmax.
+ * - LENGTH: the number of elements in softmax axis in the input/output tensors.
+ * - BETA: the beta coefficient.
+ * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data.
+ * - VEC_SIZE: the size of the vector.
+ *
+ * Additional preprocessors in case IS_QUANTIZED is present:
+ * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
+ * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
+ *
+ * @param[in] src_ptr Pointer to the source tensor.
+ * @param[in] src_stride_0 Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
+ * @param[in] src_stride_1 Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
+ * @param[in] src_stride_2 Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
+ * @param[in] src_offset_first_element Offset of the first element in the source tensor.
+ * @param[in] dst_ptr Pointer to the destination tensor.
+ * @param[in] dst_stride_0 Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
+ * @param[in] dst_stride_1 Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
+ * @param[in] dst_stride_2 Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
+ * @param[in] dst_offset_first_element Offset of the first element in the destination tensor.
+ * @param[in] tmp_ptr Pointer to the temporary tensor.
+ * @param[in] tmp_stride_0 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
+ * @param[in] tmp_stride_1 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
+ * @param[in] tmp_stride_2 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
+ * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor.
+ */
+__kernel void softmax_x(
+ __global uchar *src_ptr,
+ uint src_stride_0,
+ uint src_stride_1,
+ uint src_stride_2,
+ uint src_offset_first_element,
+
+ __global uchar *dst_ptr,
+ uint dst_stride_0,
+ uint dst_stride_1,
+ uint dst_stride_2,
+ uint dst_offset_first_element
+
+#ifdef IS_QUANTIZED
+ ,
+ __global uchar *tmp_ptr,
+ uint tmp_stride_0,
+ uint tmp_stride_1,
+ uint tmp_stride_2,
+ uint tmp_offset_first_element
+#endif // IS_QUANTIZED
+)
+{
+ const int dim_0 = get_global_id(0);
+ const int dim_1 = get_global_id(1);
+ const int dim_2 = get_global_id(2);
+
+ src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
+ dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
+
+#ifdef IS_QUANTIZED
+ tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
+#else // IS_QUANTIZED
+ __global uchar *tmp_ptr = dst_ptr;
+#endif // IS_QUANTIZED
+
+ // Calculate max value.
+ DATA_TYPE max_value = MIN_VALUE;
+ int i = 0;
+
+ for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
+ {
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)));
+
+ max_value = max(max_value, MAX_REDUCE(data, VEC_SIZE));
+ }
+
+ for (; i < LENGTH; ++i)
+ {
+ DATA_TYPE data = *(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE));
+
+ max_value = max(max_value, data);
+ }
+
+ // Regularize the data.
+ TMP_DATA_TYPE sum_value = 0;
+
+#ifdef IS_QUANTIZED
+ TMP_DATA_TYPE max_value_f = (CONVERT(max_value, TMP_DATA_TYPE) - SRC_OFFSET) * SRC_SCALE;
+ TMP_DATA_TYPE regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
+# define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
+#else // IS_QUANTIZED
+# define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
+#endif // IS_QUANTIZED
+
+ for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
+ {
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));
+
+ data = REGULARIZE(data);
+
+#ifdef IS_LOG
+ sum_value += SUM_REDUCE(exp(data), VEC_SIZE);
+#else // IS_LOG
+ data = exp(data);
+ sum_value += SUM_REDUCE(data, VEC_SIZE);
+#endif // IS_LOG
+
+ VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
+ }
+
+ for (; i < LENGTH; ++i)
+ {
+ TMP_DATA_TYPE data = CONVERT(*(__global DATA_TYPE *)(src_ptr + i * sizeof(DATA_TYPE)), TMP_DATA_TYPE);
+
+ data = REGULARIZE(data);
+
+#ifdef IS_LOG
+ sum_value += exp(data);
+#else // IS_LOG
+ data = exp(data);
+ sum_value += data;
+#endif // IS_LOG
+
+ *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)) = data;
+ }
+
+#undef REGULARIZE
+
+ // Normalize the data.
+#ifdef IS_QUANTIZED
+# if IS_LOG
+ TMP_DATA_TYPE norm_offset = -log(sum_value) + DST_OFFSET;
+# define NORMALIZE(SIZE, x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, SIZE), rte)
+# else // IS_LOG
+ TMP_DATA_TYPE norm_div = sum_value * DST_SCALE;
+# define NORMALIZE(SIZE, x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, SIZE))
+# endif // IS_LOG
+#else // IS_QUANTIZED
+# if IS_LOG
+# define NORMALIZE(SIZE, x) ((x) - log(sum_value))
+# else // IS_LOG
+# define NORMALIZE(SIZE, x) ((x) / sum_value)
+# endif // IS_LOG
+#endif // IS_QUANTIZED
+
+ for (i = 0; i < LENGTH - VEC_SIZE; i += VEC_SIZE)
+ {
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE)));
+
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result = NORMALIZE(VEC_SIZE, data);
+
+ VSTORE(VEC_SIZE)(result, 0, (__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)));
+ }
+
+ for (; i < LENGTH; ++i)
+ {
+ TMP_DATA_TYPE data = *(__global TMP_DATA_TYPE *)(tmp_ptr + i * sizeof(TMP_DATA_TYPE));
+
+ DATA_TYPE result = NORMALIZE(1, data);
+
+ *(__global DATA_TYPE *)(dst_ptr + i * sizeof(DATA_TYPE)) = result;
+ }
+
+#undef NORMALIZE
+}
+
+#endif // SOFTMAX_X
+
+#ifdef SOFTMAX_NON_X
+
+/** 3-pass softmax in any dimension higher than the x dimension.
+ *
+ * List of preprocessors:
+ * - DATA_TYPE: the input/output data type.
+ * - TMP_DATA_TYPE: the data type used for computing and temporary tensor storage.
+ * If DATA_TYPE is quantized, TMP_DATA_TYPE is floating-point, otherwise TMP_DATA_TYPE is the same as DATA_TYPE.
+ * - IS_LOG (optional): indicating whether this is log softmax.
+ * - LENGTH: the number of elements in softmax axis in the input/output tensors.
+ * - BETA: the beta coefficient.
+ * - IS_QUANTIZED (optional): indicating whether the input/output data type is quantized data.
+ * - VEC_SIZE: the size of the vector.
+ * - VEC_SIZE_LEFTOVER: the size of the leftover part.
+ *
+ * Additional preprocessors in case IS_QUANTIZED is present:
+ * - SRC_SCALE and SRC_OFFSET: the quantization information of the source tensor.
+ * - DST_SCALE and DST_OFFSET: the quantization information of the destination tensor.
+ *
+ * @param[in] src_ptr Pointer to the source tensor.
+ * @param[in] src_stride_0 Stride in bytes of the source tensor in the dimension corresponding to global ID 0.
+ * @param[in] src_stride_1 Stride in bytes of the source tensor in the dimension corresponding to global ID 1.
+ * @param[in] src_stride_2 Stride in bytes of the source tensor in the dimension corresponding to global ID 2.
+ * @param[in] src_offset_first_element Offset of the first element in the source tensor.
+ * @param[in] dst_ptr Pointer to the destination tensor.
+ * @param[in] dst_stride_0 Stride in bytes of the destination tensor in the dimension corresponding to global ID 0.
+ * @param[in] dst_stride_1 Stride in bytes of the destination tensor in the dimension corresponding to global ID 1.
+ * @param[in] dst_stride_2 Stride in bytes of the destination tensor in the dimension corresponding to global ID 2.
+ * @param[in] dst_offset_first_element Offset of the first element in the destination tensor.
+ * @param[in] tmp_ptr Pointer to the temporary tensor.
+ * @param[in] tmp_stride_0 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 0.
+ * @param[in] tmp_stride_1 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 1.
+ * @param[in] tmp_stride_2 Stride in bytes of the temporary tensor in the dimension corresponding to global ID 2.
+ * @param[in] tmp_offset_first_element Offset of the first element in the temporary tensor.
+ */
+__kernel void softmax_non_x(
+ __global uchar *src_ptr,
+ uint src_stride_0,
+ uint src_stride_1,
+ uint src_stride_2,
+ uint src_offset_first_element,
+
+ __global uchar *dst_ptr,
+ uint dst_stride_0,
+ uint dst_stride_1,
+ uint dst_stride_2,
+ uint dst_offset_first_element,
+
+ __global uchar *tmp_ptr,
+ uint tmp_stride_0,
+ uint tmp_stride_1,
+ uint tmp_stride_2,
+ uint tmp_offset_first_element,
+
+ uint src_stride_axis,
+ uint dst_stride_axis
+)
+{
+ const int dim_0 = max((int)get_global_id(0) * VEC_SIZE - (VEC_SIZE - VEC_SIZE_LEFTOVER) % VEC_SIZE, 0);
+ const int dim_1 = get_global_id(1);
+ const int dim_2 = get_global_id(2);
+
+ src_ptr += src_offset_first_element + dim_2 * src_stride_2 + dim_1 * src_stride_1 + dim_0 * src_stride_0;
+ dst_ptr += dst_offset_first_element + dim_2 * dst_stride_2 + dim_1 * dst_stride_1 + dim_0 * dst_stride_0;
+ tmp_ptr += tmp_offset_first_element + dim_2 * tmp_stride_2 + dim_1 * tmp_stride_1 + dim_0 * tmp_stride_0;
+
+ // In case of processing quantized data, i.e. DATA_TYPE is smaller than TMP_DATA_TYPE:
+ //
+ // In the first pass (finding max), the quantized data is copied from the input tensor to the temporary tensor.
+ // Dequantization is not needed to find the max value and since dequantization widens the data, we defer it
+ // to the second pass pass to reduce memory bandwidth of the first pass.
+ //
+ // In the second pass, it reads the quantized data from the temporary tensor and writes the dequantized data
+ // back to the temporary tensor.
+ //
+ // To avoid dequantized data overwritting the unprocessed quantized data in the temporary tensor,
+ // this extra offset is introduced to store the quantized data at the end of the temporary tensor.
+ //
+ // Note: Another approach is to perform the second pass in reverse order, but for unexplanable reason
+ // it doesn't work in some devices.
+ uint tmp_extra_offset = LENGTH * VEC_SIZE * (sizeof(TMP_DATA_TYPE) - sizeof(DATA_TYPE));
+
+ // Calculate max value and store the input data to the temporary tensor in suitable format.
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) max_value = MIN_VALUE;
+ int i = 0;
+
+ for (i = 0; i < LENGTH; ++i)
+ {
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(src_ptr + i * src_stride_axis));
+
+ max_value = max(max_value, data);
+
+ VSTORE(VEC_SIZE)(data, 0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE)));
+ }
+
+ // Regularize the data.
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) sum_value = 0;
+
+#ifdef IS_QUANTIZED
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) max_value_f = (CONVERT(max_value, VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE)) - SRC_OFFSET) * SRC_SCALE;
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) regularize_offset = -SRC_OFFSET * SRC_SCALE * (TMP_DATA_TYPE)BETA - max_value_f * (TMP_DATA_TYPE)BETA;
+# define REGULARIZE(x) ((x) * SRC_SCALE * (TMP_DATA_TYPE)BETA + regularize_offset)
+#else // IS_QUANTIZED
+# define REGULARIZE(x) (((x) - max_value) * (TMP_DATA_TYPE)BETA)
+#endif // IS_QUANTIZED
+
+ for (i = 0; i < LENGTH; ++i)
+ {
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)(tmp_ptr + tmp_extra_offset + i * VEC_SIZE * sizeof(DATA_TYPE))), VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE));
+
+ data = REGULARIZE(data);
+
+#ifdef IS_LOG
+ sum_value += exp(data);
+#else // IS_LOG
+ data = exp(data);
+ sum_value += data;
+#endif // IS_LOG
+
+ VSTORE(VEC_SIZE)(data, 0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
+ }
+
+#undef REGULARIZE
+
+ // Normalize the data.
+#ifdef IS_QUANTIZED
+# if IS_LOG
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_offset = -log(sum_value) + DST_OFFSET;
+# define NORMALIZE(x) CONVERT_SAT_ROUND((x) / DST_SCALE + norm_offset, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE), rte)
+# else // IS_LOG
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) norm_div = sum_value * DST_SCALE;
+# define NORMALIZE(x) CONVERT_SAT(add_sat(CONVERT_SAT_ROUND((x) / norm_div, VEC_DATA_TYPE(int, VEC_SIZE), rte), DST_OFFSET), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))
+# endif // IS_LOG
+#else // IS_QUANTIZED
+# if IS_LOG
+# define NORMALIZE(x) ((x) - log(sum_value))
+# else // IS_LOG
+# define NORMALIZE(x) ((x) / sum_value)
+# endif // IS_LOG
+#endif // IS_QUANTIZED
+
+ for (i = 0; i < LENGTH; ++i)
+ {
+ VEC_DATA_TYPE(TMP_DATA_TYPE, VEC_SIZE) data = VLOAD(VEC_SIZE)(0, (__global TMP_DATA_TYPE *)(tmp_ptr + i * VEC_SIZE * sizeof(TMP_DATA_TYPE)));
+
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) result0 = NORMALIZE(data);
+
+ STORE_VECTOR_SELECT(result, DATA_TYPE, dst_ptr + i * dst_stride_axis, VEC_SIZE, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
+ }
+
+#undef NORMALIZE
+}
+
+#endif // SOFTMAX_NON_X
+
+#undef MIN_VALUE
+#undef MIN_VALUE_TYPE
+#undef MIN_VALUE_TYPE_STR
+
+#undef MIN_VALUE_float
+#undef MIN_VALUE_half
+#undef MIN_VALUE_char
+#undef MIN_VALUE_uchar