diff options
Diffstat (limited to 'src/core')
40 files changed, 5888 insertions, 1781 deletions
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl new file mode 100644 index 0000000000..e3ec9cc98e --- /dev/null +++ b/src/core/CL/cl_kernels/common/scatter.cl @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "helpers.h" +#include "tile_helpers.h" + +// The below defines the various reduce operations for our purposes. +// Where a corresponds to the existing value, and b the new value. +#define ADD_OP(a, b) ((a) + (b)) +#define SUB_OP(a, b) ((a) - (b)) + +#ifdef IS_FLOAT +#define MAX_OP(a, b) fmax(a, b) +#define MIN_OP(a, b) fmin(a, b) +#else // ifdef IS_FLOAT +#define MAX_OP(a, b) max(a, b) +#define MIN_OP(a, b) min(a, b) +#endif // ifdef IS_FLOAT + +#define UPDATE_OP(a, b) (b) + +#ifdef SCATTER_MP1D_2D_MPND + +/** This kernel performs scatter operation + * + * @note Datatype should be given as a compile-time argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short + * @note Number of indices should be given as a compile-time argument using -DNUM_INDICES, e.g. -DNUM_INDICES=3 + * @note Index length should be given as a compile-time argument using -DINDEX_LENGTH, e.g. -DINDEX_LENGTH=2 + * @note Outermost output shapes should be given as a compile-time argument using -DOUT_SHAPE_N_MINUS_X, where + * X must be 1,2,3,4,5, e.g. -DOUT_SHAPE_N_MINUS_1=3, ... + * @note Number of elements to copy in a row should be given as a compile-time argument using -DN0, e.g. -DN0=4 + * @note Number of partial elements at the edge to copy in a row should be given as a compile-time argument using + * -DPARTIAL_N0, e.g. -DPARTIAL_N0=2 + * @note Scatter function should be given as a compile-time argument using -DSCATTER_FUNCTION, e.g. -DSCATTER_FUNCTION=ADD + * @note If the kernel should skip reading the output tensor, -DSKIP_OUTPUT_READ option should be provided. + * @note Kernel name in uppercase letters should be provided as a compile-time argument, e.g. -DSCATTER_MP1D_2D_MPND + * + * @param[in] updates_ptr Pointer to the updates tensor. Data Types: F32 + * @param[in] updates_stride_x Stride of the updates tensor in X dimension (in bytes) + * @param[in] updates_step_x updates_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] updates_stride_y Stride of the updates tensor in Y dimension (in bytes) + * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] updates_offset_first_element_in_bytes The offset of the first element in the updates tensor + * @param[in] indices_ptr Pointer to the indices tensor. Data Types: S32 + * @param[in] indices_stride_x Stride of the indices tensor in X dimension (in bytes) + * @param[in] indices_step_x indices_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] indices_stride_y Stride of the indices tensor in Y dimension (in bytes) + * @param[in] indices_step_y indices_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] indices_offset_first_element_in_bytes The offset of the first element in the indices tensor + * @param[out] output_ptr Pointer to the destination tensor. Same as @p upt_ptr + * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] output_step_y output_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] upt_block_stride Update tensor data block stride in bytes + * @param[in] out_block_stride Output tensor data block stride in bytes + */ +__kernel void scatter_mp1d_2d_mpnd( + IMAGE_DECLARATION(updates), + IMAGE_DECLARATION(indices), + IMAGE_DECLARATION(output), + int upt_block_stride, + int out_block_stride + ) +{ + const int out_shape[5] = {OUT_SHAPE_N_MINUS_1, OUT_SHAPE_N_MINUS_2, OUT_SHAPE_N_MINUS_3, + OUT_SHAPE_N_MINUS_4, OUT_SHAPE_N_MINUS_5}; + + const int x = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // x-coordinate in the tensor + const int y = get_global_id(1); // collapsed y-coordinate (ignoring the outermost dimensions) + + const bool x_cond = (PARTIAL_N0 != 0 && get_global_id(0) == 0); + + uchar *ind_ptr_raw = indices_ptr + indices_offset_first_element_in_bytes; + const uchar *out_ptr_raw = output_ptr + output_offset_first_element_in_bytes + + x * sizeof(DATA_TYPE) + y * output_stride_y; + + const uchar *upt_ptr_raw = updates_ptr + updates_offset_first_element_in_bytes + + x * sizeof(DATA_TYPE) + y * updates_stride_y; + + for(int index_element = 0; index_element < NUM_INDICES; ++index_element) + { + const int *ind_ptr = (const int *) (ind_ptr_raw); + + // Out of bounds check + bool out_of_bounds = false; + LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH, + { + if(ind_ptr[i] >= out_shape[i] || ind_ptr[i] < 0) + { + out_of_bounds = true; + } + }); + + ind_ptr_raw += indices_stride_y; + + if(out_of_bounds) + { + continue; + } + + // Index calculation + int index = 0; + LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH, + { + index = index * out_shape[i] + ind_ptr[i]; + }); + + DATA_TYPE *out_ptr = (DATA_TYPE *) (out_ptr_raw + index * out_block_stride); + + const DATA_TYPE *upt_ptr = (const DATA_TYPE *) (upt_ptr_raw + index_element * upt_block_stride); + + VEC_DATA_TYPE(DATA_TYPE, N0) data_in0 = VLOAD(N0)(0, (__global DATA_TYPE *) upt_ptr); + +#ifdef SKIP_OUTPUT_READ + STORE_VECTOR_SELECT(data_in, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond); +#else // ifdef SKIP_OUTPUT_READ + VEC_DATA_TYPE(DATA_TYPE, N0) data_out0 = VLOAD(N0)(0, (__global DATA_TYPE *) out_ptr); + data_out0 = SCATTER_FUNCTION(data_out0, data_in0); + + STORE_VECTOR_SELECT(data_out, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond); +#endif // ifdef SKIP_OUTPUT_READ + } +} + +#endif // SCATTER_MP1D_2D_MPND + +#ifdef SCATTER1D_PARALLEL + +// NOTE : This code is non-deterministic and can only be excecuted with the "update" ScatterFunction +// This code is currently unusued as it requires changes to the existing test suite. +/** Performs the Scatter1D operation with multiple threads. + * Similar to @ref scatter1D() + */ +__kernel void scatter1D_parallel( + TENSOR4D_DECLARATION(updates), + TENSOR4D_DECLARATION(indices), + TENSOR4D_DECLARATION(output)) +{ + // Currently 1D - only iterate through x dimension of indices. + const int px = get_global_id(0); + const int index_value = *(uchar*)(indices_ptr + indices_offset_first_element_in_bytes + (sizeof(int) * px)); + + if(index_value < OUT_SHAPE_X) + { + const DATA_TYPE update = *(DATA_TYPE *)(updates_ptr + updates_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * px)); + __global uchar *out_addr = output_ptr + indices_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * index_value); + *(__global DATA_TYPE *)(out_addr) = update; + } +} + +#endif // SCATTER1D_PARALLEL diff --git a/src/core/CPP/CPPTypes.cpp b/src/core/CPP/CPPTypes.cpp index 9980db42f3..67fbce490f 100644 --- a/src/core/CPP/CPPTypes.cpp +++ b/src/core/CPP/CPPTypes.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "src/common/cpuinfo/CpuInfo.h" #include "src/common/cpuinfo/CpuIsaInfo.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" namespace arm_compute { @@ -135,4 +136,21 @@ unsigned int CPUInfo::get_L2_cache_size() const { return _impl->L2_cache_size; } + +unsigned long CPUInfo::get_sme2_vector_length() const +{ +#ifdef ARM_COMPUTE_ENABLE_SME2 + return arm_gemm::utils::sme::get_vector_length<int8_t>(); +#else // ARM_COMPUTE_ENABLE_SME2 + return 0; +#endif // ARM_COMPUTE_ENABLE_SME2 +} +unsigned int CPUInfo::get_cpu_num_excluding_little() const +{ +#if defined(__ANDROID__) + return _impl->info.not_little_num_cpus(); +#else /* defined(__ANDROID__) */ + return get_cpu_num(); +#endif /* defined(__ANDROID__) */ +} } // namespace arm_compute diff --git a/src/core/NEON/NEAsymm.h b/src/core/NEON/NEAsymm.h index 5f4d08d0f6..b93e64a0ef 100644 --- a/src/core/NEON/NEAsymm.h +++ b/src/core/NEON/NEAsymm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2023 Arm Limited. + * Copyright (c) 2017-2020, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_NEASYMM_H -#define ARM_COMPUTE_NEASYMM_H +#ifndef ACL_SRC_CORE_NEON_NEASYMM_H +#define ACL_SRC_CORE_NEON_NEASYMM_H #include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/intrinsics/intrinsics.h" @@ -637,10 +637,10 @@ inline int32x4x4_t vquantize_internal(const float32x4x4_t &qv, float scale, int3 const float32x4_t vinvscale = vdupq_n_f32(1.f / scale); const int32x4x4_t rf = {{ #ifdef __aarch64__ - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[2], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[3], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[2], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[3], vinvscale)), voffset), #else //__aarch64__ vaddq_s32(vcvtq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), vaddq_s32(vcvtq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), @@ -698,4 +698,4 @@ inline uint16x8x2_t vquantize_qasymm16(const float32x4x4_t &qv, const UniformQua } // namespace arm_compute #include "src/core/NEON/NEAsymm.inl" -#endif // ARM_COMPUTE_NEASYMM_H +#endif // ACL_SRC_CORE_NEON_NEASYMM_H diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index 455d604b3b..5380e6ccce 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -31,1747 +31,221 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/Validate.h" +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "src/core/NEON/INEKernel.h" -#include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" -#include "support/SaturateCast.h" - -#include <arm_neon.h> +#include "src/cpu/kernels/reduction_layer/generic/neon/list.h" namespace arm_compute { -namespace -{ -// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized -template <typename T> -void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0) -{ - if (std::is_same<T, uint8_t>::value) - { - auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2)); - wrapper::vstore(output.ptr() + offset, res); - } - else - { - auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2)); - wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res); - } -} - -template <typename T> -uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4_t mask{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask = wrapper::vcgt(b, a); - } - else - { - mask = wrapper::vclt(b, a); - } - - uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3}; - if (axis != 0) - { - vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}}; - - return res; -} - -template <typename T> -uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4x4_t mask{{0}}; - uint8x16_t mask_u8{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask_u8 = wrapper::vcgt(b, a); - } - else - { - mask_u8 = wrapper::vclt(b, a); - } - auto wide_u16_1 = - wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); - auto wide_u16_2 = - wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); - mask.val[0] = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); - mask.val[1] = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); - mask.val[2] = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); - mask.val[3] = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); - - uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, - {idx + 4, idx + 5, idx + 6, idx + 7}, - {idx + 8, idx + 9, idx + 10, idx + 11}, - {idx + 12, idx + 13, idx + 14, idx + 15}}}; - if (axis != 0) - { - vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = { - {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]), - vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}}; - - return res; -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, - typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type -calculate_min(T in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - return wrapper::vpmin(pmin, pmin); -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, - typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type -calculate_min(T in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); - return wrapper::vpmin(pmin, pmin); -} - -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, - typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type -calculate_max(T in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - return wrapper::vpmax(pmax, pmax); -} - -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, - typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type -calculate_max(T in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); - return wrapper::vpmax(pmax, pmax); -} - -template <typename T> -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) -{ - uint32x4_t res_idx_mask{0}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); - } - else - { - auto pmax = calculate_max(vec_res_value); - auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); - } - - res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones); - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask)); - pmin = wrapper::vpmin(pmin, pmin); - uint32_t res = wrapper::vgetlane(pmin, 0); - - return (res - 0xFFFFFFFF); -} - -template <typename T> -uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) -{ - uint32x4x4_t res_idx_mask{{0}}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - uint8x16_t mask_u8{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - } - else - { - auto pmax = calculate_max(vec_res_value); - mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - } - - // Widen vectors - auto wide_u16_1 = - wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); - auto wide_u16_2 = - wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); - auto wide_u32_1 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); - auto wide_u32_2 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); - auto wide_u32_3 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); - auto wide_u32_4 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); - res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); - res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); - res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3); - res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4); - res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); - res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); - res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones); - res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones); - - uint32_t res = 0xFFFFFFFF; - int iter = 0; - do - { - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); - pmin = wrapper::vpmin(pmin, pmin); - res = std::min(wrapper::vgetlane(pmin, 0), res); - iter++; - } while (iter < 4); - - return (res - 0xFFFFFFFF); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -uint32x4x4_t -calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4x2_t mask{0}; - uint16x8_t mask_u16{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask_u16 = wrapper::vcgt(b, a); - } - else - { - mask_u16 = wrapper::vclt(b, a); - } - mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16)); - mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16)); - uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}}; - if (axis != 0) - { - vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]), - wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0}; - - return res; -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -inline float16x4_t calculate_min(float16x8_t in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmin = wrapper::vpmin(pmin, pmin); - return wrapper::vpmin(pmin, pmin); -} -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. -inline float16x4_t calculate_max(float16x8_t in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmax = wrapper::vpmax(pmax, pmax); - return wrapper::vpmax(pmax, pmax); -} - -template <> -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) -{ - uint32x4x2_t res_idx_mask{0}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - uint16x8_t mask_u16; - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - } - else - { - auto pmax = calculate_max(vec_res_value); - mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - } - - // Widen vectors - auto wide_u32_1 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16))); - auto wide_u32_2 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16))); - res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); - res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); - res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); - res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); - - uint32_t res = 0xFFFFFFFF; - uint32_t iter = 0; - do - { - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); - pmin = wrapper::vpmin(pmin, pmin); - res = std::min(wrapper::vgetlane(pmin, 0), res); - iter++; - } while (iter < 2); - - return (res - 0xFFFFFFFF); -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -template <class F> -class Reducer -{ -public: - static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set out window - Window out_window(window); - out_window.set(Window::DimX, Window::Dimension(0, 1, 1)); - - f(window, out_window, input, output, op); - } - static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in window - Window in_window(window); - Window out_window(window); - - in_window.set(Window::DimY, Window::Dimension(0, 1, 1)); - out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1))); - - f(in_window, out_window, input, output, 1, op); - } - static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in window - Window in_window(window); - Window out_window(window); - - in_window.set(Window::DimZ, Window::Dimension(0, 1, 1)); - out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2))); - - f(in_window, out_window, input, output, 2, op); - } - static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in/out window - Window in_window(window); - Window out_window(window); - - in_window.set(3, Window::Dimension(0, 1, 1)); - out_window.set(3, Window::Dimension(0, 1, 1)); - - f(in_window, out_window, input, output, 3, op); - } -}; - -template <typename T, int S> -struct RedOpX -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) - { - const size_t input_dim_0 = in->info()->dimension(0); - const int window_step_x = 16 / sizeof(T); - const auto window_start_x = static_cast<int>(in_window.x().start()); - const auto window_end_x = static_cast<int>(in_window.x().end()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_window); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<const T *>(input.ptr()); - - auto init_res_value = static_cast<T>(0.f); - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - init_res_value = static_cast<T>(*input_ptr); - break; - } - case ReductionOperation::PROD: - { - init_res_value = static_cast<T>(1.f); - break; - } - default: - break; - } - auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); - uint32x4x4_t vec_res_idx{{0}}; - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_elements = wrapper::vloadq(input_ptr + x); - switch (op) - { - case ReductionOperation::SUM_SQUARE: - vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); - break; - case ReductionOperation::MEAN_SUM: - case ReductionOperation::SUM: - vec_res_value = wrapper::vadd(vec_elements, vec_res_value); - break; - case ReductionOperation::PROD: - vec_res_value = wrapper::vmul(vec_elements, vec_res_value); - break; - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, - vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, - vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - case ReductionOperation::SUM_SQUARE: - { -#ifdef ARM_COMPUTE_DEBUG_ENABLED - auto res = static_cast<T>(0.f); - for (int i = 0; i < S; ++i) - { - res += wrapper::vgetlane(vec_res_value, i); - } -#else // ARM_COMPUTE_DEBUG_ENABLED - auto carry_res = - wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - for (int i = 0; i < S / 4; ++i) - { - carry_res = wrapper::vpadd(carry_res, carry_res); - } - auto res = wrapper::vgetlane(carry_res, 0); -#endif // ARM_COMPUTE_DEBUG_ENABLED - if (op == ReductionOperation::SUM_SQUARE) - { - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += (*(input_ptr + x)) * (*(input_ptr + x)); - } - } - else - { - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += *(input_ptr + x); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - res /= input_dim_0; - } - - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::PROD: - { - auto carry_res = - wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - T res = 1; - for (int i = 0; i < S / 2; ++i) - { - res *= wrapper::vgetlane(carry_res, i); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res *= *(input_ptr + x); - } - - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) < res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) > res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::MIN: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) < res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::MAX: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) > res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - }, - input, output); - } -}; - -template <typename T> -struct RedOpX_quantized -{ - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) - { - using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; - - const auto oq_info = out->info()->quantization_info().uniform(); - - const TensorInfo in_info = *(in->info()); - const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); - - const int window_step_x = 16 / sizeof(T); - const auto window_start_x = static_cast<int>(in_window.x().start()); - const auto window_end_x = static_cast<int>(in_window.x().end()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_window); - - const auto in_offset = static_cast<float>(iq_info.offset); - const float in_scale = iq_info.scale; - - const auto out_offset = static_cast<float>(oq_info.offset); - const float out_scale = oq_info.scale; - - const auto num_elements = static_cast<float>(in_info.dimension(0)); - - const float A = in_scale / (out_scale * num_elements); - const float B = out_offset - (in_scale * in_offset) / (out_scale); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - auto vec_res_value1 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value2 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value3 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value4 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - - auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f)); - - typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0}; - - if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || - op == ReductionOperation::MIN || op == ReductionOperation::MAX) - { - vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{}); - } - - uint32x4x4_t vec_res_idx{{0}}; - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_elements = wrapper::vloadq(input_ptr + x); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); - vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); - vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); - vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset); - const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale); - - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); - auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); - auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); - auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); - - //de-quantize vec_elements - temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); - temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); - temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); - temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); - - vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); - vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); - vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); - vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( - x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( - x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::ARG_IDX_MIN: - { - auto idx = - calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) < res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto idx = - calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) > res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::MIN: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) < res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::MAX: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) > res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::PROD: - { - auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); - carry_res = wrapper::vmul(carry_res, vec_res_value3_f); - carry_res = wrapper::vmul(carry_res, vec_res_value4_f); - - float res = wrapper::vgetlane(carry_res, 0); - res *= wrapper::vgetlane(carry_res, 1); - res *= wrapper::vgetlane(carry_res, 2); - res *= wrapper::vgetlane(carry_res, 3); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - //de-quantize input - if (std::is_same<T, uint8_t>::value) - { - res *= dequantize_qasymm8(*(input_ptr + x), iq_info); - } - else - { - res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info); - } - } - - //re-quantize result - if (std::is_same<T, uint8_t>::value) - { - res = quantize_qasymm8(res, iq_info); - } - else - { - res = quantize_qasymm8_signed(res, iq_info); - } - - *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res); - break; - } - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); - carry_res = wrapper::vadd(carry_res, vec_res_value3); - carry_res = wrapper::vadd(carry_res, vec_res_value4); - - auto carry_paddition = - wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); - carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); - auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += *(input_ptr + x); - } - - if (op == ReductionOperation::MEAN_SUM) - { - const int32_t resFinal = A * (static_cast<float>(res)) + B; - - *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal); - } - else - { - // Subtract accumulated offsets - res -= (in_info.dimension(0) - 1) * iq_info.offset; - *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res); - } - - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - }, - input, output); - } -}; - -template <typename T, int S> -struct RedOpYZW -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; - - inline void operator()(const Window &in_window, - Window &out_window, - const ITensor *in, - ITensor *out, - int axis, - const ReductionOperation op) - { - const TensorInfo in_info = *(in->info()); - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. - const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - neon_vector vec_res_value = {0}; - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vloadq(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{}); - break; - } - default: - { - vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - break; - } - } - uint32x4x4_t vec_res_idx{{0}}; - - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); - const auto vec_elements = wrapper::vloadq(in_ptr); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - vec_res_value = wrapper::vadd(vec_elements, vec_res_value); - break; - case ReductionOperation::SUM_SQUARE: - vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); - break; - case ReductionOperation::PROD: - vec_res_value = wrapper::vmul(vec_elements, vec_res_value); - break; - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = - calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = - calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - auto vec_width_inv = - wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{})); - vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv); - } - - if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - if (std::is_same<T, float16_t>::value) - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]); - } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - } - else - { - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value); - } - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto res_value = 0.f; - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - res_value = *(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - res_value = static_cast<T>(1.f); - break; - } - default: - { - res_value = static_cast<T>(0.f); - break; - } - } - - uint32_t res_idx = 0; - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); - - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - res_value += *in_ptr; - break; - case ReductionOperation::SUM_SQUARE: - res_value += *in_ptr * *in_ptr; - break; - case ReductionOperation::PROD: - res_value *= *in_ptr; - break; - case ReductionOperation::ARG_IDX_MIN: - { - if (*in_ptr < res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - if (*in_ptr > res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::MIN: - { - res_value = *in_ptr < res_value ? *in_ptr : res_value; - break; - } - case ReductionOperation::MAX: - { - res_value = *in_ptr > res_value ? *in_ptr : res_value; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - res_value /= in_info.dimension(axis); - } - - if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx; - } - else - { - *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value; - } - } - }, - input, output); - } -}; - -template <typename T, int S, int axis, ReductionOperation op> -struct RedOpYZW_complex -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; - - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation) - { - ARM_COMPUTE_ERROR_ON(axis != 2); - ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM); - - const TensorInfo in_info = *(in->info()); - const size_t stride_z = in_info.strides_in_bytes()[axis]; - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. - const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - neon_vector vec_res_value_0 = {0}; - neon_vector vec_res_value_1 = {0}; - - vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - - T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); - T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim); - - const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); - const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); - - vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); - vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); - } - - wrapper::vstore(out_ptr, vec_res_value_0); - wrapper::vstore(out_ptr + 4, vec_res_value_1); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto res_value_0 = 0.f; - auto res_value_1 = 0.f; - - T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); - res_value_0 += *in_ptr; - res_value_1 += *(in_ptr + 1); - } - *out_ptr = res_value_0; - *(out_ptr + 1) = res_value_1; - } - }, - input, output); - } -}; - -template <typename T> -struct RedOpYZW_quantized -{ - inline void operator()(const Window &in_window, - Window &out_window, - const ITensor *in, - ITensor *out, - int axis, - const ReductionOperation op) - { - const TensorInfo in_info = *(in->info()); - const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); - using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; - - const auto oq_info = out->info()->quantization_info().uniform(); - - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. - const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - using vector_type = - typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type; - using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type; - - vector_type vec_res_value1{}; - vector_type vec_res_value2{}; - vector_type vec_res_value3{}; - vector_type vec_res_value4{}; - - vector_type_f vec_res_value1_f{}; - vector_type_f vec_res_value2_f{}; - vector_type_f vec_res_value3_f{}; - vector_type_f vec_res_value4_f{}; - - const float in_offset = static_cast<float>(iq_info.offset); - const float in_scale = iq_info.scale; - - const float out_offset = static_cast<float>(oq_info.offset); - const float out_scale = oq_info.scale; - - const float num_elements = static_cast<float>(in_info.dimension(axis)); - - const float A = in_scale / (out_scale * num_elements); - const float B = out_offset - (in_scale * in_offset) / (out_scale); - - const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{}); - const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{}); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - uint32x4x4_t vec_res_idx{{0}}; - vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - - vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - - auto vec_res_value = wrapper::vloadq(input_ptr + x); - - for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim) - { - const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim; - const auto vec_elements = wrapper::vloadq(in_ptr); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); - vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); - vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); - vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), - wrapper::traits::vector_128_tag{}); - const auto scale32x4f_4 = - wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{}); - - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); - auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); - auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); - auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); - - //de-quantize vec_elements - temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4); - temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4); - temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4); - temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4); - - vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f); - vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f); - vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f); - vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f); - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, - vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, - vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::ARG_IDX_MAX: - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, - vec_res_idx.val[3]); - break; - } - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value); - break; - } - case ReductionOperation::SUM: - { - // Subtract offsets - auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset); - - auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1); - auto vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2); - auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3); - auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4); - vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets); - vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets); - vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets); - vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets); - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4)); - - combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x); - break; - } - case ReductionOperation::MEAN_SUM: - { - vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A); - vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A); - vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A); - vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A); - -#ifdef __aarch64__ - vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f); -#else // defined(__aarch64__) - vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f); -#endif // __aarch64__ - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); - auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); - - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = - wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{}); - const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale)); - - //re-quantize - vec_res_value1_f = - wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4); - vec_res_value2_f = - wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4); - vec_res_value3_f = - wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4); - vec_res_value4_f = - wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4); - - vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f); - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); - auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); - - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - float res_value = 0.f; - int32_t res_value_q = 0; - - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - res_value = *(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - res_value = static_cast<T>(1.0f); - break; - } - default: - { - res_value = static_cast<T>(0.0f); - break; - } - } - uint32_t res_idx = 0; - - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim); - switch (op) - { - case ReductionOperation::SUM: - { - res_value += *in_ptr; - break; - } - case ReductionOperation::MEAN_SUM: - { - res_value_q += *in_ptr; - break; - } - case ReductionOperation::SUM_SQUARE: - { - res_value += *in_ptr * *in_ptr; - break; - } - case ReductionOperation::PROD: - { - //de-quantize input - if (std::is_same<T, uint8_t>::value) - { - res_value *= dequantize_qasymm8(*in_ptr, iq_info); - } - else - { - res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info); - } - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - if (*in_ptr < res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - if (*in_ptr > res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::MIN: - { - res_value = *in_ptr < res_value ? *in_ptr : res_value; - break; - } - case ReductionOperation::MAX: - { - res_value = *in_ptr > res_value ? *in_ptr : res_value; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::MEAN_SUM: - { - // Apply previously calculated coefficients (with rounding on aarch64) -#ifdef __aarch64__ - const int32_t res = - arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B); -#else // defined(__aarch64__) - const int32_t res = A * (static_cast<float>(res_value_q)) + B; -#endif // __aarch64__ - *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res); - break; - } - case ReductionOperation::SUM: - { - // Subtract accumulated offsets - res_value -= (in_info.dimension(axis) - 1) * iq_info.offset; - *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value); - break; - } - case ReductionOperation::PROD: - { - //re-quantize result - T res = 0; - if (std::is_same<T, uint8_t>::value) - { - res = quantize_qasymm8(res_value, iq_info); - } - else - { - res = quantize_qasymm8_signed(res_value, iq_info); - } - *(reinterpret_cast<T *>(output.ptr() + x)) = res; - break; - } - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::ARG_IDX_MAX: - { - *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx; - break; - } - default: - *(reinterpret_cast<T *>(output.ptr() + x)) = res_value; - } - } - }, - input, output); - } -}; - -void reduce_op( - const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op) +void NEReductionOperationKernel::reduce_op() { - const bool is_complex = (input->info()->num_channels() == 2); + const bool is_complex = (_input->info()->num_channels() == 2); if (is_complex) { - switch (axis) + switch (_reduction_axis) { case 2: - switch (input->info()->data_type()) + switch (_input->info()->data_type()) { case DataType::F32: - switch (op) + { + switch (_op) { case ReductionOperation::SUM: - return Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ( - window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), - op); + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM); + break; default: ARM_COMPUTE_ERROR("Not supported"); + break; } + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } return; } - switch (axis) + switch (_reduction_axis) { case 0: { - switch (input->info()->data_type()) + switch (_input->info()->data_type()) { case DataType::QASYMM8: { - return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, - RedOpX_quantized<uint8_t>(), op); + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpX_reduceX_qasymm8); + break; } case DataType::QASYMM8_SIGNED: { - return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), - op); + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpX_reduceX_qasymm8_signed); + break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpX_reduceX_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: { - return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op); + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpX_reduceX_float32_4); + break; } case DataType::S32: { - return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op); + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpX_reduceX_S32_4); + break; } default: { ARM_COMPUTE_ERROR("Not supported"); + break; } } + break; } case 1: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: { - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceY_qasymm8); + break; } case DataType::QASYMM8_SIGNED: { - return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, - RedOpYZW_quantized<int8_t>(), op); + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceY_qasymm8_signed); + break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceY_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceY_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceY_S32_4); + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } case 2: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + { + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceZ_qasymm8); + break; + } case DataType::QASYMM8_SIGNED: - return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, - RedOpYZW_quantized<int8_t>(), op); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceZ_qasymm8_signed); + break; + } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceZ_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceZ_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceZ_S32_4); + break; + } default: + { + std::cout << int(_input->info()->data_type()) << std::endl; ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } case 3: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + { + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceW_qasymm8); + break; + } case DataType::QASYMM8_SIGNED: - return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, - RedOpYZW_quantized<int8_t>(), op); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceW_qasymm8_signed); + break; + } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceW_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceW_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceW_S32_4); + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } default: + { ARM_COMPUTE_ERROR("Unsupported reduction axis"); + break; + } } } @@ -1819,10 +293,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u return Status{}; } -} // namespace NEReductionOperationKernel::NEReductionOperationKernel() - : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE) + : _func(nullptr), _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE) { } @@ -1856,6 +329,8 @@ void NEReductionOperationKernel::configure(const ITensor *input, .set_data_type(output_data_type) .reset_padding() .set_is_resizable(true)); + // Determine the reduction function + NEReductionOperationKernel::reduce_op(); } Status NEReductionOperationKernel::validate(const ITensorInfo *input, @@ -1874,6 +349,6 @@ void NEReductionOperationKernel::run(const Window &window, const ThreadInfo &inf ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - reduce_op(window, _input, _output, _reduction_axis, _op); + (*_func)(window, _input, _output, _op); } } // namespace arm_compute diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.h b/src/core/NEON/kernels/NEReductionOperationKernel.h index 78bec62c14..407e5de6d6 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.h +++ b/src/core/NEON/kernels/NEReductionOperationKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H -#define ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H +#ifndef ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H +#define ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H #include "src/core/NEON/INEKernel.h" @@ -80,14 +80,24 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op); +private: // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; + /** Common signature for all the specialized Reduction functions + * + * @param[in] window Region on which to execute the kernel. + */ + using ReductionFunction = void (*)(const Window &window, const ITensor *in, ITensor *out, ReductionOperation op); -private: + /** Populate the _func with the right reduction operation handler + */ + void reduce_op(); + + ReductionFunction _func; const ITensor *_input; ITensor *_output; unsigned int _reduction_axis; ReductionOperation _op; }; } // namespace arm_compute -#endif /*ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H */ +#endif // ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp index f5bea3e163..fe8882f59f 100644 --- a/src/core/NEON/kernels/NEReorderKernel.cpp +++ b/src/core/NEON/kernels/NEReorderKernel.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/Scheduler.h" #include "src/common/utils/Log.h" #include "src/core/NEON/kernels/arm_gemm/transform.hpp" @@ -233,13 +234,20 @@ Status NEReorderKernel::validate(const ITensorInfo *input, } } - int ksize; + int ksize = 0; switch (output_wf) { #if defined(ARM_COMPUTE_ENABLE_SVE) case WeightFormat::OHWIo8: { - ksize = 8; + if (Scheduler::get().cpu_info().has_sve() && arm_gemm::utils::get_vector_length<float>() == 8) + { + ksize = 8; + } + else + { + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported weight format."); + } break; } #endif /* ARM_COMPUTE_ENABLE_SVE */ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index 5c08e6137d..0ddca04846 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -86,7 +86,7 @@ static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] = "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 3b444ae333..c7adf8e4ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -69,19 +69,19 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { }, { GemmMethod::GEMM_INTERLEAVED, - "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL", + "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, - "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL", + "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 44a7bb894a..0c1d3a387b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,6 +34,7 @@ #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp" #include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp" +#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp" #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS @@ -123,14 +124,14 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32bf16fp32_dot_16VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32_mla_16VL", - [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); } }, @@ -138,25 +139,25 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); } }, #endif // ARM_COMPUTE_ENABLE_BF16 { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); } }, #ifdef ARM_COMPUTE_ENABLE_BF16 { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); } @@ -165,7 +166,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); } @@ -174,7 +175,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); } }, @@ -182,7 +183,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); } }, @@ -198,14 +199,14 @@ GemmImplementation<float, float>::with_estimate( GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_6x4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); } ), GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_4x6VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); } ), @@ -292,14 +293,14 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_fp32_mla_8x4", - [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_fp32_mla_6x4", - [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; }, + [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); } }, @@ -350,6 +351,14 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); } ), +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp32bf16fp32_mmla_6x16", + KernelWeightFormat::VL256_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>(args); } +), #endif // BF16 GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 89c2d5a23e..0cc4d4f3d9 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -530,7 +530,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } else if (_convolver) { @@ -563,7 +563,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } else { @@ -579,7 +579,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index fd20e53f60..fedda3a47a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -63,7 +63,7 @@ static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = { "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); } }, { @@ -128,14 +128,14 @@ GemmImplementation<int8_t, int32_t>::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_s8s32_dot_8x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_s8s32_dot_6x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 4f732f7d94..897ec9d05f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -29,7 +29,6 @@ #include "arm_gemm.hpp" #include "bfloat.hpp" #include "convolver.hpp" -#include "kernel_weight_format.hpp" #include "kernel_traits.hpp" #include "kernel_weight_format.hpp" #include "mergeresults.hpp" @@ -191,10 +190,19 @@ void kernel_and_merge<false, false, Requantize32>::run( auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k); #endif + // Offset C pointer in a similar way to non-quantized case above. + Tri *offset_c_ptr; + + if (c_ptr == nullptr) { + offset_c_ptr = nullptr; + } else { + offset_c_ptr = c_ptr + m_0 * ldc + n_0; + } + strat.kernel(// A and B pointers are just the packed panels. a_ptr, b_panel, // Provide relevant part of output array and row stride. - c_ptr + m_0 * ldc + n_0, ldc, + offset_c_ptr, ldc, // M, N, K sizes m_max-m_0, n_max - n_0, kern_k, // Bias, activation, accumulation. Need to offset the bias as needed. @@ -247,6 +255,84 @@ void kernel_and_merge<true, false, Requantize32>::run( } } +// Run a kernel with integrated merge, dequantizing to FP32 +template<> +template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +void kernel_and_merge<false, false, DequantizeFloat>::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, + Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, + unsigned int n_0, unsigned int n_max, const Tr *bias, + const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias, + Tab *acc_buff) +{ +#ifdef CYCLE_PROFILING + auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k); +#endif + + const int32_t *offset_col_bias = nullptr; + const Tr *offset_bias = nullptr; + + if (col_bias) { + offset_col_bias = col_bias + n_0; + } + + if (bias) { + offset_bias = bias + n_0; + } + + strat.kernel(// A and B pointers are just the packed panels. + a_ptr, b_panel, + // Provide relevant part of output array and row stride. + c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc, + // M, N, K sizes + m_max-m_0, n_max - n_0, kern_k, + // Bias, activation, accumulation. Need to offset the bias as needed. + offset_col_bias, dq, offset_bias, act, accumulate, acc_buff); +} + +template<> +template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +void kernel_and_merge<true, false, DequantizeFloat>::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, + Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, + unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias, + const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *, + Tab *) +{ + const int bblocks = iceildiv(n_max - n_0, strategy::out_width()); + + { +#ifdef CYCLE_PROFILING + auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k)); +#endif + + strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k); + } + + { +#ifdef CYCLE_PROFILING + auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr))); +#endif + auto out_area = strategy::out_width() * strategy::out_height(); + for (int i=0; i<bblocks; i++) { + const unsigned int n_start = n_0 + (strategy::out_width() * i); + const unsigned int n_end = std::min(n_start + strategy::out_width(), n_max); + + dequantize_block_32(qp, (n_end - n_start), (m_max - m_0), + c_panel + (i * out_area), strategy::out_width(), + c_ptr + m_0 * ldc + n_start, ldc, + bias != nullptr ? bias + n_start : nullptr, accumulate, act); + + } + } +} + // Integer GEMMs can be used in two contexts - "normal" where the full 32-bit output is required, or in // "requantizing" context where the output will be requantized. // @@ -280,6 +366,12 @@ public: typedef int32_t type; }; +template<typename strategy> +class accumulate_buffer_type<strategy, DequantizeFloat, false> { +public: + typedef int32_t type; +}; + template<typename strategy, typename OutputStage> class accumulate_buffer_type<strategy, OutputStage, true> { public: @@ -350,6 +442,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { const bool _thread_columns; const Activation _act; + const bool _accumulate; const int _maxthreads; int _nthreads; @@ -579,15 +672,27 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return roundup(args._cfg->inner_block_size, strategy::k_unroll()); } - // K blocking not supported if we are requantizing. - if (std::is_same<OutputStage, Requantize32>::value) { + // K blocking not supported if we are requantizing with the merging + // kernels. + if (std::is_same<OutputStage, Requantize32>::value && MergeStep) { return get_ktotal(args); } + const unsigned int L1_size = args._ci->get_L1_cache_size(); + // Special blocking for SME if (is_sme<strategy>::value) { - // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32 - unsigned int scaling_threshold = 1280 / sizeof(Toi); + // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1. + unsigned int target_bytes_per_block = L1_size / 128; + + // Default cache size in gemm-linux is 32kB though - so make + // sure minimum is 512 + if (target_bytes_per_block < 512) { + target_bytes_per_block = 512; + } + + // Don't bother to block below this size threshold (1.25X target size) + unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi); if (get_ktotal(args) <= scaling_threshold) { return get_ktotal(args); @@ -595,7 +700,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { // Once we are blocking, this (lower) threshold determines when we should use more blocks // NOTE: Could be that some factor-based solution would work better here. - unsigned int max_block_size = 1024 / sizeof(Toi); + unsigned int max_block_size = target_bytes_per_block / sizeof(Toi); unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size); @@ -604,7 +709,6 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return k_block; } - const unsigned int L1_size = args._ci->get_L1_cache_size(); unsigned int k_block; // k_block: Find out how much of the larger array can be loaded into half the cache. @@ -639,6 +743,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return roundup(args._cfg->outer_block_size, strategy::out_width()); } + // Special blocking for SME + if (is_sme<strategy>::value) { + // If total width is less than 4x kernel width, return the entire width. + if (args._Nsize < strategy::out_width()*4) { + return roundup(args._Nsize, strategy::out_width()); + } + + // Otherwise block to single kernel width. + return strategy::out_width(); + } + unsigned int x_block; const unsigned int L2_size = args._ci->get_L2_cache_size(); const unsigned int k_block = get_k_block_size(args); @@ -680,7 +795,7 @@ public: _Ksections(args._Ksections), _Ktotal(get_ktotal(args)), _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())), _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())), _os(os) { } @@ -690,7 +805,7 @@ public: _Ksections(args._Ksections), _Ktotal(get_ktotal(args)), _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())), _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())), _os() { } @@ -763,6 +878,9 @@ public: const bool first_pass = (k0==0); const bool last_pass = (kmax==_Ktotal); + // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass. + const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass; + // Figure out how many "K" the kernel will actually process. unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll()); @@ -821,9 +939,9 @@ public: // K size, and M/N ranges kern_k, start_row, end_row, start_x, end_x, // Only do bias on the first pass - ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr), + ((bias_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr), // Only do activation on the last pass, and accumulation on any non-first pass. - (last_pass ? _act : Activation()), !first_pass, + (last_pass ? _act : Activation()), (!first_pass || _accumulate), // Pass in quantization parameters for requantizing kernels (others will ignore) _os, col_bias + (multi * _Nsize), // Accumulation buffer @@ -948,6 +1066,9 @@ public: const bool first_pass = (current.k0() == 0); const bool last_pass = (current.kmax() == _Ktotal); + // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass. + const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass; + // Pointer to appropriate part of result array. Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride); @@ -969,9 +1090,9 @@ public: // K size, and M/N ranges kern_k, y, ymax, current.x0(), current.xmax(), // Only do bias on the first pass - ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr), + ((bias_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr), // Only do activation on the last pass, and accumulation on any non-first pass. - (last_pass ? _act : Activation()), !first_pass, + (last_pass ? _act : Activation()), (!first_pass || _accumulate), // Pass in quantization parameters for requantizing kernels (others will ignore) _os, col_bias + (current.multi() * _Nsize), // Accumulation buffer @@ -1184,6 +1305,13 @@ public: } } + void set_dequantize_scale(const float scale) override { + if(std::is_same<OutputStage, DequantizeFloat>::value) { + DequantizeFloat* df = reinterpret_cast<DequantizeFloat *>(&_os); + df->scale = scale; + } + } + void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override { assert(string_len == _Ksize); _indirect_buf = ptr; @@ -1248,4 +1376,10 @@ using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strat template<typename strategy, typename To, typename Tr> using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>; +template<typename strategy, typename To, typename Tr> +using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>; + +template<typename strategy, typename To, typename Tr> +using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>; + } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index d1c4e49edb..321c97262f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -82,7 +82,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL", [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index b85b1c4fcf..93eecf991e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -78,7 +78,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL", [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp new file mode 100644 index 0000000000..38d9b763f6 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include "arm_gemm.hpp" + +#include "kernels/a64_gemm_s16_8x12.hpp" +#include "kernels/a64_gemm_s8_8x12.hpp" +#include "kernels/a64_gemm_s8_4x4.hpp" +#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp" + +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp" +#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp" +#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp" +#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SVE + +#include "gemm_implementation.hpp" +#include "gemm_interleaved.hpp" +#include "utils.hpp" + +#include <cstdint> +#include <vector> +namespace arm_gemm { + +static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_methods[] = +{ +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, + [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>(); + return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, float>(args, dq); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8qfp32_mopa_4Vx1VL.hpp", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, + [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>(); + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, float>(args, dq); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8qfp32_mopa_2Vx2VL.hpp", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, + nullptr, + [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, float>(args, dq); } +}, +#endif // ARM_COMPUTE_ENABLE_SME2 +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_interleaved_s8s32_mmla_8x3VL", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_svei8mm(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>(args, qp); } +), +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_interleaved_s8s32_dot_8x3VL", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sve(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>(args, qp); } +), +#endif // ARM_COMPUTE_ENABLE_SVE +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_interleaved_s8s32_mmla_8x12", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_i8mm(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>(args, qp); } +), +{ + GemmMethod::GEMM_INTERLEAVED, + "a64_gemm_s16_8x12", + nullptr, + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->get_cpu_model() == CPUModel::A53 && ((args._Msize > 28) || ((args._Msize % 8) > 4)); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s16_8x12, int8_t, float>(args, qp); } +}, +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_gemm_s8_8x12", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_dotprod(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>(args, qp); } +), +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_gemm_s8_4x4", + nullptr, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>(args, qp); } +), +{ + GemmMethod::DEFAULT, + "", + nullptr, + nullptr, + nullptr +} +}; + +template<> +const GemmImplementation<int8_t, float, DequantizeFloat> *gemm_implementation_list<int8_t, float, DequantizeFloat>() { + return gemm_s8fp32_methods; +} + +template UniqueGemmCommon<int8_t, float> gemm<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os); +template KernelDescription get_gemm_method<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os); +template std::vector<KernelDescription> get_compatible_kernels<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os); + +} // namespace arm_gemm + +#endif // __aarch64__
\ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index af5cfbbf2b..dfacb687a8 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -94,14 +94,14 @@ GemmImplementation<uint8_t, uint32_t>::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_u8u32_dot_8x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_u8u32_dot_6x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 92c884ce18..dbada36052 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -180,7 +180,7 @@ public: this->_Cptr + (multi * this->_C_multi_stride) + n, (nmax - n), (kmax-k0), this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr, - _args._act, (k0 != 0), + _args._act, (k0 != 0) || _args._accumulate, _os, col_bias, n + (_args._Nsize * multi)); } } diff --git a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp index 59591935cd..7c09608e3e 100644 --- a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp +++ b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -330,11 +330,11 @@ template void Interleave<8, 2, VLType::None>(float *, const float *, size_t, uns #endif // ARM_COMPUTE_ENABLE_SVE && ARM_COMPUTE_ENABLE_SVEF32MM /* FP16 */ -#if defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16) template void IndirectInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void Interleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); -#endif // FP16_KERNELS ar __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // FP16_KERNELS ar ARM_COMPUTE_ENABLE_FP16 template void IndirectInterleave<8, 1, VLType::None>(float *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(float *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp index 923d008bb1..ac3cbf943f 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: { if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 23.64 }; default: - return { 28.48 }; + return { 16.89 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp new file mode 100644 index 0000000000..98f7fc9403 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<float>, \ + size_t, size_t, \ + const bfloat16 *, \ + size_t, \ + IndirectOutputArg<float>, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffhybrid_fp32bf16fp32_mmla_6x16( ARGLIST ); + +class cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16 +{ +public: + typedef float lhs_operand_type; + typedef bfloat16 rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL256_BL64_BF16; + } + + static unsigned int out_width() + { + return 16; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 4> transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, float>::value) { + switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 21.05 }; + default: + return { 15.27 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffhybrid_fp32bf16fp32_mmla_6x16; + cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp new file mode 100644 index 0000000000..9ab4aa98f9 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp @@ -0,0 +1,3240 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include "arm_gemm.hpp" +#include "../../utils.hpp" +#include "../../bfloat.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void a64_ffhybrid_fp32bf16fp32_mmla_6x16 ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg, + size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg, + const float *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + float maxval = static_cast<float>(std::numeric_limits<float>::infinity()); + float minval = - static_cast<float>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const bfloat16 *B_ptr = {}; + const bfloat16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + void *output_ptr = nullptr; + const float *bias = nullptr; + } ka; + + unsigned long flags=0; + void *input_ptr; + + if (output_arg.is_indirect) { + ka.output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + ka.output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.bias = bias; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 181f\n" + "cmp %x[M], #0x4\n" + "bgt 145f\n" + "beq 109f\n" + "cmp %x[M], #0x2\n" + "bgt 73f\n" + "beq 37f\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "2:" // Height 1: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 3f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 3f\n" + "mov x11, x12\n" + "3:" // Height 1: B setup done + "cbz x15, 4f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "b 16f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 15f\n" + "cmp x14, #0x10\n" + "bge 13f\n" + "tbz x14, #3, 8f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "tbz x14, #2, 6f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "tbz x14, #1, 5f\n" + "ldr d16, [x13], #0x8\n" + "mov x20, #0x38\n" + "tbz x14, #0, 12f\n" + "ld1 { v16.s }[2], [x13]\n" + "b 12f\n" + "5:" // Height 1: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 12f\n" + "ldr s16, [x13, #0x0]\n" + "b 12f\n" + "6:" // Height 1: Partial accumulate: partial_2_8 + "tbz x14, #1, 7f\n" + "ldr d11, [x13], #0x8\n" + "mov x20, #0x28\n" + "tbz x14, #0, 12f\n" + "ld1 { v11.s }[2], [x13]\n" + "b 12f\n" + "7:" // Height 1: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 12f\n" + "ldr s11, [x13, #0x0]\n" + "b 12f\n" + "8:" // Height 1: Partial accumulate: partial_4_0 + "tbz x14, #2, 10f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "tbz x14, #1, 9f\n" + "ldr d10, [x13], #0x8\n" + "mov x20, #0x18\n" + "tbz x14, #0, 12f\n" + "ld1 { v10.s }[2], [x13]\n" + "b 12f\n" + "9:" // Height 1: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 12f\n" + "ldr s10, [x13, #0x0]\n" + "b 12f\n" + "10:" // Height 1: Partial accumulate: partial_2_0 + "tbz x14, #1, 11f\n" + "ldr d9, [x13], #0x8\n" + "mov x20, #0x8\n" + "tbz x14, #0, 12f\n" + "ld1 { v9.s }[2], [x13]\n" + "b 12f\n" + "11:" // Height 1: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "mov x20, #0x0\n" + "12:" // Height 1: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 14f\n" + "13:" // Height 1: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "14:" // Height 1: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "b 16f\n" + "15:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "16:" // Height 1: setup done + "mov x28, #0x0\n" + "17:" // Height 1: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 18f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "cbnz x28, 19f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "b 19f\n" + "18:" // Height 1: setup direct input + "mov x26, %x[input_ptr]\n" + "19:" // Height 1: input setup done + "cmp x27, #0x4\n" + "blt 22f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q7, [x12, #0x10]\n" + "blt 21f\n" + "20:" // Height 1: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 20b\n" + "21:" // Height 1: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "22:" // Height 1: Multiply loop: Main loop skip + "cbz x27, 25f\n" + "cbz x27, 25f\n" + "tbz x27, #1, 23f\n" + "ldr d0, [x26], #0x8\n" + "tbz x27, #0, 24f\n" + "ld1 { v0.s }[2], [x26]\n" + "b 24f\n" + "23:" // Height 1: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "24:" // Height 1: Multiply loop: Ragged operand read: Done + "ldr q18, [x12, #0x0]\n" + "ldr q17, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "25:" // Height 1: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 17b\n" + "uzp1 v8.2d, v8.2d, v12.2d\n" + "uzp1 v9.2d, v9.2d, v13.2d\n" + "uzp1 v10.2d, v10.2d, v14.2d\n" + "uzp1 v11.2d, v11.2d, v15.2d\n" + "tbz %x[flags], #1, 26f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v18.4s }, [x21]\n" + "ld1r { v17.4s }, [x20]\n" + "fmin v8.4s, v8.4s, v18.4s\n" + "fmin v9.4s, v9.4s, v18.4s\n" + "fmin v10.4s, v10.4s, v18.4s\n" + "fmin v11.4s, v11.4s, v18.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v11.4s, v11.4s, v17.4s\n" + "26:" // Height 1: No activation + "cmp x14, #0x10\n" + "bge 35f\n" + "tbz x14, #3, 30f\n" + "st1 { v8.4s }, [x13], #0x10\n" + "st1 { v9.4s }, [x13], #0x10\n" + "tbz x14, #2, 28f\n" + "st1 { v10.4s }, [x13], #0x10\n" + "tbz x14, #1, 27f\n" + "str d11, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v11.s }[2], [x13]\n" + "b 34f\n" + "27:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x14, #0, 34f\n" + "str s11, [x13, #0x0]\n" + "b 34f\n" + "28:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x14, #1, 29f\n" + "str d10, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v10.s }[2], [x13]\n" + "b 34f\n" + "29:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x14, #0, 34f\n" + "str s10, [x13, #0x0]\n" + "b 34f\n" + "30:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x14, #2, 32f\n" + "st1 { v8.4s }, [x13], #0x10\n" + "tbz x14, #1, 31f\n" + "str d9, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v9.s }[2], [x13]\n" + "b 34f\n" + "31:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x14, #0, 34f\n" + "str s9, [x13, #0x0]\n" + "b 34f\n" + "32:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x14, #1, 33f\n" + "str d8, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v8.s }[2], [x13]\n" + "b 34f\n" + "33:" // Height 1: Partial direct writeback: partial_1_0 + "str s8, [x13, #0x0]\n" + "34:" // Height 1: Partial direct writeback: Done + "b 36f\n" + "35:" // Height 1: Full writeback + "str q8, [x13, #0x0]\n" + "str q9, [x13, #0x10]\n" + "str q10, [x13, #0x20]\n" + "str q11, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "36:" // Height 1: Writeback done + "subs x14, x14, #0x10\n" + "bgt 2b\n" + "b 218f\n" + "37:" // Height 2 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "38:" // Height 2: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 39f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 39f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 39f\n" + "mov x11, x12\n" + "39:" // Height 2: B setup done + "cbz x15, 40f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "b 52f\n" + "40:" // Height 2: no bias + "tbz %x[flags], #0, 51f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "bge 49f\n" + "tbz x14, #3, 44f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "tbz x14, #2, 42f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "tbz x14, #1, 41f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "tbz x14, #0, 48f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "b 48f\n" + "41:" // Height 2: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 48f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "b 48f\n" + "42:" // Height 2: Partial accumulate: partial_2_8 + "tbz x14, #1, 43f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "tbz x14, #0, 48f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "b 48f\n" + "43:" // Height 2: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 48f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "b 48f\n" + "44:" // Height 2: Partial accumulate: partial_4_0 + "tbz x14, #2, 46f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "tbz x14, #1, 45f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "tbz x14, #0, 48f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "b 48f\n" + "45:" // Height 2: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 48f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "b 48f\n" + "46:" // Height 2: Partial accumulate: partial_2_0 + "tbz x14, #1, 47f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "tbz x14, #0, 48f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "b 48f\n" + "47:" // Height 2: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "48:" // Height 2: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 50f\n" + "49:" // Height 2: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "50:" // Height 2: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "b 52f\n" + "51:" // Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "52:" // Height 2: setup done + "mov x28, #0x0\n" + "53:" // Height 2: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 54f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "cbnz x28, 55f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "b 55f\n" + "54:" // Height 2: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "55:" // Height 2: input setup done + "cmp x27, #0x4\n" + "blt 58f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "cmp x27, #0x8\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 57f\n" + "56:" // Height 2: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 56b\n" + "57:" // Height 2: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "58:" // Height 2: Multiply loop: Main loop skip + "cbz x27, 61f\n" + "cbz x27, 61f\n" + "tbz x27, #1, 59f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "tbz x27, #0, 60f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "b 60f\n" + "59:" // Height 2: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "60:" // Height 2: Multiply loop: Ragged operand read: Done + "ldr q18, [x12, #0x0]\n" + "ldr q17, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "61:" // Height 2: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 53b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x26, x13, x20, LSL #2\n" + "tbz %x[flags], #1, 62f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v18.4s }, [x21]\n" + "ld1r { v17.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v18.4s\n" + "fmin v12.4s, v12.4s, v18.4s\n" + "fmin v13.4s, v13.4s, v18.4s\n" + "fmin v14.4s, v14.4s, v18.4s\n" + "fmin v8.4s, v8.4s, v18.4s\n" + "fmin v9.4s, v9.4s, v18.4s\n" + "fmin v10.4s, v10.4s, v18.4s\n" + "fmin v11.4s, v11.4s, v18.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v12.4s, v12.4s, v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v11.4s, v11.4s, v17.4s\n" + "62:" // Height 2: No activation + "cmp x14, #0x10\n" + "bge 71f\n" + "tbz x14, #3, 66f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "tbz x14, #2, 64f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "tbz x14, #1, 63f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "b 70f\n" + "63:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x14, #0, 70f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "b 70f\n" + "64:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x14, #1, 65f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "b 70f\n" + "65:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x14, #0, 70f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "b 70f\n" + "66:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x14, #2, 68f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "tbz x14, #1, 67f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "b 70f\n" + "67:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x14, #0, 70f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "b 70f\n" + "68:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x14, #1, 69f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "b 70f\n" + "69:" // Height 2: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "70:" // Height 2: Partial direct writeback: Done + "b 72f\n" + "71:" // Height 2: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "72:" // Height 2: Writeback done + "subs x14, x14, #0x10\n" + "bgt 38b\n" + "b 218f\n" + "73:" // Height 3 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "74:" // Height 3: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 75f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 75f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 75f\n" + "mov x11, x12\n" + "75:" // Height 3: B setup done + "cbz x15, 76f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "b 88f\n" + "76:" // Height 3: no bias + "tbz %x[flags], #0, 87f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "bge 85f\n" + "tbz x14, #3, 80f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "tbz x14, #2, 78f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "tbz x14, #1, 77f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "b 84f\n" + "77:" // Height 3: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 84f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "b 84f\n" + "78:" // Height 3: Partial accumulate: partial_2_8 + "tbz x14, #1, 79f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "b 84f\n" + "79:" // Height 3: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 84f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "b 84f\n" + "80:" // Height 3: Partial accumulate: partial_4_0 + "tbz x14, #2, 82f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "tbz x14, #1, 81f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "b 84f\n" + "81:" // Height 3: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 84f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "b 84f\n" + "82:" // Height 3: Partial accumulate: partial_2_0 + "tbz x14, #1, 83f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "b 84f\n" + "83:" // Height 3: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "84:" // Height 3: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 86f\n" + "85:" // Height 3: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "86:" // Height 3: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "b 88f\n" + "87:" // Height 3: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "88:" // Height 3: setup done + "mov x28, #0x0\n" + "89:" // Height 3: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 90f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "cbnz x28, 91f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "b 91f\n" + "90:" // Height 3: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "91:" // Height 3: input setup done + "cmp x27, #0x4\n" + "blt 94f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 93f\n" + "92:" // Height 3: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 92b\n" + "93:" // Height 3: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "94:" // Height 3: Multiply loop: Main loop skip + "cbz x27, 97f\n" + "cbz x27, 97f\n" + "tbz x27, #1, 95f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "tbz x27, #0, 96f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "b 96f\n" + "95:" // Height 3: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "96:" // Height 3: Multiply loop: Ragged operand read: Done + "ldr q26, [x12, #0x0]\n" + "ldr q25, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n" + ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "97:" // Height 3: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 89b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x25, x26, x20, LSL #2\n" + "uzp1 v16.2d, v16.2d, v20.2d\n" + "uzp1 v17.2d, v17.2d, v21.2d\n" + "uzp1 v18.2d, v18.2d, v22.2d\n" + "uzp1 v19.2d, v19.2d, v23.2d\n" + "tbz %x[flags], #1, 98f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v26.4s }, [x21]\n" + "ld1r { v25.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v26.4s\n" + "fmin v12.4s, v12.4s, v26.4s\n" + "fmin v13.4s, v13.4s, v26.4s\n" + "fmin v14.4s, v14.4s, v26.4s\n" + "fmin v8.4s, v8.4s, v26.4s\n" + "fmin v9.4s, v9.4s, v26.4s\n" + "fmin v10.4s, v10.4s, v26.4s\n" + "fmin v11.4s, v11.4s, v26.4s\n" + "fmin v16.4s, v16.4s, v26.4s\n" + "fmin v17.4s, v17.4s, v26.4s\n" + "fmin v18.4s, v18.4s, v26.4s\n" + "fmin v19.4s, v19.4s, v26.4s\n" + "fmax v6.4s, v6.4s, v25.4s\n" + "fmax v12.4s, v12.4s, v25.4s\n" + "fmax v13.4s, v13.4s, v25.4s\n" + "fmax v14.4s, v14.4s, v25.4s\n" + "fmax v8.4s, v8.4s, v25.4s\n" + "fmax v9.4s, v9.4s, v25.4s\n" + "fmax v10.4s, v10.4s, v25.4s\n" + "fmax v11.4s, v11.4s, v25.4s\n" + "fmax v16.4s, v16.4s, v25.4s\n" + "fmax v17.4s, v17.4s, v25.4s\n" + "fmax v18.4s, v18.4s, v25.4s\n" + "fmax v19.4s, v19.4s, v25.4s\n" + "98:" // Height 3: No activation + "cmp x14, #0x10\n" + "bge 107f\n" + "tbz x14, #3, 102f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v16.4s }, [x25], #0x10\n" + "st1 { v17.4s }, [x25], #0x10\n" + "tbz x14, #2, 100f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v18.4s }, [x25], #0x10\n" + "tbz x14, #1, 99f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d19, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v19.s }[2], [x25]\n" + "b 106f\n" + "99:" // Height 3: Partial direct writeback: partial_1_12 + "tbz x14, #0, 106f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s19, [x25, #0x0]\n" + "b 106f\n" + "100:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x14, #1, 101f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d18, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v18.s }[2], [x25]\n" + "b 106f\n" + "101:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x14, #0, 106f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s18, [x25, #0x0]\n" + "b 106f\n" + "102:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x14, #2, 104f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v16.4s }, [x25], #0x10\n" + "tbz x14, #1, 103f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d17, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v17.s }[2], [x25]\n" + "b 106f\n" + "103:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x14, #0, 106f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s17, [x25, #0x0]\n" + "b 106f\n" + "104:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x14, #1, 105f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d16, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v16.s }[2], [x25]\n" + "b 106f\n" + "105:" // Height 3: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s16, [x25, #0x0]\n" + "106:" // Height 3: Partial direct writeback: Done + "b 108f\n" + "107:" // Height 3: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q16, [x25, #0x0]\n" + "str q17, [x25, #0x10]\n" + "str q18, [x25, #0x20]\n" + "str q19, [x25, #0x30]\n" + "108:" // Height 3: Writeback done + "subs x14, x14, #0x10\n" + "bgt 74b\n" + "b 218f\n" + "109:" // Height 4 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "110:" // Height 4: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 111f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 111f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 111f\n" + "mov x11, x12\n" + "111:" // Height 4: B setup done + "cbz x15, 112f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "b 124f\n" + "112:" // Height 4: no bias + "tbz %x[flags], #0, 123f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "bge 121f\n" + "tbz x14, #3, 116f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "ld1 { v21.4s }, [x24], #0x10\n" + "tbz x14, #2, 114f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "ld1 { v22.4s }, [x24], #0x10\n" + "tbz x14, #1, 113f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "ldr d23, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v23.s }[2], [x24]\n" + "b 120f\n" + "113:" // Height 4: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 120f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "ldr s23, [x24, #0x0]\n" + "b 120f\n" + "114:" // Height 4: Partial accumulate: partial_2_8 + "tbz x14, #1, 115f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "ldr d22, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "ld1 { v22.s }[2], [x24]\n" + "b 120f\n" + "115:" // Height 4: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 120f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "ldr s22, [x24, #0x0]\n" + "b 120f\n" + "116:" // Height 4: Partial accumulate: partial_4_0 + "tbz x14, #2, 118f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "tbz x14, #1, 117f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "ld1 { v21.s }[2], [x24]\n" + "b 120f\n" + "117:" // Height 4: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 120f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "ldr s21, [x24, #0x0]\n" + "b 120f\n" + "118:" // Height 4: Partial accumulate: partial_2_0 + "tbz x14, #1, 119f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "ldr d20, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "ld1 { v20.s }[2], [x24]\n" + "b 120f\n" + "119:" // Height 4: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "ldr s20, [x24, #0x0]\n" + "120:" // Height 4: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 122f\n" + "121:" // Height 4: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "ldr q20, [x24, #0x0]\n" + "ldr q21, [x24, #0x10]\n" + "ldr q22, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "122:" // Height 4: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "b 124f\n" + "123:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "124:" // Height 4: setup done + "mov x28, #0x0\n" + "125:" // Height 4: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 126f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "cbnz x28, 127f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "b 127f\n" + "126:" // Height 4: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "127:" // Height 4: input setup done + "cmp x27, #0x4\n" + "blt 130f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "ld1 { v3.4s }, [x23], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 129f\n" + "128:" // Height 4: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x23], #0x10\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 128b\n" + "129:" // Height 4: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "130:" // Height 4: Multiply loop: Main loop skip + "cbz x27, 133f\n" + "cbz x27, 133f\n" + "tbz x27, #1, 131f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "ldr d3, [x23], #0x8\n" + "tbz x27, #0, 132f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "ld1 { v3.s }[2], [x23]\n" + "b 132f\n" + "131:" // Height 4: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "ldr s3, [x23, #0x0]\n" + "132:" // Height 4: Multiply loop: Ragged operand read: Done + "ldr q26, [x12, #0x0]\n" + "ldr q25, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "133:" // Height 4: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 125b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x24, x25, x20, LSL #2\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "tbz %x[flags], #1, 134f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v26.4s }, [x21]\n" + "ld1r { v25.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v26.4s\n" + "fmin v12.4s, v12.4s, v26.4s\n" + "fmin v13.4s, v13.4s, v26.4s\n" + "fmin v14.4s, v14.4s, v26.4s\n" + "fmin v8.4s, v8.4s, v26.4s\n" + "fmin v9.4s, v9.4s, v26.4s\n" + "fmin v10.4s, v10.4s, v26.4s\n" + "fmin v11.4s, v11.4s, v26.4s\n" + "fmin v15.4s, v15.4s, v26.4s\n" + "fmin v20.4s, v20.4s, v26.4s\n" + "fmin v21.4s, v21.4s, v26.4s\n" + "fmin v22.4s, v22.4s, v26.4s\n" + "fmin v16.4s, v16.4s, v26.4s\n" + "fmin v17.4s, v17.4s, v26.4s\n" + "fmin v18.4s, v18.4s, v26.4s\n" + "fmin v19.4s, v19.4s, v26.4s\n" + "fmax v6.4s, v6.4s, v25.4s\n" + "fmax v12.4s, v12.4s, v25.4s\n" + "fmax v13.4s, v13.4s, v25.4s\n" + "fmax v14.4s, v14.4s, v25.4s\n" + "fmax v8.4s, v8.4s, v25.4s\n" + "fmax v9.4s, v9.4s, v25.4s\n" + "fmax v10.4s, v10.4s, v25.4s\n" + "fmax v11.4s, v11.4s, v25.4s\n" + "fmax v15.4s, v15.4s, v25.4s\n" + "fmax v20.4s, v20.4s, v25.4s\n" + "fmax v21.4s, v21.4s, v25.4s\n" + "fmax v22.4s, v22.4s, v25.4s\n" + "fmax v16.4s, v16.4s, v25.4s\n" + "fmax v17.4s, v17.4s, v25.4s\n" + "fmax v18.4s, v18.4s, v25.4s\n" + "fmax v19.4s, v19.4s, v25.4s\n" + "134:" // Height 4: No activation + "cmp x14, #0x10\n" + "bge 143f\n" + "tbz x14, #3, 138f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x24], #0x10\n" + "tbz x14, #2, 136f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v18.4s }, [x24], #0x10\n" + "tbz x14, #1, 135f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d22, [x25], #0x8\n" + "str d19, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "b 142f\n" + "135:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x14, #0, 142f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s19, [x24, #0x0]\n" + "b 142f\n" + "136:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x14, #1, 137f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d21, [x25], #0x8\n" + "str d18, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v18.s }[2], [x24]\n" + "b 142f\n" + "137:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x14, #0, 142f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s18, [x24, #0x0]\n" + "b 142f\n" + "138:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x14, #2, 140f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "tbz x14, #1, 139f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d20, [x25], #0x8\n" + "str d17, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v17.s }[2], [x24]\n" + "b 142f\n" + "139:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x14, #0, 142f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s17, [x24, #0x0]\n" + "b 142f\n" + "140:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x14, #1, 141f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d15, [x25], #0x8\n" + "str d16, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v15.s }[2], [x25]\n" + "st1 { v16.s }[2], [x24]\n" + "b 142f\n" + "141:" // Height 4: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s15, [x25, #0x0]\n" + "str s16, [x24, #0x0]\n" + "142:" // Height 4: Partial direct writeback: Done + "b 144f\n" + "143:" // Height 4: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q15, [x25, #0x0]\n" + "str q20, [x25, #0x10]\n" + "str q21, [x25, #0x20]\n" + "str q22, [x25, #0x30]\n" + "str q16, [x24, #0x0]\n" + "str q17, [x24, #0x10]\n" + "str q18, [x24, #0x20]\n" + "str q19, [x24, #0x30]\n" + "144:" // Height 4: Writeback done + "subs x14, x14, #0x10\n" + "bgt 110b\n" + "b 218f\n" + "145:" // Height 5 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "146:" // Height 5: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 147f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 147f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 147f\n" + "mov x11, x12\n" + "147:" // Height 5: B setup done + "cbz x15, 148f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "mov v24.16b, v8.16b\n" + "mov v28.16b, v12.16b\n" + "mov v25.16b, v9.16b\n" + "mov v29.16b, v13.16b\n" + "mov v26.16b, v10.16b\n" + "mov v30.16b, v14.16b\n" + "mov v27.16b, v11.16b\n" + "mov v31.16b, v15.16b\n" + "b 160f\n" + "148:" // Height 5: no bias + "tbz %x[flags], #0, 159f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "bge 157f\n" + "tbz x14, #3, 152f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "ld1 { v21.4s }, [x24], #0x10\n" + "ld1 { v26.4s }, [x23], #0x10\n" + "tbz x14, #2, 150f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "ld1 { v22.4s }, [x24], #0x10\n" + "ld1 { v27.4s }, [x23], #0x10\n" + "tbz x14, #1, 149f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "ldr d23, [x24], #0x8\n" + "ldr d6, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v23.s }[2], [x24]\n" + "ld1 { v6.s }[2], [x23]\n" + "b 156f\n" + "149:" // Height 5: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 156f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "ldr s23, [x24, #0x0]\n" + "ldr s6, [x23, #0x0]\n" + "b 156f\n" + "150:" // Height 5: Partial accumulate: partial_2_8 + "tbz x14, #1, 151f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "ldr d22, [x24], #0x8\n" + "ldr d27, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "ld1 { v22.s }[2], [x24]\n" + "ld1 { v27.s }[2], [x23]\n" + "b 156f\n" + "151:" // Height 5: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 156f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "ldr s22, [x24, #0x0]\n" + "ldr s27, [x23, #0x0]\n" + "b 156f\n" + "152:" // Height 5: Partial accumulate: partial_4_0 + "tbz x14, #2, 154f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "tbz x14, #1, 153f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x24], #0x8\n" + "ldr d26, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "ld1 { v21.s }[2], [x24]\n" + "ld1 { v26.s }[2], [x23]\n" + "b 156f\n" + "153:" // Height 5: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 156f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "ldr s21, [x24, #0x0]\n" + "ldr s26, [x23, #0x0]\n" + "b 156f\n" + "154:" // Height 5: Partial accumulate: partial_2_0 + "tbz x14, #1, 155f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "ldr d20, [x24], #0x8\n" + "ldr d25, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "ld1 { v20.s }[2], [x24]\n" + "ld1 { v25.s }[2], [x23]\n" + "b 156f\n" + "155:" // Height 5: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "ldr s20, [x24, #0x0]\n" + "ldr s25, [x23, #0x0]\n" + "156:" // Height 5: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 158f\n" + "157:" // Height 5: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "ldr q20, [x24, #0x0]\n" + "ldr q21, [x24, #0x10]\n" + "ldr q22, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "ldr q25, [x23, #0x0]\n" + "ldr q26, [x23, #0x10]\n" + "ldr q27, [x23, #0x20]\n" + "ldr q6, [x23, #0x30]\n" + "158:" // Height 5: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "zip1 v24.2d, v25.2d, v28.2d\n" + "zip2 v28.2d, v25.2d, v28.2d\n" + "zip1 v25.2d, v26.2d, v29.2d\n" + "zip2 v29.2d, v26.2d, v29.2d\n" + "zip1 v26.2d, v27.2d, v30.2d\n" + "zip2 v30.2d, v27.2d, v30.2d\n" + "zip1 v27.2d, v6.2d, v31.2d\n" + "zip2 v31.2d, v6.2d, v31.2d\n" + "b 160f\n" + "159:" // Height 5: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "160:" // Height 5: setup done + "mov x28, #0x0\n" + "161:" // Height 5: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 162f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "ldr x22, [x20, #0x20]\n" + "cbnz x28, 163f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "b 163f\n" + "162:" // Height 5: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "163:" // Height 5: input setup done + "cmp x27, #0x4\n" + "blt 166f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "ld1 { v3.4s }, [x23], #0x10\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 165f\n" + "164:" // Height 5: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x23], #0x10\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q6, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q5, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e45ec0d // bfmmla v13.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + ".inst 0x6e45ec9d // bfmmla v29.4s, v4.8h, v5.8h\n" + "ldr q5, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e45ec9e // bfmmla v30.4s, v4.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e45ec0f // bfmmla v15.4s, v0.8h, v5.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e45ec57 // bfmmla v23.4s, v2.8h, v5.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + ".inst 0x6e45ec9f // bfmmla v31.4s, v4.8h, v5.8h\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 164b\n" + "165:" // Height 5: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "166:" // Height 5: Multiply loop: Main loop skip + "cbz x27, 169f\n" + "cbz x27, 169f\n" + "tbz x27, #1, 167f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "tbz x27, #0, 168f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "ld1 { v3.s }[2], [x23]\n" + "ld1 { v4.s }[2], [x22]\n" + "b 168f\n" + "167:" // Height 5: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "ldr s3, [x23, #0x0]\n" + "ldr s4, [x22, #0x0]\n" + "168:" // Height 5: Multiply loop: Ragged operand read: Done + "ldr q6, [x12, #0x0]\n" + "ldr q5, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + ".inst 0x6e45ec0c // bfmmla v12.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec9c // bfmmla v28.4s, v4.8h, v5.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "169:" // Height 5: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 161b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "add x23, x24, x20, LSL #2\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "uzp1 v24.2d, v24.2d, v28.2d\n" + "uzp1 v25.2d, v25.2d, v29.2d\n" + "uzp1 v26.2d, v26.2d, v30.2d\n" + "uzp1 v27.2d, v27.2d, v31.2d\n" + "tbz %x[flags], #1, 170f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v1.4s }, [x21]\n" + "ld1r { v0.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v6.4s, v6.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "170:" // Height 5: No activation + "cmp x14, #0x10\n" + "bge 179f\n" + "tbz x14, #3, 174f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x24], #0x10\n" + "st1 { v24.4s }, [x23], #0x10\n" + "st1 { v25.4s }, [x23], #0x10\n" + "tbz x14, #2, 172f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v18.4s }, [x24], #0x10\n" + "st1 { v26.4s }, [x23], #0x10\n" + "tbz x14, #1, 171f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d22, [x25], #0x8\n" + "str d19, [x24], #0x8\n" + "str d27, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "st1 { v27.s }[2], [x23]\n" + "b 178f\n" + "171:" // Height 5: Partial direct writeback: partial_1_12 + "tbz x14, #0, 178f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s19, [x24, #0x0]\n" + "str s27, [x23, #0x0]\n" + "b 178f\n" + "172:" // Height 5: Partial direct writeback: partial_2_8 + "tbz x14, #1, 173f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d21, [x25], #0x8\n" + "str d18, [x24], #0x8\n" + "str d26, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v18.s }[2], [x24]\n" + "st1 { v26.s }[2], [x23]\n" + "b 178f\n" + "173:" // Height 5: Partial direct writeback: partial_1_8 + "tbz x14, #0, 178f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s18, [x24, #0x0]\n" + "str s26, [x23, #0x0]\n" + "b 178f\n" + "174:" // Height 5: Partial direct writeback: partial_4_0 + "tbz x14, #2, 176f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v24.4s }, [x23], #0x10\n" + "tbz x14, #1, 175f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d20, [x25], #0x8\n" + "str d17, [x24], #0x8\n" + "str d25, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v25.s }[2], [x23]\n" + "b 178f\n" + "175:" // Height 5: Partial direct writeback: partial_1_4 + "tbz x14, #0, 178f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s17, [x24, #0x0]\n" + "str s25, [x23, #0x0]\n" + "b 178f\n" + "176:" // Height 5: Partial direct writeback: partial_2_0 + "tbz x14, #1, 177f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d15, [x25], #0x8\n" + "str d16, [x24], #0x8\n" + "str d24, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v15.s }[2], [x25]\n" + "st1 { v16.s }[2], [x24]\n" + "st1 { v24.s }[2], [x23]\n" + "b 178f\n" + "177:" // Height 5: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s15, [x25, #0x0]\n" + "str s16, [x24, #0x0]\n" + "str s24, [x23, #0x0]\n" + "178:" // Height 5: Partial direct writeback: Done + "b 180f\n" + "179:" // Height 5: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q15, [x25, #0x0]\n" + "str q20, [x25, #0x10]\n" + "str q21, [x25, #0x20]\n" + "str q22, [x25, #0x30]\n" + "str q16, [x24, #0x0]\n" + "str q17, [x24, #0x10]\n" + "str q18, [x24, #0x20]\n" + "str q19, [x24, #0x30]\n" + "str q24, [x23, #0x0]\n" + "str q25, [x23, #0x10]\n" + "str q26, [x23, #0x20]\n" + "str q27, [x23, #0x30]\n" + "180:" // Height 5: Writeback done + "subs x14, x14, #0x10\n" + "bgt 146b\n" + "b 218f\n" + "181:" // Height 6 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "mov x21, #0x18\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "madd x21, x20, x21, x13\n" + "str x21, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "182:" // Height 6: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 183f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 183f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 183f\n" + "mov x11, x12\n" + "183:" // Height 6: B setup done + "cbz x15, 184f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "mov v24.16b, v8.16b\n" + "mov v28.16b, v12.16b\n" + "mov v25.16b, v9.16b\n" + "mov v29.16b, v13.16b\n" + "mov v26.16b, v10.16b\n" + "mov v30.16b, v14.16b\n" + "mov v27.16b, v11.16b\n" + "mov v31.16b, v15.16b\n" + "b 196f\n" + "184:" // Height 6: no bias + "tbz %x[flags], #0, 195f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "add x22, x23, x20, LSL #2\n" + "bge 193f\n" + "tbz x14, #3, 188f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v28.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "ld1 { v21.4s }, [x24], #0x10\n" + "ld1 { v26.4s }, [x23], #0x10\n" + "ld1 { v29.4s }, [x22], #0x10\n" + "tbz x14, #2, 186f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "ld1 { v22.4s }, [x24], #0x10\n" + "ld1 { v27.4s }, [x23], #0x10\n" + "ld1 { v30.4s }, [x22], #0x10\n" + "tbz x14, #1, 185f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "ldr d23, [x24], #0x8\n" + "ldr d6, [x23], #0x8\n" + "ldr d31, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v23.s }[2], [x24]\n" + "ld1 { v6.s }[2], [x23]\n" + "ld1 { v31.s }[2], [x22]\n" + "b 192f\n" + "185:" // Height 6: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 192f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "ldr s23, [x24, #0x0]\n" + "ldr s6, [x23, #0x0]\n" + "ldr s31, [x22, #0x0]\n" + "b 192f\n" + "186:" // Height 6: Partial accumulate: partial_2_8 + "tbz x14, #1, 187f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "ldr d22, [x24], #0x8\n" + "ldr d27, [x23], #0x8\n" + "ldr d30, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "ld1 { v22.s }[2], [x24]\n" + "ld1 { v27.s }[2], [x23]\n" + "ld1 { v30.s }[2], [x22]\n" + "b 192f\n" + "187:" // Height 6: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 192f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "ldr s22, [x24, #0x0]\n" + "ldr s27, [x23, #0x0]\n" + "ldr s30, [x22, #0x0]\n" + "b 192f\n" + "188:" // Height 6: Partial accumulate: partial_4_0 + "tbz x14, #2, 190f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v28.4s }, [x22], #0x10\n" + "tbz x14, #1, 189f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x24], #0x8\n" + "ldr d26, [x23], #0x8\n" + "ldr d29, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "ld1 { v21.s }[2], [x24]\n" + "ld1 { v26.s }[2], [x23]\n" + "ld1 { v29.s }[2], [x22]\n" + "b 192f\n" + "189:" // Height 6: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 192f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "ldr s21, [x24, #0x0]\n" + "ldr s26, [x23, #0x0]\n" + "ldr s29, [x22, #0x0]\n" + "b 192f\n" + "190:" // Height 6: Partial accumulate: partial_2_0 + "tbz x14, #1, 191f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "ldr d20, [x24], #0x8\n" + "ldr d25, [x23], #0x8\n" + "ldr d28, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "ld1 { v20.s }[2], [x24]\n" + "ld1 { v25.s }[2], [x23]\n" + "ld1 { v28.s }[2], [x22]\n" + "b 192f\n" + "191:" // Height 6: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "ldr s20, [x24, #0x0]\n" + "ldr s25, [x23, #0x0]\n" + "ldr s28, [x22, #0x0]\n" + "192:" // Height 6: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 194f\n" + "193:" // Height 6: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "ldr q20, [x24, #0x0]\n" + "ldr q21, [x24, #0x10]\n" + "ldr q22, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "ldr q25, [x23, #0x0]\n" + "ldr q26, [x23, #0x10]\n" + "ldr q27, [x23, #0x20]\n" + "ldr q6, [x23, #0x30]\n" + "ldr q28, [x22, #0x0]\n" + "ldr q29, [x22, #0x10]\n" + "ldr q30, [x22, #0x20]\n" + "ldr q31, [x22, #0x30]\n" + "194:" // Height 6: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "zip1 v24.2d, v25.2d, v28.2d\n" + "zip2 v28.2d, v25.2d, v28.2d\n" + "zip1 v25.2d, v26.2d, v29.2d\n" + "zip2 v29.2d, v26.2d, v29.2d\n" + "zip1 v26.2d, v27.2d, v30.2d\n" + "zip2 v30.2d, v27.2d, v30.2d\n" + "zip1 v27.2d, v6.2d, v31.2d\n" + "zip2 v31.2d, v6.2d, v31.2d\n" + "b 196f\n" + "195:" // Height 6: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "196:" // Height 6: setup done + "mov x28, #0x0\n" + "197:" // Height 6: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 198f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "ldr x22, [x20, #0x20]\n" + "ldr x21, [x20, #0x28]\n" + "cbnz x28, 199f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "add x21, x21, x20, LSL #2\n" + "b 199f\n" + "198:" // Height 6: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "add x21, x22, x21, LSL #2\n" + "199:" // Height 6: input setup done + "cmp x27, #0x4\n" + "blt 202f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "ld1 { v3.4s }, [x23], #0x10\n" + "ld1 { v5.4s }, [x21], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 201f\n" + "200:" // Height 6: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x23], #0x10\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + "ld1 { v5.4s }, [x21], #0x10\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + "ldr q6, [x11, #0x0]\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + "ldr q7, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9d // bfmmla v29.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9e // bfmmla v30.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + ".inst 0x6e47ec9f // bfmmla v31.4s, v4.8h, v7.8h\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 200b\n" + "201:" // Height 6: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "202:" // Height 6: Multiply loop: Main loop skip + "cbz x27, 205f\n" + "cbz x27, 205f\n" + "tbz x27, #1, 203f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "ldr d5, [x21], #0x8\n" + "tbz x27, #0, 204f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "ld1 { v3.s }[2], [x23]\n" + "ld1 { v4.s }[2], [x22]\n" + "ld1 { v5.s }[2], [x21]\n" + "b 204f\n" + "203:" // Height 6: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "ldr s3, [x23, #0x0]\n" + "ldr s4, [x22, #0x0]\n" + "ldr s5, [x21, #0x0]\n" + "204:" // Height 6: Multiply loop: Ragged operand read: Done + "ldr q7, [x12, #0x0]\n" + "ldr q6, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "205:" // Height 6: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 197b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x23, x24, x20, LSL #2\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "add x22, x23, x20, LSL #2\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "uzp1 v23.2d, v24.2d, v28.2d\n" + "uzp2 v24.2d, v24.2d, v28.2d\n" + "uzp1 v28.2d, v25.2d, v29.2d\n" + "uzp2 v25.2d, v25.2d, v29.2d\n" + "uzp1 v29.2d, v26.2d, v30.2d\n" + "uzp2 v26.2d, v26.2d, v30.2d\n" + "uzp1 v30.2d, v27.2d, v31.2d\n" + "uzp2 v27.2d, v27.2d, v31.2d\n" + "tbz %x[flags], #1, 206f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v1.4s }, [x21]\n" + "ld1r { v0.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v6.4s, v6.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "206:" // Height 6: No activation + "cmp x14, #0x10\n" + "bge 215f\n" + "tbz x14, #3, 210f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x24], #0x10\n" + "st1 { v23.4s }, [x23], #0x10\n" + "st1 { v28.4s }, [x23], #0x10\n" + "st1 { v24.4s }, [x22], #0x10\n" + "st1 { v25.4s }, [x22], #0x10\n" + "tbz x14, #2, 208f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v18.4s }, [x24], #0x10\n" + "st1 { v29.4s }, [x23], #0x10\n" + "st1 { v26.4s }, [x22], #0x10\n" + "tbz x14, #1, 207f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d22, [x25], #0x8\n" + "str d19, [x24], #0x8\n" + "str d30, [x23], #0x8\n" + "str d27, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "st1 { v30.s }[2], [x23]\n" + "st1 { v27.s }[2], [x22]\n" + "b 214f\n" + "207:" // Height 6: Partial direct writeback: partial_1_12 + "tbz x14, #0, 214f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s19, [x24, #0x0]\n" + "str s30, [x23, #0x0]\n" + "str s27, [x22, #0x0]\n" + "b 214f\n" + "208:" // Height 6: Partial direct writeback: partial_2_8 + "tbz x14, #1, 209f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d21, [x25], #0x8\n" + "str d18, [x24], #0x8\n" + "str d29, [x23], #0x8\n" + "str d26, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v18.s }[2], [x24]\n" + "st1 { v29.s }[2], [x23]\n" + "st1 { v26.s }[2], [x22]\n" + "b 214f\n" + "209:" // Height 6: Partial direct writeback: partial_1_8 + "tbz x14, #0, 214f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s18, [x24, #0x0]\n" + "str s29, [x23, #0x0]\n" + "str s26, [x22, #0x0]\n" + "b 214f\n" + "210:" // Height 6: Partial direct writeback: partial_4_0 + "tbz x14, #2, 212f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v23.4s }, [x23], #0x10\n" + "st1 { v24.4s }, [x22], #0x10\n" + "tbz x14, #1, 211f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d20, [x25], #0x8\n" + "str d17, [x24], #0x8\n" + "str d28, [x23], #0x8\n" + "str d25, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v28.s }[2], [x23]\n" + "st1 { v25.s }[2], [x22]\n" + "b 214f\n" + "211:" // Height 6: Partial direct writeback: partial_1_4 + "tbz x14, #0, 214f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s17, [x24, #0x0]\n" + "str s28, [x23, #0x0]\n" + "str s25, [x22, #0x0]\n" + "b 214f\n" + "212:" // Height 6: Partial direct writeback: partial_2_0 + "tbz x14, #1, 213f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d15, [x25], #0x8\n" + "str d16, [x24], #0x8\n" + "str d23, [x23], #0x8\n" + "str d24, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v15.s }[2], [x25]\n" + "st1 { v16.s }[2], [x24]\n" + "st1 { v23.s }[2], [x23]\n" + "st1 { v24.s }[2], [x22]\n" + "b 214f\n" + "213:" // Height 6: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s15, [x25, #0x0]\n" + "str s16, [x24, #0x0]\n" + "str s23, [x23, #0x0]\n" + "str s24, [x22, #0x0]\n" + "214:" // Height 6: Partial direct writeback: Done + "b 216f\n" + "215:" // Height 6: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q15, [x25, #0x0]\n" + "str q20, [x25, #0x10]\n" + "str q21, [x25, #0x20]\n" + "str q22, [x25, #0x30]\n" + "str q16, [x24, #0x0]\n" + "str q17, [x24, #0x10]\n" + "str q18, [x24, #0x20]\n" + "str q19, [x24, #0x30]\n" + "str q23, [x23, #0x0]\n" + "str q28, [x23, #0x10]\n" + "str q29, [x23, #0x20]\n" + "str q30, [x23, #0x30]\n" + "str q24, [x22, #0x0]\n" + "str q25, [x22, #0x10]\n" + "str q26, [x22, #0x20]\n" + "str q27, [x22, #0x30]\n" + "216:" // Height 6: Writeback done + "subs x14, x14, #0x10\n" + "bgt 182b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 218f\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 217f\n" + "add x21, x21, #0x6\n" + "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "217:" // Update direct input + "mov x20, #0x18\n" + "madd %x[input_ptr], x20, x21, %x[input_ptr]\n" + "b 1b\n" + "218:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr) + : [args_ptr] "r" (&ka), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_output_ptr] "I" (offsetof(KernelArgs, output_ptr)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp index cf4d74266a..1a8b0fd630 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 45.25, 4.29, 4.80 }; default: - return { 38.10, 5.23, 3.15 }; + return { 29.85, 2.60, 5.49 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp index 586d6a64a4..d9668aae02 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) #include "../performance_parameters.hpp" #include "../std_transforms_fixed.hpp" @@ -89,4 +89,4 @@ public: } // namespace arm_gemm -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp new file mode 100644 index 0000000000..7792192856 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include <cstdint> +#include "../std_transforms_sme.hpp" + +namespace arm_gemm +{ + +// Implementations +void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + +class cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL +{ +public: + typedef int8_t operand_type; + typedef float result_type; + + typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + + /* Kernel blocking parameters */ + static unsigned int out_height() + { + return sme::get_vector_length<int32_t>() * 1; + } + + static unsigned int out_width() + { + return sme::get_vector_length<int32_t>() * 4; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + static constexpr bool is_sme() + { + return true; + } + + // Default to the generic kernel + kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL; + + StdTransformsSME<operand_type, result_type, 1, 4, 4> transforms = {}; + + cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp new file mode 100644 index 0000000000..4b26a6578c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp @@ -0,0 +1,417 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_gemm.hpp" + +#include <cstdint> +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer) +{ + struct KernelArgs + { + KernelArgs( + const int8_t *const A, + const int8_t *const B, + float *const C, const int ldc, + const int M, const int N, const int K, + const int32_t *const bias, const float *const late_bias, const Activation act, + bool accumulate, + int32_t *const accumulator_buffer + ) : A(A), + B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)), + C(C), ldcb(ldc * sizeof(float)), + M(M), N(N), K(K), + min(-std::numeric_limits<float>::infinity()), + max(std::numeric_limits<float>::infinity()), + bias(bias), late_bias(late_bias), + accumulator_buffer(accumulator_buffer), + flags(0x0) + { + if (accumulate) + { + flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER + } + if (C == nullptr) + { + flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER + } + + // Initialise the activation values + switch (act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + this->max = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + this->min = static_cast<float>(0); + break; + } + } + + const int8_t *const A; + const int8_t *const B; + const long kstride_bytes; + float *const C; + const long ldcb; + const long M, N, K; + float min = -std::numeric_limits<float>::infinity(); + float max = std::numeric_limits<float>::infinity(); + + const int32_t *const bias; + const float *const late_bias; + + int32_t *const accumulator_buffer; + uint64_t flags; + }; + + // Construct arguments for this kernel + KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer); + + __asm__ __volatile__( + "ldr x13, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p0.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr x11, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x10, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x13, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // Initial accumulator load from buffer: Loop + ".inst 0xa040c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11]\n" + ".inst 0xa041c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n" + ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n" + ".inst 0xa043c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n" + ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n" + "addvl x11, x11, #16\n" + ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n" + ".inst 0xc0840583 // mova za3h.s[x12], { z12.s-z15.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w9, [%x[args], %[offsetof_M]]\n" + "mov x28, #0x0\n" + "mov x27, #0x0\n" + "ldr w26, [%x[args], %[offsetof_N]]\n" + "ldr x25, [%x[args], %[offsetof_A]]\n" + "3:" // M and N loop + "mov x24, x25\n" + ".inst 0x25ba6770 // whilelt pn8.s, x27, x26, VLx4\n" + "tbnz x13, #0, 4f\n" + "ldr x20, [%x[args], %[offsetof_bias]]\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "cbz x20, 5f\n" + ".inst 0xa01bc288 // ld1w { z8.s-z11.s }, p8/Z, [x20, x27, LSL #2]\n" + ".inst 0xc0900100 // addha za0.s, p0/M, p0/M, z8.s\n" + ".inst 0xc0900121 // addha za1.s, p0/M, p0/M, z9.s\n" + ".inst 0xc0900142 // addha za2.s, p0/M, p0/M, z10.s\n" + ".inst 0xc0900163 // addha za3.s, p0/M, p0/M, z11.s\n" + "4:" // Prepare accumulators: Test for last block + "mov x20, x27\n" + "mov x21, x28\n" + "incw x20, ALL, MUL #4\n" + "incw x21\n" + "cmp x20, x26\n" + "mov x20, x13\n" + "csel x21, x28, x21, LT\n" + "bfm x13, XZR, #0x0, #0x0 // bfc x13, #0x0, #0x1\n" + "cmp x21, x9\n" + "csel x13, x20, x13, LT\n" + "5:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "ldr x23, [%x[args], %[offsetof_B]]\n" + "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "madd x23, x27, x22, x23\n" // bptr = B + n * kstride_bytes + "and x20, x20, #0x3\n" + "cbz x21, 8f\n" + "subs x21, x21, #0x1\n" + "ld1b { z31.b }, p0/Z, [x24]\n" + ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n" + "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n" + ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n" + ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n" + "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n" + "addvl x23, x23, #16\n" + "ble 7f\n" + "6:" // K loop + ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n" + ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n" + ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n" + "ld1b { z31.b }, p0/Z, [x24]\n" + ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n" + ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n" + ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n" + ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n" + ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n" + "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n" + ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n" + ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n" + ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n" + ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n" + "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n" + ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n" + ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n" + ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n" + ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n" + ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n" + "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n" + "addvl x23, x23, #16\n" + "bgt 6b\n" + "7:" // K loop tail + ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n" + ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n" + ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n" + ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n" + ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n" + ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n" + ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n" + ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n" + ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n" + ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n" + ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n" + ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n" + ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n" + ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n" + ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n" + ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n" + "8:" // K oddments + "cbz x20, 10f\n" + "9:" // K oddments: Loop + "ld1b { z18.b }, p0/Z, [x24]\n" + "subs x20, x20, #0x1\n" + "addvl x24, x24, #1\n" + ".inst 0xa04086fc // ld1b { z28.b-z31.b }, pn9.b/Z, [x23]\n" + "addvl x23, x23, #4\n" + ".inst 0xa09c0240 // smopa za0.s, p0/M, p0/M, z18.b, z28.b\n" + ".inst 0xa09d0241 // smopa za1.s, p0/M, p0/M, z18.b, z29.b\n" + ".inst 0xa09e0242 // smopa za2.s, p0/M, p0/M, z18.b, z30.b\n" + ".inst 0xa09f0243 // smopa za3.s, p0/M, p0/M, z18.b, z31.b\n" + "bgt 9b\n" + "10:" // K oddments: End + "tbz x13, #1, 14f\n" + "tbz x13, #0, 12f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "11:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11]\n" + ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xa041c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n" + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n" + ".inst 0xa043c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n" + ".inst 0xc0840400 // mova za0h.s[x12], { z0.s-z3.s }\n" + "addvl x11, x11, #16\n" + ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n" + ".inst 0xa060c548 // st1w { z8.s-z11.s }, pn9.b, [x10]\n" + ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n" + ".inst 0xa061c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0x4, MUL VL]\n" + ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c544 // st1w { z4.s-z7.s }, pn9.b, [x10, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0xc, MUL VL]\n" + "addvl x10, x10, #16\n" + "blt 11b\n" + "b 21f\n" + "12:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "13:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xa060c544 // st1w { z4.s-z7.s }, pn9.b, [x10]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c548 // st1w { z8.s-z11.s }, pn9.b, [x10, #0x8, MUL VL]\n" + ".inst 0xa063c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0xc, MUL VL]\n" + "addvl x10, x10, #16\n" + "blt 13b\n" + "b 21f\n" + "14:" // Store to output array + "ldr x23, [%x[args], %[offsetof_C]]\n" + "sub x21, x9, x28\n" + "ld1rw { z18.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n" + "fmov z20.s, #0x0\n" + "ldr x22, [%x[args], %[offsetof_ldcb]]\n" + "fmov z21.s, #0x0\n" + "fmov z22.s, #0x0\n" + "ldr x20, [%x[args], %[offsetof_late_bias]]\n" + "fmov z23.s, #0x0\n" + "add x23, x23, x27, LSL #2\n" // C += n + "madd x23, x28, x22, x23\n" // C += m * ldc + "cbz x20, 15f\n" + "add x20, x20, x27, LSL #2\n" + ".inst 0xa040c294 // ld1w { z20.s-z23.s }, p8/Z, [x20]\n" + "15:" // Store to output array: no late bias + "cntw x20\n" + "ld1rw { z17.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x21, x20\n" + "ld1rw { z16.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x20, x21, x20, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z0.s, p0/M, z18.s, z20.s\n" + "fmad z1.s, p0/M, z18.s, z20.s\n" + "fmad z2.s, p0/M, z18.s, z20.s\n" + "fmad z3.s, p0/M, z18.s, z20.s\n" + "add x12, x12, #0x4\n" + "fmad z4.s, p0/M, z18.s, z21.s\n" + "fmad z5.s, p0/M, z18.s, z21.s\n" + "cmp x12, x21, LSL #2\n" + "fmad z6.s, p0/M, z18.s, z21.s\n" + "fmad z7.s, p0/M, z18.s, z21.s\n" + "fmad z8.s, p0/M, z18.s, z22.s\n" + "fmad z9.s, p0/M, z18.s, z22.s\n" + "fmad z10.s, p0/M, z18.s, z22.s\n" + "fmad z11.s, p0/M, z18.s, z22.s\n" + "fmad z12.s, p0/M, z18.s, z23.s\n" + "fmad z13.s, p0/M, z18.s, z23.s\n" + "fmad z14.s, p0/M, z18.s, z23.s\n" + "fmad z15.s, p0/M, z18.s, z23.s\n" + ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n" + ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n" + "add x23, x23, x22\n" + ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n" + "add x23, x23, x22\n" + ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n" + "add x23, x23, x22\n" + ".inst 0xa160c2e3 // st1w { z3.s, z7.s, z11.s, z15.s }, p8, [x23]\n" + "add x23, x23, x22\n" + "blt 16b\n" + "17:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 18f\n" + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z0.s, p0/M, z18.s, z20.s\n" + "fmad z1.s, p0/M, z18.s, z20.s\n" + "fmad z2.s, p0/M, z18.s, z20.s\n" + "fmad z3.s, p0/M, z18.s, z20.s\n" + "subs x20, x20, #0x1\n" + "fmad z4.s, p0/M, z18.s, z21.s\n" + "fmad z5.s, p0/M, z18.s, z21.s\n" + "fmad z6.s, p0/M, z18.s, z21.s\n" + "fmad z7.s, p0/M, z18.s, z21.s\n" + "fmad z8.s, p0/M, z18.s, z22.s\n" + "fmad z9.s, p0/M, z18.s, z22.s\n" + "fmad z10.s, p0/M, z18.s, z22.s\n" + "fmad z11.s, p0/M, z18.s, z22.s\n" + "fmad z12.s, p0/M, z18.s, z23.s\n" + "fmad z13.s, p0/M, z18.s, z23.s\n" + "fmad z14.s, p0/M, z18.s, z23.s\n" + "fmad z15.s, p0/M, z18.s, z23.s\n" + ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n" + ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n" + "add x23, x23, x22\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n" + "add x23, x23, x22\n" + "beq 18f\n" + ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n" + "18:" // Store to output array: Accumulator row 0 oddments: End + "19:" // Store to output array: End + "tbz x13, #0, 21f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "20:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11]\n" + ".inst 0xa041c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n" + ".inst 0xa042c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n" + ".inst 0xa043c568 // ld1w { z8.s-z11.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n" + ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n" + "addvl x11, x11, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" + ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 20b\n" + "21:" // End block + "incw x27, ALL, MUL #4\n" + "cmp x27, x26\n" + "blt 3b\n" + "incw x28\n" + "mov x27, #0x0\n" + "cmp x28, x9\n" + "mov x25, x24\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp new file mode 100644 index 0000000000..df2c9c0ca3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include <cstdint> +#include "../std_transforms_sme.hpp" + +namespace arm_gemm +{ + +// Implementations +void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + +class cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL +{ +public: + typedef int8_t operand_type; + typedef float result_type; + + typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + + /* Kernel blocking parameters */ + static unsigned int out_height() + { + return sme::get_vector_length<int32_t>() * 2; + } + + static unsigned int out_width() + { + return sme::get_vector_length<int32_t>() * 2; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + static constexpr bool is_sme() + { + return true; + } + + // Default to the generic kernel + kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL; + + StdTransformsSME<operand_type, result_type, 2, 2, 4> transforms = {}; + + cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp new file mode 100644 index 0000000000..1631fae8e9 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_gemm.hpp" + +#include <cstdint> +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer) +{ + struct KernelArgs + { + KernelArgs( + const int8_t *const A, + const int8_t *const B, + float *const C, const int ldc, + const int M, const int N, const int K, + const int32_t *const bias, const float *const late_bias, const Activation act, + bool accumulate, + int32_t *const accumulator_buffer + ) : A(A), + B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)), + C(C), ldcb(ldc * sizeof(float)), + M(M), N(N), K(K), + min(-std::numeric_limits<float>::infinity()), + max(std::numeric_limits<float>::infinity()), + bias(bias), late_bias(late_bias), + accumulator_buffer(accumulator_buffer), + flags(0x0) + { + if (accumulate) + { + flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER + } + if (C == nullptr) + { + flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER + } + + // Initialise the activation values + switch (act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + this->max = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + this->min = static_cast<float>(0); + break; + } + } + + const int8_t *const A; + const int8_t *const B; + const long kstride_bytes; + float *const C; + const long ldcb; + const long M, N, K; + float min = -std::numeric_limits<float>::infinity(); + float max = std::numeric_limits<float>::infinity(); + + const int32_t *const bias; + const float *const late_bias; + + int32_t *const accumulator_buffer; + uint64_t flags; + }; + + // Construct arguments for this kernel + KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer); + + __asm__ __volatile__( + "ldr x16, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p0.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x16, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // Initial accumulator load from buffer: Loop + ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n" + ".inst 0xa041c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c5e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c5f8 // ld1w { z24.s-z27.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840681 // mova za1h.s[x12], { z20.s-z23.s }\n" + ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840703 // mova za3h.s[x12], { z24.s-z27.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "3:" // M and N loop + "mov x27, x28\n" + ".inst 0x25a94550 // whilelt pn8.s, x10, x9, VLx2\n" + "tbnz x16, #0, 4f\n" + "ldr x20, [%x[args], %[offsetof_bias]]\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "cbz x20, 5f\n" + ".inst 0xa10a4286 // ld1w { z6.s, z14.s }, p8/Z, [x20, x10, LSL #2]\n" + ".inst 0xc09000c0 // addha za0.s, p0/M, p0/M, z6.s\n" + ".inst 0xc09001c1 // addha za1.s, p0/M, p0/M, z14.s\n" + ".inst 0xc09000c2 // addha za2.s, p0/M, p0/M, z6.s\n" + ".inst 0xc09001c3 // addha za3.s, p0/M, p0/M, z14.s\n" + "4:" // Prepare accumulators: Test for last block + "mov x20, x10\n" + "mov x21, x11\n" + "incw x20, ALL, MUL #2\n" + "incw x21, ALL, MUL #2\n" + "cmp x20, x9\n" + "mov x20, x16\n" + "csel x21, x11, x21, LT\n" + "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n" + "cmp x21, x13\n" + "csel x16, x20, x16, LT\n" + "5:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "ldr x23, [%x[args], %[offsetof_B]]\n" + "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes + "and x20, x20, #0x3\n" + "cbz x21, 8f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n" + ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n" + ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n" + ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n" + "addvl x23, x23, #8\n" + "ble 7f\n" + "6:" // K loop + ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n" + ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n" + ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n" + ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n" + ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n" + ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n" + ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n" + ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n" + ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n" + ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n" + ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n" + ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n" + ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n" + ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n" + ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n" + ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n" + ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n" + ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n" + ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n" + "addvl x23, x23, #8\n" + "bgt 6b\n" + "7:" // K loop tail + ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n" + ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n" + ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n" + ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n" + ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n" + ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n" + ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n" + ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n" + ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n" + ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n" + ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n" + ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n" + ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n" + ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n" + ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n" + ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n" + "8:" // K oddments + "cbz x20, 10f\n" + "9:" // K oddments: Loop + ".inst 0xa040077e // ld1b { z30.b-z31.b }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #2\n" + ".inst 0xa14006e7 // ld1b { z7.b, z15.b }, pn9.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa08703c0 // smopa za0.s, p0/M, p0/M, z30.b, z7.b\n" + ".inst 0xa08f03c1 // smopa za1.s, p0/M, p0/M, z30.b, z15.b\n" + ".inst 0xa08703e2 // smopa za2.s, p0/M, p0/M, z31.b, z7.b\n" + ".inst 0xa08f03e3 // smopa za3.s, p0/M, p0/M, z31.b, z15.b\n" + "bgt 9b\n" + "10:" // K oddments: End + "tbz x16, #1, 14f\n" + "tbz x16, #0, 12f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "11:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n" + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" + ".inst 0xa041c5f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xc0860440 // mova { z0.s-z3.s }, za2h.s[x12]\n" + ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n" + ".inst 0xa042c5fc // ld1w { z28.s-z31.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840601 // mova za1h.s[x12], { z16.s-z19.s }\n" + ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n" + ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n" + ".inst 0xa061c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0x4, MUL VL]\n" + ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c5d8 // st1w { z24.s-z27.s }, pn9.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 11b\n" + "b 24f\n" + "12:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "13:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n" + ".inst 0xc0860468 // mova { z8.s-z11.s }, za3h.s[x12]\n" + ".inst 0xa060c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c5cc // st1w { z12.s-z15.s }, pn9.b, [x14, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c5d0 // st1w { z16.s-z19.s }, pn9.b, [x14, #0x8, MUL VL]\n" + ".inst 0xa063c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 13b\n" + "b 24f\n" + "14:" // Store to output array + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x13, x11\n" + "ld1rw { z3.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n" + "fmov z2.s, #0x0\n" + "ldr x24, [%x[args], %[offsetof_ldcb]]\n" + "fmov z10.s, #0x0\n" + "ldr x20, [%x[args], %[offsetof_late_bias]]\n" + "add x26, x26, x10, LSL #2\n" // C += n + "madd x26, x11, x24, x26\n" // C += m * ldc + "cbz x20, 15f\n" + "add x20, x20, x10, LSL #2\n" + ".inst 0xa1404282 // ld1w { z2.s, z10.s }, p8/Z, [x20]\n" + "15:" // Store to output array: no late bias + "cntw x23\n" + "ld1rw { z1.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x25, x23\n" + "ld1rw { z0.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z4.s, p0/M, z3.s, z2.s\n" + "fmad z5.s, p0/M, z3.s, z2.s\n" + "add x12, x12, #0x4\n" + "fmad z6.s, p0/M, z3.s, z2.s\n" + "fmad z7.s, p0/M, z3.s, z2.s\n" + "cmp x12, x21, LSL #2\n" + "fmad z12.s, p0/M, z3.s, z10.s\n" + "fmad z13.s, p0/M, z3.s, z10.s\n" + "fmad z14.s, p0/M, z3.s, z10.s\n" + "fmad z15.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n" + ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 16b\n" + "17:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 18f\n" + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" + "fmad z16.s, p0/M, z3.s, z2.s\n" + "fmad z17.s, p0/M, z3.s, z2.s\n" + "subs x20, x20, #0x1\n" + "fmad z18.s, p0/M, z3.s, z2.s\n" + "fmad z19.s, p0/M, z3.s, z2.s\n" + "fmad z24.s, p0/M, z3.s, z10.s\n" + "fmad z25.s, p0/M, z3.s, z10.s\n" + "fmad z26.s, p0/M, z3.s, z10.s\n" + "fmad z27.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c830 // fclamp { z16.s-z19.s }, z1.s, z0.s\n" + ".inst 0xc1a0c838 // fclamp { z24.s-z27.s }, z1.s, z0.s\n" + ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "18:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 22f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x20, x25, x23, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 20f\n" + "19:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" + ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n" + ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmad z20.s, p0/M, z3.s, z2.s\n" + "fmad z21.s, p0/M, z3.s, z2.s\n" + "add x12, x12, #0x4\n" + "fmad z22.s, p0/M, z3.s, z2.s\n" + "fmad z23.s, p0/M, z3.s, z2.s\n" + "cmp x12, x21, LSL #2\n" + "fmad z28.s, p0/M, z3.s, z10.s\n" + "fmad z29.s, p0/M, z3.s, z10.s\n" + "fmad z30.s, p0/M, z3.s, z10.s\n" + "fmad z31.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c834 // fclamp { z20.s-z23.s }, z1.s, z0.s\n" + ".inst 0xc1a0c83c // fclamp { z28.s-z31.s }, z1.s, z0.s\n" + ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 19b\n" + "20:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 21f\n" + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z4.s, p0/M, z3.s, z2.s\n" + "fmad z5.s, p0/M, z3.s, z2.s\n" + "subs x20, x20, #0x1\n" + "fmad z6.s, p0/M, z3.s, z2.s\n" + "fmad z7.s, p0/M, z3.s, z2.s\n" + "fmad z12.s, p0/M, z3.s, z10.s\n" + "fmad z13.s, p0/M, z3.s, z10.s\n" + "fmad z14.s, p0/M, z3.s, z10.s\n" + "fmad z15.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n" + ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "21:" // Store to output array: Accumulator row 1 oddments: End + "22:" // Store to output array: End + "tbz x16, #0, 24f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "23:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15]\n" + ".inst 0xa041c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c5e4 // ld1w { z4.s-z7.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c5e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" + ".inst 0xc0840482 // mova za2h.s[x12], { z4.s-z7.s }\n" + ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 23b\n" + "24:" // End block + "incw x10, ALL, MUL #2\n" + "cmp x10, x9\n" + "blt 3b\n" + "incw x11, ALL, MUL #2\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x27\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp new file mode 100644 index 0000000000..70952f4f03 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include <cstdint> +#include "../std_transforms_sme.hpp" + +namespace arm_gemm +{ + +// Implementations +void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + +class cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL +{ +public: + typedef int8_t operand_type; + typedef float result_type; + + typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + + /* Kernel blocking parameters */ + static unsigned int out_height() + { + return sme::get_vector_length<int32_t>() * 4; + } + + static unsigned int out_width() + { + return sme::get_vector_length<int32_t>() * 1; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + static constexpr bool is_sme() + { + return true; + } + + // Default to the generic kernel + kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL; + + StdTransformsSME<operand_type, result_type, 4, 1, 4> transforms = {}; + + cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp new file mode 100644 index 0000000000..bafb16bca8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp @@ -0,0 +1,513 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_gemm.hpp" + +#include <cstdint> +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer) +{ + struct KernelArgs + { + KernelArgs( + const int8_t *const A, + const int8_t *const B, + float *const C, const int ldc, + const int M, const int N, const int K, + const int32_t *const bias, const float *const late_bias, const Activation act, + bool accumulate, + int32_t *const accumulator_buffer + ) : A(A), + B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)), + C(C), ldcb(ldc * sizeof(float)), + M(M), N(N), K(K), + min(-std::numeric_limits<float>::infinity()), + max(std::numeric_limits<float>::infinity()), + bias(bias), late_bias(late_bias), + accumulator_buffer(accumulator_buffer), + flags(0x0) + { + if (accumulate) + { + flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER + } + if (C == nullptr) + { + flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER + } + + // Initialise the activation values + switch (act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + this->max = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + this->min = static_cast<float>(0); + break; + } + } + + const int8_t *const A; + const int8_t *const B; + const long kstride_bytes; + float *const C; + const long ldcb; + const long M, N, K; + float min = -std::numeric_limits<float>::infinity(); + float max = std::numeric_limits<float>::infinity(); + + const int32_t *const bias; + const float *const late_bias; + + int32_t *const accumulator_buffer; + uint64_t flags; + }; + + // Construct arguments for this kernel + KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer); + + __asm__ __volatile__( + "ldr x16, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p1.b\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x16, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // Initial accumulator load from buffer: Loop + ".inst 0xa040c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15]\n" + ".inst 0xa041c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c1f0 // ld1w { z16.s-z19.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n" + ".inst 0xc0840502 // mova za2h.s[x12], { z8.s-z11.s }\n" + ".inst 0xc0840603 // mova za3h.s[x12], { z16.s-z19.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "3:" // M and N loop + "mov x27, x28\n" + "whilelt p0.s, x10, x9\n" + "tbnz x16, #0, 4f\n" + "ldr x20, [%x[args], %[offsetof_bias]]\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "cbz x20, 5f\n" + "ld1w { z23.s }, p0/Z, [x20, x10, LSL #2]\n" + ".inst 0xc09026e0 // addha za0.s, p1/M, p1/M, z23.s\n" + ".inst 0xc09026e1 // addha za1.s, p1/M, p1/M, z23.s\n" + ".inst 0xc09026e2 // addha za2.s, p1/M, p1/M, z23.s\n" + ".inst 0xc09026e3 // addha za3.s, p1/M, p1/M, z23.s\n" + "4:" // Prepare accumulators: Test for last block + "mov x20, x10\n" + "mov x21, x11\n" + "incw x20\n" + "incw x21, ALL, MUL #4\n" + "cmp x20, x9\n" + "mov x20, x16\n" + "csel x21, x11, x21, LT\n" + "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n" + "cmp x21, x13\n" + "csel x16, x20, x16, LT\n" + "5:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "ldr x23, [%x[args], %[offsetof_B]]\n" + "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes + "and x20, x20, #0x3\n" + "cbz x21, 8f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n" + "ld1b { z4.b }, p1/Z, [x23]\n" + ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n" + ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n" + "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n" + ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n" + "addvl x27, x27, #16\n" + "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + "ble 7f\n" + "6:" // K loop + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n" + ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n" + ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n" + ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n" + "ld1b { z4.b }, p1/Z, [x23]\n" + ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n" + ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n" + ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n" + ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n" + "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n" + ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n" + ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n" + ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n" + ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n" + "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n" + ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n" + ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n" + ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n" + ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n" + ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n" + "addvl x27, x27, #16\n" + "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + "bgt 6b\n" + "7:" // K loop tail + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n" + ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n" + ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n" + ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n" + ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n" + ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n" + ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n" + ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n" + ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n" + ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n" + ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n" + ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n" + ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n" + ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n" + "8:" // K oddments + "cbz x20, 10f\n" + "9:" // K oddments: Loop + ".inst 0xa1408373 // ld1b { z19.b, z23.b, z27.b, z31.b }, pn8.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #4\n" + "ld1b { z16.b }, p1/Z, [x23]\n" + "addvl x23, x23, #1\n" + ".inst 0xa0902660 // smopa za0.s, p1/M, p1/M, z19.b, z16.b\n" + ".inst 0xa09026e1 // smopa za1.s, p1/M, p1/M, z23.b, z16.b\n" + ".inst 0xa0902762 // smopa za2.s, p1/M, p1/M, z27.b, z16.b\n" + ".inst 0xa09027e3 // smopa za3.s, p1/M, p1/M, z31.b, z16.b\n" + "bgt 9b\n" + "10:" // K oddments: End + "tbz x16, #1, 14f\n" + "tbz x16, #0, 12f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "11:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15]\n" + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n" + ".inst 0xa041c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xc0860458 // mova { z24.s-z27.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xa042c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840500 // mova za0h.s[x12], { z8.s-z11.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" + ".inst 0xa060c1c0 // st1w { z0.s-z3.s }, pn8.b, [x14]\n" + ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n" + ".inst 0xa061c1c4 // st1w { z4.s-z7.s }, pn8.b, [x14, #0x4, MUL VL]\n" + ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c1d8 // st1w { z24.s-z27.s }, pn8.b, [x14, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 11b\n" + "b 30f\n" + "12:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "13:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xa060c1c8 // st1w { z8.s-z11.s }, pn8.b, [x14]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c1cc // st1w { z12.s-z15.s }, pn8.b, [x14, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c1d4 // st1w { z20.s-z23.s }, pn8.b, [x14, #0x8, MUL VL]\n" + ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 13b\n" + "b 30f\n" + "14:" // Store to output array + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x13, x11\n" + "ld1rw { z23.s }, p1/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n" + "fmov z22.s, #0x0\n" + "ldr x24, [%x[args], %[offsetof_ldcb]]\n" + "ldr x20, [%x[args], %[offsetof_late_bias]]\n" + "add x26, x26, x10, LSL #2\n" // C += n + "madd x26, x11, x24, x26\n" // C += m * ldc + "cbz x20, 15f\n" + "add x20, x20, x10, LSL #2\n" + "ld1w { z22.s }, p0/Z, [x20]\n" + "15:" // Store to output array: no late bias + "cntw x23\n" + "ld1rw { z21.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x25, x23\n" + "ld1rw { z20.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z0.s, p1/M, z23.s, z22.s\n" + "fmad z1.s, p1/M, z23.s, z22.s\n" + "fmad z2.s, p1/M, z23.s, z22.s\n" + "fmad z3.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n" + "st1w { z0.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z1.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z2.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z3.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 16b\n" + "17:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 18f\n" + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "st1w { z18.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "18:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 28f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 20f\n" + "19:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z18.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z19.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 19b\n" + "20:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 21f\n" + ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmad z28.s, p1/M, z23.s, z22.s\n" + "fmad z29.s, p1/M, z23.s, z22.s\n" + "fmad z30.s, p1/M, z23.s, z22.s\n" + "fmad z31.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cabc // fclamp { z28.s-z31.s }, z21.s, z20.s\n" + "st1w { z28.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "subs x20, x20, #0x1\n" + "st1w { z29.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "st1w { z30.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "21:" // Store to output array: Accumulator row 1 oddments: End + "subs x25, x25, x22\n" + "beq 28f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 23f\n" + "22:" // Store to output array: Accumulator row 2 loop + ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z12.s, p1/M, z23.s, z22.s\n" + "fmad z13.s, p1/M, z23.s, z22.s\n" + "fmad z14.s, p1/M, z23.s, z22.s\n" + "fmad z15.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n" + "st1w { z12.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z13.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z14.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z15.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 22b\n" + "23:" // Store to output array: Accumulator row 2 oddments + "cbz x20, 24f\n" + ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 24f\n" + "subs x20, x20, #0x1\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 24f\n" + "st1w { z18.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "24:" // Store to output array: Accumulator row 2 oddments: End + "subs x25, x25, x22\n" + "beq 28f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x20, x25, x23, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 26f\n" + "25:" // Store to output array: Accumulator row 3 loop + ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z24.s, p1/M, z23.s, z22.s\n" + "fmad z25.s, p1/M, z23.s, z22.s\n" + "fmad z26.s, p1/M, z23.s, z22.s\n" + "fmad z27.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" + "st1w { z24.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z25.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z26.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z27.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 25b\n" + "26:" // Store to output array: Accumulator row 3 oddments + "cbz x20, 27f\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 27f\n" + "subs x20, x20, #0x1\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 27f\n" + "st1w { z18.s }, p0, [x26]\n" + "27:" // Store to output array: Accumulator row 3 oddments: End + "28:" // Store to output array: End + "tbz x16, #0, 30f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "29:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15]\n" + ".inst 0xa041c1e0 // ld1w { z0.s-z3.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c1e4 // ld1w { z4.s-z7.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840582 // mova za2h.s[x12], { z12.s-z15.s }\n" + ".inst 0xc0840483 // mova za3h.s[x12], { z4.s-z7.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 29b\n" + "30:" // End block + "incw x10\n" + "cmp x10, x9\n" + "blt 3b\n" + "incw x11, ALL, MUL #4\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x27\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp index 887d78e1de..23f686a902 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: { if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 28.74 }; default: - return { 32.35 }; + return { 15.27 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp index d0ef531c33..1fe5f48da6 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { - default: - return { 39.66, 5.18, 4.37 }; + case CPUModel::V1: + return { 53.48, 4.23, 6.53 }; + default: + return { 29.07, 2.76, 5.39 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp index a81d4504ae..ba47e0aa54 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2020, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) template<> void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const __fp16 *bias, Activation act, bool append) @@ -86,7 +86,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -140,7 +140,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -217,7 +217,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -317,7 +317,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -439,7 +439,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -584,7 +584,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -752,7 +752,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -944,7 +944,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1150,7 +1150,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1204,7 +1204,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1278,7 +1278,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1372,7 +1372,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1485,7 +1485,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1618,7 +1618,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1771,7 +1771,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1945,7 +1945,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -2112,4 +2112,4 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } } -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp index 111d01ed3a..6da9f4be0e 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.cpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2019, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -1142,6 +1142,64 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); +void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height, + const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride, + const float* bias_ptr, bool accumulate, const Activation &act) +{ + const float32x4_t vscale = vdupq_n_f32(qp.scale); + float maxval = std::numeric_limits<float>::infinity(); + float minval = -std::numeric_limits<float>::infinity(); + + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + minval = 0; + break; + } + + const float32x4_t vmin = vdupq_n_f32(minval); + const float32x4_t vmax = vdupq_n_f32(maxval); + + for(unsigned int row=0; row<height; row++) { + auto row_in_ptr = in_ptr + (row * in_stride); + auto row_out_ptr = out_ptr + (row * out_stride); + unsigned int col=0; + if (width >= 4) { + for(; col <= (width - 4); col+= 4) { + const int32x4_t vin = vld1q_s32(row_in_ptr + col); + float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale); + if(bias_ptr) { + const float32x4_t bin = vld1q_f32(bias_ptr + col); + vdeq = vaddq_f32(vdeq, bin); + } + if(accumulate) { + vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col)); + } + vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax); + vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq); + } + } + // left-over elements + for(; col < width; ++col) { + const int32_t val = *(row_in_ptr + col); + float res = static_cast<float>(val * qp.scale); + if(bias_ptr) { + res += static_cast<float>(*(bias_ptr + col)); + } + if(accumulate) { + res += *(row_out_ptr + col); + } + res = std::min(std::max(res, minval), maxval); + *(row_out_ptr + col) = res; + } + } +} + } // namespace arm_gemm #endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp index 31dd65b397..bc64fd967b 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Arm Limited. + * Copyright (c) 2019, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -45,4 +45,8 @@ template<typename T> void row_sums_indirect(size_t num_strings, const unsigned int *string_lengths, IndirectInputArg<T> A_arg, size_t M, int32_t *output_ptr, const Requantize32 *qp); +void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height, + const int32_t* input, unsigned int in_stride, float *output, unsigned int out_stride, + const float *row_bias, bool not_first_pass, const Activation &act); + } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp index 45e4f0e1de..06d9e2416c 100644 --- a/src/core/NEON/kernels/arm_gemm/transform.cpp +++ b/src/core/NEON/kernels/arm_gemm/transform.cpp @@ -129,17 +129,17 @@ void Transform( // We don't have assembler transforms for AArch32, generate templated ones here. #ifdef __arm__ template void Transform<8, 1, true, VLType::None>(float *, const float *, int, int, int, int, int); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 #endif // AArch32 -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h index 50b3fc1284..cd849c3666 100644 --- a/src/core/common/Registrars.h +++ b/src/core/common/Registrars.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023 Arm Limited. + * Copyright (c) 2020-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -38,6 +38,12 @@ #define REGISTER_FP16_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_FP16_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_FP16_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #if defined(ARM_COMPUTE_ENABLE_NEON) #define REGISTER_FP16_NEON(func_name) &(func_name) #else /* !defined(ARM_COMPUTE_ENABLE_NEON) */ @@ -48,6 +54,7 @@ #define REGISTER_FP16_NEON(func_name) nullptr #define REGISTER_FP16_SVE(func_name) nullptr #define REGISTER_FP16_SVE2(func_name) nullptr +#define REGISTER_FP16_SME2(func_name) nullptr #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ #if defined(ENABLE_FP32_KERNELS) @@ -64,6 +71,16 @@ #define REGISTER_FP32_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_FP32_SME2(func_name) &(func_name) +#define REGISTER_QASYMM8_SME2(func_name) &(func_name) +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_FP32_SME2(func_name) nullptr +#define REGISTER_QASYMM8_SME2(func_name) nullptr +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #if defined(ARM_COMPUTE_ENABLE_NEON) #define REGISTER_FP32_NEON(func_name) &(func_name) #else /* !defined(ARM_COMPUTE_ENABLE_NEON) */ @@ -74,6 +91,7 @@ #define REGISTER_FP32_NEON(func_name) nullptr #define REGISTER_FP32_SVE(func_name) nullptr #define REGISTER_FP32_SVE2(func_name) nullptr +#define REGISTER_FP32_SME2(func_name) nullptr #endif /* defined(ENABLE_FP32_KERNELS) */ #if defined(ENABLE_QASYMM8_SIGNED_KERNELS) @@ -92,10 +110,17 @@ #define REGISTER_QASYMM8_SIGNED_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #else /* defined(ENABLE_QASYMM8_SIGNED_KERNELS) */ #define REGISTER_QASYMM8_SIGNED_NEON(func_name) nullptr #define REGISTER_QASYMM8_SIGNED_SVE(func_name) nullptr #define REGISTER_QASYMM8_SIGNED_SVE2(func_name) nullptr +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr #endif /* defined(ENABLE_QASYMM8_SIGNED_KERNELS) */ #if defined(ENABLE_QASYMM8_KERNELS) @@ -113,10 +138,17 @@ #define REGISTER_QASYMM8_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_QASYMM8_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_QASYMM8_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #else /* defined(ENABLE_QASYMM8_KERNELS) */ #define REGISTER_QASYMM8_NEON(func_name) nullptr #define REGISTER_QASYMM8_SVE(func_name) nullptr #define REGISTER_QASYMM8_SVE2(func_name) nullptr +#define REGISTER_QASYMM8_SME2(func_name) nullptr #endif /* defined(ENABLE_QASYMM8_KERNELS) */ #if defined(ENABLE_QSYMM16_KERNELS) diff --git a/src/core/helpers/LUTManager.cpp b/src/core/helpers/LUTManager.cpp index 06e35eed8c..2effffbe92 100644 --- a/src/core/helpers/LUTManager.cpp +++ b/src/core/helpers/LUTManager.cpp @@ -30,17 +30,38 @@ namespace arm_compute namespace { -void init_lut_fp16(ActivationLayerInfo::LookupTable65536 *lut) +float16_t activation(float16_t x, const LUTInfo &info) +{ + float16_t out = 0.f; + switch (info.act) + { + case ActivationLayerInfo::ActivationFunction::LOGISTIC: + out = 1.f / (1.f + std::exp(-x)); + break; + case ActivationLayerInfo::ActivationFunction::TANH: + { + out = static_cast<float16_t>(info.alpha * std::tanh(info.beta * x)); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported Activation for 16-bit LUT table"); + break; + } + return out; +} + +void init_lut_fp16(ActivationLayerInfo::LookupTable65536 *lut, const LUTInfo &info) { union Element { uint16_t i = 0; float16_t fp; } item; + // Fill lut by iterating over all 16 bit values using the union. while (true) { - (*lut)[item.i] = 1.f / (1.f + std::exp(-item.fp)); + (*lut)[item.i] = activation(item.fp, info); if (item.i == 65535) break; item.i++; @@ -62,7 +83,7 @@ std::shared_ptr<ActivationLayerInfo::LookupTable65536> LUTManager::get_lut_table // Not found, or pointer not valid // We do not use make_shared to prevent the weak_ptr keeping the control block alive std::shared_ptr<ActivationLayerInfo::LookupTable65536> ptr(new ActivationLayerInfo::LookupTable65536); - init_lut_fp16(ptr.get()); + init_lut_fp16(ptr.get(), info); map_fp16[info] = ptr; return ptr; } diff --git a/src/core/helpers/LUTManager.h b/src/core/helpers/LUTManager.h index 4e13ead7e3..f3f4bf2832 100644 --- a/src/core/helpers/LUTManager.h +++ b/src/core/helpers/LUTManager.h @@ -38,19 +38,23 @@ namespace arm_compute struct LUTInfo { ActivationLayerInfo::ActivationFunction act; + float alpha; + float beta; DataType dt; - QuantizationInfo qinfo; + UniformQuantizationInfo qinfo; + // Operators enable use of map with Lutinfo as key friend bool operator<(const LUTInfo &l, const LUTInfo &r) { - return (l.act < r.act) || ((l.act == r.act) && (l.dt < r.dt)) || - ((l.act == r.act) && (l.dt == r.dt) && (l.qinfo.scale() < r.qinfo.scale())) || - ((l.act == r.act) && (l.dt == r.dt) && (l.qinfo.scale() == r.qinfo.scale()) && - (l.qinfo.offset() < l.qinfo.offset())); + const auto l_tup = std::make_tuple(l.act, l.alpha, l.beta, l.dt, l.qinfo.scale, l.qinfo.offset); + const auto r_tup = std::make_tuple(r.act, r.alpha, r.beta, r.dt, r.qinfo.scale, r.qinfo.offset); + + return l_tup < r_tup; } - bool operator==(const LUTInfo &l) + bool operator==(const LUTInfo &l) const { - return this->act == l.act && this->dt == l.dt && this->qinfo == l.qinfo; + return this->act == l.act && this->alpha == l.alpha && this->beta == l.beta && this->dt == l.dt && + this->qinfo == l.qinfo; } }; diff --git a/src/core/utils/helpers/tensor_transform.cpp b/src/core/utils/helpers/tensor_transform.cpp index 19d0badd74..212cfdabaa 100644 --- a/src/core/utils/helpers/tensor_transform.cpp +++ b/src/core/utils/helpers/tensor_transform.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2020, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -117,7 +117,10 @@ int calculate_end_on_index(TensorShape input_shape, } // Final clamp - stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1); + if (stride > 0) + stop = utility::clamp(stop, 0, dim_size); + else + stop = utility::clamp(stop, -1, dim_size - 1); return stop; } diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp index f66d3e7064..f8b74a985d 100644 --- a/src/core/utils/quantization/AsymmHelpers.cpp +++ b/src/core/utils/quantization/AsymmHelpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -122,13 +122,13 @@ arm_compute::Status calculate_quantized_multipliers(const QuantizationInfo &iq_ ARM_COMPUTE_RETURN_ERROR_ON(iq_info.scale().empty()); ARM_COMPUTE_RETURN_ERROR_ON(wq_info.scale().empty()); ARM_COMPUTE_RETURN_ERROR_ON(oq_info.scale().empty()); - - const unsigned int size = wq_info.scale().size(); - - auto &quant_multipliers = stage_info.gemmlowp_multipliers; - auto &quant_shifts = stage_info.gemmlowp_shifts; - quant_multipliers.resize(size); - quant_shifts.resize(size); + constexpr unsigned int padding_elems = 32; // assembly kernels assume the shifts and multipliers buffers are padded + const unsigned int size = wq_info.scale().size(); + const size_t padded_size = (size == 1) ? 1 : size + padding_elems; + auto &quant_multipliers = stage_info.gemmlowp_multipliers; + auto &quant_shifts = stage_info.gemmlowp_shifts; + quant_multipliers.resize(padded_size); + quant_shifts.resize(padded_size); const auto &w_scales = wq_info.scale(); const float i_scale = iq_info.scale().at(0); |