diff options
Diffstat (limited to 'src/cpu/kernels/softmax')
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/fp16.cpp | 72 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/fp32.cpp | 69 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/impl.cpp | 596 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/impl.h | 428 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/qasymm8.cpp | 68 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp | 68 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sme2/fp16.cpp | 781 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sme2/fp32.cpp | 585 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp | 634 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp | 655 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sve/impl.cpp | 179 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sve/impl.h | 46 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sve2/impl.cpp | 212 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/generic/sve2/impl.h | 43 | ||||
-rw-r--r-- | src/cpu/kernels/softmax/list.h | 81 |
15 files changed, 4517 insertions, 0 deletions
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp new file mode 100644 index 0000000000..425fcf7ac6 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#include "arm_compute/core/Helpers.h" + +#include "src/cpu/CpuTypes.h" +#include "src/cpu/kernels/softmax/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ + +template <bool IS_LOG> +void neon_fp16_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + if (axis == 0) + { + return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window); + } + else + { + return neon_softmax_non_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window); + } +} + +template void neon_fp16_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_fp16_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +} // namespace cpu +} // namespace arm_compute +#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp new file mode 100644 index 0000000000..a64946eb74 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/Helpers.h" + +#include "src/cpu/kernels/softmax/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ + +template <bool IS_LOG> +void neon_fp32_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + if (axis == 0) + { + return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window); + } + else + { + return neon_softmax_non_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window); + } +} + +template void neon_fp32_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_fp32_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp new file mode 100644 index 0000000000..31baf8a9df --- /dev/null +++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp @@ -0,0 +1,596 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/cpu/kernels/softmax/generic/neon/impl.h" + +#include "support/SaturateCast.h" + +namespace arm_compute +{ +namespace cpu +{ +template <typename T, bool IS_LOG> +void neon_softmax_x_quantized( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window) +{ + ARM_COMPUTE_UNUSED(axis); + + static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value, + "quantized type should be either qasymm8_t or qasymm8_signed_t."); + + const int input_width = in->info()->valid_region().shape.x(); + + const float scale_beta = -beta * in->info()->quantization_info().uniform().scale; + const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta); + + Iterator in_it(in, window); + Iterator out_it(out, window); + + constexpr int vec_size = 16; + +#ifndef __aarch64__ + const int sum_stages = log2(vec_size >> 1); +#endif // __aarch64__ + + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; + + execute_window_loop( + window, + [&](const Coordinates &) + { + /* Get pointers */ + const T *in_ptr = reinterpret_cast<const T *>(in_it.ptr()); + T *out_ptr = reinterpret_cast<T *>(out_it.ptr()); + float *tmp_ptr = reinterpret_cast<float *>(tmp); + + T max_val; + + /* Compute Max */ + { + // Init max value + auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{}); + int x = 0; + + for (; x <= (input_width - vec_size); x += vec_size) + { + const auto current_value = wrapper::vloadq(in_ptr + x); + vec_max = wrapper::vmax(vec_max, current_value); + } + +#ifdef __aarch64__ + max_val = wrapper::vmaxv(vec_max); +#else // __aarch64__ + auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max)); + + for (int i = 0; i < sum_stages; ++i) + { + carry_max = wrapper::vpmax(carry_max, carry_max); + } + + max_val = wrapper::vgetlane(carry_max, 0); +#endif // __aarch64__ + + // Compute left-over elements + for (; x < input_width; ++x) + { + max_val = std::max(*(in_ptr + x), max_val); + } + } // Compute Max + + float sum_transformed{}; + + /* Compute exponentials and sum */ + { + /* Get max value */ + const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{}); + + /* Init sum to zero */ + float32x4x4_t vec_sum = { + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + }; + + /* Loop over row and compute exponentials and sum */ + int x = 0; + for (; x <= (input_width - vec_size); x += vec_size) + { + auto vec_elements = wrapper::vloadq(in_ptr + x); + vec_elements = wrapper::vqsub(vec_max, vec_elements); + float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements); + + if (IS_LOG) + { + vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec); + vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec); + vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec); + vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0])); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1])); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2])); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3])); + } + else + { + vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec)); + vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec)); + vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec)); + vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec)); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]); + } + + vst4q_f32(tmp_ptr + x, vec_elements_flt); + } + + /* Reduce sum */ + const float32x4_t sum_16_byte = + vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3])); + + float sum; + +#ifdef __aarch64__ + sum = wrapper::vaddv(sum_16_byte); +#else // __aarch64__ + auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte)); + sum_res = vpadd_f32(sum_res, sum_res); + sum = wrapper::vgetlane(sum_res, 0); +#endif // __aarch64__ + + /* Run remaining elements */ + for (; x < input_width; ++x) + { + float element{}; + if (IS_LOG) + { + element = (max_val - in_ptr[x]) * scale_beta; + sum += std::exp(element); + } + else + { + element = std::exp((max_val - in_ptr[x]) * scale_beta); + sum += element; + } + + tmp_ptr[x] = element; + } + + if (!IS_LOG) + { + sum_transformed = 256.f / sum; + } + else + { + sum_transformed = std::log(sum); + } + } // Compute exponentials and sum + + /* Normalize exponentials */ + { + constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value; + + const float32x4_t sum_vec = vdupq_n_f32(sum_transformed); + + /* Loop over row and compute softmax */ + int x = 0; + for (; x <= (input_width - vec_size); x += vec_size) + { + using int_vec_type = wrapper::traits::neon_vector_t<T, 16>; + float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x); + int_vec_type normalized_value{}; + if (IS_LOG) + { + const float32x4x4_t sub = { + vsubq_f32(vec_in.val[0], sum_vec), + vsubq_f32(vec_in.val[1], sum_vec), + vsubq_f32(vec_in.val[2], sum_vec), + vsubq_f32(vec_in.val[3], sum_vec), + }; + normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub); + } + else + { + float32x4x4_t mul = { + vmulq_f32(vec_in.val[0], sum_vec), + vmulq_f32(vec_in.val[1], sum_vec), + vmulq_f32(vec_in.val[2], sum_vec), + vmulq_f32(vec_in.val[3], sum_vec), + }; + + if (is_qasymm8_signed) + { + const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{}); + mul.val[0] = wrapper::vsub(mul.val[0], offset_vec); + mul.val[1] = wrapper::vsub(mul.val[1], offset_vec); + mul.val[2] = wrapper::vsub(mul.val[2], offset_vec); + mul.val[3] = wrapper::vsub(mul.val[3], offset_vec); + } + + normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul); + } + wrapper::vstore(out_ptr + x, normalized_value); + } + /* Run remaining elements */ + for (; x < input_width; ++x) + { + if (IS_LOG) + { + out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum_transformed); + } + else + { + out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_transformed) - + (is_qasymm8_signed ? 128.f : 0)); + } + } + } // Normalize exponentials + }, + in_it, out_it); +} + +template <typename T, bool IS_LOG> +void neon_softmax_non_x_quantized( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window) +{ + static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value, + "quantized type should be either qasymm8_t or qasymm8_signed_t."); + + const float scale_beta = -beta * in->info()->quantization_info().uniform().scale; + const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta); + + Iterator in_it(in, window); + Iterator out_it(out, window); + + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; + + constexpr int vec_size = 16; + const ITensorInfo *in_info = in->info(); + const ITensorInfo *out_info = out->info(); + const int x_width = in_info->valid_region().shape.x(); + const int in_axis_stride = in_info->strides_in_bytes()[axis]; + const int out_axis_stride = out_info->strides_in_bytes()[axis]; + const int tmp_axis_stride = in_axis_stride; + const int axis_width = in_info->dimension(axis); + const int end_actual = std::min(window[0].end(), x_width); + + execute_window_loop( + window, + [&](const Coordinates &winCoords) + { + const bool vector_exceeds_bounds = ((winCoords[0] + vec_size) > end_actual); + + int num_remaining = (end_actual - winCoords[0]); + int num_remaining_full = num_remaining / 4; + int num_remaining_partial = num_remaining % 4; + + /* Get pointers */ + const uint8_t *in_ptr = in_it.ptr(); + uint8_t *out_ptr = out_it.ptr(); + uint8_t *tmp_ptr = reinterpret_cast<uint8_t *>(tmp); + + auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{}); + + /* Compute Max */ + { + if (!vector_exceeds_bounds) + { + int i = 0; + for (; i < axis_width; ++i) + { + const auto current_value = + wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr)); + vec_max = wrapper::vmax(vec_max, current_value); + } + } + else + { + int i = 0; + for (; i < axis_width; ++i) + { + const T *const base_ptr_in = ((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr)); + int j = 0; + for (; j < num_remaining; ++j) + { + const T current_value = *(base_ptr_in + j); + vec_max[j] = std::max(vec_max[j], current_value); + } + } + } + } // Compute Max + + float32x4x4_t vec_sum_transformed = { + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + }; + + /* Compute exponentials and sum */ + { + /* Init sum to zero */ + float32x4x4_t vec_sum = vec_sum_transformed; + + auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); + + float32x4x4_t vec_elements_flt; + + if (!vector_exceeds_bounds) + { + int i = 0; + for (; i < axis_width; ++i) + { + vec_elements = wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr)); + vec_elements = wrapper::vqsub(vec_max, vec_elements); + vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements); + + if (IS_LOG) + { + vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec); + vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec); + vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec); + vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0])); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1])); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2])); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3])); + } + else + { + vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec)); + vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec)); + vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec)); + vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec)); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]); + } + vst4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr), vec_elements_flt); + } + + auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256.f), ExactTagType{}); + if (!IS_LOG) + { + vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]); + vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]); + vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]); + vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]); + } + else + { + vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]); + vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]); + vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]); + vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]); + } + } + else + { + int i = 0; + for (; i < axis_width; ++i) + { + const T *const base_ptr_in = (i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr); + auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); + //vec_els is functionally redundant but is needed as a workaround for a toolchain bug. + std::vector<T> vec_els(16); + + for (int k = 0; k < num_remaining_full; ++k) + { + for (int j = 0; j < 4; ++j) + { + vec_els[k * 4 + j] = *(base_ptr_in + (4 * k + j)); + } + } + for (int j = 0; j < num_remaining_partial; ++j) + { + vec_els[num_remaining_full * 4 + j] = *(base_ptr_in + (4 * num_remaining_full + j)); + } + for (int q = 0; q < 16; q++) + { + vec_elements[q] = vec_els[q]; + } + vec_elements = wrapper::vqsub(vec_max, vec_elements); + float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements); + + if (IS_LOG) + { + vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec); + vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec); + vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec); + vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0])); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1])); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2])); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3])); + } + else + { + vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec)); + vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec)); + vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec)); + vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec)); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]); + } + + float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr); + for (int k = 0; k < num_remaining_full; ++k) + { + for (int j = 0; j < 4; ++j) + { + *(base_ptr_tmp + (4 * k + j)) = vec_elements_flt.val[k][j]; + } + } + + for (int j = 0; j < num_remaining_partial; ++j) + { + *(base_ptr_tmp + (4 * num_remaining_full + j)) = + vec_elements_flt.val[num_remaining_full][j]; + } + } + + auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256), ExactTagType{}); + if (!IS_LOG) + { + vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]); + vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]); + vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]); + vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]); + } + else + { + vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]); + vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]); + vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]); + vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]); + } + } + } // Compute exponentials and sum + + /* Normalize exponentials */ + { + constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value; + if (!vector_exceeds_bounds) + { + int i = 0; + for (; i < axis_width; ++i) + { + using int_vec_type = wrapper::traits::neon_vector_t<T, 16>; + float32x4x4_t vec_in = vld4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr)); + + int_vec_type normalized_value{}; + + if (IS_LOG) + { + const float32x4x4_t sub = { + vsubq_f32(vec_in.val[0], vec_sum_transformed.val[0]), + vsubq_f32(vec_in.val[1], vec_sum_transformed.val[1]), + vsubq_f32(vec_in.val[2], vec_sum_transformed.val[2]), + vsubq_f32(vec_in.val[3], vec_sum_transformed.val[3]), + }; + normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub); + } + else + { + float32x4x4_t mul = { + vmulq_f32(vec_in.val[0], vec_sum_transformed.val[0]), + vmulq_f32(vec_in.val[1], vec_sum_transformed.val[1]), + vmulq_f32(vec_in.val[2], vec_sum_transformed.val[2]), + vmulq_f32(vec_in.val[3], vec_sum_transformed.val[3]), + }; + + if (is_qasymm8_signed) + { + const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{}); + mul.val[0] = wrapper::vsub(mul.val[0], offset_vec); + mul.val[1] = wrapper::vsub(mul.val[1], offset_vec); + mul.val[2] = wrapper::vsub(mul.val[2], offset_vec); + mul.val[3] = wrapper::vsub(mul.val[3], offset_vec); + } + + normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul); + } + wrapper::vstore((i * out_axis_stride) + reinterpret_cast<T *>(out_ptr), normalized_value); + } + } + else + { + int i = 0; + for (; i < axis_width; ++i) + { + T *const base_ptr_out = (i * out_axis_stride) + reinterpret_cast<T *>(out_ptr); + float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr); + if (IS_LOG) + { + for (int k = 0; k < num_remaining_full; ++k) + { + for (int j = 0; j < 4; ++j) + { + *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>( + (*(base_ptr_tmp + (4 * k + j)) - vec_sum_transformed.val[k][j])); + } + } + for (int j = 0; j < num_remaining_partial; ++j) + { + *(base_ptr_out + (4 * num_remaining_full + j)) = + utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) - + vec_sum_transformed.val[num_remaining_full][j]); + } + } + else + { + for (int k = 0; k < num_remaining_full; ++k) + { + for (int j = 0; j < 4; ++j) + { + *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>( + *(base_ptr_tmp + (4 * k + j)) * vec_sum_transformed.val[k][j] - + (is_qasymm8_signed ? 128.f : 0)); + } + } + for (int j = 0; j < num_remaining_partial; ++j) + { + *(base_ptr_out + (4 * num_remaining_full + j)) = + utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) * + vec_sum_transformed.val[num_remaining_full][j] - + (is_qasymm8_signed ? 128.f : 0)); + } + } + } + } + } // Normalize exponentials + }, + in_it, out_it); +} + +template void neon_softmax_x_quantized<qasymm8_signed_t, true>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_x_quantized<qasymm8_signed_t, false>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_x_quantized<qasymm8_t, true>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_x_quantized<qasymm8_t, false>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_non_x_quantized<qasymm8_signed_t, true>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_non_x_quantized<qasymm8_signed_t, false>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_non_x_quantized<qasymm8_t, true>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template void neon_softmax_non_x_quantized<qasymm8_t, false>( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h new file mode 100644 index 0000000000..e417271d0e --- /dev/null +++ b/src/cpu/kernels/softmax/generic/neon/impl.h @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H +#define ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H + +#include "arm_compute/core/Helpers.h" + +#include "src/core/NEON/NEMath.h" +#include "src/core/NEON/wrapper/wrapper.h" + +namespace arm_compute +{ +namespace cpu +{ + +#ifdef __aarch64__ +namespace +{ +// These helper functions are added because vaddv does not exist for fp16, +// and, therefore, is not part of the wrapper::vaddv interface. +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +inline float16_t wrapper_vaddv(const float16x8_t &a, int sum_stages) +{ + auto sum_res = wrapper::vpadd(wrapper::vgethigh(a), wrapper::vgetlow(a)); + for (int i = 0; i < sum_stages; ++i) + { + sum_res = wrapper::vpadd(sum_res, sum_res); + } + return wrapper::vgetlane(sum_res, 0); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +inline float wrapper_vaddv(const float32x4_t &a, int sum_stages) +{ + ARM_COMPUTE_UNUSED(sum_stages); + return wrapper::vaddv(a); +} +} // namespace +#endif // __aarch64__ + +// The template implementation for float data types is stored in the header file because +// we need all fp16 instantiated code to live in fp16.cpp files. +template <typename T, bool IS_LOG> +void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window) +{ + ARM_COMPUTE_UNUSED(axis); + ARM_COMPUTE_UNUSED(tmp); + + const int input_width = in->info()->valid_region().shape.x(); + + Iterator in_it(in, window); + Iterator out_it(out, window); + + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; + + constexpr int vec_size = 16 / sizeof(T); + + const int sum_stages = log2(vec_size >> 1); + + const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}); + + execute_window_loop( + window, + [&](const Coordinates &) + { + /* Get pointers */ + const T *in_ptr = reinterpret_cast<const T *>(in_it.ptr()); + T *out_ptr = reinterpret_cast<T *>(out_it.ptr()); + + T max_val; + + /* Compute Max */ + { + // Init max value + auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{}); + int x = 0; + + for (; x <= (input_width - vec_size); x += vec_size) + { + const auto current_value = wrapper::vloadq(in_ptr + x); + vec_max = wrapper::vmax(vec_max, current_value); + } + +#ifdef __aarch64__ + max_val = wrapper::vmaxv(vec_max); +#else // __aarch64__ + auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max)); + + for (int i = 0; i < sum_stages; ++i) + { + carry_max = wrapper::vpmax(carry_max, carry_max); + } + + max_val = wrapper::vgetlane(carry_max, 0); +#endif // __aarch64__ + + // Compute left-over elements + for (; x < input_width; ++x) + { + max_val = std::max(*(in_ptr + x), max_val); + } + } // compute max + + T sum_transformed{}; + + /* Compute exponentials and sum */ + { + /* Get max value */ + const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{}); + + /* Init sum to zero */ + auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); + + /* Loop over row and compute exponentials and sum */ + int x = 0; + for (; x <= (input_width - vec_size); x += vec_size) + { + auto vec_elements = wrapper::vloadq(in_ptr + x); + vec_elements = wrapper::vsub(vec_elements, vec_max); + if (IS_LOG) + { + vec_elements = wrapper::vmul(vec_elements, beta_vec); + vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements)); + } + else + { + vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec)); + vec_sum = wrapper::vadd(vec_sum, vec_elements); + } + wrapper::vstore(out_ptr + x, vec_elements); + } + + /* Reduce sum */ + T sum{}; +#ifdef __aarch64__ + sum = wrapper_vaddv(vec_sum, sum_stages); +#else // __aarch64__ + auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum)); + for (int i = 0; i < sum_stages; ++i) + { + sum_res = wrapper::vpadd(sum_res, sum_res); + } + sum = wrapper::vgetlane(sum_res, 0); +#endif // __aarch64__ + + /* Run remaining elements */ + for (; x < input_width; ++x) + { + T element{}; + + if (IS_LOG) + { + element = (in_ptr[x] - max_val) * beta; + sum += std::exp(element); + } + else + { + element = std::exp((in_ptr[x] - max_val) * beta); + sum += element; + } + + out_ptr[x] = element; + } + + if (!IS_LOG) + { + sum_transformed = T(1) / sum; + } + else + { + sum_transformed = static_cast<T>(std::log(sum)); + } + } // Compute exponentials and sum + + /* Normalize exponentials */ + { + const auto sum_vec = wrapper::vdup_n(static_cast<T>(sum_transformed), ExactTagType{}); + + /* Loop over row and compute softmax */ + int x = 0; + for (; x <= (input_width - vec_size); x += vec_size) + { + const auto vec_in = wrapper::vloadq(out_ptr + x); + if (IS_LOG) + { + wrapper::vstore(out_ptr + x, wrapper::vsub(vec_in, sum_vec)); + } + else + { + wrapper::vstore(out_ptr + x, wrapper::vmul(vec_in, sum_vec)); + } + } + + /* Run remaining elements */ + for (; x < input_width; ++x) + { + if (IS_LOG) + { + out_ptr[x] = out_ptr[x] - sum_transformed; + } + else + { + out_ptr[x] = out_ptr[x] * sum_transformed; + } + } + } // Normalize exponentials + }, + in_it, out_it); +} +template <typename T, bool IS_LOG> +void neon_softmax_non_x_float( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window) +{ + ARM_COMPUTE_UNUSED(tmp); + + Iterator in_it(in, window); + Iterator out_it(out, window); + + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; + + const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}); + constexpr int vec_size = 16 / sizeof(T); + const ITensorInfo *in_info = in->info(); + const ITensorInfo *out_info = out->info(); + const int x_width = in_info->valid_region().shape.x(); + const unsigned int in_axis_stride = in_info->strides_in_bytes()[axis]; + const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis]; + const int axis_width = in_info->dimension(axis); + + execute_window_loop( + window, + [&](const Coordinates &winCoords) + { + const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width; + + /* Get pointers */ + const uint8_t *in_ptr = in_it.ptr(); + uint8_t *out_ptr = out_it.ptr(); + + // Init max value + auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{}); + + /* Compute Max */ + { + if (!vector_exceeds_bounds) + { + int i = 0; + for (; i < axis_width; ++i) + { + const auto current_value = + wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr)); + vec_max = wrapper::vmax(vec_max, current_value); + } + } + else + { + int i = 0; + for (; i < axis_width; ++i) + { + const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr); + int j = 0; + for (; j < (x_width - winCoords[0]); ++j) + { + const auto current_value = *(base_ptr_in + j); + vec_max[j] = std::max(vec_max[j], current_value); + } + } + } + } // compute max + + auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); + + auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); + /* Init sum to zero */ + auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); + + /* Compute exponentials and sum */ + { + if (!vector_exceeds_bounds) + { + const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{}); + /* Loop over row and compute exponentials and sum */ + int i = 0; + for (; i < axis_width; ++i) + { + vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr)); + vec_elements = wrapper::vsub(vec_elements, vec_max); + if (IS_LOG) + { + vec_elements = wrapper::vmul(vec_elements, beta_vec); + vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements)); + } + else + { + vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec)); + vec_sum = wrapper::vadd(vec_sum, vec_elements); + } + + wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements); + } + + if (!IS_LOG) + { + vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum); + } + else + { + vec_sum_transformed = wrapper::vlog(vec_sum); + } + } + else + { + int i = 0; + for (; i < axis_width; ++i) + { + const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr); + T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr); + int j = 0; + for (; j < (x_width - winCoords[0]); ++j) + { + vec_elements[j] = *(base_ptr_in + j); + vec_elements[j] -= vec_max[j]; + if (IS_LOG) + { + vec_elements[j] *= beta; + vec_sum[j] += std::exp(vec_elements[j]); + } + else + { + vec_elements[j] = std::exp(vec_elements[j] * beta); + vec_sum[j] += vec_elements[j]; + } + *(base_ptr_out + j) = vec_elements[j]; + } + } + int j = 0; + for (; j < (x_width - winCoords[0]); ++j) + { + if (!IS_LOG) + { + vec_sum_transformed[j] = 1 / vec_sum[j]; + } + else + { + vec_sum_transformed[j] = std::log(vec_sum[j]); + } + } + } + } // Compute exponentials and sum + + /* Normalize exponentials */ + { + if (!vector_exceeds_bounds) + { + /* Loop over row and compute softmax */ + int i = 0; + for (; i < axis_width; ++i) + { + T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr); + auto vec_in = wrapper::vloadq(base_ptr_out); + if (IS_LOG) + { + wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed)); + } + else + { + wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed)); + } + } + } + else + { + int i = 0; + for (; i < axis_width; ++i) + { + T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr); + int j = 0; + for (; j < (x_width - winCoords[0]); ++j) + { + if (IS_LOG) + { + *(base_ptr_out + j) -= vec_sum_transformed[j]; + } + else + { + *(base_ptr_out + j) *= vec_sum_transformed[j]; + } + } + } + } + } // Normalize exponentials + }, + in_it, out_it); +} +template <typename T, bool IS_LOG> +void neon_softmax_x_quantized( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); + +template <typename T, bool IS_LOG> +void neon_softmax_non_x_quantized( + const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); +} // namespace cpu +} // namespace arm_compute + +#endif // ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp new file mode 100644 index 0000000000..369f9bb005 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/Helpers.h" + +#include "src/cpu/kernels/softmax/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +template <bool IS_LOG> +void neon_qasymm8_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + if (axis == 0) + { + return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window); + } + else + { + return neon_softmax_non_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window); + } +} + +template void neon_qasymm8_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_qasymm8_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp new file mode 100644 index 0000000000..594ceb7654 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/Helpers.h" + +#include "src/cpu/kernels/softmax/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +template <bool IS_LOG> +void neon_qasymm8_signed_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + if (axis == 0) + { + return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window); + } + else + { + return neon_softmax_non_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window); + } +} + +template void neon_qasymm8_signed_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_qasymm8_signed_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp new file mode 100644 index 0000000000..e70c9f4793 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp @@ -0,0 +1,781 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_f16_softmax_kernel( // + const float16_t *src, + float16_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4]) +{ + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x9: temporary, index + // * x10: temporary, -inf + // * x11: temporary, 0 + // * x12: temporary, 1.0f + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // + // * z0: c1 + // * z1: c2 + // * z2: c3 + // * z3: c4 + // * z4: c5 + // * z5: shift + // * z6: inv_ln2 + // * z7: neg_ln2_hi + // * z8: neg_ln2_lo + // * z9: min_input + // * z10: 23, 0 + // * z11: max_value + // * z12-z15: x, x_fp32_lower_halves, r_hi, r, r2 + // * z16-z19: max_value, shift, z, scale, poly + // * z20-z21: n, p1, p12345 + // * z22-z23: n, p23, p2345 + // * z24-z25: p45 + // * z26: beta + // * z28-z31: sum_value, x_fp32_upper_halves + // + // * za0-za3: sum_value + // + // * p0: all-true + // * p1: left-over predicate for find-max & normalize loops + // * p2-p4: left-over predicates for regularize loop + // * p4-p7: underflow in vector loop + // * p5-p6: underflow in leftover loop + // * + // * pn9: all-true + + // Prepares all constant values + + ptrue p0.b + .inst 0x25207811 // ptrue pn9.b + + mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb + mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010 + + movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb + movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010 + + dup z0.s, w9 // c1. + dup z1.s, w10 // c2. + dup z2.s, w11 // c3. + dup z3.s, w12 // c4. + dup z4.s, w13 // c5. + + mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + dup z5.s, w9 // shift + dup z6.s, w10 // inv_ln2 + dup z7.s, w11 // neg_ln2_hi + dup z8.s, w12 // neg_ln2_lo + dup z9.s, w13 // min_input + + dup z26.s, %w[beta] // beta + fcvt h26, s26 + dup z26.h, z26.h[0] + + mov w10, #0xfc00 // -inf: 0xfc00 for fp16 + + mov w11, #0 // 0 + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cnth x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + + // ---------------------------------------------------------------- z16-z19: max_value = -inf + dup z16.h, w10 + dup z17.h, w10 + dup z18.h, w10 + dup z19.h, w10 + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + dup z11.h, w10 // z11: max_value = -inf + +find_max_body_start%=: + cmp x9, x13 + b.eq find_max_body_end%= + + .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x + .inst 0xc16cb910 // fmax {z16.h-z19.h}, {z16.h-z19.h}, {z12.h-z15.h} // z16-z19: max_value = max(max_value, x) + + inch x9, ALL, MUL #4 + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.h, x9, %x[length] + b.none find_max_leftover_end%= + + ld1h z12.h, p1/z, [x27, x9, LSL #1] // z12: x + fmax z16.h, p1/m, z16.h, z12.h // z16: max_value = max(max_value, x) + + inch x9 + b find_max_leftover_start%= +find_max_leftover_end%=: + + // ---------------------------------------------------------------- z16: max_value + .inst 0xc172b110 // fmax {z16.h-z17.h}, {z16.h-z17.h}, {z18.s-z19.h} + fmax z16.h, p0/m, z16.h, z17.h + fmaxv h16, p0, z16.h + + // ---------------------------------------------------------------- z11: max_value + dup z11.h, z16.h[0] + + // ================================================== + // Step 2: Regularize, i.e. Calculate exp(x - max(x) + // ================================================== + + .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value (in fp32) + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // ---------------------------------------------------- x9: index + +regularize_body_start%=: + cmp x9, x13 + b.eq regularize_body_end%= + + // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data + .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x + + // ---------------------------------------------------------------- z12-z15: x = input_data - max_value + fsub z12.h, z12.h, z11.h + fsub z13.h, z13.h, z11.h + fsub z14.h, z14.h, z11.h + fsub z15.h, z15.h, z11.h + + // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta + fmul z12.h, z12.h, z26.h + fmul z13.h, z13.h, z26.h + fmul z14.h, z14.h, z26.h + fmul z15.h, z15.h, z26.h + + // ---------------------------------------------------------------- + // Convert fp16 values to fp32. This results in four more registers. + // z12 --> z12, z28 + fcvtlt z28.s, p0/m, z12.h + fcvt z12.s, p0/m, z12.h + + // z13 --> z13, z29 + fcvtlt z29.s, p0/m, z13.h + fcvt z13.s, p0/m, z13.h + + // z14 --> z14, z30 + fcvtlt z30.s, p0/m, z14.h + fcvt z14.s, p0/m, z14.h + + // z15 --> z15, z31 + fcvtlt z31.s, p0/m, z15.h + fcvt z15.s, p0/m, z15.h + + // ---------------------------------------------------------------- + // Process z12-z15 + // ---------------------------------------------------------------- + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z12.s, z9.s + fcmlt p5.s, p0/z, z13.s, z9.s + fcmlt p6.s, p0/z, z14.s, z9.s + fcmlt p7.s, p0/z, z15.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z12.s, z6.s + fmla z17.s, p0/m, z13.s, z6.s + fmla z18.s, p0/m, z14.s, z6.s + fmla z19.s, p0/m, z15.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi + fmla z12.s, p0/m, z20.s, z7.s + fmla z13.s, p0/m, z21.s, z7.s + fmla z14.s, p0/m, z22.s, z7.s + fmla z15.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo + fmla z12.s, p0/m, z20.s, z8.s + fmla z13.s, p0/m, z21.s, z8.s + fmla z14.s, p0/m, z22.s, z8.s + fmla z15.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. (z12-z13) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z12.s, z0.s + fmul z21.s, z13.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z12.s, z2.s + fmla z23.s, p0/m, z13.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z12.s, z4.s + fmla z25.s, p0/m, z13.s, z4.s + + // ---------------------------------------------------------------- z12-z13: r2 = r * r + fmul z12.s, z12.s, z12.s + fmul z13.s, z13.s, z13.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z12.s, z24.s + fmla z23.s, p0/m, z13.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z12.s, z22.s + fmla z21.s, p0/m, z13.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors (z14-z15) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z14.s, z0.s + fmul z21.s, z15.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z14.s, z2.s + fmla z23.s, p0/m, z15.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z14.s, z4.s + fmla z25.s, p0/m, z15.s, z4.s + + // ---------------------------------------------------------------- z14-z15: r2 = r * r + fmul z14.s, z14.s, z14.s + fmul z15.s, z15.s, z15.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z14.s, z24.s + fmla z23.s, p0/m, z15.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z14.s, z22.s + fmla z21.s, p0/m, z15.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly + dup z10.s, #0 + sel z12.s, p4, z10.s, z16.s + sel z13.s, p5, z10.s, z17.s + sel z14.s, p6, z10.s, z18.s + sel z15.s, p7, z10.s, z19.s + + // ---------------------------------------------------------------- sum in fp32 + .inst 0xc1a17d80 // fadd za.s[w11, #0, VGx4], {z12.s-z15.s} za0-za3: sum_value = sum_value + poly + + // ---------------------------------------------------------------- + // Process z28-z31 + // ---------------------------------------------------------------- + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z28.s, z9.s + fcmlt p5.s, p0/z, z29.s, z9.s + fcmlt p6.s, p0/z, z30.s, z9.s + fcmlt p7.s, p0/z, z31.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z28.s, z6.s + fmla z17.s, p0/m, z29.s, z6.s + fmla z18.s, p0/m, z30.s, z6.s + fmla z19.s, p0/m, z31.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z24-z27: r_hi = x + n * neg_ln2_hi + fmla z28.s, p0/m, z20.s, z7.s + fmla z29.s, p0/m, z21.s, z7.s + fmla z30.s, p0/m, z22.s, z7.s + fmla z31.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z27-z30: r = r_hi + n * neg_ln2_lo + fmla z28.s, p0/m, z20.s, z8.s + fmla z29.s, p0/m, z21.s, z8.s + fmla z30.s, p0/m, z22.s, z8.s + fmla z31.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. (z28-z29) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z28.s, z0.s + fmul z21.s, z29.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z28.s, z2.s + fmla z23.s, p0/m, z29.s, z2.s + + // ---------------------------------------------------------------- z24-z25: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z28.s, z4.s + fmla z25.s, p0/m, z29.s, z4.s + + // ---------------------------------------------------------------- z28-z29: r2 = r * r + fmul z28.s, z28.s, z28.s + fmul z29.s, z29.s, z29.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z28.s, z24.s + fmla z23.s, p0/m, z29.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z28.s, z22.s + fmla z21.s, p0/m, z29.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors (z30-z31) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z30.s, z0.s + fmul z21.s, z31.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z30.s, z2.s + fmla z23.s, p0/m, z31.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z30.s, z4.s + fmla z25.s, p0/m, z31.s, z4.s + + // ---------------------------------------------------------------- z30-z31: r2 = r * r + fmul z30.s, z30.s, z30.s + fmul z31.s, z31.s, z31.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z30.s, z24.s + fmla z23.s, p0/m, z31.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z30.s, z22.s + fmla z21.s, p0/m, z31.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly + dup z10.s, #0 + sel z28.s, p4, z10.s, z16.s + sel z29.s, p5, z10.s, z17.s + sel z30.s, p6, z10.s, z18.s + sel z31.s, p7, z10.s, z19.s + + // ---------------------------------------------------------------- sum in fp32 + .inst 0xc1a17f80 // fadd za.s[w11, #0, VGx4], {z28.s-z31.s} za0-za3: sum_value = sum_value + poly + + fcvt z12.h, p0/m, z12.s + fcvtnt z12.h, p0/m, z28.s + + fcvt z13.h, p0/m, z13.s + fcvtnt z13.h, p0/m, z29.s + + fcvt z14.h, p0/m, z14.s + fcvtnt z14.h, p0/m, z30.s + + fcvt z15.h, p0/m, z15.s + fcvtnt z15.h, p0/m, z31.s + + // Stores 4 consecutive registers to the output + .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1] + + inch x9, ALL, MUL #4 + b regularize_body_start%= +regularize_body_end%=: + + // ---------------------------------------------------------------- z28: sum_value + .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4] + fadd z28.s, z28.s, z29.s + fadd z30.s, z30.s, z31.s + fadd z28.s, z28.s, z30.s + + // Loop for processing the leftover part. +regularize_leftover_start%=: + whilelo p2.h, x9, %x[length] + b.none regularize_leftover_end%= + + ld1h z12.h, p2/z, [x27, x9, LSL #1] // x12: input_data + + fsub z12.h, z12.h, z11.h // z12: x = input_data - max_value + fmul z12.h, z12.h, z26.h // z12: x = (input_data - max_value) * beta + + // ---------------------------------------------------------------- z12.h --> z12.s, z13.s + fcvtlt z13.s, p2/m, z12.h + fcvt z12.s, p2/m, z12.h + + // ---------------------------------------------------------------- p3, p4: predicates for z12, z14 + pfalse p1.b + trn1 p3.h, p2.h, p1.h // for z12 + trn2 p4.h, p2.h, p1.h // for z13 + + mov z16.d, z5.d // z16: shift + mov z17.d, z5.d // z17: shift + fcmlt p5.s, p3/z, z12.s, z9.s // p5: underflow = x < min_input + fcmlt p6.s, p4/z, z13.s, z9.s // p6: underflow = x < min_input + fmla z16.s, p3/m, z12.s, z6.s // z16: z = shift + x * inv_ln2 + fmla z17.s, p4/m, z13.s, z6.s // z17: z = shift + x * inv_ln2 + fsub z20.s, z16.s, z5.s // z20: n = z - shift + fsub z21.s, z17.s, z5.s // z21: n = z - shift + fmla z12.s, p3/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi + fmla z13.s, p4/m, z21.s, z7.s // z13: r_hi = x + n * neg_ln2_hi + fmla z12.s, p3/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo + fmla z13.s, p4/m, z21.s, z8.s // z13: r = r_hi + n * neg_ln2_lo + dup z10.s, #23 // z10: 23 + urshl z16.s, p3/m, z16.s, z10.s // z16: scale = z << 23 (2^n) + urshl z17.s, p4/m, z17.s, z10.s // z17: scale = z << 23 (2^n) + fmul z20.s, z12.s, z0.s // z20: p1 = r * c1 + fmul z21.s, z13.s, z0.s // z21: p1 = r * c1 + mov z22.d, z1.d // z22: p23 = c2 + mov z23.d, z1.d // z23: p23 = c2 + fmla z22.s, p3/m, z12.s, z2.s // z22: p23 = c2 + r * c3 + fmla z23.s, p4/m, z13.s, z2.s // z23: p23 = c2 + r * c3 + mov z24.d, z3.d // z24: c4 + mov z25.d, z3.d // z25: c4 + fmla z24.s, p3/m, z12.s, z4.s // z24: p45 = c4 + r * c5 + fmla z25.s, p4/m, z13.s, z4.s // z25: p45 = c4 + r * c5 + fmul z12.s, z12.s, z12.s // z12: r2 = r * r + fmul z13.s, z13.s, z13.s // z13: r2 = r * r + fmla z22.s, p3/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45 + fmla z23.s, p4/m, z13.s, z25.s // z23: p2345 = p23 + r2 * p45 + fmla z20.s, p3/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345 + fmla z21.s, p4/m, z13.s, z23.s // z21: p12345 = p1 + r2 * p2345 + fmla z16.s, p3/m, z20.s, z16.s // z16: poly = scale + p12345 * scale + fmla z17.s, p4/m, z21.s, z17.s // z17: poly = scale + p12345 * scale + dup z10.s, #0 // z10: 0 + sel z16.s, p5, z10.s, z16.s // z16: poly = underflow ? 0 : poly + sel z17.s, p6, z10.s, z17.s // z17: poly = underflow ? 0 : poly + fadd z28.s, p3/m, z28.s, z16.s // z28: sum_value = sum_value + poly + fadd z28.s, p4/m, z28.s, z17.s // z28: sum_value = sum_value + poly + + fcvt z16.h, p3/m, z16.s + fcvtnt z16.h, p4/m, z17.s + st1h z16.h, p2, [x28, x9, LSL #1] + + inch x9 + b regularize_leftover_start%= +regularize_leftover_end%=: + + // ================================================== + // Step 3: Normalize + // ================================================== + + // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value + faddv s28, p0, z28.s + fmov s29, #1.0 // 1.0f + fdiv s28, s29, s28 + fcvt h28, s28 + + dup z28.h, z28.h[0] + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + +normalize_body_start%=: + cmp x9, x13 + b.eq normalize_body_end%= + + .inst 0xa009a78c // ld1h {z12.h-z15.h}, pn9/z, [x28, x9, LSL #1] + + // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value + fmul z12.h, z12.h, z28.h + fmul z13.h, z13.h, z28.h + fmul z14.h, z14.h, z28.h + fmul z15.h, z15.h, z28.h + + .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1] + + inch x9, ALL, MUL #4 + b normalize_body_start%= +normalize_body_end%=: + + // Loop for processing the leftover part. +normalize_leftover_start%=: + whilelo p1.h, x9, %x[length] + b.none normalize_leftover_end%= + + ld1h z12.h, p1/z, [x28, x9, LSL #1] // z12: x + fmul z12.h, z12.h, z28.h // z12: result = x * inv_sum_value + + st1h z12.h, p1, [x28, x9, LSL #1] + + inch x9 + b normalize_leftover_start%= +normalize_leftover_end%=: + + // ================================================== + // 3D loop closing + // ================================================== + + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", // + "x9", "x10", "x11", "x12", "x13", "x14", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_fp16_softmax(const ITensor *in, + void *const, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const auto *k_src = reinterpret_cast<const float16_t *>(in->buffer() + k_src_offset); + auto *k_dst = reinterpret_cast<float16_t *>(out->buffer() + k_dst_offset); + + sme2_f16_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp new file mode 100644 index 0000000000..5e29d51746 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp @@ -0,0 +1,585 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_f32_softmax_kernel( // + const float *src, + float *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4]) +{ + // Precondition: + // * src_strides[0] == sizeof(float) + // * dst_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x9: temporary, index + // * x10: temporary, -inf + // * x11: temporary, 0 + // * x12: temporary, 1.0f + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // + // * z0: c1 + // * z1: c2 + // * z2: c3 + // * z3: c4 + // * z4: c5 + // * z5: shift + // * z6: inv_ln2 + // * z7: neg_ln2_hi + // * z8: neg_ln2_lo + // * z9: min_input + // * z10: 23, 0 + // * z11: max_value + // * z12-z15: x, r_hi, r, r2 + // * z16-z19: max_value, shift, z, scale, poly + // * z20-z21: n, p1, p12345 + // * z22-z23: n, p23, p2345 + // * z24-z25: p45 + // * z26: beta + // * z28-z31: sum_value + // + // * za0-za3: sum_value + // + // * p0: all-true + // * p1: left-over predicate + // * p4-p7: underflow + // * pn9: all-true + + // Prepares all constant values + + ptrue p0.b + .inst 0x25207811 // ptrue pn9.b + + mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb + mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010 + + movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb + movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010 + + dup z0.s, w9 // c1. + dup z1.s, w10 // c2. + dup z2.s, w11 // c3. + dup z3.s, w12 // c4. + dup z4.s, w13 // c5. + + mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + dup z5.s, w9 // shift + dup z6.s, w10 // inv_ln2 + dup z7.s, w11 // neg_ln2_hi + dup z8.s, w12 // neg_ln2_lo + dup z9.s, w13 // min_input + + dup z26.s, %w[beta] // beta + + mov w10, #0x0000 // -inf: 0xff800000 + movk w10, #0xff80 // -inf: 0xff800000 + + mov w11, #0 // 0 + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntw x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + dup z11.s, w10 // z11: max_value = -inf + + // ---------------------------------------------------------------- z16-z19: max_value = -inf + mov z16.d, z11.d + mov z17.d, z11.d + mov z18.d, z11.d + mov z19.d, z11.d + +find_max_body_start%=: + cmp x9, x13 + b.eq find_max_body_end%= + + .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x + .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x) + + incw x9, ALL, MUL #4 + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none find_max_leftover_end%= + + ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x + fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x) + + incw x9 + b find_max_leftover_start%= +find_max_leftover_end%=: + + // ---------------------------------------------------------------- z16: max_value + .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s} + fmax z16.s, p0/m, z16.s, z17.s + fmaxv s16, p0, z16.s + + // ---------------------------------------------------------------- z11: max_value + dup z11.s, z16.s[0] + + // ================================================== + // Step 2: Regularize + // ================================================== + + .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // ---------------------------------------------------- x9: index + +regularize_body_start%=: + cmp x9, x13 + b.eq regularize_body_end%= + + // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data + .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] + + // ---------------------------------------------------------------- z12-z15: x = input_data - max_value + fsub z12.s, z12.s, z11.s + fsub z13.s, z13.s, z11.s + fsub z14.s, z14.s, z11.s + fsub z15.s, z15.s, z11.s + + // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta + fmul z12.s, z12.s, z26.s + fmul z13.s, z13.s, z26.s + fmul z14.s, z14.s, z26.s + fmul z15.s, z15.s, z26.s + + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z12.s, z9.s + fcmlt p5.s, p0/z, z13.s, z9.s + fcmlt p6.s, p0/z, z14.s, z9.s + fcmlt p7.s, p0/z, z15.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z12.s, z6.s + fmla z17.s, p0/m, z13.s, z6.s + fmla z18.s, p0/m, z14.s, z6.s + fmla z19.s, p0/m, z15.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi + fmla z12.s, p0/m, z20.s, z7.s + fmla z13.s, p0/m, z21.s, z7.s + fmla z14.s, p0/m, z22.s, z7.s + fmla z15.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo + fmla z12.s, p0/m, z20.s, z8.s + fmla z13.s, p0/m, z21.s, z8.s + fmla z14.s, p0/m, z22.s, z8.s + fmla z15.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z12.s, z0.s + fmul z21.s, z13.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z12.s, z2.s + fmla z23.s, p0/m, z13.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z12.s, z4.s + fmla z25.s, p0/m, z13.s, z4.s + + // ---------------------------------------------------------------- z12-z13: r2 = r * r + fmul z12.s, z12.s, z12.s + fmul z13.s, z13.s, z13.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z12.s, z24.s + fmla z23.s, p0/m, z13.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z12.s, z22.s + fmla z21.s, p0/m, z13.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z14.s, z0.s + fmul z21.s, z15.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z14.s, z2.s + fmla z23.s, p0/m, z15.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z14.s, z4.s + fmla z25.s, p0/m, z15.s, z4.s + + // ---------------------------------------------------------------- z14-z15: r2 = r * r + fmul z14.s, z14.s, z14.s + fmul z15.s, z15.s, z15.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z14.s, z24.s + fmla z23.s, p0/m, z15.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z14.s, z22.s + fmla z21.s, p0/m, z15.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly + dup z10.s, #0 + sel z16.s, p4, z10.s, z16.s + sel z17.s, p5, z10.s, z17.s + sel z18.s, p6, z10.s, z18.s + sel z19.s, p7, z10.s, z19.s + + // Stores 4 consecutive registers to the output + .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2] + + .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly + + incw x9, ALL, MUL #4 + b regularize_body_start%= +regularize_body_end%=: + + // ---------------------------------------------------------------- z28: sum_value + .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4] + fadd z28.s, z28.s, z29.s + fadd z30.s, z30.s, z31.s + fadd z28.s, z28.s, z30.s + + // Loop for processing the leftover part. +regularize_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none regularize_leftover_end%= + + ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data + + fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value + fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta + + mov z16.d, z5.d // z16: shift + fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input + fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2 + fsub z20.s, z16.s, z5.s // z20: n = z - shift + fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi + fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo + dup z10.s, #23 // z10: 23 + urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n) + fmul z20.s, z12.s, z0.s // z20: p1 = r * c1 + mov z22.d, z1.d // z22: p23 = c2 + fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3 + mov z24.d, z3.d // z24: c4 + fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5 + fmul z12.s, z12.s, z12.s // z12: r2 = r * r + fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45 + fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345 + fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale + dup z10.s, #0 // z10: 0 + sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly + + st1w z16.s, p1, [x28, x9, LSL #2] + + fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly + + incw x9 + b regularize_leftover_start%= +regularize_leftover_end%=: + + // ================================================== + // Step 3: Normalize + // ================================================== + + // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value + fmov s29, #1.0 // 1.0f + faddv s28, p0, z28.s + fdiv s28, s29, s28 + dup z28.s, z28.s[0] + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + +normalize_body_start%=: + cmp x9, x13 + b.eq normalize_body_end%= + + .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x + + // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value + fmul z12.s, z12.s, z28.s + fmul z13.s, z13.s, z28.s + fmul z14.s, z14.s, z28.s + fmul z15.s, z15.s, z28.s + + .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2] + + incw x9, ALL, MUL #4 + b normalize_body_start%= +normalize_body_end%=: + + // Loop for processing the leftover part. +normalize_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none normalize_leftover_end%= + + ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x + fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value + + st1w z12.s, p1, [x28, x9, LSL #2] + + incw x9 + b normalize_leftover_start%= +normalize_leftover_end%=: + + // ================================================== + // 3D loop closing + // ================================================== + + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p4", "p5", "p6", "p7", "p9", // + "x9", "x10", "x11", "x12", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_fp32_softmax(const ITensor *in, + void *const, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset); + auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset); + + sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp new file mode 100644 index 0000000000..9feb669f7c --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp @@ -0,0 +1,634 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_qasymm8_softmax_kernel_512VL( // + const uint8_t *src, + uint8_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4], + const float *lut, + float *tmp) +{ + // Precondition: + // * src_strides[0] == sizeof(uint8_t) + // * dst_strides[0] == sizeof(uint8_t) + // * tmp_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x1: Loop index + // * x2: LUT index + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // * x29 tmp + // + // + // * p0: all-true + // * p1: predicate for QASYMM8 values + // * p2: predicate 0 for FP32 values (first quarter of expanded/unpacked p1) + // * p3: predicate 1 for FP32 values (second quarter of expanded/unpacked p1) + // * p4: predicate 2 for FP32 values (third quarter of expanded/unpacked p1) + // * p5: predicate 3 for FP32 values (fourth quarter of expanded/unpacked p1) + // * pn9: all-true for 32 bit values + // * pn8: all-true for 8-bit values + // + // * z0-z15 the 256 LUT values of exp(-scale*beta*x) for x in QASYMM8, stored as FP32 values + + // Prepares all constant values + + ptrue p0.b + .inst 0x25a07811 // ptrue pn9.s + .inst 0x25207810 // ptrue pn8.b + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntb x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + mov x19, %x[lut] + mov x29, %x[tmp] + + // Load the LUT to the register file. + mov x2, %x[lut] + .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2] + + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + // z16-z19 = minimum QASYMM8 value (0) to allow for it to be used for comparison to find the max. + dup z16.b, #0 + dup z17.b, #0 + dup z18.b, #0 + dup z19.b, #0 + mov x1, #0 // x1: index +find_max_body_start%=: + cmp x1, x13 + b.eq find_max_body_end%= + .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z20-z23: x + .inst 0xc134b811 // umax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x) + add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers. + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none find_max_leftover_end%= + + ld1b z30.b, p1/z, [x27, x1] // z30: x + umax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x) + + add x1, x1, #64 + + b find_max_leftover_start%= +find_max_leftover_end%=: + + .inst 0xc132b011 // umax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b } + umax z16.b, p0/m, z16.b, z17.b + umaxv b16, p0, z16.b // Reduction unsigned max operation to get maximum_value + dup z16.b, z16.b[0] + uunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction + uunpklo z16.s, z16.h + + mov x1, #0 // reset index + dup z25.s, #0 + + mov x1, #0 + +regularize_start%=: + whilelo p1.b, x1, %x[length] + b.none regularize_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + ld1b z17.b, p1/z, [x27, x1] //z17: input data + + uunpklo z18.h, z17.b //Using unpack instructions to align the input QASYMM8 values with the FP32 entries in the LUT for use in the TBX instruction + uunpkhi z19.h, z17.b + + uunpklo z17.s, z18.h // z17 = low low input QASYMM8 values + uunpkhi z18.s, z18.h // z18 = low high input QASYMM8 values + + uunpkhi z20.s, z19.h // z20 = high high input QASYMM8 values + uunpklo z19.s, z19.h // z19 = high low input QASYMM8 values + + sub z17.s, z16.s, z17.s // z12: x = max_value - input_data + sub z18.s, z16.s, z18.s // z13: x = max_value - input_data + sub z19.s, z16.s, z19.s // z14: x = max_value - input_data + sub z20.s, z16.s, z20.s // z15: x = max_value - input_data + + tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT. + tbx z22.s, z0.s, z18.s + tbx z23.s, z0.s, z19.s + tbx z24.s, z0.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT. + tbx z22.s, z1.s, z18.s + tbx z23.s, z1.s, z19.s + tbx z24.s, z1.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT. + tbx z22.s, z2.s, z18.s + tbx z23.s, z2.s, z19.s + tbx z24.s, z2.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT. + tbx z22.s, z3.s, z18.s + tbx z23.s, z3.s, z19.s + tbx z24.s, z3.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT. + tbx z22.s, z4.s, z18.s + tbx z23.s, z4.s, z19.s + tbx z24.s, z4.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT. + tbx z22.s, z5.s, z18.s + tbx z23.s, z5.s, z19.s + tbx z24.s, z5.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT. + tbx z22.s, z6.s, z18.s + tbx z23.s, z6.s, z19.s + tbx z24.s, z6.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT. + tbx z22.s, z7.s, z18.s + tbx z23.s, z7.s, z19.s + tbx z24.s, z7.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT. + tbx z22.s, z8.s, z18.s + tbx z23.s, z8.s, z19.s + tbx z24.s, z8.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT. + tbx z22.s, z9.s, z18.s + tbx z23.s, z9.s, z19.s + tbx z24.s, z9.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT. + tbx z22.s, z10.s, z18.s + tbx z23.s, z10.s, z19.s + tbx z24.s, z10.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT. + tbx z22.s, z11.s, z18.s + tbx z23.s, z11.s, z19.s + tbx z24.s, z11.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT. + tbx z22.s, z12.s, z18.s + tbx z23.s, z12.s, z19.s + tbx z24.s, z12.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT. + tbx z22.s, z13.s, z18.s + tbx z23.s, z13.s, z19.s + tbx z24.s, z13.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT. + tbx z22.s, z14.s, z18.s + tbx z23.s, z14.s, z19.s + tbx z24.s, z14.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT. + tbx z22.s, z15.s, z18.s + tbx z23.s, z15.s, z19.s + tbx z24.s, z15.s, z20.s + + + st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p2/m, z25.s, z21.s + add x1, x1, #16 + + st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p3/m, z25.s, z22.s + add x1, x1, #16 + + st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p4/m, z25.s, z23.s + add x1, x1, #16 + + st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p5/m, z25.s, z24.s + add x1, x1, #16 + + b regularize_start%= +regularize_end%=: + + mov w9, 0x0000 + movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [0,255] integer range of QASYMM8 + dup z29.s, w9 + faddv s25, p0, z25.s + fdiv s25, s29, s25 + dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax. + + // ================================================== + // Step 3: Normalize + // ================================================== + mov x1, #0 +normalize_body_start%=: + cmp x1, x13 + b.eq normalize_body_end%= + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z16.s, z25.s, z16.s + fmul z17.s, z25.s, z17.s + fmul z18.s, z25.s, z18.s + fmul z19.s, z25.s, z19.s + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z16-z23: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z16.s, p0/m, z16.s + fcvtzu z17.s, p0/m, z17.s + fcvtzu z18.s, p0/m, z18.s + fcvtzu z19.s, p0/m, z19.s + fcvtzu z20.s, p0/m, z20.s + fcvtzu z21.s, p0/m, z21.s + fcvtzu z22.s, p0/m, z22.s + fcvtzu z23.s, p0/m, z23.s + + // z16-z17: narrow the uint32 values into uint8 and saturate them. + .inst 0xc133e230 // uqcvt z16.b, { z16.s - z19.s } + .inst 0xc133e2b1 // uqcvt z17.b, { z20.s - z23.s } + + dup z20.s, z25.s[0] // Juggling the value to z20 as z25 will be overwritten by the load below + + .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z24.s, z20.s, z24.s + fmul z25.s, z20.s, z25.s + fmul z26.s, z20.s, z26.s + fmul z27.s, z20.s, z27.s + fmul z28.s, z20.s, z28.s + fmul z29.s, z20.s, z29.s + fmul z30.s, z20.s, z30.s + fmul z31.s, z20.s, z31.s + + // z24-z31: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z24.s, p0/m, z24.s + fcvtzu z25.s, p0/m, z25.s + fcvtzu z26.s, p0/m, z26.s + fcvtzu z27.s, p0/m, z27.s + fcvtzu z28.s, p0/m, z28.s + fcvtzu z29.s, p0/m, z29.s + fcvtzu z30.s, p0/m, z30.s + fcvtzu z31.s, p0/m, z31.s + + // z18-z19: narrow the uint32 values into uint8 and saturate them. + .inst 0xc133e332 // uqcvt z18.b, { z24.s - z27.s } + .inst 0xc133e3b3 // uqcvt z19.b, { z28.s - z31.s } + + .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2] + + dup z25.s, z20.s[0] // Juggling the value back to z25 as z20 will be overwritten by the next iteration or z25 will be used below. + +b normalize_body_start%= +normalize_body_end%=: + +normalize_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none normalize_leftover_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + + // z20-z23: load exp(-scale*beta*x) from the tmp tensor + ld1w z20.s, p2/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z21.s, p3/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z22.s, p4/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z23.s, p5/z, [x29, x1, LSL #2] + add x1, x1, #16 + + // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z20-23: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z20.s, p0/m, z20.s + fcvtzu z21.s, p0/m, z21.s + fcvtzu z22.s, p0/m, z22.s + fcvtzu z23.s, p0/m, z23.s + + .inst 0xc133e2b3 // uqcvt z19.b, { z20.s - z23.s }, narrow the uint32 values into uint8 and saturate them into z19. + + st1b z19.b, p1, [x28, x2] + + b normalize_leftover_start%= +normalize_leftover_end%=: + // ================================================== + // 3D loop closing + // ================================================== + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", // + "x2", "x9", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_qasymm8_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + Strides tmp_strides; + + tmp_strides[0] = src_strides[0] * 4; + tmp_strides[1] = src_strides[1] * 4; + tmp_strides[2] = src_strides[2] * 4; + tmp_strides[3] = src_strides[3] * 4; + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + // + window[1].start() * tmp_strides[1] + // + window[2].start() * tmp_strides[2] + // + window[3].start() * tmp_strides[3]; + + const auto *k_src = reinterpret_cast<const uint8_t *>(in->buffer() + k_src_offset); + float *tmp_float_ptr = reinterpret_cast<float *>(tmp); + auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset); + auto *k_dst = reinterpret_cast<uint8_t *>(out->buffer() + k_dst_offset); + + sme2_qasymm8_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp new file mode 100644 index 0000000000..14c0f6c327 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_qasymm8_signed_softmax_kernel_512VL( // + const int8_t *src, + int8_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4], + const float *lut, + float *tmp) +{ + // Precondition: + // * src_strides[0] == sizeof(int8_t) + // * dst_strides[0] == sizeof(int8_t) + // * tmp_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // For register list explanation refer to qasymm8.cpp. + + // Prepares all constant values + + ptrue p0.b + .inst 0x25a07811 // ptrue pn9.s + .inst 0x25207810 // ptrue pn8.b + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntb x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + mov x19, %x[lut] + mov x29, %x[tmp] + + // Load the LUT to the register file. + mov x2, %x[lut] + .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2] + + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + // z16-z19 = minimum QASYMM8_SIGNED value (-128) to allow for it to be used for comparison to find the max. + dup z16.b, #0x80 + dup z17.b, #0x80 + dup z18.b, #0x80 + dup z19.b, #0x80 + + mov x1, #0 // x1: index +find_max_body_start%=: + cmp x1, x13 + b.eq find_max_body_end%= + .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z16-z19: x + .inst 0xc134b810 // smax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x) + add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers. + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none find_max_leftover_end%= + + ld1b z30.b, p1/z, [x27, x1] // z30: x + smax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x) + + add x1, x1, #64 + + b find_max_leftover_start%= +find_max_leftover_end%=: + .inst 0xc132b010 // smax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b } + smax z16.b, p0/m, z16.b, z17.b + smaxv b16, p0, z16.b // Reduction signed max operation to get maximum_value + mov z16.b, b16 // z16: duplicated max_value for current row + + sunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction + sunpklo z16.s, z16.h + + mov x1, #0 // reset index + dup z25.s, #0 + + +regularize_start%=: + whilelo p1.b, x1, %x[length] + b.none regularize_end%= + + mov w9, 0xFF80 + movk w9, 0xFFFF, LSL #16 // Moving -127.f into w9 to set the registers below to the minimum QASYMM8_SIGNED value + dup z17.s, w9 + dup z18.s, w9 + dup z19.s, w9 + dup z20.s, w9 + + dup z21.s, #0x0 + dup z22.s, #0x0 + dup z23.s, #0x0 + dup z24.s, #0x0 + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + ld1b z17.b, p1/z, [x27, x1] //z17: input data + + sunpklo z18.h, z17.b // Using unpack instructions to align the input QASYMM8_SIGNED values with the FP32 entries in the LUT for use in the TBX instruction + sunpkhi z19.h, z17.b // + + sunpklo z17.s, z18.h // z17 = low low input QASYMM8_SIGNED values + sunpkhi z18.s, z18.h // z18 = low high input QASYMM8_SIGNED values + + sunpkhi z20.s, z19.h // z20 = high high input QASYMM8_SIGNED values + sunpklo z19.s, z19.h // z19 = high low input QASYMM8_SIGNED values + + sub z17.s, z16.s, z17.s // z12: x = max_value - input_data + sub z18.s, z16.s, z18.s // z13: x = max_value - input_data + sub z19.s, z16.s, z19.s // z14: x = max_value - input_data + sub z20.s, z16.s, z20.s // z15: x = max_value - input_data + + add z17.s, z17.s, #128 + add z18.s, z18.s, #128 + add z19.s, z19.s, #128 + add z20.s, z20.s, #128 + + tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT. + tbx z22.s, z0.s, z18.s + tbx z23.s, z0.s, z19.s + tbx z24.s, z0.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT. + tbx z22.s, z1.s, z18.s + tbx z23.s, z1.s, z19.s + tbx z24.s, z1.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT. + tbx z22.s, z2.s, z18.s + tbx z23.s, z2.s, z19.s + tbx z24.s, z2.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT. + tbx z22.s, z3.s, z18.s + tbx z23.s, z3.s, z19.s + tbx z24.s, z3.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT. + tbx z22.s, z4.s, z18.s + tbx z23.s, z4.s, z19.s + tbx z24.s, z4.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT. + tbx z22.s, z5.s, z18.s + tbx z23.s, z5.s, z19.s + tbx z24.s, z5.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT. + tbx z22.s, z6.s, z18.s + tbx z23.s, z6.s, z19.s + tbx z24.s, z6.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT. + tbx z22.s, z7.s, z18.s + tbx z23.s, z7.s, z19.s + tbx z24.s, z7.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT. + tbx z22.s, z8.s, z18.s + tbx z23.s, z8.s, z19.s + tbx z24.s, z8.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT. + tbx z22.s, z9.s, z18.s + tbx z23.s, z9.s, z19.s + tbx z24.s, z9.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT. + tbx z22.s, z10.s, z18.s + tbx z23.s, z10.s, z19.s + tbx z24.s, z10.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT. + tbx z22.s, z11.s, z18.s + tbx z23.s, z11.s, z19.s + tbx z24.s, z11.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT. + tbx z22.s, z12.s, z18.s + tbx z23.s, z12.s, z19.s + tbx z24.s, z12.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT. + tbx z22.s, z13.s, z18.s + tbx z23.s, z13.s, z19.s + tbx z24.s, z13.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT. + tbx z22.s, z14.s, z18.s + tbx z23.s, z14.s, z19.s + tbx z24.s, z14.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT. + tbx z22.s, z15.s, z18.s + tbx z23.s, z15.s, z19.s + tbx z24.s, z15.s, z20.s + + + st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p2/m, z25.s, z21.s + add x1, x1, #16 + + st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p3/m, z25.s, z22.s + add x1, x1, #16 + + st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p4/m, z25.s, z23.s + add x1, x1, #16 + + st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p5/m, z25.s, z24.s + add x1, x1, #16 + + b regularize_start%= +regularize_end%=: + + mov w9, 0x0000 + movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [-128, 127] integer range of QASYMM8_SIGNED + mov w10, 0x0000 + movk w10, 0x4300, LSL #16 // Moving 128.f into w10 for the subtraction to move the results - via subtraction - from the [0,255] range to the [-128,127] range + dup z29.s, w9 + dup z30.s, w10 + faddv s25, p0, z25.s + fdiv s25, s29, s25 + dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax. + + // ================================================== + // Step 3: Normalize + // ================================================== + mov x1, #0 +normalize_body_start%=: + cmp x1, x13 + b.eq normalize_body_end%= + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z16.s, z25.s, z16.s + fmul z17.s, z25.s, z17.s + fmul z18.s, z25.s, z18.s + fmul z19.s, z25.s, z19.s + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z16-z23: subtract 128.f. + fsub z16.s, z16.s, z30.s // Subtract 128.f + fsub z17.s, z17.s, z30.s // Subtract 128.f + fsub z18.s, z18.s, z30.s // Subtract 128.f + fsub z19.s, z19.s, z30.s // Subtract 128.f + fsub z20.s, z20.s, z30.s // Subtract 128.f + fsub z21.s, z21.s, z30.s // Subtract 128.f + fsub z22.s, z22.s, z30.s // Subtract 128.f + fsub z23.s, z23.s, z30.s // Subtract 128.f + + // z16-z23: convert the FP32 values from the tmp tensor to int32. + fcvtzs z16.s, p0/m, z16.s + fcvtzs z17.s, p0/m, z17.s + fcvtzs z18.s, p0/m, z18.s + fcvtzs z19.s, p0/m, z19.s + fcvtzs z20.s, p0/m, z20.s + fcvtzs z21.s, p0/m, z21.s + fcvtzs z22.s, p0/m, z22.s + fcvtzs z23.s, p0/m, z23.s + + // z16-z17: narrow the int32 values into int8 and saturate them. + .inst 0xc133e210 // sqcvt z16.b, { z16.s - z19.s } + .inst 0xc133e291 // sqcvt z17.b, { z20.s - z23.s } + + // Juggling the value to z20 (resp. 21) as z25 (resp. z30) will be overwritten by the load below. + dup z20.s, z25.s[0] + dup z21.s, z30.s[0] + + .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z24.s, z20.s, z24.s + fmul z25.s, z20.s, z25.s + fmul z26.s, z20.s, z26.s + fmul z27.s, z20.s, z27.s + fmul z28.s, z20.s, z28.s + fmul z29.s, z20.s, z29.s + fmul z30.s, z20.s, z30.s + fmul z31.s, z20.s, z31.s + + // z24-z31: subtract 128.f. + fsub z24.s, z24.s, z21.s + fsub z25.s, z25.s, z21.s + fsub z26.s, z26.s, z21.s + fsub z27.s, z27.s, z21.s + fsub z28.s, z28.s, z21.s + fsub z29.s, z29.s, z21.s + fsub z30.s, z30.s, z21.s + fsub z31.s, z31.s, z21.s + + // z24-z31: convert the FP32 values from the tmp tensor to int32. + fcvtzs z24.s, p0/m, z24.s + fcvtzs z25.s, p0/m, z25.s + fcvtzs z26.s, p0/m, z26.s + fcvtzs z27.s, p0/m, z27.s + fcvtzs z28.s, p0/m, z28.s + fcvtzs z29.s, p0/m, z29.s + fcvtzs z30.s, p0/m, z30.s + fcvtzs z31.s, p0/m, z31.s + + // z18-z19: narrow the int32 values into int8 and saturate them. + .inst 0xc133e312 // sqcvt z18.b, { z24.s - z27.s } + .inst 0xc133e393 // sqcvt z19.b, { z28.s - z31.s } + + .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2] + + // Juggling the values back to z25 (resp. z30) as z20 (resp. z21) will be overwritten by the next iteration or z25 (resp. z30) will be used below. + dup z25.s, z20.s[0] + dup z30.s, z21.s[0] +b normalize_body_start%= +normalize_body_end%=: +normalize_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none normalize_leftover_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + + // z20-z23: load exp(-scale*beta*x) from the tmp tensor + ld1w z20.s, p2/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z21.s, p3/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z22.s, p4/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z23.s, p5/z, [x29, x1, LSL #2] + add x1, x1, #16 + + // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + //z20-z23: Subtract 128.f. + fsub z20.s, z20.s, z30.s + fsub z21.s, z21.s, z30.s + fsub z22.s, z22.s, z30.s + fsub z23.s, z23.s, z30.s + + // z20-23: convert the FP32 values from the tmp tensor to int32. + fcvtzs z20.s, p0/m, z20.s + fcvtzs z21.s, p0/m, z21.s + fcvtzs z22.s, p0/m, z22.s + fcvtzs z23.s, p0/m, z23.s + + .inst 0xc133e293 // sqcvt z19.b, { z20.s - z23.s }, narrow the int32 values into int8 and saturate them into z19. + + st1b z19.b, p1, [x28, x2] + + b normalize_leftover_start%= +normalize_leftover_end%=: + // ================================================== + // 3D loop closing + // ================================================== + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", // + "x2", "x9", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + Strides tmp_strides; + + tmp_strides[0] = src_strides[0] * 4; + tmp_strides[1] = src_strides[1] * 4; + tmp_strides[2] = src_strides[2] * 4; + tmp_strides[3] = src_strides[3] * 4; + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + // + window[1].start() * tmp_strides[1] + // + window[2].start() * tmp_strides[2] + // + window[3].start() * tmp_strides[3]; + + const auto *k_src = reinterpret_cast<const int8_t *>(in->buffer() + k_src_offset); + float *tmp_float_ptr = reinterpret_cast<float *>(tmp); + auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset); + auto *k_dst = reinterpret_cast<int8_t *>(out->buffer() + k_dst_offset); + + sme2_qasymm8_signed_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp new file mode 100644 index 0000000000..0d4b7f4509 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2021-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/softmax/generic/sve/impl.h" + +#include "src/core/NEON/wrapper/intrinsics/intrinsics.h" + +namespace arm_compute +{ +namespace cpu +{ +/// TODO: (COMPMID-6505) Similar to Neon(TM), this implementation be converted to +/// a single kernel that performs softmax operation. Leaving the SVE code here for +/// future references. Implementation for Neon(TM) is introduced in COMPMID-6500 +template <typename ScalarType> +void sve_logits_1d_max(const ITensor *in, ITensor *out, const Window &window) +{ + const auto all_true_pg = wrapper::svptrue<ScalarType>(); + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + Window win{window}; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + Iterator input(in, win); + Iterator output(out, win); + + execute_window_loop( + win, + [&](const Coordinates &) + { + // Get pointers + const auto in_ptr = reinterpret_cast<const ScalarType *>(input.ptr()); + const auto out_ptr = reinterpret_cast<ScalarType *>(output.ptr()); + + // Init max value + auto vec_max = wrapper::svdup_n(support::cpp11::lowest<ScalarType>()); + + int x = window_start_x; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); + do + { + const auto current_value = svld1(pg, in_ptr + x); + vec_max = svmax_m(pg, vec_max, current_value); + + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, window_end_x); + } while (svptest_any(all_true_pg, pg)); + + auto max_val = svmaxv(all_true_pg, vec_max); + + *out_ptr = max_val; + }, + input, output); +} + +template <typename ScalarType> +void sve_softmax_logits_1d_float(const ITensor *in, + const ITensor *max, + void *const tmp, + ITensor *out, + const float beta, + bool is_log, + const Window &window) +{ + const int start_x = in->info()->valid_region().anchor.x(); + const int input_width = in->info()->valid_region().shape.x(); + + Iterator in_it(in, window); + Iterator max_it(max, window); + Iterator out_it(out, window); + + const auto all_true_pg = wrapper::svptrue<ScalarType>(); + + execute_window_loop( + window, + [&](const Coordinates &) + { + /* Get pointers */ + const auto in_ptr = reinterpret_cast<const ScalarType *>(in_it.ptr()) + start_x; + const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr()) + start_x; + const auto tmp_ptr = reinterpret_cast<ScalarType *>(tmp); + + ScalarType sum{0}; + + /* Compute exponentials and sum */ + { + /* Get max value */ + const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr()); + const auto vec_max = wrapper::svdup_n(max_val); + const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta)); + + /* Init sum to zero */ + auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0)); + + /* Loop over row and compute exponentials and sum */ + int x = 0; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); + do + { + auto vec_elements = svld1(pg, in_ptr + x); + vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta); + if (!is_log) + { + vec_elements = wrapper::svexp_z(pg, vec_elements); + vec_sum = svadd_m(pg, vec_sum, vec_elements); + } + svst1(pg, tmp_ptr + x, vec_elements); + + if (is_log) + { + vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); + } + + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, input_width); + } while (svptest_any(all_true_pg, pg)); + + /* Reduce sum */ + sum = svaddv(all_true_pg, vec_sum); + + if (is_log) + { + sum = static_cast<ScalarType>(std::log(sum)); + } + else + { + sum = ScalarType(1) / sum; + } + } + + /* Normalize exponentials */ + { + /* Loop over row and compute softmax */ + int x = 0; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); + do + { + auto vec_in = svld1(pg, tmp_ptr + x); + auto normalized_value = wrapper::svdup_n(static_cast<ScalarType>(0)); + if (is_log) + { + normalized_value = svsub_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum))); + } + else + { + normalized_value = svmul_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum))); + } + svst1(pg, out_ptr + x, normalized_value); + + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, input_width); + } while (svptest_any(all_true_pg, pg)); + } + }, + in_it, max_it, out_it); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/sve/impl.h b/src/cpu/kernels/softmax/generic/sve/impl.h new file mode 100644 index 0000000000..89a30d042f --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sve/impl.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_SVE_KERNELS_SOFTMAX_IMPL_H +#define SRC_CORE_SVE_KERNELS_SOFTMAX_IMPL_H + +#include "arm_compute/core/Helpers.h" +namespace arm_compute +{ +namespace cpu +{ +template <typename ScalarType> +void sve_logits_1d_max(const ITensor *in, ITensor *out, const Window &window); + +template <typename ScalarType> +void sve_softmax_logits_1d_float(const ITensor *in, + const ITensor *max, + void *const tmp, + ITensor *out, + const float beta, + bool is_log, + const Window &window); +} // namespace cpu +} // namespace arm_compute + +#endif /* SRC_CORE_SVE_KERNELS_SOFTMAX_IMPL_H */ diff --git a/src/cpu/kernels/softmax/generic/sve2/impl.cpp b/src/cpu/kernels/softmax/generic/sve2/impl.cpp new file mode 100644 index 0000000000..a8fb1d4adf --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sve2/impl.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2021-2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/softmax/generic/sve2/impl.h" + +#include "arm_compute/core/Types.h" + +#include "src/core/NEON/wrapper/wrapper.h" + +namespace arm_compute +{ +namespace cpu +{ +/// TODO: (COMPMID-6505) Similar to Neon(TM), this implementation be converted to +/// a single kernel that performs softmax operation. Leaving the SVE2 code here for +/// future references. Implementation for Neon(TM) is introduced in COMPMID-6500 +template <typename ScalarType> +void sve2_softmax_logits_1d_quantized( + const ITensor *in, const ITensor *max, void *const tmp, ITensor *out, float beta, bool is_log, const Window &window) +{ + const int start_x = in->info()->valid_region().anchor.x(); + const int input_width = in->info()->valid_region().shape.x(); + + const float scale_beta = -beta * in->info()->quantization_info().uniform().scale; + const auto scale_beta_vec = svdup_n_f32(scale_beta); + + Iterator in_it(in, window); + Iterator max_it(max, window); + Iterator out_it(out, window); + const auto all_true_pg = wrapper::svptrue<ScalarType>(); + using SVEType = typename wrapper::traits::sve_vector<ScalarType>::type; + + const int inc_1 = static_cast<int>(svcntw()); + const int inc_2 = static_cast<int>(2 * svcntw()); + const int inc_3 = static_cast<int>(3 * svcntw()); + + execute_window_loop( + window, + [&](const Coordinates &) + { + /* Get pointers */ + const auto in_ptr = reinterpret_cast<const ScalarType *>(in_it.ptr()) + start_x; + const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr()) + start_x; + const auto tmp_ptr = reinterpret_cast<float *>(tmp); + + float sum{}; + + /* Compute exponentials and sum */ + { + /* Get max value */ + const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr()); + const auto vec_max = wrapper::svdup_n(max_val); + + /* Init sum to zero */ + auto vec_sum_0 = svdup_n_f32(0.f); + auto vec_sum_1 = svdup_n_f32(0.f); + auto vec_sum_2 = svdup_n_f32(0.f); + auto vec_sum_3 = svdup_n_f32(0.f); + + /* Loop over row and compute exponentials and sum */ + int x = 0; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); + svbool_t pg_0 = svunpklo(svunpklo(pg)); + svbool_t pg_1 = svunpkhi(svunpklo(pg)); + svbool_t pg_2 = svunpklo(svunpkhi(pg)); + svbool_t pg_3 = svunpkhi(svunpkhi(pg)); + do + { + const auto vec_elements = svld1(pg, in_ptr + x); + const auto vec_elements_sub = svreinterpret_u8(svsub_z(pg, vec_max, vec_elements)); + + auto vec_elements_flt_0 = svcvt_f32_z(pg_0, svunpklo(svunpklo(vec_elements_sub))); + auto vec_elements_flt_1 = svcvt_f32_z(pg_1, svunpkhi(svunpklo(vec_elements_sub))); + auto vec_elements_flt_2 = svcvt_f32_z(pg_2, svunpklo(svunpkhi(vec_elements_sub))); + auto vec_elements_flt_3 = svcvt_f32_z(pg_3, svunpkhi(svunpkhi(vec_elements_sub))); + + if (is_log) + { + vec_elements_flt_0 = svmul_f32_z(pg_0, vec_elements_flt_0, scale_beta_vec); + vec_elements_flt_1 = svmul_f32_z(pg_1, vec_elements_flt_1, scale_beta_vec); + vec_elements_flt_2 = svmul_f32_z(pg_2, vec_elements_flt_2, scale_beta_vec); + vec_elements_flt_3 = svmul_f32_z(pg_3, vec_elements_flt_3, scale_beta_vec); + vec_sum_0 = svadd_f32_m(pg_0, vec_sum_0, svexp_f32_z(pg_0, vec_elements_flt_0)); + vec_sum_1 = svadd_f32_m(pg_1, vec_sum_1, svexp_f32_z(pg_1, vec_elements_flt_1)); + vec_sum_2 = svadd_f32_m(pg_2, vec_sum_2, svexp_f32_z(pg_2, vec_elements_flt_2)); + vec_sum_3 = svadd_f32_m(pg_3, vec_sum_3, svexp_f32_z(pg_3, vec_elements_flt_3)); + } + else + { + vec_elements_flt_0 = svexp_f32_z(pg_0, svmul_f32_z(pg_0, vec_elements_flt_0, scale_beta_vec)); + vec_elements_flt_1 = svexp_f32_z(pg_1, svmul_f32_z(pg_1, vec_elements_flt_1, scale_beta_vec)); + vec_elements_flt_2 = svexp_f32_z(pg_2, svmul_f32_z(pg_2, vec_elements_flt_2, scale_beta_vec)); + vec_elements_flt_3 = svexp_f32_z(pg_3, svmul_f32_z(pg_3, vec_elements_flt_3, scale_beta_vec)); + vec_sum_0 = svadd_f32_m(pg_0, vec_sum_0, vec_elements_flt_0); + vec_sum_1 = svadd_f32_m(pg_1, vec_sum_1, vec_elements_flt_1); + vec_sum_2 = svadd_f32_m(pg_2, vec_sum_2, vec_elements_flt_2); + vec_sum_3 = svadd_f32_m(pg_3, vec_sum_3, vec_elements_flt_3); + } + + svst1_f32(pg_0, tmp_ptr + x, vec_elements_flt_0); + svst1_f32(pg_1, tmp_ptr + x + inc_1, vec_elements_flt_1); + svst1_f32(pg_2, tmp_ptr + x + inc_2, vec_elements_flt_2); + svst1_f32(pg_3, tmp_ptr + x + inc_3, vec_elements_flt_3); + + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, input_width); + pg_0 = svunpklo(svunpklo(pg)); + pg_1 = svunpkhi(svunpklo(pg)); + pg_2 = svunpklo(svunpkhi(pg)); + pg_3 = svunpkhi(svunpkhi(pg)); + } while (svptest_any(all_true_pg, pg)); + + /* Reduce sum */ + const auto vec_sum = svadd_f32_z(all_true_pg, svadd_f32_z(all_true_pg, vec_sum_0, vec_sum_1), + svadd_f32_z(all_true_pg, vec_sum_2, vec_sum_3)); + sum = svaddv_f32(all_true_pg, vec_sum); + + /* Run remaining elements */ + x = 0; + if (is_log) + { + sum = std::log(sum); + } + else + { + sum = 256.f / sum; + } + } + + /* Normalize exponentials */ + { + constexpr bool is_qasymm8_signed = std::is_same<ScalarType, qasymm8_signed_t>::value; + /* Loop over row and compute softmax */ + int x = 0; + svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width); + svbool_t pg_0 = svunpklo(svunpklo(pg)); + svbool_t pg_1 = svunpkhi(svunpklo(pg)); + svbool_t pg_2 = svunpklo(svunpkhi(pg)); + svbool_t pg_3 = svunpkhi(svunpkhi(pg)); + do + { + auto vec_in_0 = svld1_f32(pg_0, tmp_ptr + x); + auto vec_in_1 = svld1_f32(pg_1, tmp_ptr + x + inc_1); + auto vec_in_2 = svld1_f32(pg_2, tmp_ptr + x + inc_2); + auto vec_in_3 = svld1_f32(pg_3, tmp_ptr + x + inc_3); + + svfloat32_t res_0{}; + svfloat32_t res_1{}; + svfloat32_t res_2{}; + svfloat32_t res_3{}; + + if (is_log) + { + res_0 = svsub_f32_z(pg_0, vec_in_0, svdup_n_f32(sum)); + res_1 = svsub_f32_z(pg_1, vec_in_1, svdup_n_f32(sum)); + res_2 = svsub_f32_z(pg_2, vec_in_2, svdup_n_f32(sum)); + res_3 = svsub_f32_z(pg_3, vec_in_3, svdup_n_f32(sum)); + } + else + { + res_0 = svmul_f32_z(pg_0, vec_in_0, svdup_n_f32(sum)); + res_1 = svmul_f32_z(pg_1, vec_in_1, svdup_n_f32(sum)); + res_2 = svmul_f32_z(pg_2, vec_in_2, svdup_n_f32(sum)); + res_3 = svmul_f32_z(pg_3, vec_in_3, svdup_n_f32(sum)); + + if (is_qasymm8_signed) + { + const auto offset_vec = svdup_n_f32(128.f); + res_0 = svsub_z(pg_0, res_0, offset_vec); + res_1 = svsub_z(pg_1, res_1, offset_vec); + res_2 = svsub_z(pg_2, res_2, offset_vec); + res_3 = svsub_z(pg_3, res_3, offset_vec); + } + } + + // Store value + const auto out = convert_float_to_int<SVEType>(res_0, res_1, res_2, res_3); + svst1(pg, out_ptr + x, out); + x += wrapper::svcnt<ScalarType>(); + pg = wrapper::svwhilelt<ScalarType>(x, input_width); + pg_0 = svunpklo(svunpklo(pg)); + pg_1 = svunpkhi(svunpklo(pg)); + pg_2 = svunpklo(svunpkhi(pg)); + pg_3 = svunpkhi(svunpkhi(pg)); + } while (svptest_any(all_true_pg, pg)); + } + }, + in_it, max_it, out_it); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/sve2/impl.h b/src/cpu/kernels/softmax/generic/sve2/impl.h new file mode 100644 index 0000000000..33fcc26cda --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sve2/impl.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_SVE2_KERNELS_SOFTMAX_IMPL_H +#define SRC_CORE_SVE2_KERNELS_SOFTMAX_IMPL_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ +template <typename ScalarType> +void sve2_softmax_logits_1d_quantized(const ITensor *in, + const ITensor *max, + void *const tmp, + ITensor *out, + float beta, + bool is_log, + const Window &window); +} // namespace cpu +} // namespace arm_compute +#endif /* SRC_CORE_SVE2_KERNELS_SOFTMAX_IMPL_H */ diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h new file mode 100644 index 0000000000..7bbb265022 --- /dev/null +++ b/src/cpu/kernels/softmax/list.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2021-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_SOFTMAX_LIST_H +#define ACL_SRC_CPU_KERNELS_SOFTMAX_LIST_H + +namespace arm_compute +{ +namespace cpu +{ +#define DECLARE_SOFTMAX_KERNEL(func_name) \ + template <bool IS_LOG> \ + void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window, \ + const float *lut_ptr) + +DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax); +DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax); +DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax); +DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax); + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +void sme2_fp32_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_fp16_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_qasymm8_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +#endif // ARM_COMPUTE_ENABLE_SME2 + +#undef DECLARE_SOFTMAX_KERNEL +} // namespace cpu +} // namespace arm_compute + +#endif // ACL_SRC_CPU_KERNELS_SOFTMAX_LIST_H |