aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp')
-rw-r--r--src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp655
1 files changed, 655 insertions, 0 deletions
diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp
new file mode 100644
index 0000000000..14c0f6c327
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp
@@ -0,0 +1,655 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
+void sme2_qasymm8_signed_softmax_kernel_512VL( //
+ const int8_t *src,
+ int8_t *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4],
+ const float *lut,
+ float *tmp)
+{
+ // Precondition:
+ // * src_strides[0] == sizeof(int8_t)
+ // * dst_strides[0] == sizeof(int8_t)
+ // * tmp_strides[0] == sizeof(float)
+
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // For register list explanation refer to qasymm8.cpp.
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25a07811 // ptrue pn9.s
+ .inst 0x25207810 // ptrue pn8.b
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntb x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+ mov x19, %x[lut]
+ mov x29, %x[tmp]
+
+ // Load the LUT to the register file.
+ mov x2, %x[lut]
+ .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2]
+ add x2, x2, #256
+ .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2]
+
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_2 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+ // z16-z19 = minimum QASYMM8_SIGNED value (-128) to allow for it to be used for comparison to find the max.
+ dup z16.b, #0x80
+ dup z17.b, #0x80
+ dup z18.b, #0x80
+ dup z19.b, #0x80
+
+ mov x1, #0 // x1: index
+find_max_body_start%=:
+ cmp x1, x13
+ b.eq find_max_body_end%=
+ .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z16-z19: x
+ .inst 0xc134b810 // smax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x)
+ add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers.
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1b z30.b, p1/z, [x27, x1] // z30: x
+ smax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x)
+
+ add x1, x1, #64
+
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+ .inst 0xc132b010 // smax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b }
+ smax z16.b, p0/m, z16.b, z17.b
+ smaxv b16, p0, z16.b // Reduction signed max operation to get maximum_value
+ mov z16.b, b16 // z16: duplicated max_value for current row
+
+ sunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction
+ sunpklo z16.s, z16.h
+
+ mov x1, #0 // reset index
+ dup z25.s, #0
+
+
+regularize_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none regularize_end%=
+
+ mov w9, 0xFF80
+ movk w9, 0xFFFF, LSL #16 // Moving -127.f into w9 to set the registers below to the minimum QASYMM8_SIGNED value
+ dup z17.s, w9
+ dup z18.s, w9
+ dup z19.s, w9
+ dup z20.s, w9
+
+ dup z21.s, #0x0
+ dup z22.s, #0x0
+ dup z23.s, #0x0
+ dup z24.s, #0x0
+
+ // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated
+ punpklo p2.h, p1.b
+ punpkhi p4.h, p1.b
+
+ punpkhi p3.h, p2.b
+ punpklo p2.h, p2.b
+
+ punpkhi p5.h, p4.b
+ punpklo p4.h, p4.b
+
+ ld1b z17.b, p1/z, [x27, x1] //z17: input data
+
+ sunpklo z18.h, z17.b // Using unpack instructions to align the input QASYMM8_SIGNED values with the FP32 entries in the LUT for use in the TBX instruction
+ sunpkhi z19.h, z17.b //
+
+ sunpklo z17.s, z18.h // z17 = low low input QASYMM8_SIGNED values
+ sunpkhi z18.s, z18.h // z18 = low high input QASYMM8_SIGNED values
+
+ sunpkhi z20.s, z19.h // z20 = high high input QASYMM8_SIGNED values
+ sunpklo z19.s, z19.h // z19 = high low input QASYMM8_SIGNED values
+
+ sub z17.s, z16.s, z17.s // z12: x = max_value - input_data
+ sub z18.s, z16.s, z18.s // z13: x = max_value - input_data
+ sub z19.s, z16.s, z19.s // z14: x = max_value - input_data
+ sub z20.s, z16.s, z20.s // z15: x = max_value - input_data
+
+ add z17.s, z17.s, #128
+ add z18.s, z18.s, #128
+ add z19.s, z19.s, #128
+ add z20.s, z20.s, #128
+
+ tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT.
+ tbx z22.s, z0.s, z18.s
+ tbx z23.s, z0.s, z19.s
+ tbx z24.s, z0.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT.
+ tbx z22.s, z1.s, z18.s
+ tbx z23.s, z1.s, z19.s
+ tbx z24.s, z1.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT.
+ tbx z22.s, z2.s, z18.s
+ tbx z23.s, z2.s, z19.s
+ tbx z24.s, z2.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT.
+ tbx z22.s, z3.s, z18.s
+ tbx z23.s, z3.s, z19.s
+ tbx z24.s, z3.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT.
+ tbx z22.s, z4.s, z18.s
+ tbx z23.s, z4.s, z19.s
+ tbx z24.s, z4.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT.
+ tbx z22.s, z5.s, z18.s
+ tbx z23.s, z5.s, z19.s
+ tbx z24.s, z5.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT.
+ tbx z22.s, z6.s, z18.s
+ tbx z23.s, z6.s, z19.s
+ tbx z24.s, z6.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT.
+ tbx z22.s, z7.s, z18.s
+ tbx z23.s, z7.s, z19.s
+ tbx z24.s, z7.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT.
+ tbx z22.s, z8.s, z18.s
+ tbx z23.s, z8.s, z19.s
+ tbx z24.s, z8.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT.
+ tbx z22.s, z9.s, z18.s
+ tbx z23.s, z9.s, z19.s
+ tbx z24.s, z9.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT.
+ tbx z22.s, z10.s, z18.s
+ tbx z23.s, z10.s, z19.s
+ tbx z24.s, z10.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT.
+ tbx z22.s, z11.s, z18.s
+ tbx z23.s, z11.s, z19.s
+ tbx z24.s, z11.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT.
+ tbx z22.s, z12.s, z18.s
+ tbx z23.s, z12.s, z19.s
+ tbx z24.s, z12.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT.
+ tbx z22.s, z13.s, z18.s
+ tbx z23.s, z13.s, z19.s
+ tbx z24.s, z13.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT.
+ tbx z22.s, z14.s, z18.s
+ tbx z23.s, z14.s, z19.s
+ tbx z24.s, z14.s, z20.s
+
+ sub z17.s, z17.s, #16
+ sub z18.s, z18.s, #16
+ sub z19.s, z19.s, #16
+ sub z20.s, z20.s, #16
+
+ tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT.
+ tbx z22.s, z15.s, z18.s
+ tbx z23.s, z15.s, z19.s
+ tbx z24.s, z15.s, z20.s
+
+
+ st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p2/m, z25.s, z21.s
+ add x1, x1, #16
+
+ st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p3/m, z25.s, z22.s
+ add x1, x1, #16
+
+ st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p4/m, z25.s, z23.s
+ add x1, x1, #16
+
+ st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor
+ fadd z25.s, p5/m, z25.s, z24.s
+ add x1, x1, #16
+
+ b regularize_start%=
+regularize_end%=:
+
+ mov w9, 0x0000
+ movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [-128, 127] integer range of QASYMM8_SIGNED
+ mov w10, 0x0000
+ movk w10, 0x4300, LSL #16 // Moving 128.f into w10 for the subtraction to move the results - via subtraction - from the [0,255] range to the [-128,127] range
+ dup z29.s, w9
+ dup z30.s, w10
+ faddv s25, p0, z25.s
+ fdiv s25, s29, s25
+ dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax.
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+ mov x1, #0
+normalize_body_start%=:
+ cmp x1, x13
+ b.eq normalize_body_end%=
+
+ mov x2, x1 // Preserve the index into x2 for the final store to dst.
+ .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+ .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+
+ // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z16.s, z25.s, z16.s
+ fmul z17.s, z25.s, z17.s
+ fmul z18.s, z25.s, z18.s
+ fmul z19.s, z25.s, z19.s
+ fmul z20.s, z25.s, z20.s
+ fmul z21.s, z25.s, z21.s
+ fmul z22.s, z25.s, z22.s
+ fmul z23.s, z25.s, z23.s
+
+ // z16-z23: subtract 128.f.
+ fsub z16.s, z16.s, z30.s // Subtract 128.f
+ fsub z17.s, z17.s, z30.s // Subtract 128.f
+ fsub z18.s, z18.s, z30.s // Subtract 128.f
+ fsub z19.s, z19.s, z30.s // Subtract 128.f
+ fsub z20.s, z20.s, z30.s // Subtract 128.f
+ fsub z21.s, z21.s, z30.s // Subtract 128.f
+ fsub z22.s, z22.s, z30.s // Subtract 128.f
+ fsub z23.s, z23.s, z30.s // Subtract 128.f
+
+ // z16-z23: convert the FP32 values from the tmp tensor to int32.
+ fcvtzs z16.s, p0/m, z16.s
+ fcvtzs z17.s, p0/m, z17.s
+ fcvtzs z18.s, p0/m, z18.s
+ fcvtzs z19.s, p0/m, z19.s
+ fcvtzs z20.s, p0/m, z20.s
+ fcvtzs z21.s, p0/m, z21.s
+ fcvtzs z22.s, p0/m, z22.s
+ fcvtzs z23.s, p0/m, z23.s
+
+ // z16-z17: narrow the int32 values into int8 and saturate them.
+ .inst 0xc133e210 // sqcvt z16.b, { z16.s - z19.s }
+ .inst 0xc133e291 // sqcvt z17.b, { z20.s - z23.s }
+
+ // Juggling the value to z20 (resp. 21) as z25 (resp. z30) will be overwritten by the load below.
+ dup z20.s, z25.s[0]
+ dup z21.s, z30.s[0]
+
+ .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+ .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2]
+ add x1, x1, #64
+
+ // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z24.s, z20.s, z24.s
+ fmul z25.s, z20.s, z25.s
+ fmul z26.s, z20.s, z26.s
+ fmul z27.s, z20.s, z27.s
+ fmul z28.s, z20.s, z28.s
+ fmul z29.s, z20.s, z29.s
+ fmul z30.s, z20.s, z30.s
+ fmul z31.s, z20.s, z31.s
+
+ // z24-z31: subtract 128.f.
+ fsub z24.s, z24.s, z21.s
+ fsub z25.s, z25.s, z21.s
+ fsub z26.s, z26.s, z21.s
+ fsub z27.s, z27.s, z21.s
+ fsub z28.s, z28.s, z21.s
+ fsub z29.s, z29.s, z21.s
+ fsub z30.s, z30.s, z21.s
+ fsub z31.s, z31.s, z21.s
+
+ // z24-z31: convert the FP32 values from the tmp tensor to int32.
+ fcvtzs z24.s, p0/m, z24.s
+ fcvtzs z25.s, p0/m, z25.s
+ fcvtzs z26.s, p0/m, z26.s
+ fcvtzs z27.s, p0/m, z27.s
+ fcvtzs z28.s, p0/m, z28.s
+ fcvtzs z29.s, p0/m, z29.s
+ fcvtzs z30.s, p0/m, z30.s
+ fcvtzs z31.s, p0/m, z31.s
+
+ // z18-z19: narrow the int32 values into int8 and saturate them.
+ .inst 0xc133e312 // sqcvt z18.b, { z24.s - z27.s }
+ .inst 0xc133e393 // sqcvt z19.b, { z28.s - z31.s }
+
+ .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2]
+
+ // Juggling the values back to z25 (resp. z30) as z20 (resp. z21) will be overwritten by the next iteration or z25 (resp. z30) will be used below.
+ dup z25.s, z20.s[0]
+ dup z30.s, z21.s[0]
+b normalize_body_start%=
+normalize_body_end%=:
+normalize_leftover_start%=:
+ whilelo p1.b, x1, %x[length]
+ b.none normalize_leftover_end%=
+
+ // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated
+ punpklo p2.h, p1.b
+ punpkhi p4.h, p1.b
+
+ punpkhi p3.h, p2.b
+ punpklo p2.h, p2.b
+
+ punpkhi p5.h, p4.b
+ punpklo p4.h, p4.b
+
+ mov x2, x1 // Preserve the index into x2 for the final store to dst.
+
+ // z20-z23: load exp(-scale*beta*x) from the tmp tensor
+ ld1w z20.s, p2/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z21.s, p3/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z22.s, p4/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ ld1w z23.s, p5/z, [x29, x1, LSL #2]
+ add x1, x1, #16
+
+ // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256.
+ fmul z20.s, z25.s, z20.s
+ fmul z21.s, z25.s, z21.s
+ fmul z22.s, z25.s, z22.s
+ fmul z23.s, z25.s, z23.s
+
+ //z20-z23: Subtract 128.f.
+ fsub z20.s, z20.s, z30.s
+ fsub z21.s, z21.s, z30.s
+ fsub z22.s, z22.s, z30.s
+ fsub z23.s, z23.s, z30.s
+
+ // z20-23: convert the FP32 values from the tmp tensor to int32.
+ fcvtzs z20.s, p0/m, z20.s
+ fcvtzs z21.s, p0/m, z21.s
+ fcvtzs z22.s, p0/m, z22.s
+ fcvtzs z23.s, p0/m, z23.s
+
+ .inst 0xc133e293 // sqcvt z19.b, { z20.s - z23.s }, narrow the int32 values into int8 and saturate them into z19.
+
+ st1b z19.b, p1, [x28, x2]
+
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", //
+ "x2", "x9", "x13", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in,
+ void *const tmp,
+ ITensor *out,
+ const float beta,
+ int axis,
+ const Window &window,
+ const float *lut_ptr)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+ Strides tmp_strides;
+
+ tmp_strides[0] = src_strides[0] * 4;
+ tmp_strides[1] = src_strides[1] * 4;
+ tmp_strides[2] = src_strides[2] * 4;
+ tmp_strides[3] = src_strides[3] * 4;
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + //
+ window[1].start() * tmp_strides[1] + //
+ window[2].start() * tmp_strides[2] + //
+ window[3].start() * tmp_strides[3];
+
+ const auto *k_src = reinterpret_cast<const int8_t *>(in->buffer() + k_src_offset);
+ float *tmp_float_ptr = reinterpret_cast<float *>(tmp);
+ auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset);
+ auto *k_dst = reinterpret_cast<int8_t *>(out->buffer() + k_dst_offset);
+
+ sme2_qasymm8_signed_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2