From 77bbe2e08b0376edfd3f504950be7f4b5720eeb0 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Wed, 6 Dec 2023 11:01:15 +0000 Subject: Add SME2 implementation of softmax for FP32 Signed-off-by: Viet-Hoa Do Change-Id: I8a63610cfb9ccff89dec6045d023439fc19b027a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11357 Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- docs/user_guide/release_version_and_change_log.dox | 1 + filelist.json | 3 +- scripts/generate_android_bp.py | 4 +- src/BUILD.bazel | 1 + src/CMakeLists.txt | 1 + src/cpu/kernels/CpuSoftmaxKernel.cpp | 6 + src/cpu/kernels/softmax/generic/sme2/fp32.cpp | 578 +++++++++++++++++++++ src/cpu/kernels/softmax/list.h | 7 + 8 files changed, 599 insertions(+), 2 deletions(-) create mode 100644 src/cpu/kernels/softmax/generic/sme2/fp32.cpp diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox index 31b756070d..aa27c2b44c 100644 --- a/docs/user_guide/release_version_and_change_log.dox +++ b/docs/user_guide/release_version_and_change_log.dox @@ -45,6 +45,7 @@ v24.04 Public major release - Add Bfloat16 data type support for @ref NEMatMul. - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm - Optimize @ref NEConvolutionLayer for input tensor size > 1e7 bytes and weight tensor height > 7 + - Add support for SoftMax in SME2 for FP32. - Performance optimizations: - Optimize @ref NESoftmaxLayer for axis != 0 by natively supporting higher axes up to axis 3. diff --git a/filelist.json b/filelist.json index 06a09139db..f6e85473c2 100644 --- a/filelist.json +++ b/filelist.json @@ -2237,7 +2237,8 @@ "common": [ "src/cpu/kernels/softmax/generic/sve/impl.cpp" ] }, "sve2":{ - "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"] + "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"], + "fp32" :["src/cpu/kernels/softmax/generic/sme2/fp32.cpp"] } } }, diff --git a/scripts/generate_android_bp.py b/scripts/generate_android_bp.py index 6efd072acd..d5b268f522 100755 --- a/scripts/generate_android_bp.py +++ b/scripts/generate_android_bp.py @@ -45,7 +45,9 @@ excluded_paths = ["build", "/sve/", "/SVE/", "/sve2/", - "/SVE2/" + "/SVE2/", + "/sme/", + "/sme2/", ] excluded_files = ["TracePoint.cpp"] diff --git a/src/BUILD.bazel b/src/BUILD.bazel index dd19f38d6d..be6337a2c0 100644 --- a/src/BUILD.bazel +++ b/src/BUILD.bazel @@ -117,6 +117,7 @@ filegroup( "cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp", "cpu/kernels/elementwise_unary/generic/sve2/q8.cpp", "cpu/kernels/lut/generic/sve2/u8.cpp", + "cpu/kernels/softmax/generic/sme2/fp32.cpp", "cpu/kernels/softmax/generic/sve2/impl.cpp"] + glob(["**/*.h", "**/*.hpp", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1a085a0011..a1cba792b7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -335,6 +335,7 @@ target_sources( cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp cpu/kernels/elementwise_unary/generic/sve2/q8.cpp cpu/kernels/lut/generic/sve2/u8.cpp + cpu/kernels/softmax/generic/sme2/fp32.cpp cpu/kernels/softmax/generic/sve2/impl.cpp ) diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp index 54ff858eeb..a088fb6660 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.cpp +++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp @@ -50,6 +50,12 @@ namespace { /* Softmax */ static const std::vector available_kernels = { +#ifdef ARM_COMPUTE_ENABLE_SME2 + {"sme2_fp32_softmax", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2); }, + REGISTER_FP32_NEON(sme2_fp32_softmax)}, +#endif // ARM_COMPUTE_ENABLE_SME2 {"neon_fp32_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); }, REGISTER_FP32_NEON(neon_fp32_softmax)}, diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp new file mode 100644 index 0000000000..e80041c812 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp @@ -0,0 +1,578 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_f32_softmax_kernel( // + const float *src, + float *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4]) +{ + // Precondition: + // * src_strides[0] == sizeof(float) + // * dst_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x9: temporary, index + // * x10: temporary, -inf + // * x11: temporary, 0 + // * x12: temporary, 1.0f + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // + // * z0: c1 + // * z1: c2 + // * z2: c3 + // * z3: c4 + // * z4: c5 + // * z5: shift + // * z6: inv_ln2 + // * z7: neg_ln2_hi + // * z8: neg_ln2_lo + // * z9: min_input + // * z10: 23, 0 + // * z11: max_value + // * z12-z15: x, r_hi, r, r2 + // * z16-z19: max_value, shift, z, scale, poly + // * z20-z21: n, p1, p12345 + // * z22-z23: n, p23, p2345 + // * z24-z25: p45 + // * z26: beta + // * z28-z31: sum_value + // + // * za0-za3: sum_value + // + // * p0: all-true + // * p1: left-over predicate + // * p4-p7: underflow + // * pn9: all-true + + // Prepares all constant values + + ptrue p0.b + .inst 0x25207811 // ptrue pn9.b + + mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb + mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010 + + movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb + movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010 + + dup z0.s, w9 // c1. + dup z1.s, w10 // c2. + dup z2.s, w11 // c3. + dup z3.s, w12 // c4. + dup z4.s, w13 // c5. + + mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + dup z5.s, w9 // shift + dup z6.s, w10 // inv_ln2 + dup z7.s, w11 // neg_ln2_hi + dup z8.s, w12 // neg_ln2_lo + dup z9.s, w13 // min_input + + dup z26.s, %w[beta] // beta + + mov w10, #0x0000 // -inf: 0xff800000 + movk w10, #0xff80 // -inf: 0xff800000 + + mov w11, #0 // 0 + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntw x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + + // ---------------------------------------------------------------- z16-z19: max_value = -inf + mov z16.d, z11.d + mov z17.d, z11.d + mov z18.d, z11.d + mov z19.d, z11.d + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + dup z11.s, w10 // z11: max_value = -inf + +find_max_body_start%=: + cmp x9, x13 + b.eq find_max_body_end%= + + .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x + .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x) + + incw x9, ALL, MUL #4 + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none find_max_leftover_end%= + + ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x + fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x) + + incw x9 + b find_max_leftover_start%= +find_max_leftover_end%=: + + // ---------------------------------------------------------------- z16: max_value + .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s} + fmax z16.s, p0/m, z16.s, z17.s + fmaxv s16, p0, z16.s + + // ---------------------------------------------------------------- z11: max_value + dup z11.s, z16.s[0] + + // ================================================== + // Step 2: Regularize + // ================================================== + + .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // ---------------------------------------------------- x9: index + +regularize_body_start%=: + cmp x9, x13 + b.eq regularize_body_end%= + + // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data + .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] + + // ---------------------------------------------------------------- z12-z15: x = input_data - max_value + fsub z12.s, z12.s, z11.s + fsub z13.s, z13.s, z11.s + fsub z14.s, z14.s, z11.s + fsub z15.s, z15.s, z11.s + + // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta + fmul z12.s, z12.s, z26.s + fmul z13.s, z13.s, z26.s + fmul z14.s, z14.s, z26.s + fmul z15.s, z15.s, z26.s + + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z12.s, z9.s + fcmlt p5.s, p0/z, z13.s, z9.s + fcmlt p6.s, p0/z, z14.s, z9.s + fcmlt p7.s, p0/z, z15.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z12.s, z6.s + fmla z17.s, p0/m, z13.s, z6.s + fmla z18.s, p0/m, z14.s, z6.s + fmla z19.s, p0/m, z15.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi + fmla z12.s, p0/m, z20.s, z7.s + fmla z13.s, p0/m, z21.s, z7.s + fmla z14.s, p0/m, z22.s, z7.s + fmla z15.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo + fmla z12.s, p0/m, z20.s, z8.s + fmla z13.s, p0/m, z21.s, z8.s + fmla z14.s, p0/m, z22.s, z8.s + fmla z15.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z12.s, z0.s + fmul z21.s, z13.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z12.s, z2.s + fmla z23.s, p0/m, z13.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z12.s, z4.s + fmla z25.s, p0/m, z13.s, z4.s + + // ---------------------------------------------------------------- z12-z13: r2 = r * r + fmul z12.s, z12.s, z12.s + fmul z13.s, z13.s, z13.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z12.s, z24.s + fmla z23.s, p0/m, z13.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z12.s, z22.s + fmla z21.s, p0/m, z13.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z14.s, z0.s + fmul z21.s, z15.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z14.s, z2.s + fmla z23.s, p0/m, z15.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z14.s, z4.s + fmla z25.s, p0/m, z15.s, z4.s + + // ---------------------------------------------------------------- z14-z15: r2 = r * r + fmul z14.s, z14.s, z14.s + fmul z15.s, z15.s, z15.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z14.s, z24.s + fmla z23.s, p0/m, z15.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z14.s, z22.s + fmla z21.s, p0/m, z15.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly + dup z10.s, #0 + sel z16.s, p4, z10.s, z16.s + sel z17.s, p5, z10.s, z17.s + sel z18.s, p6, z10.s, z18.s + sel z19.s, p7, z10.s, z19.s + + // Stores 4 consecutive registers to the output + .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2] + + .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly + + incw x9, ALL, MUL #4 + b regularize_body_start%= +regularize_body_end%=: + + // ---------------------------------------------------------------- z28: sum_value + .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4] + fadd z28.s, z28.s, z29.s + fadd z30.s, z30.s, z31.s + fadd z28.s, z28.s, z30.s + + // Loop for processing the leftover part. +regularize_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none regularize_leftover_end%= + + ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data + + fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value + fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta + + mov z16.d, z5.d // z16: shift + fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input + fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2 + fsub z20.s, z16.s, z5.s // z20: n = z - shift + fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi + fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo + dup z10.s, #23 // z10: 23 + urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n) + fmul z20.s, z12.s, z0.s // z20: p1 = r * c1 + mov z22.d, z1.d // z22: p23 = c2 + fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3 + mov z24.d, z3.d // z24: c4 + fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5 + fmul z12.s, z12.s, z12.s // z12: r2 = r * r + fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45 + fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345 + fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale + dup z10.s, #0 // z10: 0 + sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly + + st1w z16.s, p1, [x28, x9, LSL #2] + + fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly + + incw x9 + b regularize_leftover_start%= +regularize_leftover_end%=: + + // ================================================== + // Step 3: Normalize + // ================================================== + + // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value + fmov s29, #1.0 // 1.0f + faddv s28, p0, z28.s + fdiv s28, s29, s28 + dup z28.s, z28.s[0] + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + +normalize_body_start%=: + cmp x9, x13 + b.eq normalize_body_end%= + + .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x + + // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value + fmul z12.s, z12.s, z28.s + fmul z13.s, z13.s, z28.s + fmul z14.s, z14.s, z28.s + fmul z15.s, z15.s, z28.s + + .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2] + + incw x9, ALL, MUL #4 + b normalize_body_start%= +normalize_body_end%=: + + // Loop for processing the leftover part. +normalize_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none normalize_leftover_end%= + + ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x + fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value + + st1w z12.s, p1, [x28, x9, LSL #2] + + incw x9 + b normalize_leftover_start%= +normalize_leftover_end%=: + + // ================================================== + // 3D loop closing + // ================================================== + + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p4", "p5", "p6", "p7", "p9", // + "x9", "x10", "x11", "x12", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const auto *k_src = reinterpret_cast(in->buffer() + k_src_offset); + auto *k_dst = reinterpret_cast(out->buffer() + k_dst_offset); + + sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h index f9295ebbcc..16fbd31a19 100644 --- a/src/cpu/kernels/softmax/list.h +++ b/src/cpu/kernels/softmax/list.h @@ -37,6 +37,13 @@ DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax); DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax); DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax); +#ifdef ARM_COMPUTE_ENABLE_SME2 + +void sme2_fp32_softmax( + const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); + +#endif // ARM_COMPUTE_ENABLE_SME2 + #undef DECLARE_SOFTMAX_KERNEL } // namespace cpu } // namespace arm_compute -- cgit v1.2.1