Diffstat (limited to 'src'): 95 files changed, 12808 insertions, 2600 deletions
diff --git a/src/BUILD.bazel b/src/BUILD.bazel index d4a3b61836..f270824ab4 100644 --- a/src/BUILD.bazel +++ b/src/BUILD.bazel @@ -117,6 +117,10 @@ filegroup( "cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp", "cpu/kernels/elementwise_unary/generic/sve2/q8.cpp", "cpu/kernels/lut/generic/sve2/u8.cpp", + "cpu/kernels/softmax/generic/sme2/fp16.cpp", + "cpu/kernels/softmax/generic/sme2/fp32.cpp", + "cpu/kernels/softmax/generic/sme2/qasymm8.cpp", + "cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp", "cpu/kernels/softmax/generic/sve2/impl.cpp"] + glob(["**/*.h", "**/*.hpp", @@ -261,6 +265,9 @@ filegroup( "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp", + "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp", + "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp", + "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp", @@ -516,6 +523,7 @@ filegroup( "core/NEON/kernels/arm_gemm/gemm_int8.cpp", "core/NEON/kernels/arm_gemm/gemm_qint8.cpp", "core/NEON/kernels/arm_gemm/gemm_quint8.cpp", + "core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp", "core/NEON/kernels/arm_gemm/gemm_uint16.cpp", "core/NEON/kernels/arm_gemm/gemm_uint8.cpp", "core/NEON/kernels/arm_gemm/interleave-8way.cpp", @@ -524,6 +532,7 @@ filegroup( "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp", + "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp", "core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp", @@ -744,6 +753,8 @@ filegroup( "cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp", "cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp", "cpu/kernels/depthwiseconv2d/generic/neon/qasymm8_signed.cpp", + "cpu/kernels/dequantize/generic/neon/fp16.cpp", + "cpu/kernels/dequantize/generic/neon/fp32.cpp", "cpu/kernels/directconv2d/nchw/all.cpp", "cpu/kernels/directconv2d/nchw/fp16.cpp", "cpu/kernels/directconv2d/nhwc/neon/fp16.cpp", @@ -809,9 +820,17 @@ filegroup( "cpu/kernels/pool3d/neon/fp32.cpp", "cpu/kernels/pool3d/neon/qasymm8.cpp", "cpu/kernels/pool3d/neon/qasymm8_signed.cpp", + "cpu/kernels/quantize/generic/neon/fp16.cpp", + "cpu/kernels/quantize/generic/neon/fp32.cpp", + "cpu/kernels/quantize/generic/neon/integer.cpp", "cpu/kernels/range/generic/neon/fp16.cpp", "cpu/kernels/range/generic/neon/fp32.cpp", "cpu/kernels/range/generic/neon/integer.cpp", + "cpu/kernels/reduction_layer/generic/neon/fp16.cpp", + "cpu/kernels/reduction_layer/generic/neon/fp32.cpp", + "cpu/kernels/reduction_layer/generic/neon/integer.cpp", + "cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp", + 
"cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp", "cpu/kernels/roialign/generic/neon/fp16.cpp", "cpu/kernels/roialign/generic/neon/fp32.cpp", "cpu/kernels/roialign/generic/neon/qasymm8.cpp", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c6410714d2..87c5f8b21d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -238,6 +238,9 @@ target_sources( core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp + core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp + core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp + core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp @@ -335,6 +338,10 @@ target_sources( cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp cpu/kernels/elementwise_unary/generic/sve2/q8.cpp cpu/kernels/lut/generic/sve2/u8.cpp + cpu/kernels/softmax/generic/sme2/fp16.cpp + cpu/kernels/softmax/generic/sme2/fp32.cpp + cpu/kernels/softmax/generic/sme2/qasymm8.cpp + cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp cpu/kernels/softmax/generic/sve2/impl.cpp ) @@ -507,6 +514,7 @@ target_sources( core/NEON/kernels/arm_gemm/gemm_int8.cpp core/NEON/kernels/arm_gemm/gemm_qint8.cpp core/NEON/kernels/arm_gemm/gemm_quint8.cpp + core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp core/NEON/kernels/arm_gemm/gemm_uint16.cpp core/NEON/kernels/arm_gemm/gemm_uint8.cpp core/NEON/kernels/arm_gemm/interleave-8way.cpp @@ -515,6 +523,7 @@ target_sources( core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp + core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp @@ -735,6 +744,8 @@ target_sources( cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp cpu/kernels/depthwiseconv2d/generic/neon/qasymm8_signed.cpp + cpu/kernels/dequantize/generic/neon/fp16.cpp + cpu/kernels/dequantize/generic/neon/fp32.cpp cpu/kernels/directconv2d/nchw/all.cpp cpu/kernels/directconv2d/nchw/fp16.cpp cpu/kernels/directconv2d/nhwc/neon/fp16.cpp @@ -800,9 +811,17 @@ target_sources( cpu/kernels/pool3d/neon/fp32.cpp cpu/kernels/pool3d/neon/qasymm8.cpp cpu/kernels/pool3d/neon/qasymm8_signed.cpp + cpu/kernels/quantize/generic/neon/fp16.cpp + cpu/kernels/quantize/generic/neon/fp32.cpp + cpu/kernels/quantize/generic/neon/integer.cpp cpu/kernels/range/generic/neon/fp16.cpp cpu/kernels/range/generic/neon/fp32.cpp cpu/kernels/range/generic/neon/integer.cpp + cpu/kernels/reduction_layer/generic/neon/fp16.cpp + cpu/kernels/reduction_layer/generic/neon/fp32.cpp + cpu/kernels/reduction_layer/generic/neon/integer.cpp + 
cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp + cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp cpu/kernels/roialign/generic/neon/fp16.cpp cpu/kernels/roialign/generic/neon/fp32.cpp cpu/kernels/roialign/generic/neon/qasymm8.cpp diff --git a/src/common/cpuinfo/CpuInfo.cpp b/src/common/cpuinfo/CpuInfo.cpp index 93f51e599a..d46d8d7773 100644 --- a/src/common/cpuinfo/CpuInfo.cpp +++ b/src/common/cpuinfo/CpuInfo.cpp @@ -29,6 +29,7 @@ #include "support/StringSupport.h" #include "support/ToolchainSupport.h" +#include <map> #include <sstream> #if !defined(BARE_METAL) @@ -269,6 +270,46 @@ int get_max_cpus() } return max_cpus; } +#if defined(__ANDROID__) +std::vector<uint32_t> get_cpu_capacities() +{ + std::vector<uint32_t> cpu_capacities; + for (int i = 0; i < get_max_cpus(); ++i) + { + std::stringstream str; + str << "/sys/devices/system/cpu/cpu" << i << "/cpu_capacity"; + std::ifstream file(str.str(), std::ios::in); + if (file.is_open()) + { + std::string line; + if (bool(getline(file, line))) + { + cpu_capacities.emplace_back(support::cpp11::stoul(line)); + } + } + } + + return cpu_capacities; +} + +uint32_t not_little_num_cpus_internal() +{ + std::vector<uint32_t> cpus_all = get_cpu_capacities(); + std::vector<uint32_t> cpus_not_little; + + std::vector<uint32_t>::iterator result = std::max_element(cpus_all.begin(), cpus_all.end()); + uint32_t max_capacity = *result; + uint32_t threshold = max_capacity / 2; + for (unsigned int i = 0; i < cpus_all.size(); i++) + { + if (!(cpus_all[i] < threshold)) + { + cpus_not_little.emplace_back(cpus_all[i]); + } + } + return cpus_not_little.size(); +} +#endif /* defined(__ANDROID__) */ #elif defined(__aarch64__) && \ defined(__APPLE__) /* !defined(BARE_METAL) && !defined(__APPLE__) && (defined(__arm__) || defined(__aarch64__)) */ /** Query features through sysctlbyname @@ -363,6 +404,8 @@ CpuInfo CpuInfo::build() isainfo.neon = get_hw_capability("hw.optional.neon"); isainfo.fp16 = get_hw_capability("hw.optional.neon_fp16"); isainfo.dot = get_hw_capability("hw.optional.arm.FEAT_DotProd"); + isainfo.bf16 = get_hw_capability("hw.optional.arm.FEAT_BF16"); + isainfo.i8mm = get_hw_capability("hw.optional.arm.FEAT_I8MM"); CpuInfo info(isainfo, cpus_model); return info; #elif defined(__aarch64__) && defined(_WIN64) /* #elif defined(__aarch64__) && defined(__APPLE__) */ @@ -400,6 +443,15 @@ uint32_t CpuInfo::num_cpus() const return _cpus.size(); } +uint32_t CpuInfo::not_little_num_cpus() const +{ +#if defined(__ANDROID__) + return not_little_num_cpus_internal(); +#else /* defined(__ANDROID__) */ + return num_cpus(); +#endif /* defined(__ANDROID__) */ +} + uint32_t num_threads_hint() { unsigned int num_threads_hint = 1; diff --git a/src/common/cpuinfo/CpuInfo.h b/src/common/cpuinfo/CpuInfo.h index 953e4883c3..78d11e9610 100644 --- a/src/common/cpuinfo/CpuInfo.h +++ b/src/common/cpuinfo/CpuInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 Arm Limited. + * Copyright (c) 2021-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#ifndef SRC_COMMON_CPUINFO_H -#define SRC_COMMON_CPUINFO_H +#ifndef ACL_SRC_COMMON_CPUINFO_CPUINFO_H +#define ACL_SRC_COMMON_CPUINFO_CPUINFO_H #include "src/common/cpuinfo/CpuIsaInfo.h" #include "src/common/cpuinfo/CpuModel.h" @@ -120,6 +120,7 @@ public: CpuModel cpu_model(uint32_t cpuid) const; CpuModel cpu_model() const; uint32_t num_cpus() const; + uint32_t not_little_num_cpus() const; private: CpuIsaInfo _isa{}; @@ -135,4 +136,4 @@ private: uint32_t num_threads_hint(); } // namespace cpuinfo } // namespace arm_compute -#endif /* SRC_COMMON_CPUINFO_H */ +#endif // ACL_SRC_COMMON_CPUINFO_CPUINFO_H diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl new file mode 100644 index 0000000000..e3ec9cc98e --- /dev/null +++ b/src/core/CL/cl_kernels/common/scatter.cl @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "helpers.h" +#include "tile_helpers.h" + +// The below defines the various reduce operations for our purposes. +// Where a corresponds to the existing value, and b the new value. +#define ADD_OP(a, b) ((a) + (b)) +#define SUB_OP(a, b) ((a) - (b)) + +#ifdef IS_FLOAT +#define MAX_OP(a, b) fmax(a, b) +#define MIN_OP(a, b) fmin(a, b) +#else // ifdef IS_FLOAT +#define MAX_OP(a, b) max(a, b) +#define MIN_OP(a, b) min(a, b) +#endif // ifdef IS_FLOAT + +#define UPDATE_OP(a, b) (b) + +#ifdef SCATTER_MP1D_2D_MPND + +/** This kernel performs scatter operation + * + * @note Datatype should be given as a compile-time argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short + * @note Number of indices should be given as a compile-time argument using -DNUM_INDICES, e.g. -DNUM_INDICES=3 + * @note Index length should be given as a compile-time argument using -DINDEX_LENGTH, e.g. -DINDEX_LENGTH=2 + * @note Outermost output shapes should be given as a compile-time argument using -DOUT_SHAPE_N_MINUS_X, where + * X must be 1,2,3,4,5, e.g. -DOUT_SHAPE_N_MINUS_1=3, ... + * @note Number of elements to copy in a row should be given as a compile-time argument using -DN0, e.g. -DN0=4 + * @note Number of partial elements at the edge to copy in a row should be given as a compile-time argument using + * -DPARTIAL_N0, e.g. -DPARTIAL_N0=2 + * @note Scatter function should be given as a compile-time argument using -DSCATTER_FUNCTION, e.g. 
-DSCATTER_FUNCTION=ADD + * @note If the kernel should skip reading the output tensor, -DSKIP_OUTPUT_READ option should be provided. + * @note Kernel name in uppercase letters should be provided as a compile-time argument, e.g. -DSCATTER_MP1D_2D_MPND + * + * @param[in] updates_ptr Pointer to the updates tensor. Data Types: F32 + * @param[in] updates_stride_x Stride of the updates tensor in X dimension (in bytes) + * @param[in] updates_step_x updates_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] updates_stride_y Stride of the updates tensor in Y dimension (in bytes) + * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] updates_offset_first_element_in_bytes The offset of the first element in the updates tensor + * @param[in] indices_ptr Pointer to the indices tensor. Data Types: S32 + * @param[in] indices_stride_x Stride of the indices tensor in X dimension (in bytes) + * @param[in] indices_step_x indices_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] indices_stride_y Stride of the indices tensor in Y dimension (in bytes) + * @param[in] indices_step_y indices_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] indices_offset_first_element_in_bytes The offset of the first element in the indices tensor + * @param[out] output_ptr Pointer to the destination tensor. Same as @p upt_ptr + * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes) + * @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] output_step_y output_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] upt_block_stride Update tensor data block stride in bytes + * @param[in] out_block_stride Output tensor data block stride in bytes + */ +__kernel void scatter_mp1d_2d_mpnd( + IMAGE_DECLARATION(updates), + IMAGE_DECLARATION(indices), + IMAGE_DECLARATION(output), + int upt_block_stride, + int out_block_stride + ) +{ + const int out_shape[5] = {OUT_SHAPE_N_MINUS_1, OUT_SHAPE_N_MINUS_2, OUT_SHAPE_N_MINUS_3, + OUT_SHAPE_N_MINUS_4, OUT_SHAPE_N_MINUS_5}; + + const int x = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // x-coordinate in the tensor + const int y = get_global_id(1); // collapsed y-coordinate (ignoring the outermost dimensions) + + const bool x_cond = (PARTIAL_N0 != 0 && get_global_id(0) == 0); + + uchar *ind_ptr_raw = indices_ptr + indices_offset_first_element_in_bytes; + const uchar *out_ptr_raw = output_ptr + output_offset_first_element_in_bytes + + x * sizeof(DATA_TYPE) + y * output_stride_y; + + const uchar *upt_ptr_raw = updates_ptr + updates_offset_first_element_in_bytes + + x * sizeof(DATA_TYPE) + y * updates_stride_y; + + for(int index_element = 0; index_element < NUM_INDICES; ++index_element) + { + const int *ind_ptr = (const int *) (ind_ptr_raw); + + // Out of bounds check + bool out_of_bounds = false; + LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH, + { + if(ind_ptr[i] >= out_shape[i] || ind_ptr[i] < 0) + { + out_of_bounds = true; + } + }); + + ind_ptr_raw += indices_stride_y; + + if(out_of_bounds) + { + continue; + } + + // Index calculation + int index = 0; + LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH, + { + index = index * 
out_shape[i] + ind_ptr[i]; + }); + + DATA_TYPE *out_ptr = (DATA_TYPE *) (out_ptr_raw + index * out_block_stride); + + const DATA_TYPE *upt_ptr = (const DATA_TYPE *) (upt_ptr_raw + index_element * upt_block_stride); + + VEC_DATA_TYPE(DATA_TYPE, N0) data_in0 = VLOAD(N0)(0, (__global DATA_TYPE *) upt_ptr); + +#ifdef SKIP_OUTPUT_READ + STORE_VECTOR_SELECT(data_in, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond); +#else // ifdef SKIP_OUTPUT_READ + VEC_DATA_TYPE(DATA_TYPE, N0) data_out0 = VLOAD(N0)(0, (__global DATA_TYPE *) out_ptr); + data_out0 = SCATTER_FUNCTION(data_out0, data_in0); + + STORE_VECTOR_SELECT(data_out, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond); +#endif // ifdef SKIP_OUTPUT_READ + } +} + +#endif // SCATTER_MP1D_2D_MPND + +#ifdef SCATTER1D_PARALLEL + +// NOTE : This code is non-deterministic and can only be excecuted with the "update" ScatterFunction +// This code is currently unusued as it requires changes to the existing test suite. +/** Performs the Scatter1D operation with multiple threads. + * Similar to @ref scatter1D() + */ +__kernel void scatter1D_parallel( + TENSOR4D_DECLARATION(updates), + TENSOR4D_DECLARATION(indices), + TENSOR4D_DECLARATION(output)) +{ + // Currently 1D - only iterate through x dimension of indices. + const int px = get_global_id(0); + const int index_value = *(uchar*)(indices_ptr + indices_offset_first_element_in_bytes + (sizeof(int) * px)); + + if(index_value < OUT_SHAPE_X) + { + const DATA_TYPE update = *(DATA_TYPE *)(updates_ptr + updates_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * px)); + __global uchar *out_addr = output_ptr + indices_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * index_value); + *(__global DATA_TYPE *)(out_addr) = update; + } +} + +#endif // SCATTER1D_PARALLEL diff --git a/src/core/CPP/CPPTypes.cpp b/src/core/CPP/CPPTypes.cpp index 9980db42f3..67fbce490f 100644 --- a/src/core/CPP/CPPTypes.cpp +++ b/src/core/CPP/CPPTypes.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "src/common/cpuinfo/CpuInfo.h" #include "src/common/cpuinfo/CpuIsaInfo.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" namespace arm_compute { @@ -135,4 +136,21 @@ unsigned int CPUInfo::get_L2_cache_size() const { return _impl->L2_cache_size; } + +unsigned long CPUInfo::get_sme2_vector_length() const +{ +#ifdef ARM_COMPUTE_ENABLE_SME2 + return arm_gemm::utils::sme::get_vector_length<int8_t>(); +#else // ARM_COMPUTE_ENABLE_SME2 + return 0; +#endif // ARM_COMPUTE_ENABLE_SME2 +} +unsigned int CPUInfo::get_cpu_num_excluding_little() const +{ +#if defined(__ANDROID__) + return _impl->info.not_little_num_cpus(); +#else /* defined(__ANDROID__) */ + return get_cpu_num(); +#endif /* defined(__ANDROID__) */ +} } // namespace arm_compute diff --git a/src/core/NEON/NEAsymm.h b/src/core/NEON/NEAsymm.h index 5f4d08d0f6..b93e64a0ef 100644 --- a/src/core/NEON/NEAsymm.h +++ b/src/core/NEON/NEAsymm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2023 Arm Limited. + * Copyright (c) 2017-2020, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#ifndef ARM_COMPUTE_NEASYMM_H -#define ARM_COMPUTE_NEASYMM_H +#ifndef ACL_SRC_CORE_NEON_NEASYMM_H +#define ACL_SRC_CORE_NEON_NEASYMM_H #include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/intrinsics/intrinsics.h" @@ -637,10 +637,10 @@ inline int32x4x4_t vquantize_internal(const float32x4x4_t &qv, float scale, int3 const float32x4_t vinvscale = vdupq_n_f32(1.f / scale); const int32x4x4_t rf = {{ #ifdef __aarch64__ - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[2], vinvscale)), voffset), - vaddq_s32(vcvtaq_s32_f32(vmulq_f32(qv.val[3], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[2], vinvscale)), voffset), + vaddq_s32(vcvtnq_s32_f32(vmulq_f32(qv.val[3], vinvscale)), voffset), #else //__aarch64__ vaddq_s32(vcvtq_s32_f32(vmulq_f32(qv.val[0], vinvscale)), voffset), vaddq_s32(vcvtq_s32_f32(vmulq_f32(qv.val[1], vinvscale)), voffset), @@ -698,4 +698,4 @@ inline uint16x8x2_t vquantize_qasymm16(const float32x4x4_t &qv, const UniformQua } // namespace arm_compute #include "src/core/NEON/NEAsymm.inl" -#endif // ARM_COMPUTE_NEASYMM_H +#endif // ACL_SRC_CORE_NEON_NEASYMM_H diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp index 455d604b3b..5380e6ccce 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp +++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -31,1747 +31,221 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/Validate.h" +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "src/core/NEON/INEKernel.h" -#include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" -#include "support/SaturateCast.h" - -#include <arm_neon.h> +#include "src/cpu/kernels/reduction_layer/generic/neon/list.h" namespace arm_compute { -namespace -{ -// Helper function that calls vqmovun/vqmvn, vcombine and vstore, allows templating of RedOpYZW_quantized -template <typename T> -void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0) -{ - if (std::is_same<T, uint8_t>::value) - { - auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2)); - wrapper::vstore(output.ptr() + offset, res); - } - else - { - auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2)); - wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res); - } -} - -template <typename T> -uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4_t mask{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask = wrapper::vcgt(b, a); - } - else - { - mask = wrapper::vclt(b, a); - } - - uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3}; - if (axis != 0) - { - vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}}; - - return res; -} - -template <typename T> -uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t 
c, ReductionOperation op, int axis) -{ - uint32x4x4_t mask{{0}}; - uint8x16_t mask_u8{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask_u8 = wrapper::vcgt(b, a); - } - else - { - mask_u8 = wrapper::vclt(b, a); - } - auto wide_u16_1 = - wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); - auto wide_u16_2 = - wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); - mask.val[0] = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); - mask.val[1] = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); - mask.val[2] = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); - mask.val[3] = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); - - uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, - {idx + 4, idx + 5, idx + 6, idx + 7}, - {idx + 8, idx + 9, idx + 10, idx + 11}, - {idx + 12, idx + 13, idx + 14, idx + 15}}}; - if (axis != 0) - { - vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = { - {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]), - vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}}; - - return res; -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, - typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type -calculate_min(T in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - return wrapper::vpmin(pmin, pmin); -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, - typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type -calculate_min(T in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmin = wrapper::vpmin(pmin, pmin); - pmin = wrapper::vpmin(pmin, pmin); - return wrapper::vpmin(pmin, pmin); -} - -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. -template <typename T> -inline typename std::enable_if< - std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, - typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type -calculate_max(T in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - return wrapper::vpmax(pmax, pmax); -} - -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
-template <typename T> -inline typename std::enable_if< - std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, - typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type -calculate_max(T in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmax = wrapper::vpmax(pmax, pmax); - pmax = wrapper::vpmax(pmax, pmax); - return wrapper::vpmax(pmax, pmax); -} - -template <typename T> -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) -{ - uint32x4_t res_idx_mask{0}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); - } - else - { - auto pmax = calculate_max(vec_res_value); - auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); - } - - res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones); - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask)); - pmin = wrapper::vpmin(pmin, pmin); - uint32_t res = wrapper::vgetlane(pmin, 0); - - return (res - 0xFFFFFFFF); -} - -template <typename T> -uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) -{ - uint32x4x4_t res_idx_mask{{0}}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - uint8x16_t mask_u8{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - } - else - { - auto pmax = calculate_max(vec_res_value); - mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - } - - // Widen vectors - auto wide_u16_1 = - wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); - auto wide_u16_2 = - wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); - auto wide_u32_1 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); - auto wide_u32_2 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); - auto wide_u32_3 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); - auto wide_u32_4 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); - res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); - res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); - res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3); - res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4); - res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); - res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); - res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones); - res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones); - - uint32_t res = 0xFFFFFFFF; - int iter = 0; - do - { - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); - pmin = wrapper::vpmin(pmin, pmin); - res = std::min(wrapper::vgetlane(pmin, 0), res); - iter++; - } while (iter < 4); - - return 
(res - 0xFFFFFFFF); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -uint32x4x4_t -calculate_index(uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis) -{ - uint32x4x2_t mask{0}; - uint16x8_t mask_u16{0}; - if (op == ReductionOperation::ARG_IDX_MIN) - { - mask_u16 = wrapper::vcgt(b, a); - } - else - { - mask_u16 = wrapper::vclt(b, a); - } - mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16)); - mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16)); - uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}}; - if (axis != 0) - { - vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); - } - uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]), - wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0}; - - return res; -} - -// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. -inline float16x4_t calculate_min(float16x8_t in) -{ - auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmin = wrapper::vpmin(pmin, pmin); - return wrapper::vpmin(pmin, pmin); -} -// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. -inline float16x4_t calculate_max(float16x8_t in) -{ - auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); - pmax = wrapper::vpmax(pmax, pmax); - return wrapper::vpmax(pmax, pmax); -} - -template <> -uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) -{ - uint32x4x2_t res_idx_mask{0}; - uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); - uint16x8_t mask_u16; - if (op == ReductionOperation::ARG_IDX_MIN) - { - auto pmin = calculate_min(vec_res_value); - mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); - } - else - { - auto pmax = calculate_max(vec_res_value); - mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); - } - - // Widen vectors - auto wide_u32_1 = - wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16))); - auto wide_u32_2 = - wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16))); - res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); - res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); - res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); - res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); - - uint32_t res = 0xFFFFFFFF; - uint32_t iter = 0; - do - { - auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); - pmin = wrapper::vpmin(pmin, pmin); - res = std::min(wrapper::vgetlane(pmin, 0), res); - iter++; - } while (iter < 2); - - return (res - 0xFFFFFFFF); -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -template <class F> -class Reducer -{ -public: - static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set out window - Window out_window(window); - out_window.set(Window::DimX, Window::Dimension(0, 1, 1)); - - f(window, out_window, input, output, op); - } - static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - 
{ - // Set in window - Window in_window(window); - Window out_window(window); - - in_window.set(Window::DimY, Window::Dimension(0, 1, 1)); - out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1))); - - f(in_window, out_window, input, output, 1, op); - } - static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in window - Window in_window(window); - Window out_window(window); - - in_window.set(Window::DimZ, Window::Dimension(0, 1, 1)); - out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2))); - - f(in_window, out_window, input, output, 2, op); - } - static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) - { - // Set in/out window - Window in_window(window); - Window out_window(window); - - in_window.set(3, Window::Dimension(0, 1, 1)); - out_window.set(3, Window::Dimension(0, 1, 1)); - - f(in_window, out_window, input, output, 3, op); - } -}; - -template <typename T, int S> -struct RedOpX -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) - { - const size_t input_dim_0 = in->info()->dimension(0); - const int window_step_x = 16 / sizeof(T); - const auto window_start_x = static_cast<int>(in_window.x().start()); - const auto window_end_x = static_cast<int>(in_window.x().end()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_window); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<const T *>(input.ptr()); - - auto init_res_value = static_cast<T>(0.f); - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - init_res_value = static_cast<T>(*input_ptr); - break; - } - case ReductionOperation::PROD: - { - init_res_value = static_cast<T>(1.f); - break; - } - default: - break; - } - auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); - uint32x4x4_t vec_res_idx{{0}}; - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_elements = wrapper::vloadq(input_ptr + x); - switch (op) - { - case ReductionOperation::SUM_SQUARE: - vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); - break; - case ReductionOperation::MEAN_SUM: - case ReductionOperation::SUM: - vec_res_value = wrapper::vadd(vec_elements, vec_res_value); - break; - case ReductionOperation::PROD: - vec_res_value = wrapper::vmul(vec_elements, vec_res_value); - break; - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, - vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, - 
vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - case ReductionOperation::SUM_SQUARE: - { -#ifdef ARM_COMPUTE_DEBUG_ENABLED - auto res = static_cast<T>(0.f); - for (int i = 0; i < S; ++i) - { - res += wrapper::vgetlane(vec_res_value, i); - } -#else // ARM_COMPUTE_DEBUG_ENABLED - auto carry_res = - wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - for (int i = 0; i < S / 4; ++i) - { - carry_res = wrapper::vpadd(carry_res, carry_res); - } - auto res = wrapper::vgetlane(carry_res, 0); -#endif // ARM_COMPUTE_DEBUG_ENABLED - if (op == ReductionOperation::SUM_SQUARE) - { - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += (*(input_ptr + x)) * (*(input_ptr + x)); - } - } - else - { - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += *(input_ptr + x); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - res /= input_dim_0; - } - - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::PROD: - { - auto carry_res = - wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); - T res = 1; - for (int i = 0; i < S / 2; ++i) - { - res *= wrapper::vgetlane(carry_res, i); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res *= *(input_ptr + x); - } - - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) < res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) > res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::MIN: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) < res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::MAX: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) > res ? 
*(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - }, - input, output); - } -}; - -template <typename T> -struct RedOpX_quantized -{ - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) - { - using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; - - const auto oq_info = out->info()->quantization_info().uniform(); - - const TensorInfo in_info = *(in->info()); - const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); - - const int window_step_x = 16 / sizeof(T); - const auto window_start_x = static_cast<int>(in_window.x().start()); - const auto window_end_x = static_cast<int>(in_window.x().end()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_window); - - const auto in_offset = static_cast<float>(iq_info.offset); - const float in_scale = iq_info.scale; - - const auto out_offset = static_cast<float>(oq_info.offset); - const float out_scale = oq_info.scale; - - const auto num_elements = static_cast<float>(in_info.dimension(0)); - - const float A = in_scale / (out_scale * num_elements); - const float B = out_offset - (in_scale * in_offset) / (out_scale); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - auto vec_res_value1 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value2 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value3 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - auto vec_res_value4 = - wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); - - auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f)); - auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f)); - - typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0}; - - if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || - op == ReductionOperation::MIN || op == ReductionOperation::MAX) - { - vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{}); - } - - uint32x4x4_t vec_res_idx{{0}}; - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vec_elements = wrapper::vloadq(input_ptr + x); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); - vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); - vec_res_value3 = wrapper::vadd(temp32x4t_3, 
vec_res_value3); - vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset); - const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale); - - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); - auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); - auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); - auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); - - //de-quantize vec_elements - temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); - temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); - temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); - temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); - - vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); - vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); - vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); - vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( - x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( - x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::ARG_IDX_MIN: - { - auto idx = - calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) < res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto idx = - calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - if (*(input_ptr + x) > res) - { - idx = x; - res = *(input_ptr + x); - } - } - *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; - break; - } - case ReductionOperation::MIN: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) < res ? 
*(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::MAX: - { - auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res = *(input_ptr + x) > res ? *(input_ptr + x) : res; - } - *(reinterpret_cast<T *>(output.ptr())) = res; - break; - } - case ReductionOperation::PROD: - { - auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); - carry_res = wrapper::vmul(carry_res, vec_res_value3_f); - carry_res = wrapper::vmul(carry_res, vec_res_value4_f); - - float res = wrapper::vgetlane(carry_res, 0); - res *= wrapper::vgetlane(carry_res, 1); - res *= wrapper::vgetlane(carry_res, 2); - res *= wrapper::vgetlane(carry_res, 3); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - //de-quantize input - if (std::is_same<T, uint8_t>::value) - { - res *= dequantize_qasymm8(*(input_ptr + x), iq_info); - } - else - { - res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info); - } - } - - //re-quantize result - if (std::is_same<T, uint8_t>::value) - { - res = quantize_qasymm8(res, iq_info); - } - else - { - res = quantize_qasymm8_signed(res, iq_info); - } - - *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res); - break; - } - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); - carry_res = wrapper::vadd(carry_res, vec_res_value3); - carry_res = wrapper::vadd(carry_res, vec_res_value4); - - auto carry_paddition = - wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); - carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); - auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0)); - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - res += *(input_ptr + x); - } - - if (op == ReductionOperation::MEAN_SUM) - { - const int32_t resFinal = A * (static_cast<float>(res)) + B; - - *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal); - } - else - { - // Subtract accumulated offsets - res -= (in_info.dimension(0) - 1) * iq_info.offset; - *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res); - } - - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - }, - input, output); - } -}; - -template <typename T, int S> -struct RedOpYZW -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; - - inline void operator()(const Window &in_window, - Window &out_window, - const ITensor *in, - ITensor *out, - int axis, - const ReductionOperation op) - { - const TensorInfo in_info = *(in->info()); - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. 
- const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - neon_vector vec_res_value = {0}; - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vloadq(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{}); - break; - } - default: - { - vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - break; - } - } - uint32x4x4_t vec_res_idx{{0}}; - - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); - const auto vec_elements = wrapper::vloadq(in_ptr); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - vec_res_value = wrapper::vadd(vec_elements, vec_res_value); - break; - case ReductionOperation::SUM_SQUARE: - vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); - break; - case ReductionOperation::PROD: - vec_res_value = wrapper::vmul(vec_elements, vec_res_value); - break; - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = - calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = - calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - auto vec_width_inv = - wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{})); - vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv); - } - - if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - if (std::is_same<T, float16_t>::value) - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]); - } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - } - else - { - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value); - } - } - - // Compute left-over elements - 
for (; x < window_end_x; ++x) - { - auto res_value = 0.f; - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - res_value = *(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - res_value = static_cast<T>(1.f); - break; - } - default: - { - res_value = static_cast<T>(0.f); - break; - } - } - - uint32_t res_idx = 0; - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); - - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - res_value += *in_ptr; - break; - case ReductionOperation::SUM_SQUARE: - res_value += *in_ptr * *in_ptr; - break; - case ReductionOperation::PROD: - res_value *= *in_ptr; - break; - case ReductionOperation::ARG_IDX_MIN: - { - if (*in_ptr < res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - if (*in_ptr > res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::MIN: - { - res_value = *in_ptr < res_value ? *in_ptr : res_value; - break; - } - case ReductionOperation::MAX: - { - res_value = *in_ptr > res_value ? *in_ptr : res_value; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - if (op == ReductionOperation::MEAN_SUM) - { - res_value /= in_info.dimension(axis); - } - - if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) - { - *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx; - } - else - { - *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value; - } - } - }, - input, output); - } -}; - -template <typename T, int S, int axis, ReductionOperation op> -struct RedOpYZW_complex -{ - /** SIMD vector tag type. */ - using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; - using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; - - inline void operator()( - const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation) - { - ARM_COMPUTE_ERROR_ON(axis != 2); - ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM); - - const TensorInfo in_info = *(in->info()); - const size_t stride_z = in_info.strides_in_bytes()[axis]; - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. 
- const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - neon_vector vec_res_value_0 = {0}; - neon_vector vec_res_value_1 = {0}; - - vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); - - T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); - T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim); - - const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); - const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); - - vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); - vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); - } - - wrapper::vstore(out_ptr, vec_res_value_0); - wrapper::vstore(out_ptr + 4, vec_res_value_1); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto res_value_0 = 0.f; - auto res_value_1 = 0.f; - - T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); - res_value_0 += *in_ptr; - res_value_1 += *(in_ptr + 1); - } - *out_ptr = res_value_0; - *(out_ptr + 1) = res_value_1; - } - }, - input, output); - } -}; - -template <typename T> -struct RedOpYZW_quantized -{ - inline void operator()(const Window &in_window, - Window &out_window, - const ITensor *in, - ITensor *out, - int axis, - const ReductionOperation op) - { - const TensorInfo in_info = *(in->info()); - const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); - using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; - - const auto oq_info = out->info()->quantization_info().uniform(); - - const int window_step_x = 16 / sizeof(T); - const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); - const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); - // As it split over x-axis, need to set the correct spiltted window start and end. 
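// The quantized MEAN_SUM path below folds dequantize -> average -> requantize into a single affine
// transform. For inputs q_i with (in_scale, in_offset), averaged over N = in_info.dimension(axis)
// elements and written back with (out_scale, out_offset):
//   real mean      = in_scale * (sum(q_i) - N * in_offset) / N
//   quantized mean = real_mean / out_scale + out_offset
//                  = [in_scale / (out_scale * N)] * sum(q_i) + [out_offset - in_scale * in_offset / out_scale]
//                  = A * sum(q_i) + B
// which is exactly the A and B precomputed once per window in the code that follows. A small,
// self-contained check of the algebra (plain C++, independent of the library types):

#include <cassert>
#include <cmath>

int main()
{
    const float in_scale = 0.5f, out_scale = 0.25f;
    const int   in_offset = 3, out_offset = 10, N = 4;
    const int   q[N] = {7, 1, 12, 4};

    int   sum       = 0;
    float real_mean = 0.f;
    for (int v : q)
    {
        sum += v;
        real_mean += in_scale * (v - in_offset); // dequantize and accumulate
    }
    real_mean /= N;

    const float direct = real_mean / out_scale + out_offset;          // dequantize -> mean -> requantize
    const float A      = in_scale / (out_scale * N);                  // as precomputed in the kernel
    const float B      = out_offset - (in_scale * in_offset) / out_scale;

    assert(std::fabs(direct - (A * sum + B)) < 1e-4f);                // both routes agree
    return 0;
}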
- const auto window_start_x = static_cast<int>(0); - const auto window_end_x = static_cast<int>(in_window.shape().x()); - - Window in_win_no_pad = in_window; - in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); - Window out_win_no_pad = out_window; - out_win_no_pad.set(Window::DimX, - Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); - - Iterator input(in, in_win_no_pad); - Iterator output(out, out_win_no_pad); - - using vector_type = - typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type; - using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type; - - vector_type vec_res_value1{}; - vector_type vec_res_value2{}; - vector_type vec_res_value3{}; - vector_type vec_res_value4{}; - - vector_type_f vec_res_value1_f{}; - vector_type_f vec_res_value2_f{}; - vector_type_f vec_res_value3_f{}; - vector_type_f vec_res_value4_f{}; - - const float in_offset = static_cast<float>(iq_info.offset); - const float in_scale = iq_info.scale; - - const float out_offset = static_cast<float>(oq_info.offset); - const float out_scale = oq_info.scale; - - const float num_elements = static_cast<float>(in_info.dimension(axis)); - - const float A = in_scale / (out_scale * num_elements); - const float B = out_offset - (in_scale * in_offset) / (out_scale); - - const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{}); - const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{}); - - execute_window_loop( - in_win_no_pad, - [&](const Coordinates &) - { - const auto input_ptr = reinterpret_cast<T *>(input.ptr()); - - // Compute window_step_x elements per iteration - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - uint32x4x4_t vec_res_idx{{0}}; - vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); - - vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); - - auto vec_res_value = wrapper::vloadq(input_ptr + x); - - for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim) - { - const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim; - const auto vec_elements = wrapper::vloadq(in_ptr); - switch (op) - { - case ReductionOperation::SUM: - case ReductionOperation::MEAN_SUM: - { - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - vec_res_value1 = 
wrapper::vadd(temp32x4t_1, vec_res_value1); - vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); - vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); - vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), - wrapper::traits::vector_128_tag{}); - const auto scale32x4f_4 = - wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{}); - - const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); - const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); - - const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); - const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); - const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); - const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); - - auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); - auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); - auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); - auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); - - //de-quantize vec_elements - temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4); - temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4); - temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4); - temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4); - - vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f); - vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f); - vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f); - vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f); - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, - vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, - vec_res_idx, op, axis); - vec_res_value = temp_vec_res_value; - break; - } - case ReductionOperation::MIN: - { - vec_res_value = wrapper::vmin(vec_elements, vec_res_value); - break; - } - case ReductionOperation::MAX: - { - vec_res_value = wrapper::vmax(vec_elements, vec_res_value); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::ARG_IDX_MAX: - { - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]); - wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, - vec_res_idx.val[3]); - break; - } - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value); - break; - } - case ReductionOperation::SUM: - { - // Subtract offsets - auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset); - - auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1); - auto 
vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2); - auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3); - auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4); - vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets); - vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets); - vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets); - vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets); - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4)); - - combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x); - break; - } - case ReductionOperation::MEAN_SUM: - { - vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A); - vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A); - vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A); - vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A); - -#ifdef __aarch64__ - vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f); -#else // defined(__aarch64__) - vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f); -#endif // __aarch64__ - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); - auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); - - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); - break; - } - case ReductionOperation::PROD: - { - const auto offset32x4f_4 = - wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{}); - const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale)); - - //re-quantize - vec_res_value1_f = - wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4); - vec_res_value2_f = - wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4); - vec_res_value3_f = - wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4); - vec_res_value4_f = - wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4); - - vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f); - vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f); - vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f); - vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f); - - const auto temp16x8t_1 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); - const auto temp16x8t_2 = - wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); - auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); - - wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { 
- float res_value = 0.f; - int32_t res_value_q = 0; - - switch (op) - { - case ReductionOperation::ARG_IDX_MAX: - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::MIN: - case ReductionOperation::MAX: - { - res_value = *(input_ptr + x); - break; - } - case ReductionOperation::PROD: - { - res_value = static_cast<T>(1.0f); - break; - } - default: - { - res_value = static_cast<T>(0.0f); - break; - } - } - uint32_t res_idx = 0; - - for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) - { - const T *in_ptr = - reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim); - switch (op) - { - case ReductionOperation::SUM: - { - res_value += *in_ptr; - break; - } - case ReductionOperation::MEAN_SUM: - { - res_value_q += *in_ptr; - break; - } - case ReductionOperation::SUM_SQUARE: - { - res_value += *in_ptr * *in_ptr; - break; - } - case ReductionOperation::PROD: - { - //de-quantize input - if (std::is_same<T, uint8_t>::value) - { - res_value *= dequantize_qasymm8(*in_ptr, iq_info); - } - else - { - res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info); - } - break; - } - case ReductionOperation::ARG_IDX_MIN: - { - if (*in_ptr < res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::ARG_IDX_MAX: - { - if (*in_ptr > res_value) - { - res_value = *in_ptr; - res_idx = dim; - } - break; - } - case ReductionOperation::MIN: - { - res_value = *in_ptr < res_value ? *in_ptr : res_value; - break; - } - case ReductionOperation::MAX: - { - res_value = *in_ptr > res_value ? *in_ptr : res_value; - break; - } - default: - ARM_COMPUTE_ERROR("Not supported"); - } - } - - switch (op) - { - case ReductionOperation::MEAN_SUM: - { - // Apply previously calculated coefficients (with rounding on aarch64) -#ifdef __aarch64__ - const int32_t res = - arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B); -#else // defined(__aarch64__) - const int32_t res = A * (static_cast<float>(res_value_q)) + B; -#endif // __aarch64__ - *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res); - break; - } - case ReductionOperation::SUM: - { - // Subtract accumulated offsets - res_value -= (in_info.dimension(axis) - 1) * iq_info.offset; - *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value); - break; - } - case ReductionOperation::PROD: - { - //re-quantize result - T res = 0; - if (std::is_same<T, uint8_t>::value) - { - res = quantize_qasymm8(res_value, iq_info); - } - else - { - res = quantize_qasymm8_signed(res_value, iq_info); - } - *(reinterpret_cast<T *>(output.ptr() + x)) = res; - break; - } - case ReductionOperation::ARG_IDX_MIN: - case ReductionOperation::ARG_IDX_MAX: - { - *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx; - break; - } - default: - *(reinterpret_cast<T *>(output.ptr() + x)) = res_value; - } - } - }, - input, output); - } -}; - -void reduce_op( - const Window &window, const ITensor *input, ITensor *output, unsigned int axis, const ReductionOperation op) +void NEReductionOperationKernel::reduce_op() { - const bool is_complex = (input->info()->num_channels() == 2); + const bool is_complex = (_input->info()->num_channels() == 2); if (is_complex) { - switch (axis) + switch (_reduction_axis) { case 2: - switch (input->info()->data_type()) + switch (_input->info()->data_type()) { case DataType::F32: - switch (op) + { + switch (_op) { case ReductionOperation::SUM: - return Reducer<RedOpYZW_complex<float, 4, 2, 
ReductionOperation::SUM>>::reduceZ( - window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), - op); + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM); + break; default: ARM_COMPUTE_ERROR("Not supported"); + break; } + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } return; } - switch (axis) + switch (_reduction_axis) { case 0: { - switch (input->info()->data_type()) + switch (_input->info()->data_type()) { case DataType::QASYMM8: { - return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, - RedOpX_quantized<uint8_t>(), op); + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpX_reduceX_qasymm8); + break; } case DataType::QASYMM8_SIGNED: { - return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), - op); + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpX_reduceX_qasymm8_signed); + break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpX_reduceX_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: { - return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op); + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpX_reduceX_float32_4); + break; } case DataType::S32: { - return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op); + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpX_reduceX_S32_4); + break; } default: { ARM_COMPUTE_ERROR("Not supported"); + break; } } + break; } case 1: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: { - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceY_qasymm8); + break; } case DataType::QASYMM8_SIGNED: { - return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, - RedOpYZW_quantized<int8_t>(), op); + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceY_qasymm8_signed); + break; } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceY_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceY_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceY_S32_4); + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } case 2: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + { + _func = 
REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceZ_qasymm8); + break; + } case DataType::QASYMM8_SIGNED: - return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, - RedOpYZW_quantized<int8_t>(), op); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceZ_qasymm8_signed); + break; + } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceZ_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceZ_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceZ_S32_4); + break; + } default: + { + std::cout << int(_input->info()->data_type()) << std::endl; ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } case 3: - switch (input->info()->data_type()) + { + switch (_input->info()->data_type()) { case DataType::QASYMM8: - return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, - RedOpYZW_quantized<uint8_t>(), op); + { + _func = REGISTER_QASYMM8_NEON(cpu::reduce_RedOpYZW_reduceW_qasymm8); + break; + } case DataType::QASYMM8_SIGNED: - return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, - RedOpYZW_quantized<int8_t>(), op); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_QASYMM8_SIGNED_NEON(cpu::reduce_RedOpYZW_reduceW_qasymm8_signed); + break; + } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: - return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), - op); -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + { + _func = REGISTER_FP16_NEON(cpu::reduce_RedOpYZW_reduceW_float16_8); + break; + } +#endif // ARM_COMPUTE_ENABLE_FP16 case DataType::F32: - return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op); + { + _func = REGISTER_FP32_NEON(cpu::reduce_RedOpYZW_reduceW_float32_4); + break; + } case DataType::S32: - return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op); + { + _func = REGISTER_INTEGER_NEON(cpu::reduce_RedOpYZW_reduceW_S32_4); + break; + } default: + { ARM_COMPUTE_ERROR("Not supported"); + break; + } } + break; + } default: + { ARM_COMPUTE_ERROR("Unsupported reduction axis"); + break; + } } } @@ -1819,10 +293,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, u return Status{}; } -} // namespace NEReductionOperationKernel::NEReductionOperationKernel() - : _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE) + : _func(nullptr), _input(nullptr), _output(nullptr), _reduction_axis(0), _op(ReductionOperation::SUM_SQUARE) { } @@ -1856,6 +329,8 @@ void NEReductionOperationKernel::configure(const ITensor *input, .set_data_type(output_data_type) .reset_padding() .set_is_resizable(true)); + // Determine the reduction function + NEReductionOperationKernel::reduce_op(); } Status NEReductionOperationKernel::validate(const ITensorInfo *input, @@ -1874,6 +349,6 @@ void NEReductionOperationKernel::run(const Window &window, const 
ThreadInfo &inf ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - reduce_op(window, _input, _output, _reduction_axis, _op); + (*_func)(window, _input, _output, _op); } } // namespace arm_compute diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.h b/src/core/NEON/kernels/NEReductionOperationKernel.h index 78bec62c14..407e5de6d6 100644 --- a/src/core/NEON/kernels/NEReductionOperationKernel.h +++ b/src/core/NEON/kernels/NEReductionOperationKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H -#define ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H +#ifndef ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H +#define ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H #include "src/core/NEON/INEKernel.h" @@ -80,14 +80,24 @@ public: static Status validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op); +private: // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; + /** Common signature for all the specialized Reduction functions + * + * @param[in] window Region on which to execute the kernel. + */ + using ReductionFunction = void (*)(const Window &window, const ITensor *in, ITensor *out, ReductionOperation op); -private: + /** Populate the _func with the right reduction operation handler + */ + void reduce_op(); + + ReductionFunction _func; const ITensor *_input; ITensor *_output; unsigned int _reduction_axis; ReductionOperation _op; }; } // namespace arm_compute -#endif /*ARM_COMPUTE_NEREDUCTIONOPERATIONKERNEL_H */ +#endif // ACL_SRC_CORE_NEON_KERNELS_NEREDUCTIONOPERATIONKERNEL_H diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp index f5bea3e163..fe8882f59f 100644 --- a/src/core/NEON/kernels/NEReorderKernel.cpp +++ b/src/core/NEON/kernels/NEReorderKernel.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/Scheduler.h" #include "src/common/utils/Log.h" #include "src/core/NEON/kernels/arm_gemm/transform.hpp" @@ -233,13 +234,20 @@ Status NEReorderKernel::validate(const ITensorInfo *input, } } - int ksize; + int ksize = 0; switch (output_wf) { #if defined(ARM_COMPUTE_ENABLE_SVE) case WeightFormat::OHWIo8: { - ksize = 8; + if (Scheduler::get().cpu_info().has_sve() && arm_gemm::utils::get_vector_length<float>() == 8) + { + ksize = 8; + } + else + { + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported weight format."); + } break; } #endif /* ARM_COMPUTE_ENABLE_SVE */ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index 5c08e6137d..0ddca04846 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -86,7 +86,7 @@ static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] = "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const 
GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 3b444ae333..c7adf8e4ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -69,19 +69,19 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { }, { GemmMethod::GEMM_INTERLEAVED, - "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL", + "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, - "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL", + "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, - [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); } }, { GemmMethod::GEMM_INTERLEAVED, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 44a7bb894a..0c1d3a387b 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -34,6 +34,7 @@ #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp" #include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp" +#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp" #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS @@ -123,14 +124,14 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32bf16fp32_dot_16VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "sme2_gemv_fp32_mla_16VL", - [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); } }, @@ -138,25 +139,25 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); } }, #endif // ARM_COMPUTE_ENABLE_BF16 { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); } }, #ifdef ARM_COMPUTE_ENABLE_BF16 { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); } @@ -165,7 +166,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL", - 
[](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>(); return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); } @@ -174,7 +175,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); } }, @@ -182,7 +183,7 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_INTERLEAVED, "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL", - [](const GemmArgs &args) { return args._ci->has_sme2(); }, + [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); } }, @@ -198,14 +199,14 @@ GemmImplementation<float, float>::with_estimate( GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_6x4VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); } ), GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "sve_hybrid_fp32bf16fp32_mmla_4x6VL", - [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); } ), @@ -292,14 +293,14 @@ GemmImplementation<float, float>::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_fp32_mla_8x4", - [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; }, + [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_fp32_mla_6x4", - [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; }, + [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; }, nullptr, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); } }, @@ -350,6 +351,14 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return 
GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); } ), +GemmImplementation<float, float>::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp32bf16fp32_mmla_6x16", + KernelWeightFormat::VL256_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>(args); } +), #endif // BF16 GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 89c2d5a23e..0cc4d4f3d9 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -530,7 +530,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } else if (_convolver) { @@ -563,7 +563,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } else { @@ -579,7 +579,7 @@ public: (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), - !first_pass, + !first_pass || _args._accumulate, // Quantization parameters _os, _col_bias+(multi * _args._Nsize), n0); } diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index fd20e53f60..fedda3a47a 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -63,7 +63,7 @@ static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = { "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL", [](const GemmArgs &args) { return args._ci->has_sme2(); }, [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); } }, { @@ -128,14 +128,14 @@ GemmImplementation<int8_t, int32_t>::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_s8s32_dot_8x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_s8s32_dot_6x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 4f732f7d94..897ec9d05f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -29,7 +29,6 @@ #include "arm_gemm.hpp" #include "bfloat.hpp" #include "convolver.hpp" -#include "kernel_weight_format.hpp" #include "kernel_traits.hpp" #include "kernel_weight_format.hpp" #include "mergeresults.hpp" @@ -191,10 +190,19 @@ void kernel_and_merge<false, false, Requantize32>::run( auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k); #endif + // Offset C pointer in a similar way to non-quantized case above. + Tri *offset_c_ptr; + + if (c_ptr == nullptr) { + offset_c_ptr = nullptr; + } else { + offset_c_ptr = c_ptr + m_0 * ldc + n_0; + } + strat.kernel(// A and B pointers are just the packed panels. a_ptr, b_panel, // Provide relevant part of output array and row stride. - c_ptr + m_0 * ldc + n_0, ldc, + offset_c_ptr, ldc, // M, N, K sizes m_max-m_0, n_max - n_0, kern_k, // Bias, activation, accumulation. Need to offset the bias as needed. 
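// The next hunk adds kernel_and_merge specializations for the new DequantizeFloat output stage:
// the compute kernel still produces int32 accumulators, and the merge step converts them to FP32
// while applying the dequantize scale, an optional bias, optional accumulation into the existing
// output, and the activation. The exact element-wise behaviour lives in dequantize_block_32 (not
// shown in this diff); a rough, self-contained scalar model of what is assumed to happen per
// element:

#include <algorithm>
#include <cstdint>

// Hypothetical scalar equivalent of the dequantize/merge step (names and ordering are assumptions).
inline float dequantize_merge_element(int32_t acc, float scale, float bias,
                                      float prev_out, bool accumulate,
                                      float act_min, float act_max)
{
    float v = static_cast<float>(acc) * scale + bias; // int32 accumulator -> FP32 with bias
    if (accumulate)
    {
        v += prev_out; // fold in the previously written output (C += A*B style accumulation)
    }
    return std::min(std::max(v, act_min), act_max); // bounded activation, e.g. ReLU-style clamp
}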
@@ -247,6 +255,84 @@ void kernel_and_merge<true, false, Requantize32>::run( } } +// Run a kernel with integrated merge, dequantizing to FP32 +template<> +template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +void kernel_and_merge<false, false, DequantizeFloat>::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, + Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, + unsigned int n_0, unsigned int n_max, const Tr *bias, + const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias, + Tab *acc_buff) +{ +#ifdef CYCLE_PROFILING + auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k); +#endif + + const int32_t *offset_col_bias = nullptr; + const Tr *offset_bias = nullptr; + + if (col_bias) { + offset_col_bias = col_bias + n_0; + } + + if (bias) { + offset_bias = bias + n_0; + } + + strat.kernel(// A and B pointers are just the packed panels. + a_ptr, b_panel, + // Provide relevant part of output array and row stride. + c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc, + // M, N, K sizes + m_max-m_0, n_max - n_0, kern_k, + // Bias, activation, accumulation. Need to offset the bias as needed. + offset_col_bias, dq, offset_bias, act, accumulate, acc_buff); +} + +template<> +template<typename strategy, typename To, typename Tr, typename Tri, typename Tab> +void kernel_and_merge<true, false, DequantizeFloat>::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, + Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, + unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias, + const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *, + Tab *) +{ + const int bblocks = iceildiv(n_max - n_0, strategy::out_width()); + + { +#ifdef CYCLE_PROFILING + auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k)); +#endif + + strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k); + } + + { +#ifdef CYCLE_PROFILING + auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr))); +#endif + auto out_area = strategy::out_width() * strategy::out_height(); + for (int i=0; i<bblocks; i++) { + const unsigned int n_start = n_0 + (strategy::out_width() * i); + const unsigned int n_end = std::min(n_start + strategy::out_width(), n_max); + + dequantize_block_32(qp, (n_end - n_start), (m_max - m_0), + c_panel + (i * out_area), strategy::out_width(), + c_ptr + m_0 * ldc + n_start, ldc, + bias != nullptr ? bias + n_start : nullptr, accumulate, act); + + } + } +} + // Integer GEMMs can be used in two contexts - "normal" where the full 32-bit output is required, or in // "requantizing" context where the output will be requantized. 
// @@ -280,6 +366,12 @@ public: typedef int32_t type; }; +template<typename strategy> +class accumulate_buffer_type<strategy, DequantizeFloat, false> { +public: + typedef int32_t type; +}; + template<typename strategy, typename OutputStage> class accumulate_buffer_type<strategy, OutputStage, true> { public: @@ -350,6 +442,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { const bool _thread_columns; const Activation _act; + const bool _accumulate; const int _maxthreads; int _nthreads; @@ -579,15 +672,27 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return roundup(args._cfg->inner_block_size, strategy::k_unroll()); } - // K blocking not supported if we are requantizing. - if (std::is_same<OutputStage, Requantize32>::value) { + // K blocking not supported if we are requantizing with the merging + // kernels. + if (std::is_same<OutputStage, Requantize32>::value && MergeStep) { return get_ktotal(args); } + const unsigned int L1_size = args._ci->get_L1_cache_size(); + // Special blocking for SME if (is_sme<strategy>::value) { - // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32 - unsigned int scaling_threshold = 1280 / sizeof(Toi); + // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1. + unsigned int target_bytes_per_block = L1_size / 128; + + // Default cache size in gemm-linux is 32kB though - so make + // sure minimum is 512 + if (target_bytes_per_block < 512) { + target_bytes_per_block = 512; + } + + // Don't bother to block below this size threshold (1.25X target size) + unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi); if (get_ktotal(args) <= scaling_threshold) { return get_ktotal(args); @@ -595,7 +700,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> { // Once we are blocking, this (lower) threshold determines when we should use more blocks // NOTE: Could be that some factor-based solution would work better here. - unsigned int max_block_size = 1024 / sizeof(Toi); + unsigned int max_block_size = target_bytes_per_block / sizeof(Toi); unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size); @@ -604,7 +709,6 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return k_block; } - const unsigned int L1_size = args._ci->get_L1_cache_size(); unsigned int k_block; // k_block: Find out how much of the larger array can be loaded into half the cache. @@ -639,6 +743,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return roundup(args._cfg->outer_block_size, strategy::out_width()); } + // Special blocking for SME + if (is_sme<strategy>::value) { + // If total width is less than 4x kernel width, return the entire width. + if (args._Nsize < strategy::out_width()*4) { + return roundup(args._Nsize, strategy::out_width()); + } + + // Otherwise block to single kernel width. 
+ return strategy::out_width(); + } + unsigned int x_block; const unsigned int L2_size = args._ci->get_L2_cache_size(); const unsigned int k_block = get_k_block_size(args); @@ -680,7 +795,7 @@ public: _Ksections(args._Ksections), _Ktotal(get_ktotal(args)), _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())), _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())), _os(os) { } @@ -690,7 +805,7 @@ public: _Ksections(args._Ksections), _Ktotal(get_ktotal(args)), _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())), _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())), _os() { } @@ -763,6 +878,9 @@ public: const bool first_pass = (k0==0); const bool last_pass = (kmax==_Ktotal); + // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass. + const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass; + // Figure out how many "K" the kernel will actually process. unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll()); @@ -821,9 +939,9 @@ public: // K size, and M/N ranges kern_k, start_row, end_row, start_x, end_x, // Only do bias on the first pass - ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr), + ((bias_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr), // Only do activation on the last pass, and accumulation on any non-first pass. - (last_pass ? _act : Activation()), !first_pass, + (last_pass ? _act : Activation()), (!first_pass || _accumulate), // Pass in quantization parameters for requantizing kernels (others will ignore) _os, col_bias + (multi * _Nsize), // Accumulation buffer @@ -948,6 +1066,9 @@ public: const bool first_pass = (current.k0() == 0); const bool last_pass = (current.kmax() == _Ktotal); + // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass. + const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass; + // Pointer to appropriate part of result array. Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride); @@ -969,9 +1090,9 @@ public: // K size, and M/N ranges kern_k, y, ymax, current.x0(), current.xmax(), // Only do bias on the first pass - ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr), + ((bias_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr), // Only do activation on the last pass, and accumulation on any non-first pass. - (last_pass ? _act : Activation()), !first_pass, + (last_pass ? 
_act : Activation()), (!first_pass || _accumulate), // Pass in quantization parameters for requantizing kernels (others will ignore) _os, col_bias + (current.multi() * _Nsize), // Accumulation buffer @@ -1184,6 +1305,13 @@ public: } } + void set_dequantize_scale(const float scale) override { + if(std::is_same<OutputStage, DequantizeFloat>::value) { + DequantizeFloat* df = reinterpret_cast<DequantizeFloat *>(&_os); + df->scale = scale; + } + } + void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override { assert(string_len == _Ksize); _indirect_buf = ptr; @@ -1248,4 +1376,10 @@ using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strat template<typename strategy, typename To, typename Tr> using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>; +template<typename strategy, typename To, typename Tr> +using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>; + +template<typename strategy, typename To, typename Tr> +using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>; + } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index d1c4e49edb..321c97262f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -82,7 +82,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL", [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index b85b1c4fcf..93eecf991e 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -78,7 +78,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL", [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));}, [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>(); - return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); } }, { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp new file mode 100644 index 0000000000..38d9b763f6 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp 
@@ -0,0 +1,142 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include "arm_gemm.hpp" + +#include "kernels/a64_gemm_s16_8x12.hpp" +#include "kernels/a64_gemm_s8_8x12.hpp" +#include "kernels/a64_gemm_s8_4x4.hpp" +#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp" + +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp" +#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp" +#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SME2 +#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp" +#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp" +#endif // ARM_COMPUTE_ENABLE_SVE + +#include "gemm_implementation.hpp" +#include "gemm_interleaved.hpp" +#include "utils.hpp" + +#include <cstdint> +#include <vector> +namespace arm_gemm { + +static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_methods[] = +{ +#ifdef ARM_COMPUTE_ENABLE_SVE +#ifdef ARM_COMPUTE_ENABLE_SME2 +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, + [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>(); + return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); }, + [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, float>(args, dq); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8qfp32_mopa_4Vx1VL.hpp", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, + [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>(); + return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); }, + [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, float>(args, dq); } +}, +{ + GemmMethod::GEMM_INTERLEAVED, + "sme2_interleaved_nomerge_s8qfp32_mopa_2Vx2VL.hpp", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; }, + nullptr, + [](const GemmArgs &args, const 
DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, float>(args, dq); } +}, +#endif // ARM_COMPUTE_ENABLE_SME2 +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_interleaved_s8s32_mmla_8x3VL", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_svei8mm(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>(args, qp); } +), +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_interleaved_s8s32_dot_8x3VL", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sve(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>(args, qp); } +), +#endif // ARM_COMPUTE_ENABLE_SVE +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_interleaved_s8s32_mmla_8x12", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_i8mm(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>(args, qp); } +), +{ + GemmMethod::GEMM_INTERLEAVED, + "a64_gemm_s16_8x12", + nullptr, + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->get_cpu_model() == CPUModel::A53 && ((args._Msize > 28) || ((args._Msize % 8) > 4)); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s16_8x12, int8_t, float>(args, qp); } +}, +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_gemm_s8_8x12", + [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_dotprod(); }, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>(args, qp); } +), +GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_gemm_s8_4x4", + nullptr, + [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>::estimate_cycles<int8_t>(args); }, + [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>(args, qp); } +), +{ + GemmMethod::DEFAULT, + "", + nullptr, + nullptr, + nullptr +} +}; + +template<> +const GemmImplementation<int8_t, float, DequantizeFloat> *gemm_implementation_list<int8_t, float, DequantizeFloat>() { + return gemm_s8fp32_methods; +} + +template UniqueGemmCommon<int8_t, float> gemm<int8_t, float, 
DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os); +template KernelDescription get_gemm_method<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os); +template std::vector<KernelDescription> get_compatible_kernels<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os); + +} // namespace arm_gemm + +#endif // __aarch64__
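[Editorial note, not part of the patch] For context on how the gemm_s8fp32_methods table above is consumed: each GemmImplementation entry pairs a support predicate, an optional recommendation/cycle-estimate hook, and a factory lambda, and the framework walks the list in order until an entry accepts the GemmArgs, with the GemmMethod::DEFAULT sentinel at the end acting as the fallback. Below is a minimal, self-contained C++ sketch of that "first supported candidate" shape; all types and names in it are simplified stand-ins, not the real definitions from gemm_implementation.hpp, and the real selection additionally compares cycle estimates between candidates that provide one.

// Editorial sketch (assumption-laden, simplified): stand-in for the
// GemmImplementation table walk.  Only the "first supported entry wins,
// DEFAULT is the unconditional fallback" pattern is shown.
#include <functional>
#include <string>
#include <vector>

struct Args {
    bool has_sme2   = false;   // stand-in for args._ci->has_sme2()
    bool accumulate = false;   // stand-in for args._accumulate
};

struct Candidate {
    std::string name;
    std::function<bool(const Args &)> is_supported;  // empty => always supported
};

// Walk the table in order and return the first entry whose predicate accepts
// the arguments; a trailing entry with no predicate plays the DEFAULT role.
const Candidate *pick_first_supported(const std::vector<Candidate> &table, const Args &args)
{
    for (const auto &c : table) {
        if (!c.is_supported || c.is_supported(args)) {
            return &c;
        }
    }
    return nullptr;  // unreachable when the table ends with an unconditional entry
}

// Usage sketch mirroring the ordering above: the SME2 MOPA entry is tried
// first and only accepted when SME2 is present and no accumulation is
// requested, otherwise selection falls through towards the default entry.
//
//   std::vector<Candidate> table = {
//       {"sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL",
//        [](const Args &a) { return a.has_sme2 && !a.accumulate; }},
//       {"default", nullptr},
//   };
//   const Candidate *chosen = pick_first_supported(table, Args{});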
\ No newline at end of file diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index af5cfbbf2b..dfacb687a8 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020, 2022-2023 Arm Limited. + * Copyright (c) 2017-2020, 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -94,14 +94,14 @@ GemmImplementation<uint8_t, uint32_t>::with_estimate( { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_u8u32_dot_8x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); } }, { GemmMethod::GEMM_HYBRID, "a64_smallK_hybrid_u8u32_dot_6x4", - [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; }, + [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; }, [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); }, [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); } }, diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp index 92c884ce18..dbada36052 100644 --- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp @@ -180,7 +180,7 @@ public: this->_Cptr + (multi * this->_C_multi_stride) + n, (nmax - n), (kmax-k0), this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr, - _args._act, (k0 != 0), + _args._act, (k0 != 0) || _args._accumulate, _os, col_bias, n + (_args._Nsize * multi)); } } diff --git a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp index 59591935cd..7c09608e3e 100644 --- a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp +++ b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2022, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -330,11 +330,11 @@ template void Interleave<8, 2, VLType::None>(float *, const float *, size_t, uns #endif // ARM_COMPUTE_ENABLE_SVE && ARM_COMPUTE_ENABLE_SVEF32MM /* FP16 */ -#if defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16) template void IndirectInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void Interleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); -#endif // FP16_KERNELS ar __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // FP16_KERNELS ar ARM_COMPUTE_ENABLE_FP16 template void IndirectInterleave<8, 1, VLType::None>(float *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(float *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp index 923d008bb1..ac3cbf943f 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: { if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 23.64 }; default: - return { 28.48 }; + return { 16.89 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp new file mode 100644 index 0000000000..98f7fc9403 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<float>, \ + size_t, size_t, \ + const bfloat16 *, \ + size_t, \ + IndirectOutputArg<float>, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffhybrid_fp32bf16fp32_mmla_6x16( ARGLIST ); + +class cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16 +{ +public: + typedef float lhs_operand_type; + typedef bfloat16 rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL256_BL64_BF16; + } + + static unsigned int out_width() + { + return 16; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 4> transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, float>::value) { + switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 21.05 }; + default: + return { 15.27 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffhybrid_fp32bf16fp32_mmla_6x16; + cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp new file mode 100644 index 0000000000..9ab4aa98f9 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp @@ -0,0 +1,3240 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __aarch64__ + +#include "arm_gemm.hpp" +#include "../../utils.hpp" +#include "../../bfloat.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void a64_ffhybrid_fp32bf16fp32_mmla_6x16 ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg, + size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg, + const float *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + float maxval = static_cast<float>(std::numeric_limits<float>::infinity()); + float minval = - static_cast<float>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const bfloat16 *B_ptr = {}; + const bfloat16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + void *output_ptr = nullptr; + const float *bias = nullptr; + } ka; + + unsigned long flags=0; + void *input_ptr; + + if (output_arg.is_indirect) { + ka.output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + ka.output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.bias = bias; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 181f\n" + "cmp %x[M], #0x4\n" + "bgt 145f\n" + "beq 109f\n" + "cmp %x[M], #0x2\n" + "bgt 73f\n" + "beq 37f\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "2:" // Height 1: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 3f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 3f\n" + "mov x11, x12\n" + "3:" // Height 1: B setup done + "cbz x15, 4f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, 
v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "b 16f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 15f\n" + "cmp x14, #0x10\n" + "bge 13f\n" + "tbz x14, #3, 8f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "tbz x14, #2, 6f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "tbz x14, #1, 5f\n" + "ldr d16, [x13], #0x8\n" + "mov x20, #0x38\n" + "tbz x14, #0, 12f\n" + "ld1 { v16.s }[2], [x13]\n" + "b 12f\n" + "5:" // Height 1: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 12f\n" + "ldr s16, [x13, #0x0]\n" + "b 12f\n" + "6:" // Height 1: Partial accumulate: partial_2_8 + "tbz x14, #1, 7f\n" + "ldr d11, [x13], #0x8\n" + "mov x20, #0x28\n" + "tbz x14, #0, 12f\n" + "ld1 { v11.s }[2], [x13]\n" + "b 12f\n" + "7:" // Height 1: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 12f\n" + "ldr s11, [x13, #0x0]\n" + "b 12f\n" + "8:" // Height 1: Partial accumulate: partial_4_0 + "tbz x14, #2, 10f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "tbz x14, #1, 9f\n" + "ldr d10, [x13], #0x8\n" + "mov x20, #0x18\n" + "tbz x14, #0, 12f\n" + "ld1 { v10.s }[2], [x13]\n" + "b 12f\n" + "9:" // Height 1: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 12f\n" + "ldr s10, [x13, #0x0]\n" + "b 12f\n" + "10:" // Height 1: Partial accumulate: partial_2_0 + "tbz x14, #1, 11f\n" + "ldr d9, [x13], #0x8\n" + "mov x20, #0x8\n" + "tbz x14, #0, 12f\n" + "ld1 { v9.s }[2], [x13]\n" + "b 12f\n" + "11:" // Height 1: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "mov x20, #0x0\n" + "12:" // Height 1: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 14f\n" + "13:" // Height 1: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "14:" // Height 1: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "b 16f\n" + "15:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "16:" // Height 1: setup done + "mov x28, #0x0\n" + "17:" // Height 1: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 18f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "cbnz x28, 19f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "b 19f\n" + "18:" // Height 1: setup direct input + "mov x26, %x[input_ptr]\n" + "19:" // Height 1: input setup done + "cmp x27, #0x4\n" + "blt 22f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "cmp x27, #0x8\n" + "ldr q7, [x12, #0x10]\n" + "blt 21f\n" + "20:" // Height 1: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, 
v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 20b\n" + "21:" // Height 1: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "22:" // Height 1: Multiply loop: Main loop skip + "cbz x27, 25f\n" + "cbz x27, 25f\n" + "tbz x27, #1, 23f\n" + "ldr d0, [x26], #0x8\n" + "tbz x27, #0, 24f\n" + "ld1 { v0.s }[2], [x26]\n" + "b 24f\n" + "23:" // Height 1: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "24:" // Height 1: Multiply loop: Ragged operand read: Done + "ldr q18, [x12, #0x0]\n" + "ldr q17, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "25:" // Height 1: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 17b\n" + "uzp1 v8.2d, v8.2d, v12.2d\n" + "uzp1 v9.2d, v9.2d, v13.2d\n" + "uzp1 v10.2d, v10.2d, v14.2d\n" + "uzp1 v11.2d, v11.2d, v15.2d\n" + "tbz %x[flags], #1, 26f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v18.4s }, [x21]\n" + "ld1r { v17.4s }, [x20]\n" + "fmin v8.4s, v8.4s, v18.4s\n" + "fmin v9.4s, v9.4s, v18.4s\n" + "fmin v10.4s, v10.4s, v18.4s\n" + "fmin v11.4s, v11.4s, v18.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v11.4s, v11.4s, v17.4s\n" + "26:" // Height 1: No activation + "cmp x14, #0x10\n" + "bge 35f\n" + "tbz x14, #3, 30f\n" + "st1 { v8.4s }, [x13], #0x10\n" + "st1 { v9.4s }, [x13], #0x10\n" + "tbz x14, #2, 28f\n" + "st1 { v10.4s }, [x13], #0x10\n" + "tbz x14, #1, 27f\n" + "str d11, [x13], #0x8\n" + "tbz x14, 
#0, 34f\n" + "st1 { v11.s }[2], [x13]\n" + "b 34f\n" + "27:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x14, #0, 34f\n" + "str s11, [x13, #0x0]\n" + "b 34f\n" + "28:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x14, #1, 29f\n" + "str d10, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v10.s }[2], [x13]\n" + "b 34f\n" + "29:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x14, #0, 34f\n" + "str s10, [x13, #0x0]\n" + "b 34f\n" + "30:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x14, #2, 32f\n" + "st1 { v8.4s }, [x13], #0x10\n" + "tbz x14, #1, 31f\n" + "str d9, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v9.s }[2], [x13]\n" + "b 34f\n" + "31:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x14, #0, 34f\n" + "str s9, [x13, #0x0]\n" + "b 34f\n" + "32:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x14, #1, 33f\n" + "str d8, [x13], #0x8\n" + "tbz x14, #0, 34f\n" + "st1 { v8.s }[2], [x13]\n" + "b 34f\n" + "33:" // Height 1: Partial direct writeback: partial_1_0 + "str s8, [x13, #0x0]\n" + "34:" // Height 1: Partial direct writeback: Done + "b 36f\n" + "35:" // Height 1: Full writeback + "str q8, [x13, #0x0]\n" + "str q9, [x13, #0x10]\n" + "str q10, [x13, #0x20]\n" + "str q11, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "36:" // Height 1: Writeback done + "subs x14, x14, #0x10\n" + "bgt 2b\n" + "b 218f\n" + "37:" // Height 2 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "38:" // Height 2: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 39f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 39f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 39f\n" + "mov x11, x12\n" + "39:" // Height 2: B setup done + "cbz x15, 40f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "b 52f\n" + "40:" // Height 2: no bias + "tbz %x[flags], #0, 51f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "bge 49f\n" + "tbz x14, #3, 44f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "tbz x14, #2, 42f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "tbz x14, #1, 41f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "tbz x14, #0, 48f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "b 48f\n" + "41:" // Height 2: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 48f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "b 48f\n" + "42:" // Height 2: Partial accumulate: partial_2_8 + "tbz x14, #1, 43f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + 
"tbz x14, #0, 48f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "b 48f\n" + "43:" // Height 2: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 48f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "b 48f\n" + "44:" // Height 2: Partial accumulate: partial_4_0 + "tbz x14, #2, 46f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "tbz x14, #1, 45f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "tbz x14, #0, 48f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "b 48f\n" + "45:" // Height 2: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 48f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "b 48f\n" + "46:" // Height 2: Partial accumulate: partial_2_0 + "tbz x14, #1, 47f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "tbz x14, #0, 48f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "b 48f\n" + "47:" // Height 2: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "48:" // Height 2: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 50f\n" + "49:" // Height 2: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "50:" // Height 2: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "b 52f\n" + "51:" // Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "52:" // Height 2: setup done + "mov x28, #0x0\n" + "53:" // Height 2: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 54f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "cbnz x28, 55f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "b 55f\n" + "54:" // Height 2: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "55:" // Height 2: input setup done + "cmp x27, #0x4\n" + "blt 58f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "cmp x27, #0x8\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 57f\n" + "56:" // Height 2: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla 
v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 56b\n" + "57:" // Height 2: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "58:" // Height 2: Multiply loop: Main loop skip + "cbz x27, 61f\n" + "cbz x27, 61f\n" + "tbz x27, #1, 59f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "tbz x27, #0, 60f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "b 60f\n" + "59:" // Height 2: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "60:" // Height 2: Multiply loop: Ragged operand read: Done + "ldr q18, [x12, #0x0]\n" + "ldr q17, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n" + "ldr q18, [x11, #0x0]\n" + ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n" + "ldr q17, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n" + "ldr q18, [x10, #0x0]\n" + ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n" + "ldr q17, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n" + "ldr q18, [x9, #0x0]\n" + ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n" + "ldr q17, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n" + ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n" + "61:" // Height 2: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 53b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x26, x13, x20, LSL #2\n" + "tbz %x[flags], #1, 62f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v18.4s }, [x21]\n" + "ld1r { v17.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v18.4s\n" + "fmin v12.4s, v12.4s, v18.4s\n" + "fmin v13.4s, v13.4s, v18.4s\n" + "fmin v14.4s, v14.4s, v18.4s\n" + "fmin v8.4s, v8.4s, v18.4s\n" + "fmin v9.4s, v9.4s, v18.4s\n" + "fmin v10.4s, v10.4s, v18.4s\n" + "fmin v11.4s, v11.4s, v18.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v12.4s, v12.4s, 
v17.4s\n" + "fmax v13.4s, v13.4s, v17.4s\n" + "fmax v14.4s, v14.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v11.4s, v11.4s, v17.4s\n" + "62:" // Height 2: No activation + "cmp x14, #0x10\n" + "bge 71f\n" + "tbz x14, #3, 66f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "tbz x14, #2, 64f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "tbz x14, #1, 63f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "b 70f\n" + "63:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x14, #0, 70f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "b 70f\n" + "64:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x14, #1, 65f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "b 70f\n" + "65:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x14, #0, 70f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "b 70f\n" + "66:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x14, #2, 68f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "tbz x14, #1, 67f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "b 70f\n" + "67:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x14, #0, 70f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "b 70f\n" + "68:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x14, #1, 69f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "tbz x14, #0, 70f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "b 70f\n" + "69:" // Height 2: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "70:" // Height 2: Partial direct writeback: Done + "b 72f\n" + "71:" // Height 2: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "72:" // Height 2: Writeback done + "subs x14, x14, #0x10\n" + "bgt 38b\n" + "b 218f\n" + "73:" // Height 3 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "74:" // Height 3: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 75f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 75f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 75f\n" + "mov x11, x12\n" + "75:" // Height 3: B setup done + "cbz x15, 76f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, 
v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "b 88f\n" + "76:" // Height 3: no bias + "tbz %x[flags], #0, 87f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "bge 85f\n" + "tbz x14, #3, 80f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "tbz x14, #2, 78f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "tbz x14, #1, 77f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "b 84f\n" + "77:" // Height 3: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 84f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "b 84f\n" + "78:" // Height 3: Partial accumulate: partial_2_8 + "tbz x14, #1, 79f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "b 84f\n" + "79:" // Height 3: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 84f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "b 84f\n" + "80:" // Height 3: Partial accumulate: partial_4_0 + "tbz x14, #2, 82f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "tbz x14, #1, 81f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "b 84f\n" + "81:" // Height 3: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 84f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "b 84f\n" + "82:" // Height 3: Partial accumulate: partial_2_0 + "tbz x14, #1, 83f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "tbz x14, #0, 84f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "b 84f\n" + "83:" // Height 3: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "84:" // Height 3: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 86f\n" + "85:" // Height 3: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "86:" // Height 3: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, 
v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "b 88f\n" + "87:" // Height 3: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "88:" // Height 3: setup done + "mov x28, #0x0\n" + "89:" // Height 3: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 90f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "cbnz x28, 91f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "b 91f\n" + "90:" // Height 3: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "91:" // Height 3: input setup done + "cmp x27, #0x4\n" + "blt 94f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 93f\n" + "92:" // Height 3: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 92b\n" + "93:" // Height 3: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, 
#0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "94:" // Height 3: Multiply loop: Main loop skip + "cbz x27, 97f\n" + "cbz x27, 97f\n" + "tbz x27, #1, 95f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "tbz x27, #0, 96f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "b 96f\n" + "95:" // Height 3: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "96:" // Height 3: Multiply loop: Ragged operand read: Done + "ldr q26, [x12, #0x0]\n" + "ldr q25, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n" + ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "97:" // Height 3: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 89b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, 
v11.2d, v15.2d\n" + "add x25, x26, x20, LSL #2\n" + "uzp1 v16.2d, v16.2d, v20.2d\n" + "uzp1 v17.2d, v17.2d, v21.2d\n" + "uzp1 v18.2d, v18.2d, v22.2d\n" + "uzp1 v19.2d, v19.2d, v23.2d\n" + "tbz %x[flags], #1, 98f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v26.4s }, [x21]\n" + "ld1r { v25.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v26.4s\n" + "fmin v12.4s, v12.4s, v26.4s\n" + "fmin v13.4s, v13.4s, v26.4s\n" + "fmin v14.4s, v14.4s, v26.4s\n" + "fmin v8.4s, v8.4s, v26.4s\n" + "fmin v9.4s, v9.4s, v26.4s\n" + "fmin v10.4s, v10.4s, v26.4s\n" + "fmin v11.4s, v11.4s, v26.4s\n" + "fmin v16.4s, v16.4s, v26.4s\n" + "fmin v17.4s, v17.4s, v26.4s\n" + "fmin v18.4s, v18.4s, v26.4s\n" + "fmin v19.4s, v19.4s, v26.4s\n" + "fmax v6.4s, v6.4s, v25.4s\n" + "fmax v12.4s, v12.4s, v25.4s\n" + "fmax v13.4s, v13.4s, v25.4s\n" + "fmax v14.4s, v14.4s, v25.4s\n" + "fmax v8.4s, v8.4s, v25.4s\n" + "fmax v9.4s, v9.4s, v25.4s\n" + "fmax v10.4s, v10.4s, v25.4s\n" + "fmax v11.4s, v11.4s, v25.4s\n" + "fmax v16.4s, v16.4s, v25.4s\n" + "fmax v17.4s, v17.4s, v25.4s\n" + "fmax v18.4s, v18.4s, v25.4s\n" + "fmax v19.4s, v19.4s, v25.4s\n" + "98:" // Height 3: No activation + "cmp x14, #0x10\n" + "bge 107f\n" + "tbz x14, #3, 102f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v16.4s }, [x25], #0x10\n" + "st1 { v17.4s }, [x25], #0x10\n" + "tbz x14, #2, 100f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v18.4s }, [x25], #0x10\n" + "tbz x14, #1, 99f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d19, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v19.s }[2], [x25]\n" + "b 106f\n" + "99:" // Height 3: Partial direct writeback: partial_1_12 + "tbz x14, #0, 106f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s19, [x25, #0x0]\n" + "b 106f\n" + "100:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x14, #1, 101f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d18, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v18.s }[2], [x25]\n" + "b 106f\n" + "101:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x14, #0, 106f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s18, [x25, #0x0]\n" + "b 106f\n" + "102:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x14, #2, 104f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v16.4s }, [x25], #0x10\n" + "tbz x14, #1, 103f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d17, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v17.s }[2], [x25]\n" + "b 106f\n" + "103:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x14, #0, 106f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s17, [x25, #0x0]\n" + "b 106f\n" + "104:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x14, #1, 105f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d16, [x25], #0x8\n" + "tbz x14, #0, 106f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v16.s }[2], [x25]\n" + "b 106f\n" + "105:" // Height 3: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s16, [x25, #0x0]\n" + "106:" // Height 3: Partial direct writeback: Done + "b 
108f\n" + "107:" // Height 3: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q16, [x25, #0x0]\n" + "str q17, [x25, #0x10]\n" + "str q18, [x25, #0x20]\n" + "str q19, [x25, #0x30]\n" + "108:" // Height 3: Writeback done + "subs x14, x14, #0x10\n" + "bgt 74b\n" + "b 218f\n" + "109:" // Height 4 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "110:" // Height 4: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 111f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 111f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 111f\n" + "mov x11, x12\n" + "111:" // Height 4: B setup done + "cbz x15, 112f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "b 124f\n" + "112:" // Height 4: no bias + "tbz %x[flags], #0, 123f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "bge 121f\n" + "tbz x14, #3, 116f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "ld1 { v21.4s }, [x24], #0x10\n" + "tbz x14, #2, 114f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "ld1 { v22.4s }, [x24], #0x10\n" + "tbz x14, #1, 113f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "ldr d23, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v23.s }[2], [x24]\n" + "b 120f\n" + "113:" // Height 4: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 120f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "ldr s23, [x24, #0x0]\n" + "b 120f\n" + "114:" // Height 4: Partial accumulate: partial_2_8 + "tbz x14, #1, 115f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "ldr d22, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "ld1 { v22.s }[2], [x24]\n" + "b 120f\n" + "115:" // Height 4: 
Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 120f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "ldr s22, [x24, #0x0]\n" + "b 120f\n" + "116:" // Height 4: Partial accumulate: partial_4_0 + "tbz x14, #2, 118f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "tbz x14, #1, 117f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "ld1 { v21.s }[2], [x24]\n" + "b 120f\n" + "117:" // Height 4: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 120f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "ldr s21, [x24, #0x0]\n" + "b 120f\n" + "118:" // Height 4: Partial accumulate: partial_2_0 + "tbz x14, #1, 119f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "ldr d20, [x24], #0x8\n" + "tbz x14, #0, 120f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "ld1 { v20.s }[2], [x24]\n" + "b 120f\n" + "119:" // Height 4: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "ldr s20, [x24, #0x0]\n" + "120:" // Height 4: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 122f\n" + "121:" // Height 4: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "ldr q20, [x24, #0x0]\n" + "ldr q21, [x24, #0x10]\n" + "ldr q22, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "122:" // Height 4: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "b 124f\n" + "123:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "124:" // Height 4: setup done + "mov x28, #0x0\n" + "125:" // Height 4: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 126f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "cbnz x28, 127f\n" + "ldr 
x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "b 127f\n" + "126:" // Height 4: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "127:" // Height 4: input setup done + "cmp x27, #0x4\n" + "blt 130f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "ld1 { v3.4s }, [x23], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 129f\n" + "128:" // Height 4: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x23], #0x10\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 128b\n" + "129:" // Height 4: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // 
bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "130:" // Height 4: Multiply loop: Main loop skip + "cbz x27, 133f\n" + "cbz x27, 133f\n" + "tbz x27, #1, 131f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "ldr d3, [x23], #0x8\n" + "tbz x27, #0, 132f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "ld1 { v3.s }[2], [x23]\n" + "b 132f\n" + "131:" // Height 4: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "ldr s3, [x23, #0x0]\n" + "132:" // Height 4: Multiply loop: Ragged operand read: Done + "ldr q26, [x12, #0x0]\n" + "ldr q25, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n" + "ldr q26, [x11, #0x0]\n" + ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n" + "ldr q25, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n" + "ldr q26, [x10, #0x0]\n" + ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n" + "ldr q25, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n" + "ldr q26, [x9, #0x0]\n" + ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n" + "ldr q25, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n" + ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n" + ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n" + ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n" + "133:" // Height 4: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 125b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x24, x25, x20, LSL #2\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "tbz %x[flags], #1, 134f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v26.4s }, [x21]\n" + "ld1r { v25.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v26.4s\n" + "fmin v12.4s, v12.4s, v26.4s\n" + "fmin v13.4s, v13.4s, v26.4s\n" + "fmin v14.4s, v14.4s, v26.4s\n" + "fmin v8.4s, v8.4s, v26.4s\n" + "fmin v9.4s, v9.4s, v26.4s\n" + "fmin v10.4s, v10.4s, v26.4s\n" + "fmin v11.4s, v11.4s, v26.4s\n" + "fmin v15.4s, v15.4s, v26.4s\n" + "fmin v20.4s, v20.4s, v26.4s\n" + "fmin v21.4s, v21.4s, v26.4s\n" + "fmin v22.4s, 
v22.4s, v26.4s\n" + "fmin v16.4s, v16.4s, v26.4s\n" + "fmin v17.4s, v17.4s, v26.4s\n" + "fmin v18.4s, v18.4s, v26.4s\n" + "fmin v19.4s, v19.4s, v26.4s\n" + "fmax v6.4s, v6.4s, v25.4s\n" + "fmax v12.4s, v12.4s, v25.4s\n" + "fmax v13.4s, v13.4s, v25.4s\n" + "fmax v14.4s, v14.4s, v25.4s\n" + "fmax v8.4s, v8.4s, v25.4s\n" + "fmax v9.4s, v9.4s, v25.4s\n" + "fmax v10.4s, v10.4s, v25.4s\n" + "fmax v11.4s, v11.4s, v25.4s\n" + "fmax v15.4s, v15.4s, v25.4s\n" + "fmax v20.4s, v20.4s, v25.4s\n" + "fmax v21.4s, v21.4s, v25.4s\n" + "fmax v22.4s, v22.4s, v25.4s\n" + "fmax v16.4s, v16.4s, v25.4s\n" + "fmax v17.4s, v17.4s, v25.4s\n" + "fmax v18.4s, v18.4s, v25.4s\n" + "fmax v19.4s, v19.4s, v25.4s\n" + "134:" // Height 4: No activation + "cmp x14, #0x10\n" + "bge 143f\n" + "tbz x14, #3, 138f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x24], #0x10\n" + "tbz x14, #2, 136f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v18.4s }, [x24], #0x10\n" + "tbz x14, #1, 135f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d22, [x25], #0x8\n" + "str d19, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "b 142f\n" + "135:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x14, #0, 142f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s19, [x24, #0x0]\n" + "b 142f\n" + "136:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x14, #1, 137f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d21, [x25], #0x8\n" + "str d18, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v18.s }[2], [x24]\n" + "b 142f\n" + "137:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x14, #0, 142f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s18, [x24, #0x0]\n" + "b 142f\n" + "138:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x14, #2, 140f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "tbz x14, #1, 139f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d20, [x25], #0x8\n" + "str d17, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v17.s }[2], [x24]\n" + "b 142f\n" + "139:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x14, #0, 142f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s17, [x24, #0x0]\n" + "b 142f\n" + "140:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x14, #1, 141f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d15, [x25], #0x8\n" + "str d16, [x24], #0x8\n" + "tbz x14, #0, 142f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v15.s }[2], [x25]\n" + "st1 { v16.s }[2], [x24]\n" + "b 142f\n" + "141:" // Height 4: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s15, [x25, #0x0]\n" + "str s16, [x24, #0x0]\n" + "142:" // Height 4: Partial direct writeback: Done 
+ "b 144f\n" + "143:" // Height 4: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q15, [x25, #0x0]\n" + "str q20, [x25, #0x10]\n" + "str q21, [x25, #0x20]\n" + "str q22, [x25, #0x30]\n" + "str q16, [x24, #0x0]\n" + "str q17, [x24, #0x10]\n" + "str q18, [x24, #0x20]\n" + "str q19, [x24, #0x30]\n" + "144:" // Height 4: Writeback done + "subs x14, x14, #0x10\n" + "bgt 110b\n" + "b 218f\n" + "145:" // Height 5 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "146:" // Height 5: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 147f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 147f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 147f\n" + "mov x11, x12\n" + "147:" // Height 5: B setup done + "cbz x15, 148f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "mov v24.16b, v8.16b\n" + "mov v28.16b, v12.16b\n" + "mov v25.16b, v9.16b\n" + "mov v29.16b, v13.16b\n" + "mov v26.16b, v10.16b\n" + "mov v30.16b, v14.16b\n" + "mov v27.16b, v11.16b\n" + "mov v31.16b, v15.16b\n" + "b 160f\n" + "148:" // Height 5: no bias + "tbz %x[flags], #0, 159f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "bge 157f\n" + "tbz x14, #3, 152f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "ld1 { v21.4s }, [x24], #0x10\n" + "ld1 { v26.4s }, [x23], #0x10\n" + "tbz x14, #2, 150f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "ld1 { v22.4s }, [x24], #0x10\n" + "ld1 { v27.4s }, [x23], #0x10\n" + "tbz x14, #1, 149f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "ldr d23, [x24], #0x8\n" + "ldr d6, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v23.s }[2], [x24]\n" + "ld1 { v6.s }[2], [x23]\n" + "b 156f\n" + "149:" // Height 5: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz 
x14, #0, 156f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "ldr s23, [x24, #0x0]\n" + "ldr s6, [x23, #0x0]\n" + "b 156f\n" + "150:" // Height 5: Partial accumulate: partial_2_8 + "tbz x14, #1, 151f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "ldr d22, [x24], #0x8\n" + "ldr d27, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "ld1 { v22.s }[2], [x24]\n" + "ld1 { v27.s }[2], [x23]\n" + "b 156f\n" + "151:" // Height 5: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 156f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "ldr s22, [x24, #0x0]\n" + "ldr s27, [x23, #0x0]\n" + "b 156f\n" + "152:" // Height 5: Partial accumulate: partial_4_0 + "tbz x14, #2, 154f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "tbz x14, #1, 153f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x24], #0x8\n" + "ldr d26, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "ld1 { v21.s }[2], [x24]\n" + "ld1 { v26.s }[2], [x23]\n" + "b 156f\n" + "153:" // Height 5: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 156f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "ldr s21, [x24, #0x0]\n" + "ldr s26, [x23, #0x0]\n" + "b 156f\n" + "154:" // Height 5: Partial accumulate: partial_2_0 + "tbz x14, #1, 155f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "ldr d20, [x24], #0x8\n" + "ldr d25, [x23], #0x8\n" + "tbz x14, #0, 156f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "ld1 { v20.s }[2], [x24]\n" + "ld1 { v25.s }[2], [x23]\n" + "b 156f\n" + "155:" // Height 5: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "ldr s20, [x24, #0x0]\n" + "ldr s25, [x23, #0x0]\n" + "156:" // Height 5: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 158f\n" + "157:" // Height 5: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "ldr q20, [x24, #0x0]\n" + "ldr q21, [x24, #0x10]\n" + "ldr q22, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "ldr q25, [x23, #0x0]\n" + "ldr q26, [x23, #0x10]\n" + "ldr q27, [x23, #0x20]\n" + "ldr q6, [x23, #0x30]\n" + "158:" // Height 5: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, 
v24.2d, v23.2d\n" + "zip1 v24.2d, v25.2d, v28.2d\n" + "zip2 v28.2d, v25.2d, v28.2d\n" + "zip1 v25.2d, v26.2d, v29.2d\n" + "zip2 v29.2d, v26.2d, v29.2d\n" + "zip1 v26.2d, v27.2d, v30.2d\n" + "zip2 v30.2d, v27.2d, v30.2d\n" + "zip1 v27.2d, v6.2d, v31.2d\n" + "zip2 v31.2d, v6.2d, v31.2d\n" + "b 160f\n" + "159:" // Height 5: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "160:" // Height 5: setup done + "mov x28, #0x0\n" + "161:" // Height 5: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 162f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "ldr x22, [x20, #0x20]\n" + "cbnz x28, 163f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "b 163f\n" + "162:" // Height 5: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "163:" // Height 5: input setup done + "cmp x27, #0x4\n" + "blt 166f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "ld1 { v3.4s }, [x23], #0x10\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, [x12, #0x10]\n" + "blt 165f\n" + "164:" // Height 5: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x23], #0x10\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q6, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q5, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e45ec0d // bfmmla v13.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n" + ".inst 0x6e45ec9d // bfmmla v29.4s, v4.8h, v5.8h\n" + "ldr q5, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + ".inst 
0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n" + ".inst 0x6e45ec9e // bfmmla v30.4s, v4.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e45ec0f // bfmmla v15.4s, v0.8h, v5.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e45ec57 // bfmmla v23.4s, v2.8h, v5.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + ".inst 0x6e45ec9f // bfmmla v31.4s, v4.8h, v5.8h\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 164b\n" + "165:" // Height 5: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "166:" // Height 5: Multiply loop: Main loop skip + "cbz x27, 169f\n" + "cbz x27, 169f\n" + "tbz x27, #1, 167f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "tbz x27, #0, 168f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "ld1 { v3.s }[2], [x23]\n" + "ld1 { v4.s }[2], [x22]\n" + "b 168f\n" + "167:" // Height 5: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "ldr s3, [x23, #0x0]\n" + "ldr s4, [x22, #0x0]\n" + "168:" // Height 5: Multiply loop: Ragged operand read: Done + "ldr q6, [x12, #0x0]\n" + "ldr q5, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "add x12, x12, #0x20\n" + 
".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + ".inst 0x6e45ec0c // bfmmla v12.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec9c // bfmmla v28.4s, v4.8h, v5.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "169:" // Height 5: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 161b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "add x23, x24, x20, LSL #2\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "uzp1 v24.2d, v24.2d, v28.2d\n" + "uzp1 v25.2d, v25.2d, v29.2d\n" + "uzp1 v26.2d, v26.2d, v30.2d\n" + "uzp1 v27.2d, v27.2d, v31.2d\n" + "tbz %x[flags], #1, 170f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v1.4s }, [x21]\n" + "ld1r { v0.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v6.4s, v6.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, 
v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "170:" // Height 5: No activation + "cmp x14, #0x10\n" + "bge 179f\n" + "tbz x14, #3, 174f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x24], #0x10\n" + "st1 { v24.4s }, [x23], #0x10\n" + "st1 { v25.4s }, [x23], #0x10\n" + "tbz x14, #2, 172f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v18.4s }, [x24], #0x10\n" + "st1 { v26.4s }, [x23], #0x10\n" + "tbz x14, #1, 171f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d22, [x25], #0x8\n" + "str d19, [x24], #0x8\n" + "str d27, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "st1 { v27.s }[2], [x23]\n" + "b 178f\n" + "171:" // Height 5: Partial direct writeback: partial_1_12 + "tbz x14, #0, 178f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s19, [x24, #0x0]\n" + "str s27, [x23, #0x0]\n" + "b 178f\n" + "172:" // Height 5: Partial direct writeback: partial_2_8 + "tbz x14, #1, 173f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d21, [x25], #0x8\n" + "str d18, [x24], #0x8\n" + "str d26, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v18.s }[2], [x24]\n" + "st1 { v26.s }[2], [x23]\n" + "b 178f\n" + "173:" // Height 5: Partial direct writeback: partial_1_8 + "tbz x14, #0, 178f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s18, [x24, #0x0]\n" + "str s26, [x23, #0x0]\n" + "b 178f\n" + "174:" // Height 5: Partial direct writeback: partial_4_0 + "tbz x14, #2, 176f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v24.4s }, [x23], #0x10\n" + "tbz x14, #1, 175f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d20, [x25], #0x8\n" + "str d17, [x24], #0x8\n" + "str d25, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v25.s }[2], [x23]\n" + "b 178f\n" + "175:" // Height 5: Partial direct writeback: partial_1_4 + "tbz x14, #0, 178f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s17, [x24, #0x0]\n" + "str s25, [x23, #0x0]\n" + "b 178f\n" + "176:" // Height 5: Partial direct writeback: partial_2_0 + "tbz x14, #1, 177f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d15, [x25], #0x8\n" + "str d16, [x24], #0x8\n" + "str d24, [x23], #0x8\n" + "tbz x14, #0, 178f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], 
[x26]\n" + "st1 { v15.s }[2], [x25]\n" + "st1 { v16.s }[2], [x24]\n" + "st1 { v24.s }[2], [x23]\n" + "b 178f\n" + "177:" // Height 5: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s15, [x25, #0x0]\n" + "str s16, [x24, #0x0]\n" + "str s24, [x23, #0x0]\n" + "178:" // Height 5: Partial direct writeback: Done + "b 180f\n" + "179:" // Height 5: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q15, [x25, #0x0]\n" + "str q20, [x25, #0x10]\n" + "str q21, [x25, #0x20]\n" + "str q22, [x25, #0x30]\n" + "str q16, [x24, #0x0]\n" + "str q17, [x24, #0x10]\n" + "str q18, [x24, #0x20]\n" + "str q19, [x24, #0x30]\n" + "str q24, [x23, #0x0]\n" + "str q25, [x23, #0x10]\n" + "str q26, [x23, #0x20]\n" + "str q27, [x23, #0x30]\n" + "180:" // Height 5: Writeback done + "subs x14, x14, #0x10\n" + "bgt 146b\n" + "b 218f\n" + "181:" // Height 6 + "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n" + "mov x21, #0x18\n" + "ldr x14, [%x[args_ptr], %[offsetof_N]]\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "madd x21, x20, x21, x13\n" + "str x21, [%x[args_ptr], %[offsetof_output_ptr]]\n" + "182:" // Height 6: Column loop + "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cmp x14, #0xc\n" + "add x11, x12, x20, LSL #1\n" + "add x10, x11, x20, LSL #1\n" + "add x9, x10, x20, LSL #1\n" + "add x20, x9, x20, LSL #1\n" + "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 183f\n" + "cmp x14, #0x8\n" + "mov x9, x12\n" + "bgt 183f\n" + "cmp x14, #0x4\n" + "mov x10, x12\n" + "bgt 183f\n" + "mov x11, x12\n" + "183:" // Height 6: B setup done + "cbz x15, 184f\n" + "ldr q8, [x15, #0x0]\n" + "ldr q9, [x15, #0x10]\n" + "ldr q10, [x15, #0x20]\n" + "ldr q11, [x15, #0x30]\n" + "add x15, x15, #0x40\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "mov v24.16b, v8.16b\n" + "mov v28.16b, v12.16b\n" + "mov v25.16b, v9.16b\n" + "mov v29.16b, v13.16b\n" + "mov v26.16b, v10.16b\n" + "mov v30.16b, v14.16b\n" + "mov v27.16b, v11.16b\n" + "mov v31.16b, v15.16b\n" + "b 196f\n" + "184:" // Height 6: no bias + "tbz %x[flags], #0, 195f\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x14, #0x10\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "add x23, x24, x20, LSL #2\n" + "add x22, x23, x20, LSL #2\n" + "bge 193f\n" + "tbz x14, #3, 188f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v28.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x13], #0x10\n" + "ld1 { v13.4s }, [x26], #0x10\n" + "ld1 { v18.4s }, [x25], #0x10\n" + "ld1 { v21.4s }, 
[x24], #0x10\n" + "ld1 { v26.4s }, [x23], #0x10\n" + "ld1 { v29.4s }, [x22], #0x10\n" + "tbz x14, #2, 186f\n" + "ld1 { v11.4s }, [x13], #0x10\n" + "ld1 { v14.4s }, [x26], #0x10\n" + "ld1 { v19.4s }, [x25], #0x10\n" + "ld1 { v22.4s }, [x24], #0x10\n" + "ld1 { v27.4s }, [x23], #0x10\n" + "ld1 { v30.4s }, [x22], #0x10\n" + "tbz x14, #1, 185f\n" + "ldr d16, [x13], #0x8\n" + "ldr d15, [x26], #0x8\n" + "mov x20, #0x38\n" + "ldr d24, [x25], #0x8\n" + "ldr d23, [x24], #0x8\n" + "ldr d6, [x23], #0x8\n" + "ldr d31, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v16.s }[2], [x13]\n" + "ld1 { v15.s }[2], [x26]\n" + "ld1 { v24.s }[2], [x25]\n" + "ld1 { v23.s }[2], [x24]\n" + "ld1 { v6.s }[2], [x23]\n" + "ld1 { v31.s }[2], [x22]\n" + "b 192f\n" + "185:" // Height 6: Partial accumulate: partial_1_12 + "mov x20, #0x30\n" + "tbz x14, #0, 192f\n" + "ldr s16, [x13, #0x0]\n" + "ldr s15, [x26, #0x0]\n" + "ldr s24, [x25, #0x0]\n" + "ldr s23, [x24, #0x0]\n" + "ldr s6, [x23, #0x0]\n" + "ldr s31, [x22, #0x0]\n" + "b 192f\n" + "186:" // Height 6: Partial accumulate: partial_2_8 + "tbz x14, #1, 187f\n" + "ldr d11, [x13], #0x8\n" + "ldr d14, [x26], #0x8\n" + "mov x20, #0x28\n" + "ldr d19, [x25], #0x8\n" + "ldr d22, [x24], #0x8\n" + "ldr d27, [x23], #0x8\n" + "ldr d30, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v11.s }[2], [x13]\n" + "ld1 { v14.s }[2], [x26]\n" + "ld1 { v19.s }[2], [x25]\n" + "ld1 { v22.s }[2], [x24]\n" + "ld1 { v27.s }[2], [x23]\n" + "ld1 { v30.s }[2], [x22]\n" + "b 192f\n" + "187:" // Height 6: Partial accumulate: partial_1_8 + "mov x20, #0x20\n" + "tbz x14, #0, 192f\n" + "ldr s11, [x13, #0x0]\n" + "ldr s14, [x26, #0x0]\n" + "ldr s19, [x25, #0x0]\n" + "ldr s22, [x24, #0x0]\n" + "ldr s27, [x23, #0x0]\n" + "ldr s30, [x22, #0x0]\n" + "b 192f\n" + "188:" // Height 6: Partial accumulate: partial_4_0 + "tbz x14, #2, 190f\n" + "ld1 { v9.4s }, [x13], #0x10\n" + "ld1 { v12.4s }, [x26], #0x10\n" + "ld1 { v17.4s }, [x25], #0x10\n" + "ld1 { v20.4s }, [x24], #0x10\n" + "ld1 { v25.4s }, [x23], #0x10\n" + "ld1 { v28.4s }, [x22], #0x10\n" + "tbz x14, #1, 189f\n" + "ldr d10, [x13], #0x8\n" + "ldr d13, [x26], #0x8\n" + "mov x20, #0x18\n" + "ldr d18, [x25], #0x8\n" + "ldr d21, [x24], #0x8\n" + "ldr d26, [x23], #0x8\n" + "ldr d29, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v10.s }[2], [x13]\n" + "ld1 { v13.s }[2], [x26]\n" + "ld1 { v18.s }[2], [x25]\n" + "ld1 { v21.s }[2], [x24]\n" + "ld1 { v26.s }[2], [x23]\n" + "ld1 { v29.s }[2], [x22]\n" + "b 192f\n" + "189:" // Height 6: Partial accumulate: partial_1_4 + "mov x20, #0x10\n" + "tbz x14, #0, 192f\n" + "ldr s10, [x13, #0x0]\n" + "ldr s13, [x26, #0x0]\n" + "ldr s18, [x25, #0x0]\n" + "ldr s21, [x24, #0x0]\n" + "ldr s26, [x23, #0x0]\n" + "ldr s29, [x22, #0x0]\n" + "b 192f\n" + "190:" // Height 6: Partial accumulate: partial_2_0 + "tbz x14, #1, 191f\n" + "ldr d9, [x13], #0x8\n" + "ldr d12, [x26], #0x8\n" + "mov x20, #0x8\n" + "ldr d17, [x25], #0x8\n" + "ldr d20, [x24], #0x8\n" + "ldr d25, [x23], #0x8\n" + "ldr d28, [x22], #0x8\n" + "tbz x14, #0, 192f\n" + "ld1 { v9.s }[2], [x13]\n" + "ld1 { v12.s }[2], [x26]\n" + "ld1 { v17.s }[2], [x25]\n" + "ld1 { v20.s }[2], [x24]\n" + "ld1 { v25.s }[2], [x23]\n" + "ld1 { v28.s }[2], [x22]\n" + "b 192f\n" + "191:" // Height 6: Partial accumulate: partial_1_0 + "ldr s9, [x13, #0x0]\n" + "ldr s12, [x26, #0x0]\n" + "mov x20, #0x0\n" + "ldr s17, [x25, #0x0]\n" + "ldr s20, [x24, #0x0]\n" + "ldr s25, [x23, #0x0]\n" + "ldr s28, [x22, #0x0]\n" + "192:" // Height 6: Partial accumulate: Done + "sub x13, x13, x20\n" + "b 194f\n" + 
"193:" // Height 6: full accumulate + "ldr q9, [x13, #0x0]\n" + "ldr q10, [x13, #0x10]\n" + "ldr q11, [x13, #0x20]\n" + "ldr q16, [x13, #0x30]\n" + "ldr q12, [x26, #0x0]\n" + "ldr q13, [x26, #0x10]\n" + "ldr q14, [x26, #0x20]\n" + "ldr q15, [x26, #0x30]\n" + "ldr q17, [x25, #0x0]\n" + "ldr q18, [x25, #0x10]\n" + "ldr q19, [x25, #0x20]\n" + "ldr q24, [x25, #0x30]\n" + "ldr q20, [x24, #0x0]\n" + "ldr q21, [x24, #0x10]\n" + "ldr q22, [x24, #0x20]\n" + "ldr q23, [x24, #0x30]\n" + "ldr q25, [x23, #0x0]\n" + "ldr q26, [x23, #0x10]\n" + "ldr q27, [x23, #0x20]\n" + "ldr q6, [x23, #0x30]\n" + "ldr q28, [x22, #0x0]\n" + "ldr q29, [x22, #0x10]\n" + "ldr q30, [x22, #0x20]\n" + "ldr q31, [x22, #0x30]\n" + "194:" // Height 6: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "zip1 v24.2d, v25.2d, v28.2d\n" + "zip2 v28.2d, v25.2d, v28.2d\n" + "zip1 v25.2d, v26.2d, v29.2d\n" + "zip2 v29.2d, v26.2d, v29.2d\n" + "zip1 v26.2d, v27.2d, v30.2d\n" + "zip2 v30.2d, v27.2d, v30.2d\n" + "zip1 v27.2d, v6.2d, v31.2d\n" + "zip2 v31.2d, v6.2d, v31.2d\n" + "b 196f\n" + "195:" // Height 6: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "196:" // Height 6: setup done + "mov x28, #0x0\n" + "197:" // Height 6: String loop + "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "ldr w27, [x20, x28, LSL #0x2]\n" + "tbz %x[flags], #3, 198f\n" + "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n" + "add x20, x20, x21, LSL #3\n" + "ldr x26, [x20, #0x0]\n" + "ldr x25, [x20, #0x8]\n" + "ldr x24, [x20, #0x10]\n" + "ldr x23, [x20, #0x18]\n" + "ldr x22, [x20, #0x20]\n" + "ldr x21, [x20, #0x28]\n" + "cbnz x28, 199f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x26, x26, x20, LSL #2\n" + "add x25, x25, x20, LSL #2\n" + "add x24, x24, x20, LSL #2\n" + "add x23, x23, x20, LSL #2\n" + "add x22, x22, x20, LSL #2\n" + "add x21, x21, x20, LSL #2\n" + "b 199f\n" + "198:" // Height 6: setup direct input + "mov x26, %x[input_ptr]\n" + "add x25, x26, x21, LSL #2\n" + "add x24, x25, x21, LSL #2\n" + "add x23, x24, x21, LSL #2\n" + "add x22, x23, x21, LSL #2\n" + "add x21, x22, x21, LSL #2\n" + "199:" // Height 6: input setup done + "cmp x27, #0x4\n" + "blt 202f\n" + "ld1 { v0.4s }, [x26], #0x10\n" + "ld1 { v2.4s }, [x24], #0x10\n" + "cmp x27, #0x8\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ld1 { v1.4s }, [x25], #0x10\n" + "ld1 { v3.4s }, [x23], #0x10\n" + "ld1 { v5.4s }, [x21], #0x10\n" + "ldr q6, [x12, #0x0]\n" + "ldr q7, 
[x12, #0x10]\n" + "blt 201f\n" + "200:" // Height 6: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "cmp x27, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x25], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x23], #0x10\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + "ld1 { v5.4s }, [x21], #0x10\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + "ldr q6, [x11, #0x0]\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + "ldr q7, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9d // bfmmla v29.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9e // bfmmla v30.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n" + "ldr q6, [x12, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x26], #0x10\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + "ld1 { v2.4s }, [x24], #0x10\n" + ".inst 0x6e47ec9f // bfmmla v31.4s, v4.8h, v7.8h\n" + "ld1 { v4.4s }, [x22], #0x10\n" + "ldr q7, [x12, #0x10]\n" + "bge 200b\n" + "201:" // Height 6: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x27, x27, #0x4\n" + "add x12, x12, #0x20\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a 
// bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "202:" // Height 6: Multiply loop: Main loop skip + "cbz x27, 205f\n" + "cbz x27, 205f\n" + "tbz x27, #1, 203f\n" + "ldr d0, [x26], #0x8\n" + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "ldr d5, [x21], #0x8\n" + "tbz x27, #0, 204f\n" + "ld1 { v0.s }[2], [x26]\n" + "ld1 { v1.s }[2], [x25]\n" + "ld1 { v2.s }[2], [x24]\n" + "ld1 { v3.s }[2], [x23]\n" + "ld1 { v4.s }[2], [x22]\n" + "ld1 { v5.s }[2], [x21]\n" + "b 204f\n" + "203:" // Height 6: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x26, #0x0]\n" + "ldr s1, [x25, #0x0]\n" + "ldr s2, [x24, #0x0]\n" + "ldr s3, [x23, #0x0]\n" + "ldr s4, [x22, #0x0]\n" + "ldr s5, [x21, #0x0]\n" + "204:" // Height 6: Multiply loop: Ragged operand read: Done + "ldr q7, [x12, #0x0]\n" + "ldr q6, [x12, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n" + "add x12, x12, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q3, [x11, #0x0]\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q1, [x11, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n" + "ldr q3, [x10, #0x0]\n" + ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n" + "ldr q1, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n" + "ldr q3, [x9, #0x0]\n" + ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n" + "ldr q1, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n" + ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n" + ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n" + ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n" + ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n" + ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n" + "205:" // Height 6: Multiply loop: No odd multiplies + "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x28, x28, #0x1\n" + "cmp x28, x20\n" + "bne 197b\n" + "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v6.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 
v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "add x26, x13, x20, LSL #2\n" + "add x25, x26, x20, LSL #2\n" + "add x24, x25, x20, LSL #2\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "add x23, x24, x20, LSL #2\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "add x22, x23, x20, LSL #2\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "uzp1 v23.2d, v24.2d, v28.2d\n" + "uzp2 v24.2d, v24.2d, v28.2d\n" + "uzp1 v28.2d, v25.2d, v29.2d\n" + "uzp2 v25.2d, v25.2d, v29.2d\n" + "uzp1 v29.2d, v26.2d, v30.2d\n" + "uzp2 v26.2d, v26.2d, v30.2d\n" + "uzp1 v30.2d, v27.2d, v31.2d\n" + "uzp2 v27.2d, v27.2d, v31.2d\n" + "tbz %x[flags], #1, 206f\n" + "add x21, %x[args_ptr], %[offset_max]\n" + "add x20, %x[args_ptr], %[offset_min]\n" + "ld1r { v1.4s }, [x21]\n" + "ld1r { v0.4s }, [x20]\n" + "fmin v6.4s, v6.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v6.4s, v6.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "206:" // Height 6: No activation + "cmp x14, #0x10\n" + "bge 215f\n" + "tbz x14, #3, 210f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v12.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v9.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v20.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v17.4s }, [x24], #0x10\n" + "st1 { v23.4s }, [x23], #0x10\n" + "st1 { v28.4s }, [x23], #0x10\n" + "st1 { v24.4s }, [x22], #0x10\n" + "st1 { v25.4s }, [x22], #0x10\n" + "tbz x14, #2, 208f\n" + "st1 { v13.4s }, [x13], #0x10\n" + "st1 { v10.4s }, [x26], #0x10\n" + "st1 { v21.4s }, [x25], #0x10\n" + "st1 { v18.4s }, [x24], #0x10\n" + "st1 { v29.4s }, [x23], #0x10\n" + "st1 { v26.4s }, [x22], #0x10\n" + "tbz x14, #1, 207f\n" + "str d14, [x13], #0x8\n" + "str d11, [x26], #0x8\n" + "str d22, [x25], #0x8\n" + "str d19, [x24], #0x8\n" + "str d30, [x23], #0x8\n" + "str 
d27, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v14.s }[2], [x13]\n" + "st1 { v11.s }[2], [x26]\n" + "st1 { v22.s }[2], [x25]\n" + "st1 { v19.s }[2], [x24]\n" + "st1 { v30.s }[2], [x23]\n" + "st1 { v27.s }[2], [x22]\n" + "b 214f\n" + "207:" // Height 6: Partial direct writeback: partial_1_12 + "tbz x14, #0, 214f\n" + "str s14, [x13, #0x0]\n" + "str s11, [x26, #0x0]\n" + "str s22, [x25, #0x0]\n" + "str s19, [x24, #0x0]\n" + "str s30, [x23, #0x0]\n" + "str s27, [x22, #0x0]\n" + "b 214f\n" + "208:" // Height 6: Partial direct writeback: partial_2_8 + "tbz x14, #1, 209f\n" + "str d13, [x13], #0x8\n" + "str d10, [x26], #0x8\n" + "str d21, [x25], #0x8\n" + "str d18, [x24], #0x8\n" + "str d29, [x23], #0x8\n" + "str d26, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v13.s }[2], [x13]\n" + "st1 { v10.s }[2], [x26]\n" + "st1 { v21.s }[2], [x25]\n" + "st1 { v18.s }[2], [x24]\n" + "st1 { v29.s }[2], [x23]\n" + "st1 { v26.s }[2], [x22]\n" + "b 214f\n" + "209:" // Height 6: Partial direct writeback: partial_1_8 + "tbz x14, #0, 214f\n" + "str s13, [x13, #0x0]\n" + "str s10, [x26, #0x0]\n" + "str s21, [x25, #0x0]\n" + "str s18, [x24, #0x0]\n" + "str s29, [x23, #0x0]\n" + "str s26, [x22, #0x0]\n" + "b 214f\n" + "210:" // Height 6: Partial direct writeback: partial_4_0 + "tbz x14, #2, 212f\n" + "st1 { v6.4s }, [x13], #0x10\n" + "st1 { v8.4s }, [x26], #0x10\n" + "st1 { v15.4s }, [x25], #0x10\n" + "st1 { v16.4s }, [x24], #0x10\n" + "st1 { v23.4s }, [x23], #0x10\n" + "st1 { v24.4s }, [x22], #0x10\n" + "tbz x14, #1, 211f\n" + "str d12, [x13], #0x8\n" + "str d9, [x26], #0x8\n" + "str d20, [x25], #0x8\n" + "str d17, [x24], #0x8\n" + "str d28, [x23], #0x8\n" + "str d25, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v12.s }[2], [x13]\n" + "st1 { v9.s }[2], [x26]\n" + "st1 { v20.s }[2], [x25]\n" + "st1 { v17.s }[2], [x24]\n" + "st1 { v28.s }[2], [x23]\n" + "st1 { v25.s }[2], [x22]\n" + "b 214f\n" + "211:" // Height 6: Partial direct writeback: partial_1_4 + "tbz x14, #0, 214f\n" + "str s12, [x13, #0x0]\n" + "str s9, [x26, #0x0]\n" + "str s20, [x25, #0x0]\n" + "str s17, [x24, #0x0]\n" + "str s28, [x23, #0x0]\n" + "str s25, [x22, #0x0]\n" + "b 214f\n" + "212:" // Height 6: Partial direct writeback: partial_2_0 + "tbz x14, #1, 213f\n" + "str d6, [x13], #0x8\n" + "str d8, [x26], #0x8\n" + "str d15, [x25], #0x8\n" + "str d16, [x24], #0x8\n" + "str d23, [x23], #0x8\n" + "str d24, [x22], #0x8\n" + "tbz x14, #0, 214f\n" + "st1 { v6.s }[2], [x13]\n" + "st1 { v8.s }[2], [x26]\n" + "st1 { v15.s }[2], [x25]\n" + "st1 { v16.s }[2], [x24]\n" + "st1 { v23.s }[2], [x23]\n" + "st1 { v24.s }[2], [x22]\n" + "b 214f\n" + "213:" // Height 6: Partial direct writeback: partial_1_0 + "str s6, [x13, #0x0]\n" + "str s8, [x26, #0x0]\n" + "str s15, [x25, #0x0]\n" + "str s16, [x24, #0x0]\n" + "str s23, [x23, #0x0]\n" + "str s24, [x22, #0x0]\n" + "214:" // Height 6: Partial direct writeback: Done + "b 216f\n" + "215:" // Height 6: Full writeback + "str q6, [x13, #0x0]\n" + "str q12, [x13, #0x10]\n" + "str q13, [x13, #0x20]\n" + "str q14, [x13, #0x30]\n" + "add x13, x13, #0x40\n" + "str q8, [x26, #0x0]\n" + "str q9, [x26, #0x10]\n" + "str q10, [x26, #0x20]\n" + "str q11, [x26, #0x30]\n" + "str q15, [x25, #0x0]\n" + "str q20, [x25, #0x10]\n" + "str q21, [x25, #0x20]\n" + "str q22, [x25, #0x30]\n" + "str q16, [x24, #0x0]\n" + "str q17, [x24, #0x10]\n" + "str q18, [x24, #0x20]\n" + "str q19, [x24, #0x30]\n" + "str q23, [x23, #0x0]\n" + "str q28, [x23, #0x10]\n" + "str q29, [x23, #0x20]\n" + "str q30, [x23, #0x30]\n" + "str q24, [x22, 
#0x0]\n" + "str q25, [x22, #0x10]\n" + "str q26, [x22, #0x20]\n" + "str q27, [x22, #0x30]\n" + "216:" // Height 6: Writeback done + "subs x14, x14, #0x10\n" + "bgt 182b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 218f\n" + "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 217f\n" + "add x21, x21, #0x6\n" + "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "217:" // Update direct input + "mov x20, #0x18\n" + "madd %x[input_ptr], x20, x21, %x[input_ptr]\n" + "b 1b\n" + "218:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr) + : [args_ptr] "r" (&ka), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_output_ptr] "I" (offsetof(KernelArgs, output_ptr)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp index cf4d74266a..1a8b0fd630 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 45.25, 4.29, 4.80 }; default: - return { 38.10, 5.23, 3.15 }; + return { 29.85, 2.60, 5.49 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp index 586d6a64a4..d9668aae02 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) #include "../performance_parameters.hpp" #include "../std_transforms_fixed.hpp" @@ -89,4 +89,4 @@ public: } // namespace arm_gemm -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp new file mode 100644 index 0000000000..7792192856 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
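The a64_hgemm_8x24.hpp hunk above keys the compile guard off the library's own ARM_COMPUTE_ENABLE_FP16 build option instead of the compiler-defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC feature macro, presumably so the FP16 kernel can be compiled into a binary whose baseline architecture does not itself advertise FP16 vector arithmetic. A minimal sketch of the resulting compile-time/run-time split, assuming the dispatcher checks CPU FP16 support separately (that check is outside this diff):

#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16))
#define HGEMM_FP16_COMPILED 1 // kernel object code is present in the binary
#else
#define HGEMM_FP16_COMPILED 0
#endif

// Hypothetical helper: a kernel is usable only if it was compiled in AND the
// running CPU reports FP16 support (cpu_has_fp16 would come from CPUInfo).
static bool can_use_fp16_kernel(bool cpu_has_fp16)
{
    return HGEMM_FP16_COMPILED != 0 && cpu_has_fp16;
}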
+ */ +#pragma once + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include <cstdint> +#include "../std_transforms_sme.hpp" + +namespace arm_gemm +{ + +// Implementations +void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + +class cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL +{ +public: + typedef int8_t operand_type; + typedef float result_type; + + typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + + /* Kernel blocking parameters */ + static unsigned int out_height() + { + return sme::get_vector_length<int32_t>() * 1; + } + + static unsigned int out_width() + { + return sme::get_vector_length<int32_t>() * 4; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + static constexpr bool is_sme() + { + return true; + } + + // Default to the generic kernel + kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL; + + StdTransformsSME<operand_type, result_type, 1, 4, 4> transforms = {}; + + cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp new file mode 100644 index 0000000000..4b26a6578c --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp @@ -0,0 +1,417 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
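The class above describes the new int8-to-fp32 dequantizing SME2 MOPA kernel to the arm_gemm framework: int8 operands, float results, and a 1VL x 4VL output tile (one vector of rows by four vectors of columns, counted in 32-bit elements). k_unroll() is 4 because each SMOPA instruction consumes four int8 values per 32-bit accumulator lane. A small sketch of what the tile works out to for common SVE vector lengths, assuming sme::get_vector_length<int32_t>() returns the streaming vector length in 32-bit elements:

#include <cstdio>
#include <initializer_list>

static unsigned s32_elems_per_vector(unsigned sve_bits) { return sve_bits / 32; }

int main()
{
    for (unsigned bits : {128u, 256u, 512u})
    {
        const unsigned vl = s32_elems_per_vector(bits);
        // 1VLx4VL: out_height() == vl * 1, out_width() == vl * 4.
        std::printf("SVE %3u-bit: %2u x %3u fp32 outputs per tile, K unrolled by 4 int8\n",
                    bits, vl * 1, vl * 4);
    }
    return 0;
}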
+ */ +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_gemm.hpp" + +#include <cstdint> +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer) +{ + struct KernelArgs + { + KernelArgs( + const int8_t *const A, + const int8_t *const B, + float *const C, const int ldc, + const int M, const int N, const int K, + const int32_t *const bias, const float *const late_bias, const Activation act, + bool accumulate, + int32_t *const accumulator_buffer + ) : A(A), + B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)), + C(C), ldcb(ldc * sizeof(float)), + M(M), N(N), K(K), + min(-std::numeric_limits<float>::infinity()), + max(std::numeric_limits<float>::infinity()), + bias(bias), late_bias(late_bias), + accumulator_buffer(accumulator_buffer), + flags(0x0) + { + if (accumulate) + { + flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER + } + if (C == nullptr) + { + flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER + } + + // Initialise the activation values + switch (act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + this->max = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + this->min = static_cast<float>(0); + break; + } + } + + const int8_t *const A; + const int8_t *const B; + const long kstride_bytes; + float *const C; + const long ldcb; + const long M, N, K; + float min = -std::numeric_limits<float>::infinity(); + float max = std::numeric_limits<float>::infinity(); + + const int32_t *const bias; + const float *const late_bias; + + int32_t *const accumulator_buffer; + uint64_t flags; + }; + + // Construct arguments for this kernel + KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer); + + __asm__ __volatile__( + "ldr x13, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p0.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr x11, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x10, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x13, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // Initial accumulator load from buffer: Loop + ".inst 0xa040c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11]\n" + ".inst 0xa041c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n" + ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n" + ".inst 0xa043c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n" + ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n" + "addvl x11, x11, #16\n" + ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n" + ".inst 0xc0840583 // mova za3h.s[x12], { z12.s-z15.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w9, [%x[args], %[offsetof_M]]\n" + "mov x28, #0x0\n" + "mov x27, #0x0\n" + "ldr w26, [%x[args], %[offsetof_N]]\n" + "ldr x25, [%x[args], %[offsetof_A]]\n" + "3:" // M and N loop + "mov x24, x25\n" + ".inst 0x25ba6770 // whilelt pn8.s, x27, x26, VLx4\n" + "tbnz x13, #0, 4f\n" + "ldr x20, [%x[args], %[offsetof_bias]]\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "cbz x20, 
5f\n" + ".inst 0xa01bc288 // ld1w { z8.s-z11.s }, p8/Z, [x20, x27, LSL #2]\n" + ".inst 0xc0900100 // addha za0.s, p0/M, p0/M, z8.s\n" + ".inst 0xc0900121 // addha za1.s, p0/M, p0/M, z9.s\n" + ".inst 0xc0900142 // addha za2.s, p0/M, p0/M, z10.s\n" + ".inst 0xc0900163 // addha za3.s, p0/M, p0/M, z11.s\n" + "4:" // Prepare accumulators: Test for last block + "mov x20, x27\n" + "mov x21, x28\n" + "incw x20, ALL, MUL #4\n" + "incw x21\n" + "cmp x20, x26\n" + "mov x20, x13\n" + "csel x21, x28, x21, LT\n" + "bfm x13, XZR, #0x0, #0x0 // bfc x13, #0x0, #0x1\n" + "cmp x21, x9\n" + "csel x13, x20, x13, LT\n" + "5:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "ldr x23, [%x[args], %[offsetof_B]]\n" + "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "madd x23, x27, x22, x23\n" // bptr = B + n * kstride_bytes + "and x20, x20, #0x3\n" + "cbz x21, 8f\n" + "subs x21, x21, #0x1\n" + "ld1b { z31.b }, p0/Z, [x24]\n" + ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n" + "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n" + ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n" + ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n" + "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n" + "addvl x23, x23, #16\n" + "ble 7f\n" + "6:" // K loop + ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n" + ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n" + ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n" + "ld1b { z31.b }, p0/Z, [x24]\n" + ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n" + ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n" + ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n" + ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n" + ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n" + "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n" + ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n" + ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n" + ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n" + ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n" + "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n" + ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n" + ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n" + ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n" + ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n" + ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n" + "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n" + "addvl x24, x24, #4\n" + ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n" + "addvl x23, x23, #16\n" + "bgt 6b\n" + "7:" // K loop tail + ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n" + ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n" + ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n" + ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n" + ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n" + ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n" + ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n" + ".inst 
0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n" + ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n" + ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n" + ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n" + ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n" + ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n" + ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n" + ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n" + ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n" + "8:" // K oddments + "cbz x20, 10f\n" + "9:" // K oddments: Loop + "ld1b { z18.b }, p0/Z, [x24]\n" + "subs x20, x20, #0x1\n" + "addvl x24, x24, #1\n" + ".inst 0xa04086fc // ld1b { z28.b-z31.b }, pn9.b/Z, [x23]\n" + "addvl x23, x23, #4\n" + ".inst 0xa09c0240 // smopa za0.s, p0/M, p0/M, z18.b, z28.b\n" + ".inst 0xa09d0241 // smopa za1.s, p0/M, p0/M, z18.b, z29.b\n" + ".inst 0xa09e0242 // smopa za2.s, p0/M, p0/M, z18.b, z30.b\n" + ".inst 0xa09f0243 // smopa za3.s, p0/M, p0/M, z18.b, z31.b\n" + "bgt 9b\n" + "10:" // K oddments: End + "tbz x13, #1, 14f\n" + "tbz x13, #0, 12f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "11:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11]\n" + ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xa041c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n" + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n" + ".inst 0xa043c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n" + ".inst 0xc0840400 // mova za0h.s[x12], { z0.s-z3.s }\n" + "addvl x11, x11, #16\n" + ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n" + ".inst 0xa060c548 // st1w { z8.s-z11.s }, pn9.b, [x10]\n" + ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n" + ".inst 0xa061c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0x4, MUL VL]\n" + ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c544 // st1w { z4.s-z7.s }, pn9.b, [x10, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0xc, MUL VL]\n" + "addvl x10, x10, #16\n" + "blt 11b\n" + "b 21f\n" + "12:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "13:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xa060c544 // st1w { z4.s-z7.s }, pn9.b, [x10]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c548 // st1w { z8.s-z11.s }, pn9.b, [x10, #0x8, MUL VL]\n" + ".inst 0xa063c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0xc, MUL VL]\n" + "addvl x10, x10, #16\n" + "blt 13b\n" + "b 21f\n" + "14:" // Store to output array + "ldr x23, [%x[args], %[offsetof_C]]\n" + "sub x21, x9, x28\n" + "ld1rw { z18.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n" + "fmov z20.s, #0x0\n" + "ldr x22, [%x[args], %[offsetof_ldcb]]\n" + "fmov z21.s, #0x0\n" + "fmov z22.s, #0x0\n" + "ldr x20, [%x[args], %[offsetof_late_bias]]\n" + "fmov z23.s, #0x0\n" + "add x23, x23, x27, 
LSL #2\n" // C += n + "madd x23, x28, x22, x23\n" // C += m * ldc + "cbz x20, 15f\n" + "add x20, x20, x27, LSL #2\n" + ".inst 0xa040c294 // ld1w { z20.s-z23.s }, p8/Z, [x20]\n" + "15:" // Store to output array: no late bias + "cntw x20\n" + "ld1rw { z17.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x21, x20\n" + "ld1rw { z16.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x20, x21, x20, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z0.s, p0/M, z18.s, z20.s\n" + "fmad z1.s, p0/M, z18.s, z20.s\n" + "fmad z2.s, p0/M, z18.s, z20.s\n" + "fmad z3.s, p0/M, z18.s, z20.s\n" + "add x12, x12, #0x4\n" + "fmad z4.s, p0/M, z18.s, z21.s\n" + "fmad z5.s, p0/M, z18.s, z21.s\n" + "cmp x12, x21, LSL #2\n" + "fmad z6.s, p0/M, z18.s, z21.s\n" + "fmad z7.s, p0/M, z18.s, z21.s\n" + "fmad z8.s, p0/M, z18.s, z22.s\n" + "fmad z9.s, p0/M, z18.s, z22.s\n" + "fmad z10.s, p0/M, z18.s, z22.s\n" + "fmad z11.s, p0/M, z18.s, z22.s\n" + "fmad z12.s, p0/M, z18.s, z23.s\n" + "fmad z13.s, p0/M, z18.s, z23.s\n" + "fmad z14.s, p0/M, z18.s, z23.s\n" + "fmad z15.s, p0/M, z18.s, z23.s\n" + ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n" + ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n" + "add x23, x23, x22\n" + ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n" + "add x23, x23, x22\n" + ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n" + "add x23, x23, x22\n" + ".inst 0xa160c2e3 // st1w { z3.s, z7.s, z11.s, z15.s }, p8, [x23]\n" + "add x23, x23, x22\n" + "blt 16b\n" + "17:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 18f\n" + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n" + ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z0.s, p0/M, z18.s, z20.s\n" + "fmad z1.s, p0/M, z18.s, z20.s\n" + "fmad z2.s, p0/M, z18.s, z20.s\n" + "fmad z3.s, p0/M, z18.s, z20.s\n" + "subs x20, x20, #0x1\n" + "fmad z4.s, p0/M, z18.s, z21.s\n" + "fmad z5.s, p0/M, z18.s, z21.s\n" + "fmad z6.s, p0/M, z18.s, z21.s\n" + "fmad z7.s, p0/M, z18.s, z21.s\n" + "fmad z8.s, p0/M, z18.s, z22.s\n" + "fmad z9.s, p0/M, z18.s, z22.s\n" + "fmad z10.s, p0/M, z18.s, z22.s\n" + "fmad z11.s, p0/M, z18.s, z22.s\n" + "fmad z12.s, p0/M, z18.s, z23.s\n" + "fmad z13.s, p0/M, z18.s, z23.s\n" + "fmad z14.s, p0/M, z18.s, z23.s\n" + "fmad z15.s, p0/M, z18.s, z23.s\n" + ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, 
z16.s\n" + ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n" + ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n" + ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n" + "add x23, x23, x22\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n" + "add x23, x23, x22\n" + "beq 18f\n" + ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n" + "18:" // Store to output array: Accumulator row 0 oddments: End + "19:" // Store to output array: End + "tbz x13, #0, 21f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "20:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11]\n" + ".inst 0xa041c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n" + ".inst 0xa042c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n" + ".inst 0xa043c568 // ld1w { z8.s-z11.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n" + ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n" + "addvl x11, x11, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" + ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 20b\n" + "21:" // End block + "incw x27, ALL, MUL #4\n" + "cmp x27, x26\n" + "blt 3b\n" + "incw x28\n" + "mov x27, #0x0\n" + "cmp x28, x9\n" + "mov x25, x24\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp new file mode 100644 index 0000000000..df2c9c0ca3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include <cstdint> +#include "../std_transforms_sme.hpp" + +namespace arm_gemm +{ + +// Implementations +void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + +class cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL +{ +public: + typedef int8_t operand_type; + typedef float result_type; + + typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + + /* Kernel blocking parameters */ + static unsigned int out_height() + { + return sme::get_vector_length<int32_t>() * 2; + } + + static unsigned int out_width() + { + return sme::get_vector_length<int32_t>() * 2; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + static constexpr bool is_sme() + { + return true; + } + + // Default to the generic kernel + kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL; + + StdTransformsSME<operand_type, result_type, 2, 2, 4> transforms = {}; + + cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp new file mode 100644 index 0000000000..1631fae8e9 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2024 Arm Limited. 
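The KernelArgs constructor in each of these kernels (the first appears above; the remaining variants repeat it) folds the Activation descriptor into the [min, max] pair that FCLAMP applies in the epilogue: ReLU sets min to 0, and BoundedReLU additionally caps max at act.param1, hence the deliberate fall-through in the switch. A scalar sketch of that mapping; the numeric case labels below are illustrative stand-ins for Activation::Type::ReLU / BoundedReLU from arm_gemm.hpp:

#include <limits>

struct ClampBounds { float min, max; };

static ClampBounds activation_to_bounds(int type, float param1)
{
    ClampBounds b{ -std::numeric_limits<float>::infinity(),
                    std::numeric_limits<float>::infinity() };
    switch (type)
    {
        case 2:              // BoundedReLU: clamp above at param1...
            b.max = param1;
            [[fallthrough]]; // ...and also clamp below at 0
        case 1:              // ReLU
            b.min = 0.0f;
            break;
        default:             // None: leave bounds at +/- infinity
            break;
    }
    return b;
}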
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_gemm.hpp" + +#include <cstdint> +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer) +{ + struct KernelArgs + { + KernelArgs( + const int8_t *const A, + const int8_t *const B, + float *const C, const int ldc, + const int M, const int N, const int K, + const int32_t *const bias, const float *const late_bias, const Activation act, + bool accumulate, + int32_t *const accumulator_buffer + ) : A(A), + B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)), + C(C), ldcb(ldc * sizeof(float)), + M(M), N(N), K(K), + min(-std::numeric_limits<float>::infinity()), + max(std::numeric_limits<float>::infinity()), + bias(bias), late_bias(late_bias), + accumulator_buffer(accumulator_buffer), + flags(0x0) + { + if (accumulate) + { + flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER + } + if (C == nullptr) + { + flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER + } + + // Initialise the activation values + switch (act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + this->max = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + this->min = static_cast<float>(0); + break; + } + } + + const int8_t *const A; + const int8_t *const B; + const long kstride_bytes; + float *const C; + const long ldcb; + const long M, N, K; + float min = -std::numeric_limits<float>::infinity(); + float max = std::numeric_limits<float>::infinity(); + + const int32_t *const bias; + const float *const late_bias; + + int32_t *const accumulator_buffer; + uint64_t flags; + }; + + // Construct arguments for this kernel + KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer); + + __asm__ __volatile__( + "ldr x16, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p0.b\n" + ".inst 0x25207811 // ptrue pn9.b\n" + "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x16, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // 
Initial accumulator load from buffer: Loop + ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n" + ".inst 0xa041c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c5e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c5f8 // ld1w { z24.s-z27.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840681 // mova za1h.s[x12], { z20.s-z23.s }\n" + ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840703 // mova za3h.s[x12], { z24.s-z27.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "3:" // M and N loop + "mov x27, x28\n" + ".inst 0x25a94550 // whilelt pn8.s, x10, x9, VLx2\n" + "tbnz x16, #0, 4f\n" + "ldr x20, [%x[args], %[offsetof_bias]]\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "cbz x20, 5f\n" + ".inst 0xa10a4286 // ld1w { z6.s, z14.s }, p8/Z, [x20, x10, LSL #2]\n" + ".inst 0xc09000c0 // addha za0.s, p0/M, p0/M, z6.s\n" + ".inst 0xc09001c1 // addha za1.s, p0/M, p0/M, z14.s\n" + ".inst 0xc09000c2 // addha za2.s, p0/M, p0/M, z6.s\n" + ".inst 0xc09001c3 // addha za3.s, p0/M, p0/M, z14.s\n" + "4:" // Prepare accumulators: Test for last block + "mov x20, x10\n" + "mov x21, x11\n" + "incw x20, ALL, MUL #2\n" + "incw x21, ALL, MUL #2\n" + "cmp x20, x9\n" + "mov x20, x16\n" + "csel x21, x11, x21, LT\n" + "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n" + "cmp x21, x13\n" + "csel x16, x20, x16, LT\n" + "5:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "ldr x23, [%x[args], %[offsetof_B]]\n" + "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes + "and x20, x20, #0x3\n" + "cbz x21, 8f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n" + ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n" + ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n" + ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n" + "addvl x23, x23, #8\n" + "ble 7f\n" + "6:" // K loop + ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n" + ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n" + ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n" + ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n" + ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n" + ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n" + ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n" + ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n" + ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n" + ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n" + ".inst 
0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n" + ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n" + ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n" + ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n" + ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n" + ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n" + ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n" + ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n" + ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n" + ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n" + ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n" + "addvl x27, x27, #8\n" + ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n" + "addvl x23, x23, #8\n" + "bgt 6b\n" + "7:" // K loop tail + ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n" + ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n" + ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n" + ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n" + ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n" + ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n" + ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n" + ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n" + ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n" + ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n" + ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n" + ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n" + ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n" + ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n" + ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n" + ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n" + "8:" // K oddments + "cbz x20, 10f\n" + "9:" // K oddments: Loop + ".inst 0xa040077e // ld1b { z30.b-z31.b }, pn9.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #2\n" + ".inst 0xa14006e7 // ld1b { z7.b, z15.b }, pn9.b/Z, [x23]\n" + "addvl x23, x23, #2\n" + ".inst 0xa08703c0 // smopa za0.s, p0/M, p0/M, z30.b, z7.b\n" + ".inst 0xa08f03c1 // smopa za1.s, p0/M, p0/M, z30.b, z15.b\n" + ".inst 0xa08703e2 // smopa za2.s, p0/M, p0/M, z31.b, z7.b\n" + ".inst 0xa08f03e3 // smopa za3.s, p0/M, p0/M, z31.b, z15.b\n" + "bgt 9b\n" + "10:" // K oddments: End + "tbz x16, #1, 14f\n" + "tbz x16, #0, 12f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "11:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n" + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n" + ".inst 0xa041c5f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xc0860440 // mova { z0.s-z3.s }, za2h.s[x12]\n" + ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n" + ".inst 0xa042c5fc // ld1w { z28.s-z31.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840601 // mova za1h.s[x12], { z16.s-z19.s }\n" + ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n" + ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n" + ".inst 0xa061c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, 
#0x4, MUL VL]\n" + ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c5d8 // st1w { z24.s-z27.s }, pn9.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 11b\n" + "b 24f\n" + "12:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "13:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n" + ".inst 0xc0860468 // mova { z8.s-z11.s }, za3h.s[x12]\n" + ".inst 0xa060c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c5cc // st1w { z12.s-z15.s }, pn9.b, [x14, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c5d0 // st1w { z16.s-z19.s }, pn9.b, [x14, #0x8, MUL VL]\n" + ".inst 0xa063c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 13b\n" + "b 24f\n" + "14:" // Store to output array + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x13, x11\n" + "ld1rw { z3.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n" + "fmov z2.s, #0x0\n" + "ldr x24, [%x[args], %[offsetof_ldcb]]\n" + "fmov z10.s, #0x0\n" + "ldr x20, [%x[args], %[offsetof_late_bias]]\n" + "add x26, x26, x10, LSL #2\n" // C += n + "madd x26, x11, x24, x26\n" // C += m * ldc + "cbz x20, 15f\n" + "add x20, x20, x10, LSL #2\n" + ".inst 0xa1404282 // ld1w { z2.s, z10.s }, p8/Z, [x20]\n" + "15:" // Store to output array: no late bias + "cntw x23\n" + "ld1rw { z1.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x25, x23\n" + "ld1rw { z0.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z4.s, p0/M, z3.s, z2.s\n" + "fmad z5.s, p0/M, z3.s, z2.s\n" + "add x12, x12, #0x4\n" + "fmad z6.s, p0/M, z3.s, z2.s\n" + "fmad z7.s, p0/M, z3.s, z2.s\n" + "cmp x12, x21, LSL #2\n" + "fmad z12.s, p0/M, z3.s, z10.s\n" + "fmad z13.s, p0/M, z3.s, z10.s\n" + "fmad z14.s, p0/M, z3.s, z10.s\n" + "fmad z15.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n" + ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 16b\n" + "17:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 18f\n" + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" + "fmad z16.s, p0/M, z3.s, z2.s\n" + "fmad z17.s, p0/M, z3.s, z2.s\n" + "subs x20, x20, #0x1\n" + "fmad z18.s, p0/M, z3.s, z2.s\n" + "fmad z19.s, p0/M, z3.s, z2.s\n" + "fmad z24.s, p0/M, z3.s, z10.s\n" + 
"fmad z25.s, p0/M, z3.s, z10.s\n" + "fmad z26.s, p0/M, z3.s, z10.s\n" + "fmad z27.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c830 // fclamp { z16.s-z19.s }, z1.s, z0.s\n" + ".inst 0xc1a0c838 // fclamp { z24.s-z27.s }, z1.s, z0.s\n" + ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "18:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 22f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x20, x25, x23, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 20f\n" + "19:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" + ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n" + ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmad z20.s, p0/M, z3.s, z2.s\n" + "fmad z21.s, p0/M, z3.s, z2.s\n" + "add x12, x12, #0x4\n" + "fmad z22.s, p0/M, z3.s, z2.s\n" + "fmad z23.s, p0/M, z3.s, z2.s\n" + "cmp x12, x21, LSL #2\n" + "fmad z28.s, p0/M, z3.s, z10.s\n" + "fmad z29.s, p0/M, z3.s, z10.s\n" + "fmad z30.s, p0/M, z3.s, z10.s\n" + "fmad z31.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c834 // fclamp { z20.s-z23.s }, z1.s, z0.s\n" + ".inst 0xc1a0c83c // fclamp { z28.s-z31.s }, z1.s, z0.s\n" + ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n" + "add x26, x26, x24\n" + ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "blt 19b\n" + "20:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 21f\n" + ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n" + ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n" + ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "fmad z4.s, p0/M, z3.s, z2.s\n" + "fmad z5.s, p0/M, z3.s, z2.s\n" + "subs x20, x20, #0x1\n" + "fmad z6.s, p0/M, z3.s, z2.s\n" + "fmad z7.s, p0/M, z3.s, z2.s\n" + "fmad z12.s, p0/M, z3.s, z10.s\n" + "fmad z13.s, p0/M, z3.s, z10.s\n" + "fmad z14.s, p0/M, z3.s, z10.s\n" + "fmad z15.s, p0/M, z3.s, z10.s\n" + ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n" + ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n" + ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "subs x20, x20, #0x1\n" + ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n" + "21:" // Store to output array: Accumulator row 1 oddments: End + "22:" // Store to output array: End + "tbz x16, #0, 24f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "23:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15]\n" + ".inst 0xa041c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c5e4 // ld1w { z4.s-z7.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c5e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { 
z12.s-z15.s }\n" + ".inst 0xc0840482 // mova za2h.s[x12], { z4.s-z7.s }\n" + ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 23b\n" + "24:" // End block + "incw x10, ALL, MUL #2\n" + "cmp x10, x9\n" + "blt 3b\n" + "incw x11, ALL, MUL #2\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x27\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp new file mode 100644 index 0000000000..70952f4f03 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#pragma once + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include <cstdint> +#include "../std_transforms_sme.hpp" + +namespace arm_gemm +{ + +// Implementations +void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + +class cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL +{ +public: + typedef int8_t operand_type; + typedef float result_type; + + typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer); + + /* Kernel blocking parameters */ + static unsigned int out_height() + { + return sme::get_vector_length<int32_t>() * 4; + } + + static unsigned int out_width() + { + return sme::get_vector_length<int32_t>() * 1; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + static constexpr bool supports_bias() + { + return true; + } + + static constexpr bool supports_activation() + { + return true; + } + + static constexpr bool is_sme() + { + return true; + } + + // Default to the generic kernel + kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL; + + StdTransformsSME<operand_type, result_type, 4, 1, 4> transforms = {}; + + cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp new file mode 100644 index 0000000000..bafb16bca8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp @@ -0,0 +1,513 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
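The generic.cpp bodies compute their K-loop trip counts the same way in all three variants: K is first rounded up to groups of four int8 values (one SMOPA step per 32-bit accumulator lane, matching k_unroll() == 4 and kstride_bytes = roundup(K, 4)), the main loop then processes four such steps per iteration, and the remainder is handled in the "K oddments" loop. A small sketch of that bookkeeping:

#include <cstdio>

static void k_loop_counts(long K)
{
    const long k_steps    = (K + 3) / 4; // SMOPA steps: 4 int8 values per step
    const long main_iters = k_steps / 4; // main loop is unrolled by 4 steps
    const long oddments   = k_steps % 4; // leftover steps -> "K oddments" loop
    const long kstride    = k_steps * 4; // == roundup(K, 4) bytes of int8 per row
    std::printf("K=%ld: %ld main iterations, %ld oddment steps, kstride %ld bytes\n",
                K, main_iters, oddments, kstride);
}

int main()
{
    k_loop_counts(64); // exact multiple: no oddments
    k_loop_counts(70); // padded up to 72, with oddment steps
    return 0;
}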
+ */ +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_gemm.hpp" + +#include <cstdint> +#include "../../asmlib.hpp" +#include "../../utils.hpp" + +namespace arm_gemm { + +void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer) +{ + struct KernelArgs + { + KernelArgs( + const int8_t *const A, + const int8_t *const B, + float *const C, const int ldc, + const int M, const int N, const int K, + const int32_t *const bias, const float *const late_bias, const Activation act, + bool accumulate, + int32_t *const accumulator_buffer + ) : A(A), + B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)), + C(C), ldcb(ldc * sizeof(float)), + M(M), N(N), K(K), + min(-std::numeric_limits<float>::infinity()), + max(std::numeric_limits<float>::infinity()), + bias(bias), late_bias(late_bias), + accumulator_buffer(accumulator_buffer), + flags(0x0) + { + if (accumulate) + { + flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER + } + if (C == nullptr) + { + flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER + } + + // Initialise the activation values + switch (act.type) + { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + this->max = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + this->min = static_cast<float>(0); + break; + } + } + + const int8_t *const A; + const int8_t *const B; + const long kstride_bytes; + float *const C; + const long ldcb; + const long M, N, K; + float min = -std::numeric_limits<float>::infinity(); + float max = std::numeric_limits<float>::infinity(); + + const int32_t *const bias; + const float *const late_bias; + + int32_t *const accumulator_buffer; + uint64_t flags; + }; + + // Construct arguments for this kernel + KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer); + + __asm__ __volatile__( + "ldr x16, [%x[args], %[offsetof_flags]]\n" + ".inst 0xd503477f // SMSTART ZA\n" + "ptrue p1.b\n" + ".inst 0x25207810 // ptrue pn8.b\n" + "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n" + "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n" + "tbz x16, #0, 2f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "1:" // Initial accumulator load from buffer: Loop + ".inst 0xa040c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15]\n" + ".inst 0xa041c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c1f0 // ld1w { z16.s-z19.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n" + ".inst 0xc0840502 // mova za2h.s[x12], { z8.s-z11.s }\n" + ".inst 0xc0840603 // mova za3h.s[x12], { z16.s-z19.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 1b\n" + "2:" // Initial accumulator load from buffer: End + "ldr w13, [%x[args], %[offsetof_M]]\n" + "mov x11, #0x0\n" + "mov x10, #0x0\n" + "ldr w9, [%x[args], %[offsetof_N]]\n" + "ldr x28, [%x[args], %[offsetof_A]]\n" + "3:" // M and N loop + "mov x27, x28\n" + "whilelt p0.s, x10, x9\n" + "tbnz x16, #0, 4f\n" + "ldr x20, [%x[args], %[offsetof_bias]]\n" + ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n" + "cbz x20, 5f\n" + "ld1w { z23.s }, 
p0/Z, [x20, x10, LSL #2]\n" + ".inst 0xc09026e0 // addha za0.s, p1/M, p1/M, z23.s\n" + ".inst 0xc09026e1 // addha za1.s, p1/M, p1/M, z23.s\n" + ".inst 0xc09026e2 // addha za2.s, p1/M, p1/M, z23.s\n" + ".inst 0xc09026e3 // addha za3.s, p1/M, p1/M, z23.s\n" + "4:" // Prepare accumulators: Test for last block + "mov x20, x10\n" + "mov x21, x11\n" + "incw x20\n" + "incw x21, ALL, MUL #4\n" + "cmp x20, x9\n" + "mov x20, x16\n" + "csel x21, x11, x21, LT\n" + "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n" + "cmp x21, x13\n" + "csel x16, x20, x16, LT\n" + "5:" // Prepare accumulators: End + "ldr x20, [%x[args], %[offsetof_K]]\n" + "ldr x23, [%x[args], %[offsetof_B]]\n" + "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n" + "add x20, x20, #0x3\n" + "lsr x20, x20, #0x2\n" + "lsr x21, x20, #0x2\n" + "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes + "and x20, x20, #0x3\n" + "cbz x21, 8f\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n" + "ld1b { z4.b }, p1/Z, [x23]\n" + ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n" + ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n" + "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n" + ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n" + "addvl x27, x27, #16\n" + "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + "ble 7f\n" + "6:" // K loop + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + "subs x21, x21, #0x1\n" + ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n" + ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n" + ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n" + ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n" + "ld1b { z4.b }, p1/Z, [x23]\n" + ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n" + ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n" + ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n" + ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n" + ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n" + "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n" + ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n" + ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n" + ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n" + ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n" + "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n" + ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n" + ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n" + ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n" + ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n" + ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n" + "addvl x27, x27, #16\n" + "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n" + "addvl x23, x23, #4\n" + "bgt 6b\n" + "7:" // K loop tail + ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n" + ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n" + ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n" + ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n" + ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n" + ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n" + ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n" + ".inst 0xa08226e3 // smopa za3.s, p1/M, 
p1/M, z23.b, z2.b\n" + ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n" + ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n" + ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n" + ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n" + ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n" + ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n" + ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n" + ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n" + "8:" // K oddments + "cbz x20, 10f\n" + "9:" // K oddments: Loop + ".inst 0xa1408373 // ld1b { z19.b, z23.b, z27.b, z31.b }, pn8.b/Z, [x27]\n" + "subs x20, x20, #0x1\n" + "addvl x27, x27, #4\n" + "ld1b { z16.b }, p1/Z, [x23]\n" + "addvl x23, x23, #1\n" + ".inst 0xa0902660 // smopa za0.s, p1/M, p1/M, z19.b, z16.b\n" + ".inst 0xa09026e1 // smopa za1.s, p1/M, p1/M, z23.b, z16.b\n" + ".inst 0xa0902762 // smopa za2.s, p1/M, p1/M, z27.b, z16.b\n" + ".inst 0xa09027e3 // smopa za3.s, p1/M, p1/M, z31.b, z16.b\n" + "bgt 9b\n" + "10:" // K oddments: End + "tbz x16, #1, 14f\n" + "tbz x16, #0, 12f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "11:" // Store to partial result buffer: Store and refill: Loop + ".inst 0xa040c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15]\n" + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n" + ".inst 0xa041c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xc0860458 // mova { z24.s-z27.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xa042c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840500 // mova za0h.s[x12], { z8.s-z11.s }\n" + "addvl x15, x15, #16\n" + ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n" + ".inst 0xa060c1c0 // st1w { z0.s-z3.s }, pn8.b, [x14]\n" + ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n" + ".inst 0xa061c1c4 // st1w { z4.s-z7.s }, pn8.b, [x14, #0x4, MUL VL]\n" + ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n" + "add x12, x12, #0x4\n" + ".inst 0xa062c1d8 // st1w { z24.s-z27.s }, pn8.b, [x14, #0x8, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 11b\n" + "b 30f\n" + "12:" // Store to partial result buffer: Store only + "mov x12, #0x0\n" + "cntw x20\n" + "13:" // Store to partial result buffer: Store only: Loop + ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n" + ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n" + ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + ".inst 0xa060c1c8 // st1w { z8.s-z11.s }, pn8.b, [x14]\n" + "add x12, x12, #0x4\n" + ".inst 0xa061c1cc // st1w { z12.s-z15.s }, pn8.b, [x14, #0x4, MUL VL]\n" + "cmp x12, x20\n" + ".inst 0xa062c1d4 // st1w { z20.s-z23.s }, pn8.b, [x14, #0x8, MUL VL]\n" + ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n" + "addvl x14, x14, #16\n" + "blt 13b\n" + "b 30f\n" + "14:" // Store to output array + "ldr x26, [%x[args], %[offsetof_C]]\n" + "sub x25, x13, x11\n" + "ld1rw { z23.s }, p1/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n" + "fmov z22.s, #0x0\n" + "ldr x24, [%x[args], %[offsetof_ldcb]]\n" + "ldr x20, [%x[args], %[offsetof_late_bias]]\n" + "add x26, x26, x10, LSL #2\n" // C += n + "madd x26, x11, x24, x26\n" // C += m * ldc + "cbz 
x20, 15f\n" + "add x20, x20, x10, LSL #2\n" + "ld1w { z22.s }, p0/Z, [x20]\n" + "15:" // Store to output array: no late bias + "cntw x23\n" + "ld1rw { z21.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n" + "mov x12, #0x0\n" + "cmp x25, x23\n" + "ld1rw { z20.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 17f\n" + "16:" // Store to output array: Accumulator row 0 loop + ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z0.s, p1/M, z23.s, z22.s\n" + "fmad z1.s, p1/M, z23.s, z22.s\n" + "fmad z2.s, p1/M, z23.s, z22.s\n" + "fmad z3.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n" + "st1w { z0.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z1.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z2.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z3.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 16b\n" + "17:" // Store to output array: Accumulator row 0 oddments + "cbz x20, 18f\n" + ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "subs x20, x20, #0x1\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 18f\n" + "st1w { z18.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "18:" // Store to output array: Accumulator row 0 oddments: End + "subs x25, x25, x22\n" + "beq 28f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, #0x3\n" + "cbz x21, 20f\n" + "19:" // Store to output array: Accumulator row 1 loop + ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z18.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z19.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 19b\n" + "20:" // Store to output array: Accumulator row 1 oddments + "cbz x20, 21f\n" + ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n" + "fmad z28.s, p1/M, z23.s, z22.s\n" + "fmad z29.s, p1/M, z23.s, z22.s\n" + "fmad z30.s, p1/M, z23.s, z22.s\n" + "fmad z31.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cabc // fclamp { z28.s-z31.s }, z21.s, z20.s\n" + "st1w { z28.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "subs x20, x20, #0x1\n" + "st1w { z29.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 21f\n" + "st1w { z30.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "21:" // Store to output array: Accumulator row 1 oddments: End + "subs x25, x25, x22\n" + "beq 28f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x22, x25, x23, LT\n" + "lsr x21, x22, #0x2\n" + "and x20, x22, 
#0x3\n" + "cbz x21, 23f\n" + "22:" // Store to output array: Accumulator row 2 loop + ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z12.s, p1/M, z23.s, z22.s\n" + "fmad z13.s, p1/M, z23.s, z22.s\n" + "fmad z14.s, p1/M, z23.s, z22.s\n" + "fmad z15.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n" + "st1w { z12.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z13.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z14.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z15.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 22b\n" + "23:" // Store to output array: Accumulator row 2 oddments + "cbz x20, 24f\n" + ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 24f\n" + "subs x20, x20, #0x1\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 24f\n" + "st1w { z18.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "24:" // Store to output array: Accumulator row 2 oddments: End + "subs x25, x25, x22\n" + "beq 28f\n" + "cmp x25, x23\n" + "mov x12, #0x0\n" + "csel x20, x25, x23, LT\n" + "lsr x21, x20, #0x2\n" + "and x20, x20, #0x3\n" + "cbz x21, 26f\n" + "25:" // Store to output array: Accumulator row 3 loop + ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n" + "add x12, x12, #0x4\n" + ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n" + "cmp x12, x21, LSL #2\n" + "fmad z24.s, p1/M, z23.s, z22.s\n" + "fmad z25.s, p1/M, z23.s, z22.s\n" + "fmad z26.s, p1/M, z23.s, z22.s\n" + "fmad z27.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n" + "st1w { z24.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z25.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z26.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "st1w { z27.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "blt 25b\n" + "26:" // Store to output array: Accumulator row 3 oddments + "cbz x20, 27f\n" + ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n" + "subs x20, x20, #0x1\n" + ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n" + "fmad z16.s, p1/M, z23.s, z22.s\n" + "fmad z17.s, p1/M, z23.s, z22.s\n" + "fmad z18.s, p1/M, z23.s, z22.s\n" + "fmad z19.s, p1/M, z23.s, z22.s\n" + ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n" + "st1w { z16.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 27f\n" + "subs x20, x20, #0x1\n" + "st1w { z17.s }, p0, [x26]\n" + "add x26, x26, x24\n" + "beq 27f\n" + "st1w { z18.s }, p0, [x26]\n" + "27:" // Store to output array: Accumulator row 3 oddments: End + "28:" // Store to output array: End + "tbz x16, #0, 30f\n" + "mov x12, #0x0\n" + "cntw x20\n" + "29:" // Store to output array: Refill accumulators: Loop + ".inst 0xa040c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15]\n" + ".inst 0xa041c1e0 // ld1w { z0.s-z3.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n" + ".inst 0xa042c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n" + ".inst 0xa043c1e4 // ld1w { z4.s-z7.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n" + ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n" + "addvl x15, x15, #16\n" + ".inst 
0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n" + ".inst 0xc0840582 // mova za2h.s[x12], { z12.s-z15.s }\n" + ".inst 0xc0840483 // mova za3h.s[x12], { z4.s-z7.s }\n" + "add x12, x12, #0x4\n" + "cmp x12, x20\n" + "blt 29b\n" + "30:" // End block + "incw x10\n" + "cmp x10, x9\n" + "blt 3b\n" + "incw x11, ALL, MUL #4\n" + "mov x10, #0x0\n" + "cmp x11, x13\n" + "mov x28, x27\n" + "blt 3b\n" + ".inst 0xd503467f // SMSTOP\n" + : + : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp index 887d78e1de..23f686a902 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: { if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { + case CPUModel::V1: + return { 28.74 }; default: - return { 32.35 }; + return { 15.27 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp index d0ef531c33..1fe5f48da6 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. 
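Illustrative note (not from the patch): the store path of the new s8qfp32 MOPA kernel above converts each ZA accumulator with scvtf, applies the dequantization scale and the optional late bias with fmad (z23 holds the scale, z22 the bias or zero), clamps with fclamp against the activation min/max, and stores a float row. A scalar C++ model of that per-element epilogue, with placeholder names, might look like:

    #include <algorithm>
    #include <cstdint>

    // Scalar model of the kernel's store path; the function name and signature
    // are placeholders, not part of the kernel's argument structure.
    float dequantize_accumulator(int32_t acc, float scale, float late_bias, float minval, float maxval)
    {
        float v = static_cast<float>(acc) * scale + late_bias; // scvtf + fmad in the asm above
        return std::min(std::max(v, minval), maxval);          // fclamp in the asm above
    }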
* * SPDX-License-Identifier: MIT * @@ -88,8 +88,10 @@ public: if (std::is_same<T, float>::value) { switch (ci->get_cpu_model()) { - default: - return { 39.66, 5.18, 4.37 }; + case CPUModel::V1: + return { 53.48, 4.23, 6.53 }; + default: + return { 29.07, 2.76, 5.39 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp index a81d4504ae..ba47e0aa54 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2020, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) template<> void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const __fp16 *bias, Activation act, bool append) @@ -86,7 +86,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -140,7 +140,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -217,7 +217,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -317,7 +317,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -439,7 +439,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -584,7 +584,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -752,7 +752,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -944,7 +944,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1150,7 +1150,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1204,7 +1204,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1278,7 +1278,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1372,7 +1372,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1485,7 +1485,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1618,7 +1618,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1771,7 +1771,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1945,7 +1945,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -2112,4 +2112,4 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } } -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp index 111d01ed3a..6da9f4be0e 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.cpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2019, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -1142,6 +1142,64 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col); +void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height, + const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride, + const float* bias_ptr, bool accumulate, const Activation &act) +{ + const float32x4_t vscale = vdupq_n_f32(qp.scale); + float maxval = std::numeric_limits<float>::infinity(); + float minval = -std::numeric_limits<float>::infinity(); + + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + minval = 0; + break; + } + + const float32x4_t vmin = vdupq_n_f32(minval); + const float32x4_t vmax = vdupq_n_f32(maxval); + + for(unsigned int row=0; row<height; row++) { + auto row_in_ptr = in_ptr + (row * in_stride); + auto row_out_ptr = out_ptr + (row * out_stride); + unsigned int col=0; + if (width >= 4) { + for(; col <= (width - 4); col+= 4) { + const int32x4_t vin = vld1q_s32(row_in_ptr + col); + float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale); + if(bias_ptr) { + const float32x4_t bin = vld1q_f32(bias_ptr + col); + vdeq = vaddq_f32(vdeq, bin); + } + if(accumulate) { + vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col)); + } + vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax); + vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq); + } + } + // left-over elements + for(; col < width; ++col) { + const int32_t val = *(row_in_ptr + col); + float res = static_cast<float>(val * qp.scale); + if(bias_ptr) { + res += static_cast<float>(*(bias_ptr + col)); + } + if(accumulate) { + res += *(row_out_ptr + col); + } + res = std::min(std::max(res, minval), maxval); + *(row_out_ptr + col) = res; + } + } +} + } // namespace arm_gemm #endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp index 31dd65b397..bc64fd967b 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Arm Limited. + * Copyright (c) 2019, 2023-2024 Arm Limited. 
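Hedged usage sketch (not from the patch) for the dequantize_block_32 routine added above; the include paths, wrapper name and block sizes below are assumptions. Strides are passed in elements, the bias (when provided) is indexed per output column, and accumulate adds onto the existing float output instead of overwriting it:

    #include <cstdint>
    #include <vector>

    // Assumed headers providing arm_gemm::DequantizeFloat, arm_gemm::Activation
    // and the dequantize_block_32 declaration added to quantized.hpp below.
    #include "arm_gemm.hpp"
    #include "quantized.hpp"

    void dequantize_example(const arm_gemm::DequantizeFloat &qp, const arm_gemm::Activation &act)
    {
        const unsigned int M = 8, N = 24;                  // placeholder block shape
        std::vector<int32_t> acc(M * N, 0);                // raw s32 GEMM accumulators, row-major
        std::vector<float>   out(M * N, 0.f);              // dequantized float output
        std::vector<float>   bias(N, 0.f);                 // optional per-column bias

        // width = N, height = M, strides in elements; accumulate = false overwrites `out`.
        arm_gemm::dequantize_block_32(qp, N, M, acc.data(), N, out.data(), N,
                                      bias.data(), /*accumulate=*/false, act);
    }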
* * SPDX-License-Identifier: MIT * @@ -45,4 +45,8 @@ template<typename T> void row_sums_indirect(size_t num_strings, const unsigned int *string_lengths, IndirectInputArg<T> A_arg, size_t M, int32_t *output_ptr, const Requantize32 *qp); +void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height, + const int32_t* input, unsigned int in_stride, float *output, unsigned int out_stride, + const float *row_bias, bool not_first_pass, const Activation &act); + } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp index 45e4f0e1de..06d9e2416c 100644 --- a/src/core/NEON/kernels/arm_gemm/transform.cpp +++ b/src/core/NEON/kernels/arm_gemm/transform.cpp @@ -129,17 +129,17 @@ void Transform( // We don't have assembler transforms for AArch32, generate templated ones here. #ifdef __arm__ template void Transform<8, 1, true, VLType::None>(float *, const float *, int, int, int, int, int); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 #endif // AArch32 -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h index 50b3fc1284..cd849c3666 100644 --- a/src/core/common/Registrars.h +++ b/src/core/common/Registrars.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023 Arm Limited. + * Copyright (c) 2020-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -38,6 +38,12 @@ #define REGISTER_FP16_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_FP16_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_FP16_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #if defined(ARM_COMPUTE_ENABLE_NEON) #define REGISTER_FP16_NEON(func_name) &(func_name) #else /* !defined(ARM_COMPUTE_ENABLE_NEON) */ @@ -48,6 +54,7 @@ #define REGISTER_FP16_NEON(func_name) nullptr #define REGISTER_FP16_SVE(func_name) nullptr #define REGISTER_FP16_SVE2(func_name) nullptr +#define REGISTER_FP16_SME2(func_name) nullptr #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ #if defined(ENABLE_FP32_KERNELS) @@ -64,6 +71,16 @@ #define REGISTER_FP32_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_FP32_SME2(func_name) &(func_name) +#define REGISTER_QASYMM8_SME2(func_name) &(func_name) +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_FP32_SME2(func_name) nullptr +#define REGISTER_QASYMM8_SME2(func_name) nullptr +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #if defined(ARM_COMPUTE_ENABLE_NEON) #define REGISTER_FP32_NEON(func_name) &(func_name) #else /* !defined(ARM_COMPUTE_ENABLE_NEON) */ @@ -74,6 +91,7 @@ #define REGISTER_FP32_NEON(func_name) nullptr #define REGISTER_FP32_SVE(func_name) nullptr #define REGISTER_FP32_SVE2(func_name) nullptr +#define REGISTER_FP32_SME2(func_name) nullptr #endif /* defined(ENABLE_FP32_KERNELS) */ #if defined(ENABLE_QASYMM8_SIGNED_KERNELS) @@ -92,10 +110,17 @@ #define REGISTER_QASYMM8_SIGNED_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #else /* defined(ENABLE_QASYMM8_SIGNED_KERNELS) */ #define REGISTER_QASYMM8_SIGNED_NEON(func_name) nullptr #define REGISTER_QASYMM8_SIGNED_SVE(func_name) nullptr #define REGISTER_QASYMM8_SIGNED_SVE2(func_name) nullptr +#define REGISTER_QASYMM8_SIGNED_SME2(func_name) nullptr #endif /* defined(ENABLE_QASYMM8_SIGNED_KERNELS) */ #if defined(ENABLE_QASYMM8_KERNELS) @@ -113,10 +138,17 @@ #define REGISTER_QASYMM8_SVE2(func_name) nullptr #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_SME2) +#define REGISTER_QASYMM8_SME2(func_name) &(func_name) +#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */ +#define REGISTER_QASYMM8_SME2(func_name) nullptr +#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */ + #else /* defined(ENABLE_QASYMM8_KERNELS) */ #define REGISTER_QASYMM8_NEON(func_name) nullptr #define REGISTER_QASYMM8_SVE(func_name) nullptr #define REGISTER_QASYMM8_SVE2(func_name) nullptr +#define REGISTER_QASYMM8_SME2(func_name) nullptr #endif /* defined(ENABLE_QASYMM8_KERNELS) */ #if defined(ENABLE_QSYMM16_KERNELS) diff --git a/src/core/helpers/LUTManager.cpp b/src/core/helpers/LUTManager.cpp index 06e35eed8c..2effffbe92 100644 --- a/src/core/helpers/LUTManager.cpp +++ b/src/core/helpers/LUTManager.cpp @@ -30,17 +30,38 @@ namespace arm_compute namespace { -void init_lut_fp16(ActivationLayerInfo::LookupTable65536 *lut) 
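Standalone illustration (not from the patch) of how the SME2 registrar macros added above are meant to be consumed: each expands to the kernel's address when the backend is compiled in and to nullptr otherwise, so selection tables can be built unconditionally. The stub kernel and table type below are placeholders, not the library's actual types:

    #include <cstdio>

    #if defined(ARM_COMPUTE_ENABLE_SME2)
    #define REGISTER_FP32_SME2(func_name) &(func_name)
    #else /* !defined(ARM_COMPUTE_ENABLE_SME2) */
    #define REGISTER_FP32_SME2(func_name) nullptr
    #endif /* defined(ARM_COMPUTE_ENABLE_SME2) */

    void sme2_fp32_softmax_stub(const float *, float *, int) {}

    struct KernelEntry
    {
        const char *name;
        void (*ukernel)(const float *, float *, int);
    };

    int main()
    {
        KernelEntry entry{"sme2_fp32_softmax", REGISTER_FP32_SME2(sme2_fp32_softmax_stub)};
        std::printf("%s selectable: %s\n", entry.name, entry.ukernel != nullptr ? "yes" : "no");
        return 0;
    }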
+float16_t activation(float16_t x, const LUTInfo &info) +{ + float16_t out = 0.f; + switch (info.act) + { + case ActivationLayerInfo::ActivationFunction::LOGISTIC: + out = 1.f / (1.f + std::exp(-x)); + break; + case ActivationLayerInfo::ActivationFunction::TANH: + { + out = static_cast<float16_t>(info.alpha * std::tanh(info.beta * x)); + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported Activation for 16-bit LUT table"); + break; + } + return out; +} + +void init_lut_fp16(ActivationLayerInfo::LookupTable65536 *lut, const LUTInfo &info) { union Element { uint16_t i = 0; float16_t fp; } item; + // Fill lut by iterating over all 16 bit values using the union. while (true) { - (*lut)[item.i] = 1.f / (1.f + std::exp(-item.fp)); + (*lut)[item.i] = activation(item.fp, info); if (item.i == 65535) break; item.i++; @@ -62,7 +83,7 @@ std::shared_ptr<ActivationLayerInfo::LookupTable65536> LUTManager::get_lut_table // Not found, or pointer not valid // We do not use make_shared to prevent the weak_ptr keeping the control block alive std::shared_ptr<ActivationLayerInfo::LookupTable65536> ptr(new ActivationLayerInfo::LookupTable65536); - init_lut_fp16(ptr.get()); + init_lut_fp16(ptr.get(), info); map_fp16[info] = ptr; return ptr; } diff --git a/src/core/helpers/LUTManager.h b/src/core/helpers/LUTManager.h index 4e13ead7e3..f3f4bf2832 100644 --- a/src/core/helpers/LUTManager.h +++ b/src/core/helpers/LUTManager.h @@ -38,19 +38,23 @@ namespace arm_compute struct LUTInfo { ActivationLayerInfo::ActivationFunction act; + float alpha; + float beta; DataType dt; - QuantizationInfo qinfo; + UniformQuantizationInfo qinfo; + // Operators enable use of map with Lutinfo as key friend bool operator<(const LUTInfo &l, const LUTInfo &r) { - return (l.act < r.act) || ((l.act == r.act) && (l.dt < r.dt)) || - ((l.act == r.act) && (l.dt == r.dt) && (l.qinfo.scale() < r.qinfo.scale())) || - ((l.act == r.act) && (l.dt == r.dt) && (l.qinfo.scale() == r.qinfo.scale()) && - (l.qinfo.offset() < l.qinfo.offset())); + const auto l_tup = std::make_tuple(l.act, l.alpha, l.beta, l.dt, l.qinfo.scale, l.qinfo.offset); + const auto r_tup = std::make_tuple(r.act, r.alpha, r.beta, r.dt, r.qinfo.scale, r.qinfo.offset); + + return l_tup < r_tup; } - bool operator==(const LUTInfo &l) + bool operator==(const LUTInfo &l) const { - return this->act == l.act && this->dt == l.dt && this->qinfo == l.qinfo; + return this->act == l.act && this->alpha == l.alpha && this->beta == l.beta && this->dt == l.dt && + this->qinfo == l.qinfo; } }; diff --git a/src/core/utils/helpers/tensor_transform.cpp b/src/core/utils/helpers/tensor_transform.cpp index 19d0badd74..212cfdabaa 100644 --- a/src/core/utils/helpers/tensor_transform.cpp +++ b/src/core/utils/helpers/tensor_transform.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2020, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -117,7 +117,10 @@ int calculate_end_on_index(TensorShape input_shape, } // Final clamp - stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1); + if (stride > 0) + stop = utility::clamp(stop, 0, dim_size); + else + stop = utility::clamp(stop, -1, dim_size - 1); return stop; } diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp index f66d3e7064..f8b74a985d 100644 --- a/src/core/utils/quantization/AsymmHelpers.cpp +++ b/src/core/utils/quantization/AsymmHelpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. 
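Side note (not from the patch): init_lut_fp16 above walks every 16-bit pattern with an explicit break rather than a plain for loop, because a uint16_t counter wraps after 65535 and the loop condition would never fail. A minimal standalone model of that fill loop, with a placeholder value in place of the activation of the reinterpreted fp16 value:

    #include <array>
    #include <cstdint>

    std::array<float, 65536> build_table()
    {
        std::array<float, 65536> lut{};
        uint16_t i = 0;
        while (true)
        {
            lut[i] = static_cast<float>(i); // placeholder for activation(fp16_from_bits(i), info)
            if (i == 65535)
                break; // stop before the counter wraps back to 0
            ++i;
        }
        return lut;
    }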
+ * Copyright (c) 2017-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -122,13 +122,13 @@ arm_compute::Status calculate_quantized_multipliers(const QuantizationInfo &iq_ ARM_COMPUTE_RETURN_ERROR_ON(iq_info.scale().empty()); ARM_COMPUTE_RETURN_ERROR_ON(wq_info.scale().empty()); ARM_COMPUTE_RETURN_ERROR_ON(oq_info.scale().empty()); - - const unsigned int size = wq_info.scale().size(); - - auto &quant_multipliers = stage_info.gemmlowp_multipliers; - auto &quant_shifts = stage_info.gemmlowp_shifts; - quant_multipliers.resize(size); - quant_shifts.resize(size); + constexpr unsigned int padding_elems = 32; // assembly kernels assume the shifts and multipliers buffers are padded + const unsigned int size = wq_info.scale().size(); + const size_t padded_size = (size == 1) ? 1 : size + padding_elems; + auto &quant_multipliers = stage_info.gemmlowp_multipliers; + auto &quant_shifts = stage_info.gemmlowp_shifts; + quant_multipliers.resize(padded_size); + quant_shifts.resize(padded_size); const auto &w_scales = wq_info.scale(); const float i_scale = iq_info.scale().at(0); diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp index 7cfa39b286..4253027231 100644 --- a/src/cpu/kernels/CpuActivationKernel.cpp +++ b/src/cpu/kernels/CpuActivationKernel.cpp @@ -43,6 +43,13 @@ namespace kernels { namespace { + +bool is_fp16_lut_supported(ActivationLayerInfo::ActivationFunction func) +{ + return func == ActivationLayerInfo::ActivationFunction::LOGISTIC || + func == ActivationLayerInfo::ActivationFunction::TANH; +} + static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels = { #ifdef ARM_COMPUTE_ENABLE_SVE {"sve2_q8_activation_lut", @@ -85,10 +92,7 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)}, {"sve_fp16_activation_lut", [](const ActivationDataTypeISASelectorData &data) - { - return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve && - data.f == ActivationLayerInfo::ActivationFunction::LOGISTIC; - }, + { return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve && is_fp16_lut_supported(data.f); }, REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation_lut)}, {"sve_fp16_activation", [](const ActivationDataTypeISASelectorData &data) @@ -299,10 +303,10 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac activation_info.setLookupTable256(tmp_lut); } - if (src->data_type() == DataType::F16 && - activation_info.activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC) + if (std::string(uk->name) == "sve_fp16_activation_lut") { - const LUTInfo info = {activation_info.activation(), src->data_type(), src->quantization_info()}; + const LUTInfo info = {activation_info.activation(), activation_info.a(), activation_info.b(), src->data_type(), + src->quantization_info().uniform()}; activation_info.setLookupTable65536((lut_manager.get_lut_table(info))); } #endif // __aarch64__ diff --git a/src/cpu/kernels/CpuDequantizeKernel.cpp b/src/cpu/kernels/CpuDequantizeKernel.cpp index d17128b5ac..5595ace998 100644 --- a/src/cpu/kernels/CpuDequantizeKernel.cpp +++ b/src/cpu/kernels/CpuDequantizeKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. 
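Illustrative sketch (not from the patch) of the comparison style LUTInfo now uses above: packing the key fields into std::make_tuple yields a lexicographic strict weak ordering suitable for an ordered map key, which the previous hand-written chain of comparisons did not guarantee. The Key type below is a stand-in, not the library's LUTInfo:

    #include <map>
    #include <tuple>

    struct Key
    {
        int   act;
        float alpha;
        float beta;
        float scale;
        int   offset;

        friend bool operator<(const Key &l, const Key &r)
        {
            return std::make_tuple(l.act, l.alpha, l.beta, l.scale, l.offset) <
                   std::make_tuple(r.act, r.alpha, r.beta, r.scale, r.offset);
        }
    };

    std::map<Key, int> cache; // Key works as an ordered-map key thanks to the ordering above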
* * SPDX-License-Identifier: MIT * @@ -29,12 +29,14 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" #include "src/core/NEON/NEAsymm.h" #include "src/core/NEON/NESymm.h" #include "src/core/NEON/wrapper/wrapper.h" +#include "src/cpu/kernels/dequantize/generic/neon/list.h" #include <arm_neon.h> @@ -62,301 +64,6 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst) return Status{}; } - -template <typename T> -inline void store_result(T *ptr, const float32x4x4_t &v) -{ - ARM_COMPUTE_UNUSED(ptr, v); -} - -template <> -inline void store_result<float>(float *ptr, const float32x4x4_t &v) -{ - wrapper::vstore(ptr, v.val[0]); - wrapper::vstore(ptr + 4, v.val[1]); - wrapper::vstore(ptr + 8, v.val[2]); - wrapper::vstore(ptr + 12, v.val[3]); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -inline void store_result<float16_t>(float16_t *ptr, const float32x4x4_t &v) -{ - wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1]))); - wrapper::vstore(ptr + 8, vcombine_f16(vcvt_f16_f32(v.val[2]), vcvt_f16_f32(v.val[3]))); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -template <typename T> -inline void store_result(T *ptr, const float32x4x2_t &v) -{ - ARM_COMPUTE_UNUSED(ptr, v); -} - -template <> -inline void store_result<float>(float *ptr, const float32x4x2_t &v) -{ - wrapper::vstore(ptr, v.val[0]); - wrapper::vstore(ptr + 4, v.val[1]); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v) -{ - wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1]))); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -template <typename TOut, typename TIn> -void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window) -{ - const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); - const float scale = qinfo.scale; - const int32_t offset = qinfo.offset; - - const int window_step_x = 16; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Create iterators - Iterator in(input, win_collapsed); - Iterator out(output, win_collapsed); - - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const TIn *>(in.ptr()); - const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin = wrapper::vloadq(in_ptr + x); - const auto vdeq = vdequantize(vin, scale, offset); - - store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto val = *(in_ptr + x); - *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo)); - } - }, - in, out); -} - -template <typename T> -void run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window) -{ - const auto scale = 
input->info()->quantization_info().scale(); - - const int window_step_x = 16; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Reset first dimension to handle tail calculations manually - Window win(window); - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Create iterators - Iterator in(input, win); - Iterator out(output, win); - - execute_window_loop( - win, - [&](const Coordinates &id) - { - const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<T *>(out.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin = wrapper::vloadq(in_ptr + x); - const auto vdeq = vdequantize(vin, scale[id.z()]); - - store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - int8_t val = *(in_ptr + x); - *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()])); - } - }, - in, out); -} - -template <typename T> -void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window) -{ - const auto scale = input->info()->quantization_info().scale(); - - const int window_step_x = 16; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Reset first dimension to handle tail calculations manually - Window win(window); - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Create iterators - Iterator in(input, win); - Iterator out(output, win); - - execute_window_loop( - win, - [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<T *>(out.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const float32x4x4_t vscale = {{scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], scale[x + 4], - scale[x + 5], scale[x + 6], scale[x + 7], scale[x + 8], scale[x + 9], - scale[x + 10], scale[x + 11], scale[x + 12], scale[x + 13], - scale[x + 14], scale[x + 15]}}; - const auto vin = wrapper::vloadq(in_ptr + x); - const auto vdeq = vdequantize(vin, vscale); - - store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - int8_t val = *(in_ptr + x); - *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x])); - } - }, - in, out); -} - -template <typename T> -void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window) -{ - const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); - const float scale = qinfo.scale; - - const int window_step_x = 16; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Create iterators - Iterator in(input, win_collapsed); - Iterator out(output, win_collapsed); - - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<T *>(out.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x 
- window_step_x); x += window_step_x) - { - const auto vin = wrapper::vloadq(in_ptr + x); - const auto vdeq = vdequantize(vin, scale); - - store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - int8_t val = *(in_ptr + x); - *(out_ptr + x) = static_cast<T>(dequantize(val, scale)); - } - }, - in, out); -} - -template <typename T> -void run_dequantization_qsymm16(const ITensor *input, ITensor *output, const Window &window) -{ - const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); - const float scale = qinfo.scale; - - const int window_step_x = 8; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Create iterators - Iterator in(input, win_collapsed); - Iterator out(output, win_collapsed); - - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - const auto in_ptr = reinterpret_cast<const int16_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<T *>(out.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin = wrapper::vloadq(in_ptr + x); - const auto vdeq = vdequantize_int16(vin, scale); - - store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); - } - - // Compute left-over elements - for (; x < window_end_x; ++x) - { - int16_t val = *(in_ptr + x); - *(out_ptr + x) = static_cast<T>(dequantize_qsymm16(val, scale)); - } - }, - in, out); -} - -template <typename T> -void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window) -{ - switch (input->info()->data_type()) - { - case DataType::QASYMM8: - run_dequantization_qasymm8<T, uint8_t>(input, output, window); - break; - case DataType::QASYMM8_SIGNED: - run_dequantization_qasymm8<T, int8_t>(input, output, window); - break; - case DataType::QSYMM8_PER_CHANNEL: - input->info()->data_layout() == DataLayout::NHWC - ? 
run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window) - : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window); - break; - case DataType::QSYMM8: - run_dequantization_qsymm8<T>(input, output, window); - break; - case DataType::QSYMM16: - run_dequantization_qsymm16<T>(input, output, window); - break; - default: - ARM_COMPUTE_ERROR("Unsupported data type."); - } -} } // namespace void CpuDequantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) @@ -370,6 +77,20 @@ void CpuDequantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) auto_init_if_empty(*dst, src->tensor_shape(), 1, DataType::F32); ICpuKernel::configure(win); + + switch (dst->data_type()) + { + case DataType::F32: + _func = REGISTER_FP32_NEON(fp32_run_dequantization_core); + break; +#ifdef ARM_COMPUTE_ENABLE_FP16 + case DataType::F16: + _func = REGISTER_FP16_NEON(fp16_run_dequantization_core); + break; +#endif /* ARM_COMPUTE_ENABLE_FP16 */ + default: + ARM_COMPUTE_ERROR("Unsupported data type."); + } } Status CpuDequantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *dst) @@ -386,20 +107,7 @@ void CpuDequantizeKernel::run_op(ITensorPack &tensors, const Window &window, con const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - - switch (dst->info()->data_type()) - { - case DataType::F32: - run_dequantization_core<float>(src, dst, window); - break; -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - run_dequantization_core<float16_t>(src, dst, window); - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("Unsupported data type."); - } + _func(src, dst, window); } const char *CpuDequantizeKernel::name() const { diff --git a/src/cpu/kernels/CpuDequantizeKernel.h b/src/cpu/kernels/CpuDequantizeKernel.h index 6ed58587c9..d8b6444f0a 100644 --- a/src/cpu/kernels/CpuDequantizeKernel.h +++ b/src/cpu/kernels/CpuDequantizeKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H -#define ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H +#ifndef ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H +#define ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H #include "src/core/common/Macros.h" #include "src/cpu/ICpuKernel.h" @@ -56,8 +56,16 @@ public: // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + +private: + /** Common signature for all the specialised @ref CpuDequantizeKernel functions + * + * @param[in] window Region on which to execute the kernel. + */ + using DequantizeFunctionExecutorPtr = void (*)(const ITensor *input, ITensor *output, const Window &window); + DequantizeFunctionExecutorPtr _func{nullptr}; }; } // namespace kernels } // namespace cpu } // namespace arm_compute -#endif /* ARM_COMPUTE_CPU_DEQUANTIZE_KERNEL_H */ +#endif // ACL_SRC_CPU_KERNELS_CPUDEQUANTIZEKERNEL_H diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp index e290783021..2a76a5958d 100644 --- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp +++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. 
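Reduced model (not from the patch) of the CpuDequantizeKernel refactor above: the specialised routine is resolved once in configure() and stored in a member function pointer, and run_op() simply forwards to it instead of switching on the destination data type on every call. All names below are placeholders:

    #include <cstdio>

    struct Tensor; // opaque stand-in for ITensor
    using RunFn = void (*)(const Tensor *src, Tensor *dst);

    void run_fp32(const Tensor *, Tensor *) { std::puts("fp32 path"); }
    void run_fp16(const Tensor *, Tensor *) { std::puts("fp16 path"); }

    class DequantizeKernelModel
    {
    public:
        // Selects the run function once, as the real configure() does per data type.
        void configure(bool dst_is_fp16) { _func = dst_is_fp16 ? &run_fp16 : &run_fp32; }

        // Assumes configure() was called first, mirroring the kernel contract.
        void run_op(const Tensor *src, Tensor *dst) const { _func(src, dst); }

    private:
        RunFn _func{nullptr};
    };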
+ * Copyright (c) 2017-2022,2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -51,17 +51,19 @@ Status validate_arguments(const ITensorInfo *mm_result, int32_t a_offset, int32_t b_offset) { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32, DataType::F32); - // If a_offset == 0, vector_sum_col can be a nullptr - if (a_offset != 0) + // We run if the offset is nonzero or a sum col has been provided, we need + // the second option in case the QuantizationInfo is dynamic + if (a_offset != 0 || vector_sum_col != nullptr) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32); ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0)); } - // If b_offset == 0, vector_sum_row can be a nullptr - if (b_offset != 0) + // We run if the offset is nonzero or a sum row has been provided, we need + // the second option in case the QuantizationInfo is dynamic + if (b_offset != 0 || vector_sum_row != nullptr) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32); @@ -86,7 +88,7 @@ Status validate_arguments(const ITensorInfo *mm_result, ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx], "mm_result tensor must have the same number of batches of output tensor"); - if (a_offset != 0) + if (vector_sum_col != nullptr) { TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape(); vector_sum_col_shape.collapse_from(1); @@ -102,6 +104,275 @@ Status validate_arguments(const ITensorInfo *mm_result, return Status{}; } +void run_offset_contribution_float(const Window &window, + ITensor *mm_result, + const ITensor *vector_sum_col, + const ITensor *vector_sum_row, + int32_t a_offset, + int32_t b_offset, + int32_t k_offset, + float scale, + bool slide_vector_sum_col, + bool is_gemm3d) +{ + Window collapsed_window = window.collapse_if_possible(window, Window::DimZ); + collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1)); + + const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0; + const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1; + + const int window_start_x = window.x().start(); + const int window_end_x = window.x().end(); + const int window_step_x = 16; + + // if vector_sum_col is nullptr then stride_y is 0, else get stride_y + const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? 
(vector_sum_col->info()->strides_in_bytes().y()) : 0; + Iterator mm_result_it(mm_result, collapsed_window); + + if ((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true + { + // Set window for vector_sum_col + Window win_vector_sum_col(collapsed_window); + win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0)); + win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + // Set window for vector_sum_row + Window win_vector_sum_row(collapsed_window); + win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0)); + win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col); + Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row); + + const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y(); + + // Offset in case vector_sum_col is batched + const int vector_sum_col_batch_offset = + slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0; + + execute_window_loop( + collapsed_window, + [&](const Coordinates &id) + { + const int batch_id = id.z() / depth_input; + const size_t batch_offset_col = batch_id * (sum_col_stride_y); + auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col + + batch_id * vector_sum_col_batch_offset); + auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr()); + + // Compute the leftover term due to b_offset. + int32_t b_offset_term_s32 = + *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + + id.y() + (id.z() % depth_input) * height_input); + b_offset_term_s32 *= b_offset; + + const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Compute the leftover term due to a_offset. 
+ int32x4x4_t a_offset_term_s32 = { + {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4), + vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}}; + + a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset); + a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset); + a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset); + a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset); + + // Add a_offset_term_s32 and b_offset_term_s32 + int32x4x4_t offset_term_s32 = { + {vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset)}}; + + offset_term_s32.val[0] = + vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec)); + offset_term_s32.val[1] = + vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec)); + offset_term_s32.val[2] = + vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec)); + offset_term_s32.val[3] = + vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec)); + + float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4), + vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}}; + + // Convert and scale the S32 offsets to match the already scaled GEMM results + float32x4x4_t offset_terms_scaled = {{ + vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[0]), scale), + vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[1]), scale), + vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[2]), scale), + vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[3]), scale), + }}; + + // Add the offset terms to the GEMM result + in_f32.val[0] = vaddq_f32(in_f32.val[0], offset_terms_scaled.val[0]); + in_f32.val[1] = vaddq_f32(in_f32.val[1], offset_terms_scaled.val[1]); + in_f32.val[2] = vaddq_f32(in_f32.val[2], offset_terms_scaled.val[2]); + in_f32.val[3] = vaddq_f32(in_f32.val[3], offset_terms_scaled.val[3]); + + // Store the result with the offset contribution + vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]); + vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]); + vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]); + vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]); + } + + // Left-overs loop + for (; x < window_end_x; ++x) + { + // Compute the leftover term due to a_offset. 
+ int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x); + + a_offset_term_s32 *= a_offset; + + // Add the offset terms to GEMM's result + // Store the result with the offset contribution + mm_result_ptr[x] += (k_offset + a_offset_term_s32 + b_offset_term_s32) * scale; + } + }, + vector_sum_col_it, vector_sum_row_it, mm_result_it); + } + else if ((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // false, true + { + ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row); + + // Set window for vector_sum_row + Window win_vector_sum_row(collapsed_window); + win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0)); + win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row); + + const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y(); + + execute_window_loop( + collapsed_window, + [&](const Coordinates &id) + { + const int batch_id = id.z() / depth_input; + auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr()); + + // Compute the leftover term due to b_offset. + int32_t row_sum = + *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) + + id.y() + (id.z() % depth_input) * height_input); + float scaled_b_offset_term_f32 = row_sum * b_offset * scale; + + const float32x4_t b_offset_term_f32_vec = vdupq_n_f32(scaled_b_offset_term_f32); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4), + vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}}; + + // Add the offset terms to GEMM's result + in_f32.val[0] = vaddq_f32(in_f32.val[0], b_offset_term_f32_vec); + in_f32.val[1] = vaddq_f32(in_f32.val[1], b_offset_term_f32_vec); + in_f32.val[2] = vaddq_f32(in_f32.val[2], b_offset_term_f32_vec); + in_f32.val[3] = vaddq_f32(in_f32.val[3], b_offset_term_f32_vec); + + // Store the result with the offset contribution + vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]); + vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]); + vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]); + vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]); + } + + // Left-overs loop + for (; x < window_end_x; ++x) + { + // Add the offset terms to GEMM's result + // Store the result with the offset contribution + mm_result_ptr[x] += scaled_b_offset_term_f32; + } + }, + vector_sum_row_it, mm_result_it); + } + else if ((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // true, false + { + // Set window for vector_sum_col + Window win_vector_sum_col(collapsed_window); + win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0)); + win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0)); + + Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col); + + // Offset in case vector_sum_col is batched + const int vector_sum_col_batch_offset = + slide_vector_sum_col ? 
vector_sum_col->info()->strides_in_bytes().z() : 0; + + execute_window_loop( + collapsed_window, + [&](const Coordinates &id) + { + const int batch_id = id.z() / depth_input; + const size_t batch_offset_col = + batch_id * + (sum_col_stride_y); // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor + auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col + + batch_id * vector_sum_col_batch_offset); + auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + // Compute the leftover term due to a_offset. + int32x4x4_t a_offset_term_s32 = { + {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4), + vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}}; + + a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset); + a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset); + a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset); + a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset); + + float32x4x4_t a_offset_term_scaled = {{ + vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[0]), scale), + vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[1]), scale), + vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[2]), scale), + vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[3]), scale), + }}; + + float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4), + vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}}; + + // Add the offset terms to GEMM's result + in_f32.val[0] = vaddq_f32(in_f32.val[0], a_offset_term_scaled.val[0]); + in_f32.val[1] = vaddq_f32(in_f32.val[1], a_offset_term_scaled.val[1]); + in_f32.val[2] = vaddq_f32(in_f32.val[2], a_offset_term_scaled.val[2]); + in_f32.val[3] = vaddq_f32(in_f32.val[3], a_offset_term_scaled.val[3]); + + // Store the result with the offset contribution + vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]); + vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]); + vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]); + vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]); + } + + // Left-overs loop + for (; x < window_end_x; ++x) + { + // Compute the leftover term due to a_offset. 
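+                // Scalar tail: mm_result[x] += vector_sum_col[x] * a_offset * scale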
+ const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x); + + // Add the offset terms to GEMM's result + // Store the result with the offset contribution + mm_result_ptr[x] += a_offset_term_s32 * a_offset * scale; + } + }, + vector_sum_col_it, mm_result_it); + } + else // false, false + { + // No offset contribution from matrix A and matrix B + return; + } +} + void run_offset_contribution(const Window &window, ITensor *mm_result, const ITensor *vector_sum_col, @@ -361,7 +632,8 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result, ITensorInfo *vector_sum_row, int32_t k, int32_t a_offset, - int32_t b_offset) + int32_t b_offset, + float scale) { // Perform validate step ARM_COMPUTE_UNUSED(vector_sum_row); @@ -370,10 +642,11 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result, _a_offset = a_offset; _b_offset = b_offset; - _k_offset = a_offset * b_offset * k; + _k = k; - // If a_offset == 0, vector_sum_col can be a nullptr - if (a_offset != 0) + _scale = scale; + + if (vector_sum_col != nullptr) { // Check if vector_sum_col_shape should be slidden or not // Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1 @@ -386,6 +659,21 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result, ICpuKernel::configure(win); } +void CpuGemmLowpOffsetContributionKernel::set_a_offset(int32_t a_offset) +{ + _a_offset = a_offset; +} + +void CpuGemmLowpOffsetContributionKernel::set_b_offset(int32_t b_offset) +{ + _b_offset = b_offset; +} + +void CpuGemmLowpOffsetContributionKernel::set_scale(float scale) +{ + _scale = scale; +} + Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, @@ -410,8 +698,18 @@ void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Win const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->info()->num_dimensions() > 1 && mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x(); - run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset, - _slide_vector_sum_col, reinterpret_as_3d); + // check to see what is the output type of result + auto k_offset = _a_offset * _b_offset * _k; + if (mm_result->info()->data_type() == DataType::F32) + { + run_offset_contribution_float(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset, + _scale, _slide_vector_sum_col, reinterpret_as_3d); + } + else + { + run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset, + _slide_vector_sum_col, reinterpret_as_3d); + } } const char *CpuGemmLowpOffsetContributionKernel::name() const diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h index 08b2d47529..ecbfb0c282 100644 --- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h +++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2022,2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,12 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H -#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H +#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H +#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H #include "src/core/common/Macros.h" #include "src/cpu/ICpuKernel.h" +#include <cstdint> + namespace arm_compute { namespace cpu @@ -62,13 +64,16 @@ public: * @param[in] k Number of matrix A columns or Matrix B rows * @param[in] a_offset Offset to be added to each element of the matrix A. * @param[in] b_offset Offset to be added to each element of the matrix B. + * @param[in] scale (Optional) multiplies the contribution to make it the same scale as the dst in the case where mm_result is float + * (and so has already been scaled). Default is 1.0 */ void configure(ITensorInfo *mm_result, ITensorInfo *vector_sum_col, ITensorInfo *vector_sum_row, int32_t k, int32_t a_offset, - int32_t b_offset); + int32_t b_offset, + float scale = 1.0f); /** Static function to check if given info will lead to a valid configuration * * Similar to CpuGemmLowpOffsetContributionKernel::configure() @@ -81,6 +86,29 @@ public: int32_t a_offset, int32_t b_offset); + /** Set the a offset + * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op. + * Run configure or validate again if you aren't sure + * + * @param[in] a_offset Offset to be added to each element of the matrix A. + */ + void set_a_offset(int32_t a_offset); + + /** Set the b offset + * Warning: if b_offset is non-zero then vector_sum_row must be set in run_op. + * Run configure or validate again if you aren't sure + * + * @param[in] b_offset Offset to be added to each element of the matrix B. + */ + void set_b_offset(int32_t b_offset); + + /** Set the dequantize scale + * + * @param[in] scale Multiplies the contribution to make it the same scale as the dst in the case where + * mm_result is float (and so has already been scaled). + */ + void set_scale(float scale); + // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; @@ -88,10 +116,11 @@ public: private: int32_t _a_offset{0}; int32_t _b_offset{0}; - int32_t _k_offset{0}; + int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term + float _scale{1.0}; bool _slide_vector_sum_col{true}; }; } // namespace kernels } // namespace cpu } // namespace arm_compute -#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H */ +#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp index d008842398..3c113f2828 100644 --- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp +++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, 2023 Arm Limited. + * Copyright (c) 2019-2021, 2023-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -919,7 +919,7 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::configure(const ITensorInfo _a_offset = a_offset; _b_offset = b_offset; - _k_offset = a_offset * b_offset * k; + _k = k; _output_stage = output_stage; // If a_offset == 0, vector_sum_col can be a nullptr @@ -958,6 +958,16 @@ Status CpuGemmLowpOffsetContributionOutputStageKernel::validate(const ITensorInf return Status{}; } +void CpuGemmLowpOffsetContributionOutputStageKernel::set_a_offset(int32_t a_offset) +{ + _a_offset = a_offset; +} + +void CpuGemmLowpOffsetContributionOutputStageKernel::set_b_offset(int32_t b_offset) +{ + _b_offset = b_offset; +} + void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) @@ -993,10 +1003,11 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te // Check if symmetric per-channel execution const bool is_symm = _output_stage.is_quantized_per_channel; + auto k_offset = _a_offset * _b_offset * _k; if (is_symm) { run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, - _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, + _a_offset, _b_offset, k_offset, _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point); } else @@ -1004,13 +1015,13 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te if (is_signed) { run_offset_contribution_output_stage<int8_t>( - window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, + window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset, _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point); } else { run_offset_contribution_output_stage<uint8_t>( - window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, + window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset, _is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point); } } diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h index af477d4756..ff706ff3dc 100644 --- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h +++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 Arm Limited. + * Copyright (c) 2019-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H -#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H +#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H +#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H #include "arm_compute/core/KernelDescriptors.h" @@ -110,6 +110,22 @@ public: int32_t b_offset, GEMMLowpOutputStageInfo output_stage); + /** Set the a offset + * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op. + * Run configure or validate again if you aren't sure + * + * @param[in] a_offset Offset to be added to each element of the matrix A. 
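+     * @note The combined term (a_offset * b_offset * k) is recomputed from the stored k on every run_op call,
+     *       so an offset set here takes effect on the next run.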
+ */ + void set_a_offset(int32_t a_offset); + + /** Set the b offset + * Warning: if b_offset is non-zero then vector_sum_col must be set in run_op. + * Run configure or validate again if you aren't sure + * + * @param[in] b_offset Offset to be added to each element of the matrix B. + */ + void set_b_offset(int32_t b_offset); + // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; @@ -118,11 +134,11 @@ private: /** Function to use for the particular tensors passed to configure() */ int32_t _a_offset{0}; int32_t _b_offset{0}; - int32_t _k_offset{0}; + int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term bool _is_vector_sum_col_batched{true}; GEMMLowpOutputStageInfo _output_stage{GEMMLowpOutputStageInfo()}; }; } // namespace kernels } // namespace cpu } // namespace arm_compute -#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H */ +#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h index 45ebeec394..7c1e4772a6 100644 --- a/src/cpu/kernels/CpuKernelSelectionTypes.h +++ b/src/cpu/kernels/CpuKernelSelectionTypes.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -104,6 +104,8 @@ struct SoftmaxKernelDataTypeISASelectorData DataType dt; cpuinfo::CpuIsaInfo isa; bool is_log; + int axis; + unsigned long sme2_vector_length; }; // Selector pointer types diff --git a/src/cpu/kernels/CpuQuantizeKernel.cpp b/src/cpu/kernels/CpuQuantizeKernel.cpp index d2ac6cf8ac..ed4675ae3d 100644 --- a/src/cpu/kernels/CpuQuantizeKernel.cpp +++ b/src/cpu/kernels/CpuQuantizeKernel.cpp @@ -29,12 +29,12 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "src/core/common/Registrars.h" #include "src/core/CPP/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" -#include "src/core/NEON/NEAsymm.h" -#include "src/core/NEON/NEMath.h" #include "src/core/NEON/wrapper/wrapper.h" +#include "src/cpu/kernels/quantize/generic/neon/list.h" #include <arm_neon.h> #include <map> @@ -47,7 +47,6 @@ namespace kernels { namespace { -constexpr auto window_step = 16; Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst) { @@ -63,59 +62,6 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst) return Status{}; } -template <typename T> -inline float32x4x4_t load_value(const T *input_ptr) -{ - using Tx16_t = typename wrapper::traits::neon_vector<T, 16>::type; - return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr)); -} - -template <> -inline float32x4x4_t load_value(const float *input_ptr) -{ - return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8), - wrapper::vloadq(input_ptr + 12)}; -} -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template <> -inline float32x4x4_t load_value(const float16_t *input_ptr) -{ - return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)), - vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))}; -} - -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -template <typename element_type> -using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>; - -template <typename 
quantized_type> -vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi); - -template <> -vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) -{ - return vquantize(qv, qi); -} - -template <> -vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) -{ - return vquantize_signed(qv, qi); -} - -template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type> -inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) -{ - return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper)); -} - -template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type> -inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) -{ - return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper)); -} - } // namespace void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) @@ -124,38 +70,36 @@ void CpuQuantizeKernel::configure(const ITensorInfo *src, ITensorInfo *dst) ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst)); static const std::map<std::string, QuantizeFunctionExecutorPtr> quant_map = { - {"op_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<uint8_t, uint8_t>}, - {"op_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<uint8_t, int8_t>}, - {"op_QASYMM8_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<uint8_t>}, + {"op_QASYMM8_QASYMM8", REGISTER_INTEGER_NEON(u8_u8_run_quantize_qasymm8)}, + {"op_QASYMM8_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(u8_i8_run_quantize_qasymm8)}, + {"op_QASYMM8_QASYMM16", REGISTER_INTEGER_NEON(u8_run_quantize_qasymm16)}, - {"op_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, uint8_t>}, - {"op_QASYMM8_SIGNED_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<int8_t, int8_t>}, - {"op_QASYMM8_SIGNED_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<int8_t>}, + {"op_QASYMM8_SIGNED_QASYMM8", REGISTER_INTEGER_NEON(i8_u8_run_quantize_qasymm8)}, + {"op_QASYMM8_SIGNED_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(i8_i8_run_quantize_qasymm8)}, + {"op_QASYMM8_SIGNED_QASYMM16", REGISTER_INTEGER_NEON(i8_run_quantize_qasymm16)}, // Functions for offset only requantization - {"op_OFFSET_ONLY_QASYMM8_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, uint8_t>}, - {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", &CpuQuantizeKernel::run_requantize_offset_only<uint8_t, int8_t>}, - {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", &CpuQuantizeKernel::run_requantize_offset_only<int8_t, uint8_t>}, - {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED", - &CpuQuantizeKernel::run_requantize_offset_only<int8_t, int8_t>}, + {"op_OFFSET_ONLY_QASYMM8_QASYMM8", REGISTER_INTEGER_NEON(u8_u8_run_requantize_offset_only)}, + {"op_OFFSET_ONLY_QASYMM8_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(u8_i8_run_requantize_offset_only)}, + {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8", REGISTER_INTEGER_NEON(i8_u8_run_requantize_offset_only)}, + {"op_OFFSET_ONLY_QASYMM8_SIGNED_QASYMM8_SIGNED", REGISTER_INTEGER_NEON(i8_i8_run_requantize_offset_only)}, // Functions for offset uint8 to int8 and vice versa quantization (no scale changes) {"op_OFFSET_ONLY_CONVERT_QASYMM8_SIGNED_QASYMM8", - &CpuQuantizeKernel::run_requantize_offset_only_convert<int8_t, uint8_t>}, + REGISTER_INTEGER_NEON(i8_u8_run_requantize_offset_only_convert)}, {"op_OFFSET_ONLY_CONVERT_QASYMM8_QASYMM8_SIGNED", - 
&CpuQuantizeKernel::run_requantize_offset_only_convert<uint8_t, int8_t>}, - - {"op_F32_QSYMM8", &CpuQuantizeKernel::run_quantize_qsymm8<float, int8_t>}, - - {"op_F32_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float, uint8_t>}, - {"op_F32_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<float, int8_t>}, - {"op_F32_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<float>}, - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - {"op_F16_QASYMM8", &CpuQuantizeKernel::run_quantize_qasymm8<float16_t, uint8_t>}, - {"op_F16_QASYMM8_SIGNED", &CpuQuantizeKernel::run_quantize_qasymm8<float16_t, int8_t>}, - {"op_F16_QASYMM16", &CpuQuantizeKernel::run_quantize_qasymm16<float16_t>}, -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC*/ + REGISTER_INTEGER_NEON(u8_i8_run_requantize_offset_only_convert)}, + + {"op_F32_QSYMM8", REGISTER_FP32_NEON(fp32_i8_run_quantize_qsymm8)}, + {"op_F32_QASYMM8", REGISTER_FP32_NEON(fp32_u8_run_quantize_qasymm8)}, + {"op_F32_QASYMM8_SIGNED", REGISTER_FP32_NEON(fp32_i8_run_quantize_qasymm8)}, + {"op_F32_QASYMM16", REGISTER_FP32_NEON(fp32_run_quantize_qasymm16)}, + +#ifdef ARM_COMPUTE_ENABLE_FP16 + {"op_F16_QASYMM8", REGISTER_FP16_NEON(fp16_u8_run_quantize_qasymm8)}, + {"op_F16_QASYMM8_SIGNED", REGISTER_FP16_NEON(fp16_i8_run_quantize_qasymm8)}, + {"op_F16_QASYMM16", REGISTER_FP16_NEON(fp16_run_quantize_qasymm16)}, +#endif /* ARM_COMPUTE_ENABLE_FP16 */ }; std::string function_to_call("op_"); @@ -203,242 +147,6 @@ Status CpuQuantizeKernel::validate(const ITensorInfo *src, const ITensorInfo *ds return Status{}; } -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - output_ptr[x] = quantize_qsymm8(input_ptr[x], dst->info()->quantization_info()); - } - }, - input, output); -} - -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - // Calculate output offset difference. 
- const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Duplicate offset in signed vector format - const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - const wrapper::traits::neon_vector_t<TIn, window_step> qv = - wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype - - // Signed addition. - auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset); - - // Output is dependent on datatype. - wrapper::vstore(&output_ptr[x], - reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res)); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); - output_ptr[x] = static_cast<TOut>(result); - } - }, - input, output); -} - -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - // Duplicate offset in signed vector format - const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); - - const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128; - const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 255 : 127; - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - const auto qv = wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype - int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv))); - int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv))); - - // Signed addition. - lower = wrapper::vqadd(lower, offset); - upper = wrapper::vqadd(upper, offset); - - // Output is dependent on datatype. 
- auto res = recombine_8_16<TOut>(lower, upper); - wrapper::vstore(&output_ptr[x], res); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - // Add offset and clamp result to within the range of the output datatype. - int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); - result = utility::clamp<int32_t>(result, low_bound, upper_bound); - - // Cast result to output datatype. - output_ptr[x] = static_cast<TOut>(result); - } - }, - input, output); -} - -template <typename TIn, typename TOut> -void CpuQuantizeKernel::run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - if (is_data_type_quantized_asymmetric(src->info()->data_type())) - { - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - } -#ifdef __aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; -#else //__aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; -#endif //__aarch64__ - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); - auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += window_step) - { - wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - output_ptr[x] = Qasymm8QuantizationHelper<TOut>::quantize(input_ptr[x], uqinfo, rounding_policy); - } - }, - input, output); -} - -template <typename T> -void CpuQuantizeKernel::run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) -{ - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); - UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); - if (is_data_type_quantized_asymmetric(src->info()->data_type())) - { - uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); - } -#ifdef __aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; -#else //__aarch64__ - constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; -#endif //__aarch64__ - - // Collapse window and reset first dimension to handle tail calculations manually - Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); - win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input(src, win_collapsed); - Iterator output(dst, win_collapsed); - execute_window_loop( - win_collapsed, - [&](const Coordinates &) - { - auto input_ptr = reinterpret_cast<const T *>(input.ptr()); - auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr()); - - int x = window_start_x; - for (; x <= (window_end_x - window_step); x += 
window_step) - { - uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo); - vst1q_u16(&output_ptr[x], tmp.val[0]); - vst1q_u16(&output_ptr[x + 8], tmp.val[1]); - } - // Compute left-over elements - for (; x < window_end_x; ++x) - { - output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy); - } - }, - input, output); -} - void CpuQuantizeKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); @@ -448,7 +156,7 @@ void CpuQuantizeKernel::run_op(ITensorPack &tensors, const Window &window, const const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - (this->*_func)(src, dst, window); + (*_func)(src, dst, window); } const char *CpuQuantizeKernel::name() const diff --git a/src/cpu/kernels/CpuQuantizeKernel.h b/src/cpu/kernels/CpuQuantizeKernel.h index c2f7ac6d9d..750310c811 100644 --- a/src/cpu/kernels/CpuQuantizeKernel.h +++ b/src/cpu/kernels/CpuQuantizeKernel.h @@ -76,31 +76,7 @@ private: * * @param[in] window Region on which to execute the kernel. */ - using QuantizeFunctionExecutorPtr = void (CpuQuantizeKernel::*)(const ITensor *src, - ITensor *dst, - const Window &window); - /** Function to apply QASYMM8 or QASYMM8_SIGNED quantization on a tensor. - * - * @param[in] window Region on which to execute the kernel. - */ - template <typename TIn, typename TOut> - void run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window); - /** Function to apply QASYMM16 quantization on a tensor. - * - * @param[in] window Region on which to execute the kernel. - */ - template <typename T> - void run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window); - - template <typename TIn, typename TOut> - void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window); - - template <typename TIn, typename TOut> - void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window); - - template <typename TIn, typename TOut> - void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window); - + using QuantizeFunctionExecutorPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window); QuantizeFunctionExecutorPtr _func{nullptr}; size_t _split_dimension{Window::DimY}; }; diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp index 54ff858eeb..b7e395fb79 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.cpp +++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp @@ -48,18 +48,41 @@ namespace kernels { namespace { + /* Softmax */ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = { + {"sme2_fp32_softmax", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); }, + REGISTER_FP32_SME2(sme2_fp32_softmax)}, {"neon_fp32_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); }, REGISTER_FP32_NEON(neon_fp32_softmax<false>)}, + {"sme2_fp16_softmax", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); }, + REGISTER_FP16_SME2(sme2_fp16_softmax)}, {"neon_fp16_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; }, REGISTER_FP16_NEON(neon_fp16_softmax<false>)}, + {"sme2_qu8_softmax_lut_512VL", + 
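+     // LUT-based SME2 kernel: selected only for non-log QASYMM8 softmax along axis 0 with a 512-bit SME2 vector length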
[](const SoftmaxKernelDataTypeISASelectorData &data) + { + return (!data.is_log && data.dt == DataType::QASYMM8 && data.isa.sme2 && data.axis == 0 && + data.sme2_vector_length == 512); + }, + REGISTER_QASYMM8_SME2(sme2_qasymm8_softmax_lut_512VL)}, {"neon_qu8_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax<false>)}, + {"sme2_qs8_softmax_lut_512VL", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { + return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED && data.isa.sme2 && data.axis == 0 && + data.sme2_vector_length == 512); + }, + REGISTER_QASYMM8_SIGNED_SME2(sme2_qasymm8_signed_softmax_lut_512VL)}, {"neon_qs8_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::QASYMM8_SIGNED); }, @@ -80,6 +103,28 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax<true>)}, }; +void init_lut(std::vector<float> &lut, DataType type, float scale, float beta) +{ + if (type == DataType::QASYMM8) + { + for (int i = 0; i < 256; ++i) + { + lut.push_back(std::exp(-scale * beta * i)); + } + } + else if (type == DataType::QASYMM8_SIGNED) + { + for (int i = -128; i < 128; ++i) + { + lut.push_back(std::exp(-scale * beta * i)); + } + } + else + { + ARM_COMPUTE_ERROR("Invalid datatype for QASYMM8/QASYMM8_SIGNED softmax"); + } +} + Status validate_arguments_softmax( const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log) { @@ -149,8 +194,8 @@ void CpuSoftmaxKernel::configure( auto_init_if_empty(*tmp, TensorInfo(*src).set_data_type(DataType::F32).reset_padding()); } - const auto *uk = CpuSoftmaxKernel::get_implementation( - SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log}); + const auto *uk = CpuSoftmaxKernel::get_implementation(SoftmaxKernelDataTypeISASelectorData{ + src->data_type(), CPUInfo::get().get_isa(), is_log, axis, CPUInfo::get().get_sme2_vector_length()}); ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); std::string kernel_name = is_log ? 
std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel"); @@ -186,6 +231,13 @@ void CpuSoftmaxKernel::configure( win.set(_axis, Window::Dimension(0, 1, 1)); ICpuKernel<CpuSoftmaxKernel>::configure(win); + + const std::string uk_name = uk->name; + if (uk_name == "sme2_qu8_softmax_lut_512VL" || uk_name == "sme2_qs8_softmax_lut_512VL") + { + const float scale = src->quantization_info().uniform().scale; + init_lut(_lut, src->data_type(), scale, beta); + } } Status CpuSoftmaxKernel::validate( @@ -222,11 +274,11 @@ void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration; void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread); - _run_method(src, tmp_for_thread, dst, _beta, _axis, window); + _run_method(src, tmp_for_thread, dst, _beta, _axis, window, _lut.data()); } else { - _run_method(src, nullptr, dst, _beta, _axis, window); + _run_method(src, nullptr, dst, _beta, _axis, window, nullptr); } } diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h index 043ad975d5..676e79782b 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.h +++ b/src/cpu/kernels/CpuSoftmaxKernel.h @@ -37,8 +37,8 @@ namespace kernels class CpuSoftmaxKernel : public ICpuKernel<CpuSoftmaxKernel> { private: - using SoftmaxKernelPtr = - std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, const Window &)>::type; + using SoftmaxKernelPtr = std::add_pointer<void( + const ITensor *, void *const, ITensor *, float, int, const Window &, const float *)>::type; public: CpuSoftmaxKernel() = default; @@ -78,10 +78,11 @@ public: static const std::vector<SoftmaxKernel> &get_available_kernels(); private: - float _beta{1.0f}; - SoftmaxKernelPtr _run_method{nullptr}; - std::string _name{}; - int _axis{}; + float _beta{1.0f}; + SoftmaxKernelPtr _run_method{nullptr}; + std::string _name{}; + int _axis{}; + std::vector<float> _lut = {}; }; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 9a913c5c58..941fed0ba8 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2022, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,6 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ + +#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP +#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP + #pragma once #include "arm_gemm_local.hpp" @@ -151,6 +155,7 @@ public: int _maxthreads; bool _fixed_format; bool _fast_mode; + bool _accumulate; const GemmConfig *_cfg; GemmArgs(const CPUInfo *ci, @@ -165,6 +170,7 @@ public: const int maxthreads, bool fixed_format = false, bool fast_mode = false, + bool accumulate = false, const GemmConfig *cfg = nullptr) : _ci(ci), _Msize(M), @@ -178,6 +184,7 @@ public: _maxthreads(maxthreads), _fixed_format(fixed_format), _fast_mode(fast_mode), + _accumulate(accumulate), _cfg(cfg) { } @@ -253,6 +260,19 @@ public: } }; +struct DequantizeFloat +{ +public: + float scale = 0; + + DequantizeFloat() = default; + + // Constructor + DequantizeFloat(const float scale) : scale(scale) + { + } +}; + struct Nothing { }; @@ -278,3 +298,5 @@ template <typename Top, typename Tret, class OutputStage = Nothing> bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm + +#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp index 4825814e31..45d1e43274 100644 --- a/src/cpu/kernels/assembly/gemm_common.hpp +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -166,6 +166,12 @@ public: { } + /*** Dequanize scale interface (optional) ***/ + /* Set the dequantize scale for GEMMs when converting from int to float (float out = scale * float(int out) ) */ + virtual void set_dequantize_scale(const float) + { + } + /*** Introspection interface ***/ /* Get the configuration of this GEMM */ virtual GemmConfig get_config() = 0; diff --git a/src/cpu/kernels/dequantize/generic/neon/fp16.cpp b/src/cpu/kernels/dequantize/generic/neon/fp16.cpp new file mode 100644 index 0000000000..caffdf53e1 --- /dev/null +++ b/src/cpu/kernels/dequantize/generic/neon/fp16.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) +#include "src/cpu/kernels/dequantize/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void fp16_run_dequantization_core(const ITensor *input, ITensor *output, const Window &window) +{ + run_dequantization_core<float16_t>(input, output, window); +} +} // namespace cpu +} // namespace arm_compute +#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ diff --git a/src/cpu/kernels/dequantize/generic/neon/fp32.cpp b/src/cpu/kernels/dequantize/generic/neon/fp32.cpp new file mode 100644 index 0000000000..58e987b450 --- /dev/null +++ b/src/cpu/kernels/dequantize/generic/neon/fp32.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/cpu/kernels/dequantize/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void fp32_run_dequantization_core(const ITensor *input, ITensor *output, const Window &window) +{ + run_dequantization_core<float>(input, output, window); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/dequantize/generic/neon/impl.h b/src/cpu/kernels/dequantize/generic/neon/impl.h new file mode 100644 index 0000000000..7197d4dff6 --- /dev/null +++ b/src/cpu/kernels/dequantize/generic/neon/impl.h @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H +#define ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H + +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/Window.h" + +#include "src/core/NEON/NEAsymm.h" +#include "src/core/NEON/NESymm.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "src/cpu/kernels/dequantize/generic/neon/list.h" + +#include <arm_neon.h> + +namespace arm_compute +{ +namespace cpu +{ + +template <typename T> +inline void store_result(T *ptr, const float32x4x4_t &v) +{ + ARM_COMPUTE_UNUSED(ptr, v); +} + +template <> +inline void store_result<float>(float *ptr, const float32x4x4_t &v) +{ + wrapper::vstore(ptr, v.val[0]); + wrapper::vstore(ptr + 4, v.val[1]); + wrapper::vstore(ptr + 8, v.val[2]); + wrapper::vstore(ptr + 12, v.val[3]); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline void store_result<float16_t>(float16_t *ptr, const float32x4x4_t &v) +{ + wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1]))); + wrapper::vstore(ptr + 8, vcombine_f16(vcvt_f16_f32(v.val[2]), vcvt_f16_f32(v.val[3]))); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +template <typename T> +inline void store_result(T *ptr, const float32x4x2_t &v) +{ + ARM_COMPUTE_UNUSED(ptr, v); +} + +template <> +inline void store_result<float>(float *ptr, const float32x4x2_t &v) +{ + wrapper::vstore(ptr, v.val[0]); + wrapper::vstore(ptr + 4, v.val[1]); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline void store_result<float16_t>(float16_t *ptr, const float32x4x2_t &v) +{ + wrapper::vstore(ptr, vcombine_f16(vcvt_f16_f32(v.val[0]), vcvt_f16_f32(v.val[1]))); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +template <typename TOut, typename TIn> +void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Window &window) +{ + const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); + const float scale = qinfo.scale; + const int32_t offset = qinfo.offset; + + const int window_step_x = 16; + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win_collapsed); + Iterator out(output, win_collapsed); + + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast<const TIn *>(in.ptr()); + const auto out_ptr = reinterpret_cast<TOut *>(out.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize(vin, scale, offset); + + store_result(reinterpret_cast<TOut *>(out_ptr + x), vdeq); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto val = *(in_ptr + x); + *(out_ptr + x) = static_cast<TOut>(Qasymm8QuantizationHelper<TIn>::dequantize(val, qinfo)); + } + }, + in, out); +} + +template <typename T> +void 
run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window) +{ + const auto scale = input->info()->quantization_info().scale(); + + const int window_step_x = 16; + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Reset first dimension to handle tail calculations manually + Window win(window); + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win); + Iterator out(output, win); + + execute_window_loop( + win, + [&](const Coordinates &id) + { + const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr()); + const auto out_ptr = reinterpret_cast<T *>(out.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize(vin, scale[id.z()]); + + store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + int8_t val = *(in_ptr + x); + *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()])); + } + }, + in, out); +} + +template <typename T> +void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window) +{ + const auto scale = input->info()->quantization_info().scale(); + + const int window_step_x = 16; + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Reset first dimension to handle tail calculations manually + Window win(window); + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win); + Iterator out(output, win); + + execute_window_loop( + win, + [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr()); + const auto out_ptr = reinterpret_cast<T *>(out.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const float32x4x4_t vscale = {{scale[x + 0], scale[x + 1], scale[x + 2], scale[x + 3], scale[x + 4], + scale[x + 5], scale[x + 6], scale[x + 7], scale[x + 8], scale[x + 9], + scale[x + 10], scale[x + 11], scale[x + 12], scale[x + 13], + scale[x + 14], scale[x + 15]}}; + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize(vin, vscale); + + store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + int8_t val = *(in_ptr + x); + *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x])); + } + }, + in, out); +} + +template <typename T> +void run_dequantization_qsymm8(const ITensor *input, ITensor *output, const Window &window) +{ + const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); + const float scale = qinfo.scale; + + const int window_step_x = 16; + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win_collapsed); + Iterator out(output, win_collapsed); + + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast<const int8_t 
*>(in.ptr()); + const auto out_ptr = reinterpret_cast<T *>(out.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize(vin, scale); + + store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + int8_t val = *(in_ptr + x); + *(out_ptr + x) = static_cast<T>(dequantize(val, scale)); + } + }, + in, out); +} + +template <typename T> +void run_dequantization_qsymm16(const ITensor *input, ITensor *output, const Window &window) +{ + const UniformQuantizationInfo &qinfo = input->info()->quantization_info().uniform(); + const float scale = qinfo.scale; + + const int window_step_x = 8; + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Create iterators + Iterator in(input, win_collapsed); + Iterator out(output, win_collapsed); + + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast<const int16_t *>(in.ptr()); + const auto out_ptr = reinterpret_cast<T *>(out.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vin = wrapper::vloadq(in_ptr + x); + const auto vdeq = vdequantize_int16(vin, scale); + + store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + int16_t val = *(in_ptr + x); + *(out_ptr + x) = static_cast<T>(dequantize_qsymm16(val, scale)); + } + }, + in, out); +} + +template <typename T> +void run_dequantization_core(const ITensor *input, ITensor *output, const Window &window) +{ + switch (input->info()->data_type()) + { + case DataType::QASYMM8: + run_dequantization_qasymm8<T, uint8_t>(input, output, window); + break; + case DataType::QASYMM8_SIGNED: + run_dequantization_qasymm8<T, int8_t>(input, output, window); + break; + case DataType::QSYMM8_PER_CHANNEL: + input->info()->data_layout() == DataLayout::NHWC + ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window) + : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window); + break; + case DataType::QSYMM8: + run_dequantization_qsymm8<T>(input, output, window); + break; + case DataType::QSYMM16: + run_dequantization_qsymm16<T>(input, output, window); + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type."); + } +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_IMPL_H diff --git a/src/cpu/kernels/dequantize/generic/neon/list.h b/src/cpu/kernels/dequantize/generic/neon/list.h new file mode 100644 index 0000000000..678eb2c01a --- /dev/null +++ b/src/cpu/kernels/dequantize/generic/neon/list.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H +#define ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ + +#define DECLARE_DEQUANTIZE_KERNEL(func_name) void func_name(const ITensor *input, ITensor *output, const Window &window) + +DECLARE_DEQUANTIZE_KERNEL(fp32_run_dequantization_core); +DECLARE_DEQUANTIZE_KERNEL(fp16_run_dequantization_core); + +#undef DECLARE_DEQUANTIZE_KERNEL + +} // namespace cpu +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_DEQUANTIZE_GENERIC_NEON_LIST_H diff --git a/src/cpu/kernels/quantize/generic/neon/fp16.cpp b/src/cpu/kernels/quantize/generic/neon/fp16.cpp new file mode 100644 index 0000000000..37bfb5b2aa --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/fp16.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) +#include "src/cpu/kernels/quantize/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void fp16_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float16_t, uint8_t>(src, dst, window); +} +void fp16_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float16_t, int8_t>(src, dst, window); +} +void fp16_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<float16_t>(src, dst, window); +} +} // namespace cpu +} // namespace arm_compute +#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ diff --git a/src/cpu/kernels/quantize/generic/neon/fp32.cpp b/src/cpu/kernels/quantize/generic/neon/fp32.cpp new file mode 100644 index 0000000000..0cba332fd6 --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/fp32.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/cpu/kernels/quantize/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void fp32_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float, uint8_t>(src, dst, window); +} +void fp32_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<float, int8_t>(src, dst, window); +} +void fp32_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<float>(src, dst, window); +} + +void fp32_i8_run_quantize_qsymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qsymm8<float, int8_t>(src, dst, window); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/quantize/generic/neon/impl.h b/src/cpu/kernels/quantize/generic/neon/impl.h new file mode 100644 index 0000000000..9954a7645e --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/impl.h @@ -0,0 +1,330 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H +#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H + +#include "arm_compute/core/Helpers.h" + +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/NEON/NEAsymm.h" +#include "src/core/NEON/wrapper/intrinsics/intrinsics.h" + +namespace arm_compute +{ +namespace cpu +{ +constexpr auto window_step = 16; + +template <typename T> +inline float32x4x4_t load_value(const T *input_ptr) +{ + using Tx16_t = typename wrapper::traits::neon_vector<T, 16>::type; + return arm_compute::convert_to_float32x4x4<Tx16_t>(wrapper::vloadq(input_ptr)); +} + +template <> +inline float32x4x4_t load_value(const float *input_ptr) +{ + return {wrapper::vloadq(input_ptr), wrapper::vloadq(input_ptr + 4), wrapper::vloadq(input_ptr + 8), + wrapper::vloadq(input_ptr + 12)}; +} +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +inline float32x4x4_t load_value(const float16_t *input_ptr) +{ + return {vcvt_f32_f16(wrapper::vload(input_ptr)), vcvt_f32_f16(wrapper::vload(input_ptr + 4)), + vcvt_f32_f16(wrapper::vload(input_ptr + 8)), vcvt_f32_f16(wrapper::vload(input_ptr + 12))}; +} + +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <typename element_type> +using vector_type = wrapper::traits::neon_vector_t<element_type, window_step>; + +template <typename quantized_type> +inline vector_type<quantized_type> vquantize_qasymm8(const float32x4x4_t &qv, const UniformQuantizationInfo &qi); + +template <> +inline vector_type<uint8_t> vquantize_qasymm8<uint8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) +{ + return vquantize(qv, qi); +} + +template <> +inline vector_type<int8_t> vquantize_qasymm8<int8_t>(const float32x4x4_t &qv, const UniformQuantizationInfo &qi) +{ + return vquantize_signed(qv, qi); +} + +template <typename TOut, typename = typename std::enable_if<std::is_signed<TOut>::value, bool>::type> +inline int8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) +{ + return wrapper::vcombine(wrapper::vqmovn(lower), wrapper::vqmovn(upper)); +} + +template <typename TOut, typename = typename std::enable_if<std::is_unsigned<TOut>::value, bool>::type> +inline uint8x16_t recombine_8_16(int16x8_t lower, int16x8_t upper) +{ + return wrapper::vcombine(wrapper::vqmovun(lower), wrapper::vqmovun(upper)); +} + +template <typename TIn, typename TOut> +void run_quantize_qsymm8(const ITensor *src, ITensor *dst, const 
Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + output_ptr[x] = quantize_qsymm8(input_ptr[x], dst->info()->quantization_info()); + } + }, + input, output); +} + +template <typename TIn, typename TOut> +void run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + // Calculate output offset difference. + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Duplicate offset in signed vector format + const int8x16_t offset = wrapper::vdup_n(static_cast<int8_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + const wrapper::traits::neon_vector_t<TIn, window_step> qv = + wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype + + // Signed addition. + auto res = vaddq_s8(reinterpret_cast<int8x16_t>(qv), offset); + + // Output is dependent on datatype. 
+ wrapper::vstore(&output_ptr[x], + reinterpret_cast<wrapper::traits::neon_vector_t<TOut, window_step>>(res)); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); + output_ptr[x] = static_cast<TOut>(result); + } + }, + input, output); +} + +template <typename TIn, typename TOut> +void run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + // Duplicate offset in signed vector format + const int16x8_t offset = wrapper::vdup_n(static_cast<int16_t>(uqinfo.offset), wrapper::traits::vector_128_tag{}); + + const int32_t low_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 0 : -128; + const int32_t upper_bound = (dst->info()->data_type() == DataType::QASYMM8) ? 255 : 127; + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + TOut *output_ptr = reinterpret_cast<TOut *>(output.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + const auto qv = wrapper::vloadq(input_ptr + x); // load 128 bit vector of 8 bit datatype + int16x8_t lower = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgetlow(qv))); + int16x8_t upper = reinterpret_cast<int16x8_t>(wrapper::vmovl(wrapper::vgethigh(qv))); + + // Signed addition. + lower = wrapper::vqadd(lower, offset); + upper = wrapper::vqadd(upper, offset); + + // Output is dependent on datatype. + auto res = recombine_8_16<TOut>(lower, upper); + wrapper::vstore(&output_ptr[x], res); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + // Add offset and clamp result to within the range of the output datatype. + int32_t result = uqinfo.offset + static_cast<int32_t>(input_ptr[x]); + result = utility::clamp<int32_t>(result, low_bound, upper_bound); + + // Cast result to output datatype. 
+ output_ptr[x] = static_cast<TOut>(result); + } + }, + input, output); +} + +template <typename TIn, typename TOut> +void run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + if (is_data_type_quantized_asymmetric(src->info()->data_type())) + { + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + } +#ifdef __aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; +#else //__aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; +#endif //__aarch64__ + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const TIn *>(input.ptr()); + auto output_ptr = reinterpret_cast<TOut *>(output.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + wrapper::vstore(&output_ptr[x], vquantize_qasymm8<TOut>(load_value(&input_ptr[x]), uqinfo)); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + output_ptr[x] = Qasymm8QuantizationHelper<TOut>::quantize(input_ptr[x], uqinfo, rounding_policy); + } + }, + input, output); +} + +template <typename T> +void run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + const auto window_start_x = static_cast<int>(window.x().start()); + const auto window_end_x = static_cast<int>(window.x().end()); + + const UniformQuantizationInfo uqinfo_in = src->info()->quantization_info().uniform(); + UniformQuantizationInfo uqinfo = dst->info()->quantization_info().uniform(); + if (is_data_type_quantized_asymmetric(src->info()->data_type())) + { + uqinfo = compute_requantization_scale_offset(uqinfo_in, uqinfo); + } +#ifdef __aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; +#else //__aarch64__ + constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; +#endif //__aarch64__ + + // Collapse window and reset first dimension to handle tail calculations manually + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop( + win_collapsed, + [&](const Coordinates &) + { + auto input_ptr = reinterpret_cast<const T *>(input.ptr()); + auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr()); + + int x = window_start_x; + for (; x <= (window_end_x - window_step); x += window_step) + { + uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo); + vst1q_u16(&output_ptr[x], tmp.val[0]); + vst1q_u16(&output_ptr[x + 8], tmp.val[1]); + } + // Compute left-over elements + for (; x < window_end_x; ++x) + { + output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy); + } + }, + input, output); +} +} // namespace cpu +} // namespace arm_compute + +#endif // 
ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_IMPL_H diff --git a/src/cpu/kernels/quantize/generic/neon/integer.cpp b/src/cpu/kernels/quantize/generic/neon/integer.cpp new file mode 100644 index 0000000000..4e39afaaee --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/integer.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/cpu/kernels/quantize/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void u8_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<uint8_t, uint8_t>(src, dst, window); +} +void u8_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<uint8_t, int8_t>(src, dst, window); +} +void i8_u8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<int8_t, uint8_t>(src, dst, window); +} +void i8_i8_run_quantize_qasymm8(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm8<int8_t, int8_t>(src, dst, window); +} + +void u8_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<uint8_t>(src, dst, window); +} +void i8_run_quantize_qasymm16(const ITensor *src, ITensor *dst, const Window &window) +{ + run_quantize_qasymm16<int8_t>(src, dst, window); +} + +void u8_u8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<uint8_t, uint8_t>(src, dst, window); +} +void u8_i8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<uint8_t, int8_t>(src, dst, window); +} +void i8_u8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<int8_t, uint8_t>(src, dst, window); +} +void i8_i8_run_requantize_offset_only(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only<int8_t, int8_t>(src, dst, window); +} + +void i8_u8_run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only_convert<int8_t, uint8_t>(src, dst, window); +} +void u8_i8_run_requantize_offset_only_convert(const ITensor *src, ITensor *dst, const Window &window) +{ + run_requantize_offset_only_convert<uint8_t, int8_t>(src, dst, window); +} +} // namespace cpu +} // namespace 
arm_compute diff --git a/src/cpu/kernels/quantize/generic/neon/list.h b/src/cpu/kernels/quantize/generic/neon/list.h new file mode 100644 index 0000000000..c4fb1048eb --- /dev/null +++ b/src/cpu/kernels/quantize/generic/neon/list.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H +#define ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ + +#define DECLARE_QUANTIZE_KERNEL(func_name) void func_name(const ITensor *src, ITensor *dst, const Window &window) + +DECLARE_QUANTIZE_KERNEL(u8_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(u8_i8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(i8_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(i8_i8_run_quantize_qasymm8); + +DECLARE_QUANTIZE_KERNEL(u8_u8_run_requantize_offset_only); +DECLARE_QUANTIZE_KERNEL(u8_i8_run_requantize_offset_only); +DECLARE_QUANTIZE_KERNEL(i8_u8_run_requantize_offset_only); +DECLARE_QUANTIZE_KERNEL(i8_i8_run_requantize_offset_only); + +DECLARE_QUANTIZE_KERNEL(i8_u8_run_requantize_offset_only_convert); +DECLARE_QUANTIZE_KERNEL(u8_i8_run_requantize_offset_only_convert); + +DECLARE_QUANTIZE_KERNEL(u8_run_quantize_qasymm16); +DECLARE_QUANTIZE_KERNEL(i8_run_quantize_qasymm16); + +DECLARE_QUANTIZE_KERNEL(fp32_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp32_i8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp32_run_quantize_qasymm16); + +DECLARE_QUANTIZE_KERNEL(fp32_i8_run_quantize_qsymm8); + +DECLARE_QUANTIZE_KERNEL(fp16_u8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp16_i8_run_quantize_qasymm8); +DECLARE_QUANTIZE_KERNEL(fp16_run_quantize_qasymm16); + +#undef DECLARE_QUANTIZE_KERNEL + +} // namespace cpu +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_QUANTIZE_GENERIC_NEON_LIST_H diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp new file mode 100644 index 0000000000..143bb5487f --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/fp16.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX<float16_t, 8>>::reduceX(window, input, output, RedOpX<float16_t, 8>(), op); +} + +void reduce_RedOpYZW_reduceY_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float16_t, 8>>::reduceY(window, input, output, RedOpYZW<float16_t, 8>(), op); +} + +void reduce_RedOpYZW_reduceZ_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float16_t, 8>>::reduceZ(window, input, output, RedOpYZW<float16_t, 8>(), op); +} + +void reduce_RedOpYZW_reduceW_float16_8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float16_t, 8>>::reduceW(window, input, output, RedOpYZW<float16_t, 8>(), op); +} +} // namespace cpu +} // namespace arm_compute +#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */ diff --git a/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp b/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp new file mode 100644 index 0000000000..6f5f13e571 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/fp32.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + Reducer<RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>>::reduceZ( + window, input, output, RedOpYZW_complex<float, 4, 2, ReductionOperation::SUM>(), op); +} + +void reduce_RedOpX_reduceX_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op); +} + +void reduce_RedOpYZW_reduceY_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op); +} + +void reduce_RedOpYZW_reduceZ_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op); +} + +void reduce_RedOpYZW_reduceW_float32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op); +} + +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/impl.h b/src/cpu/kernels/reduction_layer/generic/neon/impl.h new file mode 100644 index 0000000000..3fa821d3a4 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/impl.h @@ -0,0 +1,1633 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H +#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H + +#include "arm_compute/core/Coordinates.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" + +#include "src/core/NEON/NEMath.h" +#include "src/core/NEON/wrapper/wrapper.h" +#include "support/SaturateCast.h" + +#include <arm_neon.h> + +namespace arm_compute +{ +// Helper function that calls vqmovun/vqmovn, vcombine and vstore; allows templating of RedOpYZW_quantized +template <typename T> +void combine_and_store(int16x8_t t1, int16x8_t t2, Iterator &output, int offset = 0) +{ + if (std::is_same<T, uint8_t>::value) + { + auto res = wrapper::vcombine(wrapper::vqmovun(t1), wrapper::vqmovun(t2)); + wrapper::vstore(output.ptr() + offset, res); + } + else + { + auto res = wrapper::vcombine(wrapper::vqmovn(t1), wrapper::vqmovn(t2)); + wrapper::vstore(reinterpret_cast<int8_t *>(output.ptr() + offset), res); + } +} + +template <typename T> +uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) +{ + uint32x4_t mask{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + mask = wrapper::vcgt(b, a); + } + else + { + mask = wrapper::vclt(b, a); + } + + uint32x4_t vec_idx = {idx, idx + 1, idx + 2, idx + 3}; + if (axis != 0) + { + vec_idx = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + } + uint32x4x4_t res = {{wrapper::vbsl(mask, vec_idx, c.val[0]), 0, 0, 0}}; + + return res; +} + +template <typename T> +uint32x4x4_t calculate_index_quantized(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis) +{ + uint32x4x4_t mask{{0}}; + uint8x16_t mask_u8{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + mask_u8 = wrapper::vcgt(b, a); + } + else + { + mask_u8 = wrapper::vclt(b, a); + } + auto wide_u16_1 = + wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); + auto wide_u16_2 = + wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); + mask.val[0] = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); + mask.val[1] = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); + mask.val[2] = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); + mask.val[3] = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); + + uint32x4x4_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, + {idx + 4, idx + 5, idx + 6, idx + 7}, + {idx + 8, idx + 9, idx + 10, idx + 11}, + {idx + 12, idx + 13, idx + 14, idx + 15}}}; + if (axis != 0) + { + vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[2] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[3] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + } + uint32x4x4_t res = { + {vbslq_u32(mask.val[0], vec_idx.val[0], c.val[0]), vbslq_u32(mask.val[1], vec_idx.val[1], c.val[1]), + vbslq_u32(mask.val[2], vec_idx.val[2], c.val[2]), vbslq_u32(mask.val[3], vec_idx.val[3], c.val[3])}}; + + return res; +} + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value.
+template <typename T> +inline typename std::enable_if< + std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, + typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type +calculate_min(T in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +template <typename T> +inline typename std::enable_if< + std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, + typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type +calculate_min(T in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +template <typename T> +inline typename std::enable_if< + std::is_same<T, float32x4_t>::value || std::is_same<T, int32x4_t>::value, + typename std::conditional<std::is_same<T, float32x4_t>::value, float32x2_t, int32x2_t>::type>::type +calculate_max(T in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + return wrapper::vpmax(pmax, pmax); +} + +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. +template <typename T> +inline typename std::enable_if< + std::is_same<T, uint8x16_t>::value || std::is_same<T, int8x16_t>::value, + typename std::conditional<std::is_same<T, uint8x16_t>::value, uint8x8_t, int8x8_t>::type>::type +calculate_max(T in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + +template <typename T> +uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) +{ + uint32x4_t res_idx_mask{0}; + uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); + + if (op == ReductionOperation::ARG_IDX_MIN) + { + auto pmin = calculate_min(vec_res_value); + auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); + res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); + } + else + { + auto pmax = calculate_max(vec_res_value); + auto mask = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); + res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask); + } + + res_idx_mask = wrapper::vadd(res_idx_mask, mask_ones); + auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask), wrapper::vgetlow(res_idx_mask)); + pmin = wrapper::vpmin(pmin, pmin); + uint32_t res = wrapper::vgetlane(pmin, 0); + + return (res - 0xFFFFFFFF); +} + +template <typename T> +uint32_t calculate_vector_index_quantized(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op) +{ + uint32x4x4_t res_idx_mask{{0}}; + uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); + uint8x16_t mask_u8{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + auto pmin = calculate_min(vec_res_value); + mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); + } + else + { + auto pmax = calculate_max(vec_res_value); + mask_u8 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); + } + + // Widen vectors + auto wide_u16_1 = + 
wrapper::vorr(vshll_n_u8(wrapper::vgetlow(mask_u8), 8), wrapper::vmovl(wrapper::vgetlow(mask_u8))); + auto wide_u16_2 = + wrapper::vorr(vshll_n_u8(wrapper::vgethigh(mask_u8), 8), wrapper::vmovl(wrapper::vgethigh(mask_u8))); + auto wide_u32_1 = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_1), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_1))); + auto wide_u32_2 = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_1), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_1))); + auto wide_u32_3 = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(wide_u16_2), 16), wrapper::vmovl(wrapper::vgetlow(wide_u16_2))); + auto wide_u32_4 = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(wide_u16_2), 16), wrapper::vmovl(wrapper::vgethigh(wide_u16_2))); + res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); + res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); + res_idx_mask.val[2] = wrapper::vand(vec_res_idx.val[2], wide_u32_3); + res_idx_mask.val[3] = wrapper::vand(vec_res_idx.val[3], wide_u32_4); + res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); + res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); + res_idx_mask.val[2] = wrapper::vadd(res_idx_mask.val[2], mask_ones); + res_idx_mask.val[3] = wrapper::vadd(res_idx_mask.val[3], mask_ones); + + uint32_t res = 0xFFFFFFFF; + int iter = 0; + do + { + auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); + pmin = wrapper::vpmin(pmin, pmin); + res = std::min(wrapper::vgetlane(pmin, 0), res); + iter++; + } while (iter < 4); + + return (res - 0xFFFFFFFF); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> +uint32x4x4_t inline calculate_index( + uint32_t idx, float16x8_t a, float16x8_t b, uint32x4x4_t c, ReductionOperation op, int axis) +{ + uint32x4x2_t mask{0}; + uint16x8_t mask_u16{0}; + if (op == ReductionOperation::ARG_IDX_MIN) + { + mask_u16 = wrapper::vcgt(b, a); + } + else + { + mask_u16 = wrapper::vclt(b, a); + } + mask.val[0] = wrapper::vmovl(wrapper::vgetlow(mask_u16)); + mask.val[1] = wrapper::vmovl(wrapper::vgethigh(mask_u16)); + uint32x4x2_t vec_idx = {{{idx + 0, idx + 1, idx + 2, idx + 3}, {idx + 4, idx + 5, idx + 6, idx + 7}}}; + if (axis != 0) + { + vec_idx.val[0] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + vec_idx.val[1] = wrapper::vdup_n(idx, wrapper::traits::vector_128_tag{}); + } + uint32x4x4_t res = {wrapper::vbsl(mask.val[0], vec_idx.val[0], c.val[0]), + wrapper::vbsl(mask.val[1], vec_idx.val[1], c.val[1]), 0, 0}; + + return res; +} + +// Helper function to calculate the minimum value of the input vector. All the elements in the output vector contain the min value. +inline float16x4_t calculate_min(float16x8_t in) +{ + auto pmin = wrapper::vpmin(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmin = wrapper::vpmin(pmin, pmin); + return wrapper::vpmin(pmin, pmin); +} +// Helper function to calculate the maximum value of the input vector. All the elements in the output vector contain the max value. 
+inline float16x4_t calculate_max(float16x8_t in) +{ + auto pmax = wrapper::vpmax(wrapper::vgethigh(in), wrapper::vgetlow(in)); + pmax = wrapper::vpmax(pmax, pmax); + return wrapper::vpmax(pmax, pmax); +} + +template <> +inline uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float16x8_t vec_res_value, ReductionOperation op) +{ + uint32x4x2_t res_idx_mask{0}; + uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF); + uint16x8_t mask_u16; + if (op == ReductionOperation::ARG_IDX_MIN) + { + auto pmin = calculate_min(vec_res_value); + mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmin, pmin)); + } + else + { + auto pmax = calculate_max(vec_res_value); + mask_u16 = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax)); + } + + // Widen vectors + auto wide_u32_1 = + wrapper::vorr(vshll_n_u16(wrapper::vgetlow(mask_u16), 8), wrapper::vmovl(wrapper::vgetlow(mask_u16))); + auto wide_u32_2 = + wrapper::vorr(vshll_n_u16(wrapper::vgethigh(mask_u16), 8), wrapper::vmovl(wrapper::vgethigh(mask_u16))); + res_idx_mask.val[0] = wrapper::vand(vec_res_idx.val[0], wide_u32_1); + res_idx_mask.val[1] = wrapper::vand(vec_res_idx.val[1], wide_u32_2); + res_idx_mask.val[0] = wrapper::vadd(res_idx_mask.val[0], mask_ones); + res_idx_mask.val[1] = wrapper::vadd(res_idx_mask.val[1], mask_ones); + + uint32_t res = 0xFFFFFFFF; + uint32_t iter = 0; + do + { + auto pmin = wrapper::vpmin(wrapper::vgethigh(res_idx_mask.val[iter]), wrapper::vgetlow(res_idx_mask.val[iter])); + pmin = wrapper::vpmin(pmin, pmin); + res = std::min(wrapper::vgetlane(pmin, 0), res); + iter++; + } while (iter < 2); + + return (res - 0xFFFFFFFF); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +template <class F> +class Reducer +{ +public: + static void reduceX(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set out window + Window out_window(window); + out_window.set(Window::DimX, Window::Dimension(0, 1, 1)); + + f(window, out_window, input, output, op); + } + static void reduceY(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in window + Window in_window(window); + Window out_window(window); + + in_window.set(Window::DimY, Window::Dimension(0, 1, 1)); + out_window.set(Window::DimY, Window::Dimension(0, output->info()->dimension(1), output->info()->dimension(1))); + + f(in_window, out_window, input, output, 1, op); + } + static void reduceZ(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in window + Window in_window(window); + Window out_window(window); + + in_window.set(Window::DimZ, Window::Dimension(0, 1, 1)); + out_window.set(Window::DimZ, Window::Dimension(0, output->info()->dimension(2), output->info()->dimension(2))); + + f(in_window, out_window, input, output, 2, op); + } + static void reduceW(const Window &window, const ITensor *input, ITensor *output, F f, const ReductionOperation op) + { + // Set in/out window + Window in_window(window); + Window out_window(window); + + in_window.set(3, Window::Dimension(0, 1, 1)); + out_window.set(3, Window::Dimension(0, 1, 1)); + + f(in_window, out_window, input, output, 3, op); + } +}; + +template <typename T, int S> +struct RedOpX +{ + /** SIMD vector tag type. 
*/ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) + { + const size_t input_dim_0 = in->info()->dimension(0); + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast<int>(in_window.x().start()); + const auto window_end_x = static_cast<int>(in_window.x().end()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_window); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<const T *>(input.ptr()); + + auto init_res_value = static_cast<T>(0.f); + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + init_res_value = static_cast<T>(*input_ptr); + break; + } + case ReductionOperation::PROD: + { + init_res_value = static_cast<T>(1.f); + break; + } + default: + break; + } + auto vec_res_value = wrapper::vdup_n(init_res_value, ExactTagType{}); + uint32x4x4_t vec_res_idx{{0}}; + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vec_elements = wrapper::vloadq(input_ptr + x); + switch (op) + { + case ReductionOperation::SUM_SQUARE: + vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); + break; + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM: + vec_res_value = wrapper::vadd(vec_elements, vec_res_value); + break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, + vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index<decltype(vec_res_value)>(x, temp_vec_res_value, vec_res_value, + vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + case ReductionOperation::SUM_SQUARE: + { +#ifdef ARM_COMPUTE_DEBUG_ENABLED + auto res = static_cast<T>(0.f); + for (int i = 0; i < S; ++i) + { + res += wrapper::vgetlane(vec_res_value, i); + } +#else // ARM_COMPUTE_DEBUG_ENABLED + auto carry_res = + wrapper::vpadd(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + for (int i = 0; i < S / 4; ++i) + { + carry_res = wrapper::vpadd(carry_res, carry_res); + } + auto res = wrapper::vgetlane(carry_res, 0); +#endif // ARM_COMPUTE_DEBUG_ENABLED + if (op == ReductionOperation::SUM_SQUARE) + { + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += (*(input_ptr + x)) * (*(input_ptr + x)); + } + } + else + { + // Compute left-over elements + for (; x 
< window_end_x; ++x) + { + res += *(input_ptr + x); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + res /= input_dim_0; + } + + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::PROD: + { + auto carry_res = + wrapper::vmul(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value)); + T res = 1; + for (int i = 0; i < S / 2; ++i) + { + res *= wrapper::vgetlane(carry_res, i); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res *= *(input_ptr + x); + } + + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) < res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto idx = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) > res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::MIN: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) < res ? *(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::MAX: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) > res ? 
*(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + }, + input, output); + } +}; + +template <typename T> +struct RedOpX_quantized +{ + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, const ReductionOperation op) + { + using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; + + const auto oq_info = out->info()->quantization_info().uniform(); + + const TensorInfo in_info = *(in->info()); + const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); + + const int window_step_x = 16 / sizeof(T); + const auto window_start_x = static_cast<int>(in_window.x().start()); + const auto window_end_x = static_cast<int>(in_window.x().end()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_window); + + const auto in_offset = static_cast<float>(iq_info.offset); + const float in_scale = iq_info.scale; + + const auto out_offset = static_cast<float>(oq_info.offset); + const float out_scale = oq_info.scale; + + const auto num_elements = static_cast<float>(in_info.dimension(0)); + + const float A = in_scale / (out_scale * num_elements); + const float B = out_offset - (in_scale * in_offset) / (out_scale); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + auto vec_res_value1 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + auto vec_res_value2 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + auto vec_res_value3 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + auto vec_res_value4 = + wrapper::vdup_n(static_cast<PromotedType>(0.f), wrapper::traits::vector_128_tag{}); + + auto vec_res_value1_f = vdupq_n_f32(static_cast<float>(1.f)); + auto vec_res_value2_f = vdupq_n_f32(static_cast<float>(1.f)); + auto vec_res_value3_f = vdupq_n_f32(static_cast<float>(1.f)); + auto vec_res_value4_f = vdupq_n_f32(static_cast<float>(1.f)); + + typename wrapper::traits::neon_vector<T, 16>::type vec_res_value = {0}; + + if (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || + op == ReductionOperation::MIN || op == ReductionOperation::MAX) + { + vec_res_value = wrapper::vdup_n(*input_ptr, wrapper::traits::vector_128_tag{}); + } + + uint32x4x4_t vec_res_idx{{0}}; + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + const auto vec_elements = wrapper::vloadq(input_ptr + x); + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + { + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + vec_res_value1 = wrapper::vadd(temp32x4t_1, vec_res_value1); + vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); + vec_res_value3 = wrapper::vadd(temp32x4t_3, 
vec_res_value3); + vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); + break; + } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = vdupq_n_f32(iq_info.offset); + const auto scale32x4f_4 = vdupq_n_f32(iq_info.scale); + + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); + auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); + auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); + auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); + + //de-quantize vec_elements + temp32x4f_1 = vmulq_f32(vsubq_f32(temp32x4f_1, offset32x4f_4), scale32x4f_4); + temp32x4f_2 = vmulq_f32(vsubq_f32(temp32x4f_2, offset32x4f_4), scale32x4f_4); + temp32x4f_3 = vmulq_f32(vsubq_f32(temp32x4f_3, offset32x4f_4), scale32x4f_4); + temp32x4f_4 = vmulq_f32(vsubq_f32(temp32x4f_4, offset32x4f_4), scale32x4f_4); + + vec_res_value1_f = vmulq_f32(temp32x4f_1, vec_res_value1_f); + vec_res_value2_f = vmulq_f32(temp32x4f_2, vec_res_value2_f); + vec_res_value3_f = vmulq_f32(temp32x4f_3, vec_res_value3_f); + vec_res_value4_f = vmulq_f32(temp32x4f_4, vec_res_value4_f); + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( + x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized<decltype(vec_res_value)>( + x, temp_vec_res_value, vec_res_value, vec_res_idx, op, 0); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::ARG_IDX_MIN: + { + auto idx = + calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) < res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto idx = + calculate_vector_index_quantized<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op); + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + if (*(input_ptr + x) > res) + { + idx = x; + res = *(input_ptr + x); + } + } + *(reinterpret_cast<uint32_t *>(output.ptr())) = idx; + break; + } + case ReductionOperation::MIN: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_min(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) < res ? 
*(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::MAX: + { + auto res = static_cast<T>(wrapper::vgetlane(calculate_max(vec_res_value), 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res = *(input_ptr + x) > res ? *(input_ptr + x) : res; + } + *(reinterpret_cast<T *>(output.ptr())) = res; + break; + } + case ReductionOperation::PROD: + { + auto carry_res = wrapper::vmul(vec_res_value1_f, vec_res_value2_f); + carry_res = wrapper::vmul(carry_res, vec_res_value3_f); + carry_res = wrapper::vmul(carry_res, vec_res_value4_f); + + float res = wrapper::vgetlane(carry_res, 0); + res *= wrapper::vgetlane(carry_res, 1); + res *= wrapper::vgetlane(carry_res, 2); + res *= wrapper::vgetlane(carry_res, 3); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + //de-quantize input + if (std::is_same<T, uint8_t>::value) + { + res *= dequantize_qasymm8(*(input_ptr + x), iq_info); + } + else + { + res *= dequantize_qasymm8_signed(*(input_ptr + x), iq_info); + } + } + + //re-quantize result + if (std::is_same<T, uint8_t>::value) + { + res = quantize_qasymm8(res, iq_info); + } + else + { + res = quantize_qasymm8_signed(res, iq_info); + } + + *reinterpret_cast<T *>(output.ptr()) = static_cast<T>(res); + break; + } + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + { + auto carry_res = wrapper::vadd(vec_res_value1, vec_res_value2); + carry_res = wrapper::vadd(carry_res, vec_res_value3); + carry_res = wrapper::vadd(carry_res, vec_res_value4); + + auto carry_paddition = + wrapper::vpadd(wrapper::vgethigh(carry_res), wrapper::vgetlow(carry_res)); + carry_paddition = wrapper::vpadd(carry_paddition, carry_paddition); + auto res = static_cast<int32_t>(wrapper::vgetlane(carry_paddition, 0)); + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + res += *(input_ptr + x); + } + + if (op == ReductionOperation::MEAN_SUM) + { + const int32_t resFinal = A * (static_cast<float>(res)) + B; + + *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(resFinal); + } + else + { + // Subtract accumulated offsets + res -= (in_info.dimension(0) - 1) * iq_info.offset; + *reinterpret_cast<T *>(output.ptr()) = utils::cast::saturate_cast<T>(res); + } + + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + }, + input, output); + } +}; + +template <typename T, int S> +struct RedOpYZW +{ + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; + + inline void operator()(const Window &in_window, + Window &out_window, + const ITensor *in, + ITensor *out, + int axis, + const ReductionOperation op) + { + const TensorInfo in_info = *(in->info()); + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it is split over the x-axis, we need to set the correct split window start and end.
+ const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + neon_vector vec_res_value = {0}; + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vloadq(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + vec_res_value = wrapper::vdup_n(static_cast<T>(1.f), ExactTagType{}); + break; + } + default: + { + vec_res_value = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + break; + } + } + uint32x4x4_t vec_res_idx{{0}}; + + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); + const auto vec_elements = wrapper::vloadq(in_ptr); + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + vec_res_value = wrapper::vadd(vec_elements, vec_res_value); + break; + case ReductionOperation::SUM_SQUARE: + vec_res_value = wrapper::vadd(wrapper::vmul(vec_elements, vec_elements), vec_res_value); + break; + case ReductionOperation::PROD: + vec_res_value = wrapper::vmul(vec_elements, vec_res_value); + break; + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = + calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = + calculate_index(dim, temp_vec_res_value, vec_res_value, vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + auto vec_width_inv = + wrapper::vinv(wrapper::vdup_n(static_cast<T>(in_info.dimension(axis)), ExactTagType{})); + vec_res_value = wrapper::vmul(vec_res_value, vec_width_inv); + } + + if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x, vec_res_idx.val[0]); +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + if (std::is_same<T, float16_t>::value) + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr()) + x + 4, vec_res_idx.val[1]); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + } + else + { + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x * sizeof(T)), vec_res_value); + } + } + + // Compute left-over elements + 
for (; x < window_end_x; ++x) + { + auto res_value = 0.f; + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + res_value = *(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + res_value = static_cast<T>(1.f); + break; + } + default: + { + res_value = static_cast<T>(0.f); + break; + } + } + + uint32_t res_idx = 0; + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x * sizeof(T) + in_info.strides_in_bytes()[axis] * dim); + + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + res_value += *in_ptr; + break; + case ReductionOperation::SUM_SQUARE: + res_value += *in_ptr * *in_ptr; + break; + case ReductionOperation::PROD: + res_value *= *in_ptr; + break; + case ReductionOperation::ARG_IDX_MIN: + { + if (*in_ptr < res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + if (*in_ptr > res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::MIN: + { + res_value = *in_ptr < res_value ? *in_ptr : res_value; + break; + } + case ReductionOperation::MAX: + { + res_value = *in_ptr > res_value ? *in_ptr : res_value; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + if (op == ReductionOperation::MEAN_SUM) + { + res_value /= in_info.dimension(axis); + } + + if (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX) + { + *(reinterpret_cast<uint32_t *>(output.ptr()) + x) = res_idx; + } + else + { + *(reinterpret_cast<T *>(output.ptr() + x * sizeof(T))) = res_value; + } + } + }, + input, output); + } +}; + +template <typename T, int S, int axis, ReductionOperation op> +struct RedOpYZW_complex +{ + /** SIMD vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type; + using neon_vector = typename wrapper::traits::neon_vector<T, S>::type; + + inline void operator()( + const Window &in_window, Window &out_window, const ITensor *in, ITensor *out, int, const ReductionOperation) + { + ARM_COMPUTE_ERROR_ON(axis != 2); + ARM_COMPUTE_ERROR_ON(op != ReductionOperation::SUM); + + const TensorInfo in_info = *(in->info()); + const size_t stride_z = in_info.strides_in_bytes()[axis]; + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. 
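(This specialization is restricted to SUM along the Z axis over interleaved complex data, i.e. [re, im, re, im, ...] pairs; a scalar equivalent, with illustrative names only, is sketched below.)

#include <cstddef>

// Scalar reference for the complex SUM reduction over Z (illustrative sketch).
// Each x position holds a (re, im) pair; both components are accumulated
// independently along the Z dimension.
void reduce_complex_sum_z_ref(const float *src, float *dst, int length_x,
                              unsigned int dim_z, std::size_t stride_z_elems)
{
    for (int x = 0; x < length_x; ++x)
    {
        float acc_re = 0.f;
        float acc_im = 0.f;
        for (unsigned int d = 0; d < dim_z; ++d)
        {
            acc_re += src[2 * x + d * stride_z_elems];
            acc_im += src[2 * x + 1 + d * stride_z_elems];
        }
        dst[2 * x]     = acc_re;
        dst[2 * x + 1] = acc_im;
    }
}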
+ const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + neon_vector vec_res_value_0 = {0}; + neon_vector vec_res_value_1 = {0}; + + vec_res_value_0 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + vec_res_value_1 = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{}); + + T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr_0 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); + T *in_ptr_1 = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + 16 + stride_z * dim); + + const auto vec_elements_0 = wrapper::vloadq(in_ptr_0); + const auto vec_elements_1 = wrapper::vloadq(in_ptr_1); + + vec_res_value_0 = wrapper::vadd(vec_elements_0, vec_res_value_0); + vec_res_value_1 = wrapper::vadd(vec_elements_1, vec_res_value_1); + } + + wrapper::vstore(out_ptr, vec_res_value_0); + wrapper::vstore(out_ptr + 4, vec_res_value_1); + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + { + auto res_value_0 = 0.f; + auto res_value_1 = 0.f; + + T *out_ptr = reinterpret_cast<T *>(output.ptr() + 2 * x * sizeof(T)); + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + T *in_ptr = reinterpret_cast<T *>(input.ptr() + 2 * x * sizeof(T) + stride_z * dim); + res_value_0 += *in_ptr; + res_value_1 += *(in_ptr + 1); + } + *out_ptr = res_value_0; + *(out_ptr + 1) = res_value_1; + } + }, + input, output); + } +}; + +template <typename T> +struct RedOpYZW_quantized +{ + inline void operator()(const Window &in_window, + Window &out_window, + const ITensor *in, + ITensor *out, + int axis, + const ReductionOperation op) + { + const TensorInfo in_info = *(in->info()); + const UniformQuantizationInfo iq_info = in_info.quantization_info().uniform(); + using PromotedType = typename wrapper::traits::promote<typename wrapper::traits::promote<T>::type>::type; + + const auto oq_info = out->info()->quantization_info().uniform(); + + const int window_step_x = 16 / sizeof(T); + const auto window_start_x_tmp = static_cast<int>(in_window.x().start()); + const auto window_end_x_tmp = static_cast<int>(in_window.x().end()); + // As it split over x-axis, need to set the correct spiltted window start and end. 
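(For the MEAN_SUM case this quantized reducer folds the dequantize-average-requantize chain into the two coefficients A and B computed a few lines below, so that the output is simply A * sum_of_raw_inputs + B. A minimal standalone sketch of that arithmetic, with illustrative names and the rounding behaviour of the aarch64 path, is:)

#include <cmath>
#include <cstdint>

// Requantize the mean of n quantized values given their raw sum `sum_q`.
//   dequantized value: in_scale * (q - in_offset)
//   mean over n:       in_scale * (sum_q - n * in_offset) / n
//   requantized:       mean / out_scale + out_offset  ==  A * sum_q + B
// The kernel additionally saturates the result to the 8-bit output type.
int32_t requantized_mean(int64_t sum_q, float in_scale, int32_t in_offset,
                         float out_scale, int32_t out_offset, int n)
{
    const float A = in_scale / (out_scale * n);
    const float B = out_offset - (in_scale * in_offset) / out_scale;
    return static_cast<int32_t>(std::lround(A * static_cast<float>(sum_q) + B));
}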
+ const auto window_start_x = static_cast<int>(0); + const auto window_end_x = static_cast<int>(in_window.shape().x()); + + Window in_win_no_pad = in_window; + in_win_no_pad.set(Window::DimX, Window::Dimension(window_start_x_tmp, window_end_x_tmp, in_window.shape().x())); + Window out_win_no_pad = out_window; + out_win_no_pad.set(Window::DimX, + Window::Dimension(window_start_x_tmp, window_end_x_tmp, out_window.shape().x())); + + Iterator input(in, in_win_no_pad); + Iterator output(out, out_win_no_pad); + + using vector_type = + typename wrapper::traits::neon_bitvector<PromotedType, wrapper::traits::BitWidth::W128>::type; + using vector_type_f = typename wrapper::traits::neon_vector<float, 4>::type; + + vector_type vec_res_value1{}; + vector_type vec_res_value2{}; + vector_type vec_res_value3{}; + vector_type vec_res_value4{}; + + vector_type_f vec_res_value1_f{}; + vector_type_f vec_res_value2_f{}; + vector_type_f vec_res_value3_f{}; + vector_type_f vec_res_value4_f{}; + + const float in_offset = static_cast<float>(iq_info.offset); + const float in_scale = iq_info.scale; + + const float out_offset = static_cast<float>(oq_info.offset); + const float out_scale = oq_info.scale; + + const float num_elements = static_cast<float>(in_info.dimension(axis)); + + const float A = in_scale / (out_scale * num_elements); + const float B = out_offset - (in_scale * in_offset) / (out_scale); + + const auto vec_A = wrapper::vdup_n(static_cast<float>(A), wrapper::traits::vector_128_tag{}); + const auto vec_B = wrapper::vdup_n(static_cast<float>(B), wrapper::traits::vector_128_tag{}); + + execute_window_loop( + in_win_no_pad, + [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast<T *>(input.ptr()); + + // Compute window_step_x elements per iteration + int x = window_start_x; + for (; x <= (window_end_x - window_step_x); x += window_step_x) + { + uint32x4x4_t vec_res_idx{{0}}; + vec_res_value1 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + vec_res_value2 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + vec_res_value3 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + vec_res_value4 = wrapper::vdup_n(static_cast<PromotedType>(0), wrapper::traits::vector_128_tag{}); + + vec_res_value1_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + vec_res_value2_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + vec_res_value3_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + vec_res_value4_f = wrapper::vdup_n(static_cast<float>(1), wrapper::traits::vector_128_tag{}); + + auto vec_res_value = wrapper::vloadq(input_ptr + x); + + for (unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim) + { + const T *in_ptr = input_ptr + x + in_info.strides_in_bytes()[axis] * index_dim; + const auto vec_elements = wrapper::vloadq(in_ptr); + switch (op) + { + case ReductionOperation::SUM: + case ReductionOperation::MEAN_SUM: + { + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + vec_res_value1 = 
wrapper::vadd(temp32x4t_1, vec_res_value1); + vec_res_value2 = wrapper::vadd(temp32x4t_2, vec_res_value2); + vec_res_value3 = wrapper::vadd(temp32x4t_3, vec_res_value3); + vec_res_value4 = wrapper::vadd(temp32x4t_4, vec_res_value4); + break; + } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = wrapper::vdup_n(static_cast<float>(iq_info.offset), + wrapper::traits::vector_128_tag{}); + const auto scale32x4f_4 = + wrapper::vdup_n(iq_info.scale, wrapper::traits::vector_128_tag{}); + + const auto temp16x8t_1 = wrapper::vmovl(wrapper::vgetlow(vec_elements)); + const auto temp16x8t_2 = wrapper::vmovl(wrapper::vgethigh(vec_elements)); + + const auto temp32x4t_1 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_1)); + const auto temp32x4t_2 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_1)); + const auto temp32x4t_3 = wrapper::vmovl(wrapper::vgetlow(temp16x8t_2)); + const auto temp32x4t_4 = wrapper::vmovl(wrapper::vgethigh(temp16x8t_2)); + + auto temp32x4f_1 = wrapper::vcvt<float>(temp32x4t_1); + auto temp32x4f_2 = wrapper::vcvt<float>(temp32x4t_2); + auto temp32x4f_3 = wrapper::vcvt<float>(temp32x4t_3); + auto temp32x4f_4 = wrapper::vcvt<float>(temp32x4t_4); + + //de-quantize vec_elements + temp32x4f_1 = wrapper::vmul(wrapper::vsub(temp32x4f_1, offset32x4f_4), scale32x4f_4); + temp32x4f_2 = wrapper::vmul(wrapper::vsub(temp32x4f_2, offset32x4f_4), scale32x4f_4); + temp32x4f_3 = wrapper::vmul(wrapper::vsub(temp32x4f_3, offset32x4f_4), scale32x4f_4); + temp32x4f_4 = wrapper::vmul(wrapper::vsub(temp32x4f_4, offset32x4f_4), scale32x4f_4); + + vec_res_value1_f = wrapper::vmul(temp32x4f_1, vec_res_value1_f); + vec_res_value2_f = wrapper::vmul(temp32x4f_2, vec_res_value2_f); + vec_res_value3_f = wrapper::vmul(temp32x4f_3, vec_res_value3_f); + vec_res_value4_f = wrapper::vmul(temp32x4f_4, vec_res_value4_f); + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, + vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + vec_res_idx = calculate_index_quantized(index_dim, temp_vec_res_value, vec_res_value, + vec_res_idx, op, axis); + vec_res_value = temp_vec_res_value; + break; + } + case ReductionOperation::MIN: + { + vec_res_value = wrapper::vmin(vec_elements, vec_res_value); + break; + } + case ReductionOperation::MAX: + { + vec_res_value = wrapper::vmax(vec_elements, vec_res_value); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::ARG_IDX_MAX: + { + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x), vec_res_idx.val[0]); + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 4, vec_res_idx.val[1]); + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 8, vec_res_idx.val[2]); + wrapper::vstore(reinterpret_cast<uint32_t *>(output.ptr() + 4 * x) + 12, + vec_res_idx.val[3]); + break; + } + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), vec_res_value); + break; + } + case ReductionOperation::SUM: + { + // Subtract offsets + auto offsets = vdupq_n_s32((in_info.dimension(axis) - 1) * iq_info.offset); + + auto vec_res_s_value1 = wrapper::vreinterpret(vec_res_value1); + auto 
vec_res_s_value2 = wrapper::vreinterpret(vec_res_value2); + auto vec_res_s_value3 = wrapper::vreinterpret(vec_res_value3); + auto vec_res_s_value4 = wrapper::vreinterpret(vec_res_value4); + + vec_res_s_value1 = wrapper::vsub(vec_res_s_value1, offsets); + vec_res_s_value2 = wrapper::vsub(vec_res_s_value2, offsets); + vec_res_s_value3 = wrapper::vsub(vec_res_s_value3, offsets); + vec_res_s_value4 = wrapper::vsub(vec_res_s_value4, offsets); + + const auto temp16x8t_1 = + wrapper::vcombine(wrapper::vqmovn(vec_res_s_value1), wrapper::vqmovn(vec_res_s_value2)); + const auto temp16x8t_2 = + wrapper::vcombine(wrapper::vqmovn(vec_res_s_value3), wrapper::vqmovn(vec_res_s_value4)); + + combine_and_store<T>(temp16x8t_1, temp16x8t_2, output, x); + break; + } + case ReductionOperation::MEAN_SUM: + { + vec_res_value1_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value1), vec_A); + vec_res_value2_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value2), vec_A); + vec_res_value3_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value3), vec_A); + vec_res_value4_f = wrapper::vmla(vec_B, wrapper::vcvt<float>(vec_res_value4), vec_A); + +#ifdef __aarch64__ + vec_res_value1 = wrapper::vcvta<PromotedType>(vec_res_value1_f); + vec_res_value2 = wrapper::vcvta<PromotedType>(vec_res_value2_f); + vec_res_value3 = wrapper::vcvta<PromotedType>(vec_res_value3_f); + vec_res_value4 = wrapper::vcvta<PromotedType>(vec_res_value4_f); +#else // defined(__aarch64__) + vec_res_value1 = wrapper::vcvt<PromotedType>(vec_res_value1_f); + vec_res_value2 = wrapper::vcvt<PromotedType>(vec_res_value2_f); + vec_res_value3 = wrapper::vcvt<PromotedType>(vec_res_value3_f); + vec_res_value4 = wrapper::vcvt<PromotedType>(vec_res_value4_f); +#endif // __aarch64__ + + const auto temp16x8t_1 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); + const auto temp16x8t_2 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); + auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); + + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); + break; + } + case ReductionOperation::PROD: + { + const auto offset32x4f_4 = + wrapper::vdup_n(static_cast<float>(iq_info.offset), wrapper::traits::vector_128_tag{}); + const auto iscale32x4f_4 = vinvq_f32(vdupq_n_f32(iq_info.scale)); + + //re-quantize + vec_res_value1_f = + wrapper::vadd(wrapper::vmul(vec_res_value1_f, iscale32x4f_4), offset32x4f_4); + vec_res_value2_f = + wrapper::vadd(wrapper::vmul(vec_res_value2_f, iscale32x4f_4), offset32x4f_4); + vec_res_value3_f = + wrapper::vadd(wrapper::vmul(vec_res_value3_f, iscale32x4f_4), offset32x4f_4); + vec_res_value4_f = + wrapper::vadd(wrapper::vmul(vec_res_value4_f, iscale32x4f_4), offset32x4f_4); + + vec_res_value1 = wrapper::vcvt<T>(vec_res_value1_f); + vec_res_value2 = wrapper::vcvt<T>(vec_res_value2_f); + vec_res_value3 = wrapper::vcvt<T>(vec_res_value3_f); + vec_res_value4 = wrapper::vcvt<T>(vec_res_value4_f); + + const auto temp16x8t_1 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value1), wrapper::vqmovn(vec_res_value2)); + const auto temp16x8t_2 = + wrapper::vcombine(wrapper::vqmovn(vec_res_value3), wrapper::vqmovn(vec_res_value4)); + auto res = wrapper::vcombine(wrapper::vqmovn(temp16x8t_1), wrapper::vqmovn(temp16x8t_2)); + + wrapper::vstore(reinterpret_cast<T *>(output.ptr() + x), res); + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + // Compute left-over elements + for (; x < window_end_x; ++x) + 
{ + float res_value = 0.f; + int32_t res_value_q = 0; + + switch (op) + { + case ReductionOperation::ARG_IDX_MAX: + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::MIN: + case ReductionOperation::MAX: + { + res_value = *(input_ptr + x); + break; + } + case ReductionOperation::PROD: + { + res_value = static_cast<T>(1.0f); + break; + } + default: + { + res_value = static_cast<T>(0.0f); + break; + } + } + uint32_t res_idx = 0; + + for (unsigned int dim = 0; dim < in_info.dimension(axis); ++dim) + { + const T *in_ptr = + reinterpret_cast<T *>(input.ptr() + x + in_info.strides_in_bytes()[axis] * dim); + switch (op) + { + case ReductionOperation::SUM: + { + res_value += *in_ptr; + break; + } + case ReductionOperation::MEAN_SUM: + { + res_value_q += *in_ptr; + break; + } + case ReductionOperation::SUM_SQUARE: + { + res_value += *in_ptr * *in_ptr; + break; + } + case ReductionOperation::PROD: + { + //de-quantize input + if (std::is_same<T, uint8_t>::value) + { + res_value *= dequantize_qasymm8(*in_ptr, iq_info); + } + else + { + res_value *= dequantize_qasymm8_signed(*in_ptr, iq_info); + } + break; + } + case ReductionOperation::ARG_IDX_MIN: + { + if (*in_ptr < res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::ARG_IDX_MAX: + { + if (*in_ptr > res_value) + { + res_value = *in_ptr; + res_idx = dim; + } + break; + } + case ReductionOperation::MIN: + { + res_value = *in_ptr < res_value ? *in_ptr : res_value; + break; + } + case ReductionOperation::MAX: + { + res_value = *in_ptr > res_value ? *in_ptr : res_value; + break; + } + default: + ARM_COMPUTE_ERROR("Not supported"); + } + } + + switch (op) + { + case ReductionOperation::MEAN_SUM: + { + // Apply previously calculated coefficients (with rounding on aarch64) +#ifdef __aarch64__ + const int32_t res = + arm_compute::support::cpp11::round(A * (static_cast<float>(res_value_q)) + B); +#else // defined(__aarch64__) + const int32_t res = A * (static_cast<float>(res_value_q)) + B; +#endif // __aarch64__ + *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res); + break; + } + case ReductionOperation::SUM: + { + // Subtract accumulated offsets + res_value -= (in_info.dimension(axis) - 1) * iq_info.offset; + *reinterpret_cast<T *>(output.ptr() + x) = utils::cast::saturate_cast<T>(res_value); + break; + } + case ReductionOperation::PROD: + { + //re-quantize result + T res = 0; + if (std::is_same<T, uint8_t>::value) + { + res = quantize_qasymm8(res_value, iq_info); + } + else + { + res = quantize_qasymm8_signed(res_value, iq_info); + } + *(reinterpret_cast<T *>(output.ptr() + x)) = res; + break; + } + case ReductionOperation::ARG_IDX_MIN: + case ReductionOperation::ARG_IDX_MAX: + { + *(reinterpret_cast<uint32_t *>(output.ptr() + x * 4)) = res_idx; + break; + } + default: + *(reinterpret_cast<T *>(output.ptr() + x)) = res_value; + } + } + }, + input, output); + } +}; + +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_IMPL_H diff --git a/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp b/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp new file mode 100644 index 0000000000..ad66b456ac --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/integer.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op); +} + +void reduce_RedOpYZW_reduceY_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op); +} +void reduce_RedOpYZW_reduceZ_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op); +} + +void reduce_RedOpYZW_reduceW_S32_4(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/list.h b/src/cpu/kernels/reduction_layer/generic/neon/list.h new file mode 100644 index 0000000000..947c28a130 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/list.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H +#define ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ + +#define DECLARE_REDUCTION_KERNEL(func_name) \ + void func_name(const Window &window, const ITensor *in, ITensor *out, const ReductionOperation op) + +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_complex_reduceZ_float32_4_2_SUM); +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_float32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_float32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_float32_4); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_float16_8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_float16_8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_float16_8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_float16_8); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_S32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_S32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_S32_4); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_S32_4); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_qasymm8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_qasymm8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_qasymm8); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_qasymm8); + +DECLARE_REDUCTION_KERNEL(reduce_RedOpX_reduceX_qasymm8_signed); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceY_qasymm8_signed); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceZ_qasymm8_signed); +DECLARE_REDUCTION_KERNEL(reduce_RedOpYZW_reduceW_qasymm8_signed); + +#undef DECLARE_REDUCTION_KERNEL +} // namespace cpu +} // namespace arm_compute +#endif // ACL_SRC_CPU_KERNELS_REDUCTION_LAYER_GENERIC_NEON_LIST_H diff --git a/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp new file mode 100644 index 0000000000..bc711c6855 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX_quantized<uint8_t>>::reduceX(window, input, output, RedOpX_quantized<uint8_t>(), op); +} + +void reduce_RedOpYZW_reduceY_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<uint8_t>>::reduceY(window, input, output, RedOpYZW_quantized<uint8_t>(), op); +} + +void reduce_RedOpYZW_reduceZ_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<uint8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<uint8_t>(), op); +} + +void reduce_RedOpYZW_reduceW_qasymm8(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<uint8_t>>::reduceW(window, input, output, RedOpYZW_quantized<uint8_t>(), op); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp new file mode 100644 index 0000000000..10ac3d6715 --- /dev/null +++ b/src/cpu/kernels/reduction_layer/generic/neon/qasymm8_signed.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/cpu/kernels/reduction_layer/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void reduce_RedOpX_reduceX_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpX_quantized<int8_t>>::reduceX(window, input, output, RedOpX_quantized<int8_t>(), op); +} + +void reduce_RedOpYZW_reduceY_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<int8_t>>::reduceY(window, input, output, RedOpYZW_quantized<int8_t>(), op); +} + +void reduce_RedOpYZW_reduceZ_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<int8_t>>::reduceZ(window, input, output, RedOpYZW_quantized<int8_t>(), op); +} + +void reduce_RedOpYZW_reduceW_qasymm8_signed(const Window &window, + const ITensor *input, + ITensor *output, + const ReductionOperation op) +{ + return Reducer<RedOpYZW_quantized<int8_t>>::reduceW(window, input, output, RedOpYZW_quantized<int8_t>(), op); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp index da62d2d614..425fcf7ac6 100644 --- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp +++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp @@ -33,9 +33,15 @@ namespace cpu { template <bool IS_LOG> -void neon_fp16_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_fp16_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window); @@ -46,10 +52,20 @@ void neon_fp16_softmax( } } -template void neon_fp16_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_fp16_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_fp16_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_fp16_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp index 0701620636..a64946eb74 100644 --- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp +++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp @@ -31,9 +31,15 @@ namespace cpu { template <bool IS_LOG> -void neon_fp32_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_fp32_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window); @@ -44,10 +50,20 @@ void neon_fp32_softmax( } } -template void neon_fp32_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window 
&window); -template void neon_fp32_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_fp32_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_fp32_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp index d39240bb38..369f9bb005 100644 --- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp +++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp @@ -30,9 +30,15 @@ namespace arm_compute namespace cpu { template <bool IS_LOG> -void neon_qasymm8_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_qasymm8_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window); @@ -43,10 +49,20 @@ void neon_qasymm8_softmax( } } -template void neon_qasymm8_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_qasymm8_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_qasymm8_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_qasymm8_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp index 26fd5dbfa0..594ceb7654 100644 --- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp +++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp @@ -30,9 +30,15 @@ namespace arm_compute namespace cpu { template <bool IS_LOG> -void neon_qasymm8_signed_softmax( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +void neon_qasymm8_signed_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) { + ARM_COMPUTE_UNUSED(lut_ptr); if (axis == 0) { return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window); @@ -43,10 +49,20 @@ void neon_qasymm8_signed_softmax( } } -template void neon_qasymm8_signed_softmax<true>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); -template void neon_qasymm8_signed_softmax<false>( - const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window); +template void neon_qasymm8_signed_softmax<true>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); +template void neon_qasymm8_signed_softmax<false>(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int 
axis, + const Window &window, + const float *lut_ptr); } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp new file mode 100644 index 0000000000..e70c9f4793 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp @@ -0,0 +1,781 @@ +/* + * Copyright (c) 2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_f16_softmax_kernel( // + const float16_t *src, + float16_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4]) +{ + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x9: temporary, index + // * x10: temporary, -inf + // * x11: temporary, 0 + // * x12: temporary, 1.0f + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // + // * z0: c1 + // * z1: c2 + // * z2: c3 + // * z3: c4 + // * z4: c5 + // * z5: shift + // * z6: inv_ln2 + // * z7: neg_ln2_hi + // * z8: neg_ln2_lo + // * z9: min_input + // * z10: 23, 0 + // * z11: max_value + // * z12-z15: x, x_fp32_lower_halves, r_hi, r, r2 + // * z16-z19: max_value, shift, z, scale, poly + // * z20-z21: n, p1, p12345 + // * z22-z23: n, p23, p2345 + // * z24-z25: p45 + // * z26: beta + // * z28-z31: sum_value, x_fp32_upper_halves + // + // * za0-za3: sum_value + // + // * p0: all-true + // * p1: left-over predicate for find-max & normalize loops + // * p2-p4: left-over predicates for regularize loop + // * p4-p7: underflow in vector loop + // * p5-p6: underflow in leftover loop + // * + // * pn9: all-true + + // Prepares all constant values + + ptrue p0.b + .inst 0x25207811 // ptrue pn9.b + + mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb + mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010 + + movk 
w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb + movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010 + + dup z0.s, w9 // c1. + dup z1.s, w10 // c2. + dup z2.s, w11 // c3. + dup z3.s, w12 // c4. + dup z4.s, w13 // c5. + + mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + dup z5.s, w9 // shift + dup z6.s, w10 // inv_ln2 + dup z7.s, w11 // neg_ln2_hi + dup z8.s, w12 // neg_ln2_lo + dup z9.s, w13 // min_input + + dup z26.s, %w[beta] // beta + fcvt h26, s26 + dup z26.h, z26.h[0] + + mov w10, #0xfc00 // -inf: 0xfc00 for fp16 + + mov w11, #0 // 0 + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cnth x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + + // ---------------------------------------------------------------- z16-z19: max_value = -inf + dup z16.h, w10 + dup z17.h, w10 + dup z18.h, w10 + dup z19.h, w10 + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + dup z11.h, w10 // z11: max_value = -inf + +find_max_body_start%=: + cmp x9, x13 + b.eq find_max_body_end%= + + .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x + .inst 0xc16cb910 // fmax {z16.h-z19.h}, {z16.h-z19.h}, {z12.h-z15.h} // z16-z19: max_value = max(max_value, x) + + inch x9, ALL, MUL #4 + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. 
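// whilelo builds an element predicate covering the remaining elements (fewer
// than four vectors' worth), so the tail is handled with predicated vector
// code rather than a scalar epilogue.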
+find_max_leftover_start%=: + whilelo p1.h, x9, %x[length] + b.none find_max_leftover_end%= + + ld1h z12.h, p1/z, [x27, x9, LSL #1] // z12: x + fmax z16.h, p1/m, z16.h, z12.h // z16: max_value = max(max_value, x) + + inch x9 + b find_max_leftover_start%= +find_max_leftover_end%=: + + // ---------------------------------------------------------------- z16: max_value + .inst 0xc172b110 // fmax {z16.h-z17.h}, {z16.h-z17.h}, {z18.s-z19.h} + fmax z16.h, p0/m, z16.h, z17.h + fmaxv h16, p0, z16.h + + // ---------------------------------------------------------------- z11: max_value + dup z11.h, z16.h[0] + + // ================================================== + // Step 2: Regularize, i.e. Calculate exp(x - max(x) + // ================================================== + + .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value (in fp32) + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // ---------------------------------------------------- x9: index + +regularize_body_start%=: + cmp x9, x13 + b.eq regularize_body_end%= + + // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data + .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x + + // ---------------------------------------------------------------- z12-z15: x = input_data - max_value + fsub z12.h, z12.h, z11.h + fsub z13.h, z13.h, z11.h + fsub z14.h, z14.h, z11.h + fsub z15.h, z15.h, z11.h + + // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta + fmul z12.h, z12.h, z26.h + fmul z13.h, z13.h, z26.h + fmul z14.h, z14.h, z26.h + fmul z15.h, z15.h, z26.h + + // ---------------------------------------------------------------- + // Convert fp16 values to fp32. This results in four more registers. 
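// The exp() polynomial and the running sum are evaluated in fp32 for accuracy;
// results are converted back to fp16 only when they are stored.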
+ // z12 --> z12, z28 + fcvtlt z28.s, p0/m, z12.h + fcvt z12.s, p0/m, z12.h + + // z13 --> z13, z29 + fcvtlt z29.s, p0/m, z13.h + fcvt z13.s, p0/m, z13.h + + // z14 --> z14, z30 + fcvtlt z30.s, p0/m, z14.h + fcvt z14.s, p0/m, z14.h + + // z15 --> z15, z31 + fcvtlt z31.s, p0/m, z15.h + fcvt z15.s, p0/m, z15.h + + // ---------------------------------------------------------------- + // Process z12-z15 + // ---------------------------------------------------------------- + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z12.s, z9.s + fcmlt p5.s, p0/z, z13.s, z9.s + fcmlt p6.s, p0/z, z14.s, z9.s + fcmlt p7.s, p0/z, z15.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z12.s, z6.s + fmla z17.s, p0/m, z13.s, z6.s + fmla z18.s, p0/m, z14.s, z6.s + fmla z19.s, p0/m, z15.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi + fmla z12.s, p0/m, z20.s, z7.s + fmla z13.s, p0/m, z21.s, z7.s + fmla z14.s, p0/m, z22.s, z7.s + fmla z15.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo + fmla z12.s, p0/m, z20.s, z8.s + fmla z13.s, p0/m, z21.s, z8.s + fmla z14.s, p0/m, z22.s, z8.s + fmla z15.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. 
(z12-z13) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z12.s, z0.s + fmul z21.s, z13.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z12.s, z2.s + fmla z23.s, p0/m, z13.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z12.s, z4.s + fmla z25.s, p0/m, z13.s, z4.s + + // ---------------------------------------------------------------- z12-z13: r2 = r * r + fmul z12.s, z12.s, z12.s + fmul z13.s, z13.s, z13.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z12.s, z24.s + fmla z23.s, p0/m, z13.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z12.s, z22.s + fmla z21.s, p0/m, z13.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors (z14-z15) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z14.s, z0.s + fmul z21.s, z15.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z14.s, z2.s + fmla z23.s, p0/m, z15.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z14.s, z4.s + fmla z25.s, p0/m, z15.s, z4.s + + // ---------------------------------------------------------------- z14-z15: r2 = r * r + fmul z14.s, z14.s, z14.s + fmul z15.s, z15.s, z15.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z14.s, z24.s + fmla z23.s, p0/m, z15.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z14.s, z22.s + fmla z21.s, p0/m, z15.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ? 
0 : poly + dup z10.s, #0 + sel z12.s, p4, z10.s, z16.s + sel z13.s, p5, z10.s, z17.s + sel z14.s, p6, z10.s, z18.s + sel z15.s, p7, z10.s, z19.s + + // ---------------------------------------------------------------- sum in fp32 + .inst 0xc1a17d80 // fadd za.s[w11, #0, VGx4], {z12.s-z15.s} za0-za3: sum_value = sum_value + poly + + // ---------------------------------------------------------------- + // Process z28-z31 + // ---------------------------------------------------------------- + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z28.s, z9.s + fcmlt p5.s, p0/z, z29.s, z9.s + fcmlt p6.s, p0/z, z30.s, z9.s + fcmlt p7.s, p0/z, z31.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z28.s, z6.s + fmla z17.s, p0/m, z29.s, z6.s + fmla z18.s, p0/m, z30.s, z6.s + fmla z19.s, p0/m, z31.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z24-z27: r_hi = x + n * neg_ln2_hi + fmla z28.s, p0/m, z20.s, z7.s + fmla z29.s, p0/m, z21.s, z7.s + fmla z30.s, p0/m, z22.s, z7.s + fmla z31.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z27-z30: r = r_hi + n * neg_ln2_lo + fmla z28.s, p0/m, z20.s, z8.s + fmla z29.s, p0/m, z21.s, z8.s + fmla z30.s, p0/m, z22.s, z8.s + fmla z31.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. 
(z28-z29) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z28.s, z0.s + fmul z21.s, z29.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z28.s, z2.s + fmla z23.s, p0/m, z29.s, z2.s + + // ---------------------------------------------------------------- z24-z25: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z28.s, z4.s + fmla z25.s, p0/m, z29.s, z4.s + + // ---------------------------------------------------------------- z28-z29: r2 = r * r + fmul z28.s, z28.s, z28.s + fmul z29.s, z29.s, z29.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z28.s, z24.s + fmla z23.s, p0/m, z29.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z28.s, z22.s + fmla z21.s, p0/m, z29.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors (z30-z31) + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z30.s, z0.s + fmul z21.s, z31.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z30.s, z2.s + fmla z23.s, p0/m, z31.s, z2.s + + // ---------------------------------------------------------------- z24-z35: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z30.s, z4.s + fmla z25.s, p0/m, z31.s, z4.s + + // ---------------------------------------------------------------- z30-z31: r2 = r * r + fmul z30.s, z30.s, z30.s + fmul z31.s, z31.s, z31.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z30.s, z24.s + fmla z23.s, p0/m, z31.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z30.s, z22.s + fmla z21.s, p0/m, z31.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ? 
0 : poly + dup z10.s, #0 + sel z28.s, p4, z10.s, z16.s + sel z29.s, p5, z10.s, z17.s + sel z30.s, p6, z10.s, z18.s + sel z31.s, p7, z10.s, z19.s + + // ---------------------------------------------------------------- sum in fp32 + .inst 0xc1a17f80 // fadd za.s[w11, #0, VGx4], {z28.s-z31.s} za0-za3: sum_value = sum_value + poly + + fcvt z12.h, p0/m, z12.s + fcvtnt z12.h, p0/m, z28.s + + fcvt z13.h, p0/m, z13.s + fcvtnt z13.h, p0/m, z29.s + + fcvt z14.h, p0/m, z14.s + fcvtnt z14.h, p0/m, z30.s + + fcvt z15.h, p0/m, z15.s + fcvtnt z15.h, p0/m, z31.s + + // Stores 4 consecutive registers to the output + .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1] + + inch x9, ALL, MUL #4 + b regularize_body_start%= +regularize_body_end%=: + + // ---------------------------------------------------------------- z28: sum_value + .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4] + fadd z28.s, z28.s, z29.s + fadd z30.s, z30.s, z31.s + fadd z28.s, z28.s, z30.s + + // Loop for processing the leftover part. +regularize_leftover_start%=: + whilelo p2.h, x9, %x[length] + b.none regularize_leftover_end%= + + ld1h z12.h, p2/z, [x27, x9, LSL #1] // x12: input_data + + fsub z12.h, z12.h, z11.h // z12: x = input_data - max_value + fmul z12.h, z12.h, z26.h // z12: x = (input_data - max_value) * beta + + // ---------------------------------------------------------------- z12.h --> z12.s, z13.s + fcvtlt z13.s, p2/m, z12.h + fcvt z12.s, p2/m, z12.h + + // ---------------------------------------------------------------- p3, p4: predicates for z12, z14 + pfalse p1.b + trn1 p3.h, p2.h, p1.h // for z12 + trn2 p4.h, p2.h, p1.h // for z13 + + mov z16.d, z5.d // z16: shift + mov z17.d, z5.d // z17: shift + fcmlt p5.s, p3/z, z12.s, z9.s // p5: underflow = x < min_input + fcmlt p6.s, p4/z, z13.s, z9.s // p6: underflow = x < min_input + fmla z16.s, p3/m, z12.s, z6.s // z16: z = shift + x * inv_ln2 + fmla z17.s, p4/m, z13.s, z6.s // z17: z = shift + x * inv_ln2 + fsub z20.s, z16.s, z5.s // z20: n = z - shift + fsub z21.s, z17.s, z5.s // z21: n = z - shift + fmla z12.s, p3/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi + fmla z13.s, p4/m, z21.s, z7.s // z13: r_hi = x + n * neg_ln2_hi + fmla z12.s, p3/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo + fmla z13.s, p4/m, z21.s, z8.s // z13: r = r_hi + n * neg_ln2_lo + dup z10.s, #23 // z10: 23 + urshl z16.s, p3/m, z16.s, z10.s // z16: scale = z << 23 (2^n) + urshl z17.s, p4/m, z17.s, z10.s // z17: scale = z << 23 (2^n) + fmul z20.s, z12.s, z0.s // z20: p1 = r * c1 + fmul z21.s, z13.s, z0.s // z21: p1 = r * c1 + mov z22.d, z1.d // z22: p23 = c2 + mov z23.d, z1.d // z23: p23 = c2 + fmla z22.s, p3/m, z12.s, z2.s // z22: p23 = c2 + r * c3 + fmla z23.s, p4/m, z13.s, z2.s // z23: p23 = c2 + r * c3 + mov z24.d, z3.d // z24: c4 + mov z25.d, z3.d // z25: c4 + fmla z24.s, p3/m, z12.s, z4.s // z24: p45 = c4 + r * c5 + fmla z25.s, p4/m, z13.s, z4.s // z25: p45 = c4 + r * c5 + fmul z12.s, z12.s, z12.s // z12: r2 = r * r + fmul z13.s, z13.s, z13.s // z13: r2 = r * r + fmla z22.s, p3/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45 + fmla z23.s, p4/m, z13.s, z25.s // z23: p2345 = p23 + r2 * p45 + fmla z20.s, p3/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345 + fmla z21.s, p4/m, z13.s, z23.s // z21: p12345 = p1 + r2 * p2345 + fmla z16.s, p3/m, z20.s, z16.s // z16: poly = scale + p12345 * scale + fmla z17.s, p4/m, z21.s, z17.s // z17: poly = scale + p12345 * scale + dup z10.s, #0 // z10: 0 + sel z16.s, p5, z10.s, z16.s // z16: poly = underflow ? 
0 : poly + sel z17.s, p6, z10.s, z17.s // z17: poly = underflow ? 0 : poly + fadd z28.s, p3/m, z28.s, z16.s // z28: sum_value = sum_value + poly + fadd z28.s, p4/m, z28.s, z17.s // z28: sum_value = sum_value + poly + + fcvt z16.h, p3/m, z16.s + fcvtnt z16.h, p4/m, z17.s + st1h z16.h, p2, [x28, x9, LSL #1] + + inch x9 + b regularize_leftover_start%= +regularize_leftover_end%=: + + // ================================================== + // Step 3: Normalize + // ================================================== + + // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value + faddv s28, p0, z28.s + fmov s29, #1.0 // 1.0f + fdiv s28, s29, s28 + fcvt h28, s28 + + dup z28.h, z28.h[0] + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + +normalize_body_start%=: + cmp x9, x13 + b.eq normalize_body_end%= + + .inst 0xa009a78c // ld1h {z12.h-z15.h}, pn9/z, [x28, x9, LSL #1] + + // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value + fmul z12.h, z12.h, z28.h + fmul z13.h, z13.h, z28.h + fmul z14.h, z14.h, z28.h + fmul z15.h, z15.h, z28.h + + .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1] + + inch x9, ALL, MUL #4 + b normalize_body_start%= +normalize_body_end%=: + + // Loop for processing the leftover part. +normalize_leftover_start%=: + whilelo p1.h, x9, %x[length] + b.none normalize_leftover_end%= + + ld1h z12.h, p1/z, [x28, x9, LSL #1] // z12: x + fmul z12.h, z12.h, z28.h // z12: result = x * inv_sum_value + + st1h z12.h, p1, [x28, x9, LSL #1] + + inch x9 + b normalize_leftover_start%= +normalize_leftover_end%=: + + // ================================================== + // 3D loop closing + // ================================================== + + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", // + "x9", "x10", "x11", "x12", "x13", "x14", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_fp16_softmax(const ITensor *in, + void *const, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + 
window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const auto *k_src = reinterpret_cast<const float16_t *>(in->buffer() + k_src_offset); + auto *k_dst = reinterpret_cast<float16_t *>(out->buffer() + k_dst_offset); + + sme2_f16_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp new file mode 100644 index 0000000000..5e29d51746 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp @@ -0,0 +1,585 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_f32_softmax_kernel( // + const float *src, + float *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4]) +{ + // Precondition: + // * src_strides[0] == sizeof(float) + // * dst_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x9: temporary, index + // * x10: temporary, -inf + // * x11: temporary, 0 + // * x12: temporary, 1.0f + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // + // * z0: c1 + // * z1: c2 + // * z2: c3 + // * z3: c4 + // * z4: c5 + // * z5: shift + // * z6: inv_ln2 + // * z7: neg_ln2_hi + // * z8: neg_ln2_lo + // * z9: min_input + // * z10: 23, 0 + // * z11: max_value + // * z12-z15: x, r_hi, r, r2 + // * z16-z19: max_value, shift, z, scale, poly + // * z20-z21: n, p1, p12345 + // * z22-z23: n, p23, p2345 + // * z24-z25: p45 + // * z26: beta + // * z28-z31: sum_value + // + // * za0-za3: sum_value + // + // * p0: all-true + // * p1: left-over predicate + // * p4-p7: underflow + // * pn9: all-true + + // Prepares all constant values + + ptrue p0.b + .inst 0x25207811 // ptrue pn9.b + + mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb + mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010 + + movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6 + movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb + movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33 + movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17 + movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010 + + dup z0.s, w9 // c1. + dup z1.s, w10 // c2. + dup z2.s, w11 // c3. + dup z3.s, w12 // c4. + dup z4.s, w13 // c5. 
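+            // Note: together with the shift/inv_ln2/neg_ln2/min_input constants loaded
+            // below, c1-c5 implement the usual range-reduction scheme for exp(x) that the
+            // vector code in this kernel follows (a sketch of the math, not extra code):
+            //   z     = shift + x * inv_ln2      // low mantissa bits of z hold round(x/ln2) + 127
+            //   n     = z - shift                // n = round(x / ln(2)) recovered as a float
+            //   r     = x + n * neg_ln2_hi + n * neg_ln2_lo   // r = x - n*ln(2), split for accuracy
+            //   scale = z << 23                  // reinterpreted as float, this is 2^n
+            //   exp(x) ~= scale * (1 + c1*r + c2*r^2 + c3*r^3 + c4*r^4 + c5*r^5)
+            // Lanes with x < min_input underflow and are forced to 0 via the predicated sel.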
+ + mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f + movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b + movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200 + movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e + movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae + + dup z5.s, w9 // shift + dup z6.s, w10 // inv_ln2 + dup z7.s, w11 // neg_ln2_hi + dup z8.s, w12 // neg_ln2_lo + dup z9.s, w13 // min_input + + dup z26.s, %w[beta] // beta + + mov w10, #0x0000 // -inf: 0xff800000 + movk w10, #0xff80 // -inf: 0xff800000 + + mov w11, #0 // 0 + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntw x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + dup z11.s, w10 // z11: max_value = -inf + + // ---------------------------------------------------------------- z16-z19: max_value = -inf + mov z16.d, z11.d + mov z17.d, z11.d + mov z18.d, z11.d + mov z19.d, z11.d + +find_max_body_start%=: + cmp x9, x13 + b.eq find_max_body_end%= + + .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x + .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x) + + incw x9, ALL, MUL #4 + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. 
+find_max_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none find_max_leftover_end%= + + ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x + fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x) + + incw x9 + b find_max_leftover_start%= +find_max_leftover_end%=: + + // ---------------------------------------------------------------- z16: max_value + .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s} + fmax z16.s, p0/m, z16.s, z17.s + fmaxv s16, p0, z16.s + + // ---------------------------------------------------------------- z11: max_value + dup z11.s, z16.s[0] + + // ================================================== + // Step 2: Regularize + // ================================================== + + .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // ---------------------------------------------------- x9: index + +regularize_body_start%=: + cmp x9, x13 + b.eq regularize_body_end%= + + // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data + .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] + + // ---------------------------------------------------------------- z12-z15: x = input_data - max_value + fsub z12.s, z12.s, z11.s + fsub z13.s, z13.s, z11.s + fsub z14.s, z14.s, z11.s + fsub z15.s, z15.s, z11.s + + // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta + fmul z12.s, z12.s, z26.s + fmul z13.s, z13.s, z26.s + fmul z14.s, z14.s, z26.s + fmul z15.s, z15.s, z26.s + + // ---------------------------------------------------------------- z16-z19: shift + mov z16.d, z5.d + mov z17.d, z5.d + mov z18.d, z5.d + mov z19.d, z5.d + + // ---------------------------------------------------------------- p4-p7: underflow = x < min_input + fcmlt p4.s, p0/z, z12.s, z9.s + fcmlt p5.s, p0/z, z13.s, z9.s + fcmlt p6.s, p0/z, z14.s, z9.s + fcmlt p7.s, p0/z, z15.s, z9.s + + // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2 + fmla z16.s, p0/m, z12.s, z6.s + fmla z17.s, p0/m, z13.s, z6.s + fmla z18.s, p0/m, z14.s, z6.s + fmla z19.s, p0/m, z15.s, z6.s + + // ---------------------------------------------------------------- z20-z23: n = z - shift + fsub z20.s, z16.s, z5.s + fsub z21.s, z17.s, z5.s + fsub z22.s, z18.s, z5.s + fsub z23.s, z19.s, z5.s + + // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi + fmla z12.s, p0/m, z20.s, z7.s + fmla z13.s, p0/m, z21.s, z7.s + fmla z14.s, p0/m, z22.s, z7.s + fmla z15.s, p0/m, z23.s, z7.s + + // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo + fmla z12.s, p0/m, z20.s, z8.s + fmla z13.s, p0/m, z21.s, z8.s + fmla z14.s, p0/m, z22.s, z8.s + fmla z15.s, p0/m, z23.s, z8.s + + // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n) + dup z10.s, #23 + urshl z16.s, p0/m, z16.s, z10.s + urshl z17.s, p0/m, z17.s, z10.s + urshl z18.s, p0/m, z18.s, z10.s + urshl z19.s, p0/m, z19.s, z10.s + + // Processes the first 2 vectors. 
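+            // The degree-5 polynomial is evaluated in two halves to shorten the dependency
+            // chain: p23 = c2 + r*c3 and p45 = c4 + r*c5 are combined as
+            // p2345 = p23 + r2*p45, then p12345 = r*c1 + r2*p2345, and finally
+            // poly = scale + p12345*scale, i.e. scale * (1 + p12345). The same sequence is
+            // applied to each pair of vectors below.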
+ + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z12.s, z0.s + fmul z21.s, z13.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z12.s, z2.s + fmla z23.s, p0/m, z13.s, z2.s + + // ---------------------------------------------------------------- z24-z25: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z12.s, z4.s + fmla z25.s, p0/m, z13.s, z4.s + + // ---------------------------------------------------------------- z12-z13: r2 = r * r + fmul z12.s, z12.s, z12.s + fmul z13.s, z13.s, z13.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z12.s, z24.s + fmla z23.s, p0/m, z13.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z12.s, z22.s + fmla z21.s, p0/m, z13.s, z23.s + + // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale + fmla z16.s, p0/m, z20.s, z16.s + fmla z17.s, p0/m, z21.s, z17.s + + // Processes the last 2 vectors + + // ---------------------------------------------------------------- z20-z21: p1 = r * c1 + fmul z20.s, z14.s, z0.s + fmul z21.s, z15.s, z0.s + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + mov z22.d, z1.d + mov z23.d, z1.d + + // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3 + fmla z22.s, p0/m, z14.s, z2.s + fmla z23.s, p0/m, z15.s, z2.s + + // ---------------------------------------------------------------- z24-z25: c4 + mov z24.d, z3.d + mov z25.d, z3.d + + // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5 + fmla z24.s, p0/m, z14.s, z4.s + fmla z25.s, p0/m, z15.s, z4.s + + // ---------------------------------------------------------------- z14-z15: r2 = r * r + fmul z14.s, z14.s, z14.s + fmul z15.s, z15.s, z15.s + + // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45 + fmla z22.s, p0/m, z14.s, z24.s + fmla z23.s, p0/m, z15.s, z25.s + + // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345 + fmla z20.s, p0/m, z14.s, z22.s + fmla z21.s, p0/m, z15.s, z23.s + + // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale + fmla z18.s, p0/m, z20.s, z18.s + fmla z19.s, p0/m, z21.s, z19.s + + // ---------------------------------------------------------------- z16-z19: poly = underflow ?
0 : poly + dup z10.s, #0 + sel z16.s, p4, z10.s, z16.s + sel z17.s, p5, z10.s, z17.s + sel z18.s, p6, z10.s, z18.s + sel z19.s, p7, z10.s, z19.s + + // Stores 4 consecutive registers to the output + .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2] + + .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly + + incw x9, ALL, MUL #4 + b regularize_body_start%= +regularize_body_end%=: + + // ---------------------------------------------------------------- z28: sum_value + .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4] + fadd z28.s, z28.s, z29.s + fadd z30.s, z30.s, z31.s + fadd z28.s, z28.s, z30.s + + // Loop for processing the leftover part. +regularize_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none regularize_leftover_end%= + + ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data + + fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value + fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta + + mov z16.d, z5.d // z16: shift + fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input + fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2 + fsub z20.s, z16.s, z5.s // z20: n = z - shift + fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi + fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo + dup z10.s, #23 // z10: 23 + urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n) + fmul z20.s, z12.s, z0.s // z20: p1 = r * c1 + mov z22.d, z1.d // z22: p23 = c2 + fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3 + mov z24.d, z3.d // z24: c4 + fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5 + fmul z12.s, z12.s, z12.s // z12: r2 = r * r + fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45 + fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345 + fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale + dup z10.s, #0 // z10: 0 + sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly + + st1w z16.s, p1, [x28, x9, LSL #2] + + fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly + + incw x9 + b regularize_leftover_start%= +regularize_leftover_end%=: + + // ================================================== + // Step 3: Normalize + // ================================================== + + // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value + fmov s29, #1.0 // 1.0f + faddv s28, p0, z28.s + fdiv s28, s29, s28 + dup z28.s, z28.s[0] + + // Loop for processing 4 vectors per iteration. + mov x9, #0 // x9: index + +normalize_body_start%=: + cmp x9, x13 + b.eq normalize_body_end%= + + .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x + + // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value + fmul z12.s, z12.s, z28.s + fmul z13.s, z13.s, z28.s + fmul z14.s, z14.s, z28.s + fmul z15.s, z15.s, z28.s + + .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2] + + incw x9, ALL, MUL #4 + b normalize_body_start%= +normalize_body_end%=: + + // Loop for processing the leftover part. 
+normalize_leftover_start%=: + whilelo p1.s, x9, %x[length] + b.none normalize_leftover_end%= + + ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x + fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value + + st1w z12.s, p1, [x28, x9, LSL #2] + + incw x9 + b normalize_leftover_start%= +normalize_leftover_end%=: + + // ================================================== + // 3D loop closing + // ================================================== + + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p4", "p5", "p6", "p7", "p9", // + "x9", "x10", "x11", "x12", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_fp32_softmax(const ITensor *in, + void *const, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(lut_ptr); + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset); + auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset); + + sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp new file mode 100644 index 0000000000..9feb669f7c --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8.cpp @@ -0,0 +1,634 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_qasymm8_softmax_kernel_512VL( // + const uint8_t *src, + uint8_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4], + const float *lut, + float *tmp) +{ + // Precondition: + // * src_strides[0] == sizeof(uint8_t) + // * dst_strides[0] == sizeof(uint8_t) + // * tmp_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // Registers + // + // * x1: Loop index + // * x2: LUT index + // * x13: temporary, body_length + // + // * x20: index_3 + // * x21: src_3 + // * x22: dst_3 + // * x23: index_2 + // * x24: src_2 + // * x25: dst_2 + // * x26: index_1 + // * x27: src_1 + // * x28: dst_1 + // * x29 tmp + // + // + // * p0: all-true + // * p1: predicate for QASYMM8 values + // * p2: predicate 0 for FP32 values (first quarter of expanded/unpacked p1) + // * p3: predicate 1 for FP32 values (second quarter of expanded/unpacked p1) + // * p4: predicate 2 for FP32 values (third quarter of expanded/unpacked p1) + // * p5: predicate 3 for FP32 values (fourth quarter of expanded/unpacked p1) + // * pn9: all-true for 32 bit values + // * pn8: all-true for 8-bit values + // + // * z0-z15 the 256 LUT values of exp(-scale*beta*x) for x in QASYMM8, stored as FP32 values + + // Prepares all constant values + + ptrue p0.b + .inst 0x25a07811 // ptrue pn9.s + .inst 0x25207810 // ptrue pn8.b + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntb x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + mov x19, %x[lut] + mov x29, %x[tmp] + + // Load the LUT to the register file. 
+ mov x2, %x[lut] + .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2] + + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + // z16-z19 = minimum QASYMM8 value (0) to allow for it to be used for comparison to find the max. + dup z16.b, #0 + dup z17.b, #0 + dup z18.b, #0 + dup z19.b, #0 + mov x1, #0 // x1: index +find_max_body_start%=: + cmp x1, x13 + b.eq find_max_body_end%= + .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z20-z23: x + .inst 0xc134b811 // umax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x) + add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers. + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none find_max_leftover_end%= + + ld1b z30.b, p1/z, [x27, x1] // z30: x + umax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x) + + add x1, x1, #64 + + b find_max_leftover_start%= +find_max_leftover_end%=: + + .inst 0xc132b011 // umax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b } + umax z16.b, p0/m, z16.b, z17.b + umaxv b16, p0, z16.b // Reduction unsigned max operation to get maximum_value + dup z16.b, z16.b[0] + uunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction + uunpklo z16.s, z16.h + + mov x1, #0 // reset index + dup z25.s, #0 + + mov x1, #0 + +regularize_start%=: + whilelo p1.b, x1, %x[length] + b.none regularize_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + ld1b z17.b, p1/z, [x27, x1] //z17: input data + + uunpklo z18.h, z17.b //Using unpack instructions to align the input QASYMM8 values with the FP32 entries in the LUT for use in the TBX instruction + uunpkhi z19.h, z17.b + + uunpklo z17.s, z18.h // z17 = low low input QASYMM8 values + uunpkhi z18.s, z18.h // z18 = low high input QASYMM8 values + + uunpkhi z20.s, z19.h // z20 = high high input QASYMM8 values + uunpklo z19.s, z19.h // z19 = high low input QASYMM8 values + + sub z17.s, z16.s, z17.s // z12: x = max_value - input_data + sub z18.s, z16.s, z18.s // z13: x = max_value - input_data + sub z19.s, z16.s, z19.s // z14: x = max_value - input_data + sub z20.s, z16.s, z20.s // z15: x = max_value - input_data + + tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT. 
+ tbx z22.s, z0.s, z18.s + tbx z23.s, z0.s, z19.s + tbx z24.s, z0.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT. + tbx z22.s, z1.s, z18.s + tbx z23.s, z1.s, z19.s + tbx z24.s, z1.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT. + tbx z22.s, z2.s, z18.s + tbx z23.s, z2.s, z19.s + tbx z24.s, z2.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT. + tbx z22.s, z3.s, z18.s + tbx z23.s, z3.s, z19.s + tbx z24.s, z3.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT. + tbx z22.s, z4.s, z18.s + tbx z23.s, z4.s, z19.s + tbx z24.s, z4.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT. + tbx z22.s, z5.s, z18.s + tbx z23.s, z5.s, z19.s + tbx z24.s, z5.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT. + tbx z22.s, z6.s, z18.s + tbx z23.s, z6.s, z19.s + tbx z24.s, z6.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT. + tbx z22.s, z7.s, z18.s + tbx z23.s, z7.s, z19.s + tbx z24.s, z7.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT. + tbx z22.s, z8.s, z18.s + tbx z23.s, z8.s, z19.s + tbx z24.s, z8.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT. + tbx z22.s, z9.s, z18.s + tbx z23.s, z9.s, z19.s + tbx z24.s, z9.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT. + tbx z22.s, z10.s, z18.s + tbx z23.s, z10.s, z19.s + tbx z24.s, z10.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT. + tbx z22.s, z11.s, z18.s + tbx z23.s, z11.s, z19.s + tbx z24.s, z11.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT. + tbx z22.s, z12.s, z18.s + tbx z23.s, z12.s, z19.s + tbx z24.s, z12.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT. + tbx z22.s, z13.s, z18.s + tbx z23.s, z13.s, z19.s + tbx z24.s, z13.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT. 
+ tbx z22.s, z14.s, z18.s + tbx z23.s, z14.s, z19.s + tbx z24.s, z14.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT. + tbx z22.s, z15.s, z18.s + tbx z23.s, z15.s, z19.s + tbx z24.s, z15.s, z20.s + + + st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p2/m, z25.s, z21.s + add x1, x1, #16 + + st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p3/m, z25.s, z22.s + add x1, x1, #16 + + st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p4/m, z25.s, z23.s + add x1, x1, #16 + + st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p5/m, z25.s, z24.s + add x1, x1, #16 + + b regularize_start%= +regularize_end%=: + + mov w9, 0x0000 + movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [0,255] integer range of QASYMM8 + dup z29.s, w9 + faddv s25, p0, z25.s + fdiv s25, s29, s25 + dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax. + + // ================================================== + // Step 3: Normalize + // ================================================== + mov x1, #0 +normalize_body_start%=: + cmp x1, x13 + b.eq normalize_body_end%= + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z16.s, z25.s, z16.s + fmul z17.s, z25.s, z17.s + fmul z18.s, z25.s, z18.s + fmul z19.s, z25.s, z19.s + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z16-z23: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z16.s, p0/m, z16.s + fcvtzu z17.s, p0/m, z17.s + fcvtzu z18.s, p0/m, z18.s + fcvtzu z19.s, p0/m, z19.s + fcvtzu z20.s, p0/m, z20.s + fcvtzu z21.s, p0/m, z21.s + fcvtzu z22.s, p0/m, z22.s + fcvtzu z23.s, p0/m, z23.s + + // z16-z17: narrow the uint32 values into uint8 and saturate them. + .inst 0xc133e230 // uqcvt z16.b, { z16.s - z19.s } + .inst 0xc133e2b1 // uqcvt z17.b, { z20.s - z23.s } + + dup z20.s, z25.s[0] // Juggling the value to z20 as z25 will be overwritten by the load below + + .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z24.s, z20.s, z24.s + fmul z25.s, z20.s, z25.s + fmul z26.s, z20.s, z26.s + fmul z27.s, z20.s, z27.s + fmul z28.s, z20.s, z28.s + fmul z29.s, z20.s, z29.s + fmul z30.s, z20.s, z30.s + fmul z31.s, z20.s, z31.s + + // z24-z31: convert the FP32 values from the tmp tensor to uint32. 
+ fcvtzu z24.s, p0/m, z24.s + fcvtzu z25.s, p0/m, z25.s + fcvtzu z26.s, p0/m, z26.s + fcvtzu z27.s, p0/m, z27.s + fcvtzu z28.s, p0/m, z28.s + fcvtzu z29.s, p0/m, z29.s + fcvtzu z30.s, p0/m, z30.s + fcvtzu z31.s, p0/m, z31.s + + // z18-z19: narrow the uint32 values into uint8 and saturate them. + .inst 0xc133e332 // uqcvt z18.b, { z24.s - z27.s } + .inst 0xc133e3b3 // uqcvt z19.b, { z28.s - z31.s } + + .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2] + + dup z25.s, z20.s[0] // Juggling the value back to z25 as z20 will be overwritten by the next iteration or z25 will be used below. + +b normalize_body_start%= +normalize_body_end%=: + +normalize_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none normalize_leftover_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + + // z20-z23: load exp(-scale*beta*x) from the tmp tensor + ld1w z20.s, p2/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z21.s, p3/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z22.s, p4/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z23.s, p5/z, [x29, x1, LSL #2] + add x1, x1, #16 + + // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z20-23: convert the FP32 values from the tmp tensor to uint32. + fcvtzu z20.s, p0/m, z20.s + fcvtzu z21.s, p0/m, z21.s + fcvtzu z22.s, p0/m, z22.s + fcvtzu z23.s, p0/m, z23.s + + .inst 0xc133e2b3 // uqcvt z19.b, { z20.s - z23.s }, narrow the uint32 values into uint8 and saturate them into z19. 
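+
+        // In scalar terms, each leftover output byte computed here is
+        //   dst[i] = saturate_u8(trunc(tmp[i] * 256.0f / sum)),
+        // where tmp[i] holds exp(-scale*beta*x) from the LUT pass above: fcvtzu
+        // truncates towards zero and uqcvt narrows with unsigned saturation.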
+ + st1b z19.b, p1, [x28, x2] + + b normalize_leftover_start%= +normalize_leftover_end%=: + // ================================================== + // 3D loop closing + // ================================================== + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", // + "x2", "x9", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_qasymm8_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + Strides tmp_strides; + + tmp_strides[0] = src_strides[0] * 4; + tmp_strides[1] = src_strides[1] * 4; + tmp_strides[2] = src_strides[2] * 4; + tmp_strides[3] = src_strides[3] * 4; + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + // + window[1].start() * tmp_strides[1] + // + window[2].start() * tmp_strides[2] + // + window[3].start() * tmp_strides[3]; + + const auto *k_src = reinterpret_cast<const uint8_t *>(in->buffer() + k_src_offset); + float *tmp_float_ptr = reinterpret_cast<float *>(tmp); + auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset); + auto *k_dst = reinterpret_cast<uint8_t *>(out->buffer() + k_dst_offset); + + sme2_qasymm8_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp 
b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp new file mode 100644 index 0000000000..14c0f6c327 --- /dev/null +++ b/src/cpu/kernels/softmax/generic/sme2/qasymm8_signed.cpp @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ARM_COMPUTE_ENABLE_SME2 + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Window.h" + +namespace arm_compute +{ +namespace cpu +{ + +// SoftMax +// +// Steps: +// * Find max: max_value = max(src) +// * Regularize: dst[i] = exp(src[i] - max_value) +// sum_value = sum(dst) +// * Normalize: dst[i] = dst[i] / sum_value +void sme2_qasymm8_signed_softmax_kernel_512VL( // + const int8_t *src, + int8_t *dst, + float beta, + const uintptr_t shape[4], + const uintptr_t src_strides[4], + const uintptr_t dst_strides[4], + const float *lut, + float *tmp) +{ + // Precondition: + // * src_strides[0] == sizeof(int8_t) + // * dst_strides[0] == sizeof(int8_t) + // * tmp_strides[0] == sizeof(float) + + __asm__ volatile( + R"( + .inst 0xd503477f // smstart + + // For register list explanation refer to qasymm8.cpp. + + // Prepares all constant values + + ptrue p0.b + .inst 0x25a07811 // ptrue pn9.s + .inst 0x25207810 // ptrue pn8.b + + // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl + cntb x13, ALL, MUL #4 + udiv x9, %x[length], x13 + mul x13, x13, x9 + + // ================================================== + // 3D loop opening + // ================================================== + + mov x20, %x[shape_3] + mov x21, %x[src] + mov x22, %x[dst] + mov x19, %x[lut] + mov x29, %x[tmp] + + // Load the LUT to the register file. 
+ mov x2, %x[lut] + .inst 0xa040c440 //ld1w { z0.s - z3.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c444 //ld1w { z4.s - z7.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c448 //ld1w { z8.s - z11.s }, pn9/z, [x2] + add x2, x2, #256 + .inst 0xa040c44c //ld1w { z12.s - z15.s }, pn9/z, [x2] + + +loop_3_start%=: + // for index_3 in shape_3 downto 1 + cmp x20, #0 + b.eq loop_3_end%= + sub x20, x20, #1 + + mov x23, %x[shape_2] + mov x24, x21 + mov x25, x22 + +loop_2_start%=: + // for index_2 in shape_2 downto 1 + cmp x23, #0 + b.eq loop_2_end%= + sub x23, x23, #1 + + mov x26, %x[shape_1] + mov x27, x24 + mov x28, x25 + +loop_1_start%=: + // for index_1 in shape_2 downto 1 + cmp x26, #0 + b.eq loop_1_end%= + sub x26, x26, #1 + + // ================================================== + // Step 1: Find max + // ================================================== + // z16-z19 = minimum QASYMM8_SIGNED value (-128) to allow for it to be used for comparison to find the max. + dup z16.b, #0x80 + dup z17.b, #0x80 + dup z18.b, #0x80 + dup z19.b, #0x80 + + mov x1, #0 // x1: index +find_max_body_start%=: + cmp x1, x13 + b.eq find_max_body_end%= + .inst 0xa0018374 // ld1b { z20.b - z23.b }, pn8/z, [x27, x1] z16-z19: x + .inst 0xc134b810 // smax { z16.b - z19.b }, { z16.b - z19.b }, { z20.b - z23.b } z16-z19: max_value = max(max_value, x) + add x1, x1, #256 // Advance index by 256 bytes/integers: Z registers = 2048-bit data = 256 8-bit integers. + b find_max_body_start%= +find_max_body_end%=: + + // Loop for processing the leftover part. +find_max_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none find_max_leftover_end%= + + ld1b z30.b, p1/z, [x27, x1] // z30: x + smax z16.b, p1/m, z16.b, z30.b // z16: max_value = max(max_value, x) + + add x1, x1, #64 + + b find_max_leftover_start%= +find_max_leftover_end%=: + .inst 0xc132b010 // smax { z16.b, z17.b }, { z16.b, z17.b }, { z18.b, z19.b } + smax z16.b, p0/m, z16.b, z17.b + smaxv b16, p0, z16.b // Reduction signed max operation to get maximum_value + mov z16.b, b16 // z16: duplicated max_value for current row + + sunpklo z16.h, z16.b // Using unpack instructions to align the max value with the FP32 entries in the LUT for use in the TBX instruction + sunpklo z16.s, z16.h + + mov x1, #0 // reset index + dup z25.s, #0 + + +regularize_start%=: + whilelo p1.b, x1, %x[length] + b.none regularize_end%= + + mov w9, 0xFF80 + movk w9, 0xFFFF, LSL #16 // Moving -127.f into w9 to set the registers below to the minimum QASYMM8_SIGNED value + dup z17.s, w9 + dup z18.s, w9 + dup z19.s, w9 + dup z20.s, w9 + + dup z21.s, #0x0 + dup z22.s, #0x0 + dup z23.s, #0x0 + dup z24.s, #0x0 + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + ld1b z17.b, p1/z, [x27, x1] //z17: input data + + sunpklo z18.h, z17.b // Using unpack instructions to align the input QASYMM8_SIGNED values with the FP32 entries in the LUT for use in the TBX instruction + sunpkhi z19.h, z17.b // + + sunpklo z17.s, z18.h // z17 = low low input QASYMM8_SIGNED values + sunpkhi z18.s, z18.h // z18 = low high input QASYMM8_SIGNED values + + sunpkhi z20.s, z19.h // z20 = high high input QASYMM8_SIGNED values + sunpklo z19.s, z19.h // z19 = high low input QASYMM8_SIGNED values + + sub z17.s, z16.s, z17.s // z12: x = max_value - input_data + sub 
z18.s, z16.s, z18.s // z13: x = max_value - input_data + sub z19.s, z16.s, z19.s // z14: x = max_value - input_data + sub z20.s, z16.s, z20.s // z15: x = max_value - input_data + + add z17.s, z17.s, #128 + add z18.s, z18.s, #128 + add z19.s, z19.s, #128 + add z20.s, z20.s, #128 + + tbx z21.s, z0.s, z17.s // Look-up entries 0-15 in the LUT. + tbx z22.s, z0.s, z18.s + tbx z23.s, z0.s, z19.s + tbx z24.s, z0.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z1.s, z17.s // Look-up entries 16-31 in the LUT. + tbx z22.s, z1.s, z18.s + tbx z23.s, z1.s, z19.s + tbx z24.s, z1.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z2.s, z17.s // Look-up entries 32-47 in the LUT. + tbx z22.s, z2.s, z18.s + tbx z23.s, z2.s, z19.s + tbx z24.s, z2.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z3.s, z17.s // Look-up entries 48-63 in the LUT. + tbx z22.s, z3.s, z18.s + tbx z23.s, z3.s, z19.s + tbx z24.s, z3.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z4.s, z17.s // Look-up entries 64-79 in the LUT. + tbx z22.s, z4.s, z18.s + tbx z23.s, z4.s, z19.s + tbx z24.s, z4.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z5.s, z17.s // Look-up entries 80-95 in the LUT. + tbx z22.s, z5.s, z18.s + tbx z23.s, z5.s, z19.s + tbx z24.s, z5.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z6.s, z17.s // Look-up entries 96-111 in the LUT. + tbx z22.s, z6.s, z18.s + tbx z23.s, z6.s, z19.s + tbx z24.s, z6.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z7.s, z17.s // Look-up entries 112-127 in the LUT. + tbx z22.s, z7.s, z18.s + tbx z23.s, z7.s, z19.s + tbx z24.s, z7.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z8.s, z17.s // Look-up entries 128-143 in the LUT. + tbx z22.s, z8.s, z18.s + tbx z23.s, z8.s, z19.s + tbx z24.s, z8.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z9.s, z17.s // Look-up entries 144-159 in the LUT. + tbx z22.s, z9.s, z18.s + tbx z23.s, z9.s, z19.s + tbx z24.s, z9.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z10.s, z17.s // Look-up entries 160-175 in the LUT. + tbx z22.s, z10.s, z18.s + tbx z23.s, z10.s, z19.s + tbx z24.s, z10.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z11.s, z17.s // Look-up entries 176-191 in the LUT. + tbx z22.s, z11.s, z18.s + tbx z23.s, z11.s, z19.s + tbx z24.s, z11.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z12.s, z17.s // Look-up entries 192-207 in the LUT. + tbx z22.s, z12.s, z18.s + tbx z23.s, z12.s, z19.s + tbx z24.s, z12.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z13.s, z17.s // Look-up entries 208-223 in the LUT. 
+ tbx z22.s, z13.s, z18.s + tbx z23.s, z13.s, z19.s + tbx z24.s, z13.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z14.s, z17.s // Look-up entries 224-239 in the LUT. + tbx z22.s, z14.s, z18.s + tbx z23.s, z14.s, z19.s + tbx z24.s, z14.s, z20.s + + sub z17.s, z17.s, #16 + sub z18.s, z18.s, #16 + sub z19.s, z19.s, #16 + sub z20.s, z20.s, #16 + + tbx z21.s, z15.s, z17.s // Look-up entries 240-255 in the LUT. + tbx z22.s, z15.s, z18.s + tbx z23.s, z15.s, z19.s + tbx z24.s, z15.s, z20.s + + + st1w z21.s, p2, [x29, x1, LSL #2]// z21 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p2/m, z25.s, z21.s + add x1, x1, #16 + + st1w z22.s, p3, [x29, x1, LSL #2]// z22 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p3/m, z25.s, z22.s + add x1, x1, #16 + + st1w z23.s, p4, [x29, x1, LSL #2]// z23 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p4/m, z25.s, z23.s + add x1, x1, #16 + + st1w z24.s, p5, [x29, x1, LSL #2]// z24 store exp(-scale*beta*x) into the tmp tensor + fadd z25.s, p5/m, z25.s, z24.s + add x1, x1, #16 + + b regularize_start%= +regularize_end%=: + + mov w9, 0x0000 + movk w9, 0x4380, LSL #16 // Moving 256.f into w9 to scale - via multiplication (division by reciprocal) - the floating point [0,1] range of the results to the [-128, 127] integer range of QASYMM8_SIGNED + mov w10, 0x0000 + movk w10, 0x4300, LSL #16 // Moving 128.f into w10 for the subtraction to move the results - via subtraction - from the [0,255] range to the [-128,127] range + dup z29.s, w9 + dup z30.s, w10 + faddv s25, p0, z25.s + fdiv s25, s29, s25 + dup z25.s, z25.s[0] // z25: 256.f/sum. 256 is needed to get the full range and 1/sum is part of softmax. + + // ================================================== + // Step 3: Normalize + // ================================================== + mov x1, #0 +normalize_body_start%=: + cmp x1, x13 + b.eq normalize_body_end%= + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + .inst 0xa001c7b0 // ld1w { z16.s - z19.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7b4 // ld1w { z20.s - z23.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z16-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z16.s, z25.s, z16.s + fmul z17.s, z25.s, z17.s + fmul z18.s, z25.s, z18.s + fmul z19.s, z25.s, z19.s + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + // z16-z23: subtract 128.f. + fsub z16.s, z16.s, z30.s // Subtract 128.f + fsub z17.s, z17.s, z30.s // Subtract 128.f + fsub z18.s, z18.s, z30.s // Subtract 128.f + fsub z19.s, z19.s, z30.s // Subtract 128.f + fsub z20.s, z20.s, z30.s // Subtract 128.f + fsub z21.s, z21.s, z30.s // Subtract 128.f + fsub z22.s, z22.s, z30.s // Subtract 128.f + fsub z23.s, z23.s, z30.s // Subtract 128.f + + // z16-z23: convert the FP32 values from the tmp tensor to int32. + fcvtzs z16.s, p0/m, z16.s + fcvtzs z17.s, p0/m, z17.s + fcvtzs z18.s, p0/m, z18.s + fcvtzs z19.s, p0/m, z19.s + fcvtzs z20.s, p0/m, z20.s + fcvtzs z21.s, p0/m, z21.s + fcvtzs z22.s, p0/m, z22.s + fcvtzs z23.s, p0/m, z23.s + + // z16-z17: narrow the int32 values into int8 and saturate them. + .inst 0xc133e210 // sqcvt z16.b, { z16.s - z19.s } + .inst 0xc133e291 // sqcvt z17.b, { z20.s - z23.s } + + // Juggling the value to z20 (resp. 21) as z25 (resp. z30) will be overwritten by the load below. 
+ dup z20.s, z25.s[0] + dup z21.s, z30.s[0] + + .inst 0xa001c7b8 // ld1w { z24.s - z27.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + .inst 0xa001c7bc // ld1w { z28.s - z31.s }, pn9/z, [x29, x1, lsl #2] + add x1, x1, #64 + + // z24-z31: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z24.s, z20.s, z24.s + fmul z25.s, z20.s, z25.s + fmul z26.s, z20.s, z26.s + fmul z27.s, z20.s, z27.s + fmul z28.s, z20.s, z28.s + fmul z29.s, z20.s, z29.s + fmul z30.s, z20.s, z30.s + fmul z31.s, z20.s, z31.s + + // z24-z31: subtract 128.f. + fsub z24.s, z24.s, z21.s + fsub z25.s, z25.s, z21.s + fsub z26.s, z26.s, z21.s + fsub z27.s, z27.s, z21.s + fsub z28.s, z28.s, z21.s + fsub z29.s, z29.s, z21.s + fsub z30.s, z30.s, z21.s + fsub z31.s, z31.s, z21.s + + // z24-z31: convert the FP32 values from the tmp tensor to int32. + fcvtzs z24.s, p0/m, z24.s + fcvtzs z25.s, p0/m, z25.s + fcvtzs z26.s, p0/m, z26.s + fcvtzs z27.s, p0/m, z27.s + fcvtzs z28.s, p0/m, z28.s + fcvtzs z29.s, p0/m, z29.s + fcvtzs z30.s, p0/m, z30.s + fcvtzs z31.s, p0/m, z31.s + + // z18-z19: narrow the int32 values into int8 and saturate them. + .inst 0xc133e312 // sqcvt z18.b, { z24.s - z27.s } + .inst 0xc133e393 // sqcvt z19.b, { z28.s - z31.s } + + .inst 0xa0228390 // st1b { z16.b - z19.b }, pn8, [x28, x2] + + // Juggling the values back to z25 (resp. z30) as z20 (resp. z21) will be overwritten by the next iteration or z25 (resp. z30) will be used below. + dup z25.s, z20.s[0] + dup z30.s, z21.s[0] +b normalize_body_start%= +normalize_body_end%=: +normalize_leftover_start%=: + whilelo p1.b, x1, %x[length] + b.none normalize_leftover_end%= + + // p2-p5 are - together - the 32-bit version of p1, the instructions below unpack p1 into those four predicate registers to allow for the 32-bit loads below to be correctly predicated + punpklo p2.h, p1.b + punpkhi p4.h, p1.b + + punpkhi p3.h, p2.b + punpklo p2.h, p2.b + + punpkhi p5.h, p4.b + punpklo p4.h, p4.b + + mov x2, x1 // Preserve the index into x2 for the final store to dst. + + // z20-z23: load exp(-scale*beta*x) from the tmp tensor + ld1w z20.s, p2/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z21.s, p3/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z22.s, p4/z, [x29, x1, LSL #2] + add x1, x1, #16 + + ld1w z23.s, p5/z, [x29, x1, LSL #2] + add x1, x1, #16 + + // z20-z23: effectively divides exp(-scale*beta*x) by the sum of the exponentials for the current row and multiplies by 256. + fmul z20.s, z25.s, z20.s + fmul z21.s, z25.s, z21.s + fmul z22.s, z25.s, z22.s + fmul z23.s, z25.s, z23.s + + //z20-z23: Subtract 128.f. + fsub z20.s, z20.s, z30.s + fsub z21.s, z21.s, z30.s + fsub z22.s, z22.s, z30.s + fsub z23.s, z23.s, z30.s + + // z20-23: convert the FP32 values from the tmp tensor to int32. + fcvtzs z20.s, p0/m, z20.s + fcvtzs z21.s, p0/m, z21.s + fcvtzs z22.s, p0/m, z22.s + fcvtzs z23.s, p0/m, z23.s + + .inst 0xc133e293 // sqcvt z19.b, { z20.s - z23.s }, narrow the int32 values into int8 and saturate them into z19. 
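This leftover path mirrors the main normalize loop: each stored exp(-scale*beta*x) value is multiplied by 256/sum (kept in z25), shifted down by 128.f, converted to int32 with FCVTZS and narrowed with saturation by SQCVT before the store that follows. A scalar sketch of the per-element arithmetic, assuming sum_of_exps is the row sum reduced earlier with FADDV (the helper name is illustrative):

#include <algorithm>
#include <cstdint>

// out = saturate_int8(trunc(exp_val * 256/sum - 128))
int8_t quantize_softmax_value(float exp_val, float sum_of_exps)
{
    const float   scaled = exp_val * (256.0f / sum_of_exps) - 128.0f; // fmul by z25, fsub 128.f
    const int32_t as_i32 = static_cast<int32_t>(scaled);              // fcvtzs: convert toward zero
    return static_cast<int8_t>(std::min<int32_t>(std::max<int32_t>(as_i32, -128), 127)); // sqcvt: saturating narrow
}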
+ + st1b z19.b, p1, [x28, x2] + + b normalize_leftover_start%= +normalize_leftover_end%=: + // ================================================== + // 3D loop closing + // ================================================== + add x27, x27, %x[src_stride_1] + add x28, x28, %x[dst_stride_1] + b loop_1_start%= +loop_1_end%=: + + add x24, x24, %x[src_stride_2] + add x25, x25, %x[dst_stride_2] + b loop_2_start%= +loop_2_end%=: + + add x21, x21, %x[src_stride_3] + add x22, x22, %x[dst_stride_3] + b loop_3_start%= +loop_3_end%=: + .inst 0xd503467f // smstop + )" + : + : [src] "r"(src), [tmp] "r"(tmp), [dst] "r"(dst), [beta] "r"(beta), [lut] "r"(lut), // + [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), // + [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]), + [src_stride_3] "r"(src_strides[3]), // + [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]), + [dst_stride_3] "r"(dst_strides[3]), // + [length] "r"(shape[0]) // + : "cc", "memory", // + "p0", "p1", "p2", "p3", "p4", // + "x2", "x9", "x13", // + "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "x19", // + "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", // + "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", // + "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", // + "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" // + ); +} + +void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr) +{ + ARM_COMPUTE_UNUSED(axis); + + const auto *src_info = in->info(); + const auto *dst_info = out->info(); + + const auto &full_shape = dst_info->tensor_shape(); + const auto &src_strides = src_info->strides_in_bytes(); + const auto &dst_strides = dst_info->strides_in_bytes(); + Strides tmp_strides; + + tmp_strides[0] = src_strides[0] * 4; + tmp_strides[1] = src_strides[1] * 4; + tmp_strides[2] = src_strides[2] * 4; + tmp_strides[3] = src_strides[3] * 4; + + const uintptr_t k_shape[] = { + full_shape[0], + window.num_iterations(1), + window.num_iterations(2), + window.num_iterations(3), + }; + + const uintptr_t k_src_strides[] = { + src_strides[0], + src_strides[1], + src_strides[2], + src_strides[3], + }; + + const uintptr_t k_dst_strides[] = { + dst_strides[0], + dst_strides[1], + dst_strides[2], + dst_strides[3], + }; + + const uintptr_t k_src_offset = window[0].start() * src_strides[0] + // + window[1].start() * src_strides[1] + // + window[2].start() * src_strides[2] + // + window[3].start() * src_strides[3]; + + const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + // + window[1].start() * dst_strides[1] + // + window[2].start() * dst_strides[2] + // + window[3].start() * dst_strides[3]; + + const uintptr_t k_tmp_offset = window[0].start() * tmp_strides[0] + // + window[1].start() * tmp_strides[1] + // + window[2].start() * tmp_strides[2] + // + window[3].start() * tmp_strides[3]; + + const auto *k_src = reinterpret_cast<const int8_t *>(in->buffer() + k_src_offset); + float *tmp_float_ptr = reinterpret_cast<float *>(tmp); + auto *k_tmp = reinterpret_cast<float *>(tmp_float_ptr + k_tmp_offset); + auto *k_dst = reinterpret_cast<int8_t *>(out->buffer() + k_dst_offset); + + sme2_qasymm8_signed_softmax_kernel_512VL(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides, lut_ptr, k_tmp); +} + +} // namespace cpu +} // namespace arm_compute + +#endif // ARM_COMPUTE_ENABLE_SME2 diff --git a/src/cpu/kernels/softmax/list.h 
b/src/cpu/kernels/softmax/list.h index f9295ebbcc..7bbb265022 100644 --- a/src/cpu/kernels/softmax/list.h +++ b/src/cpu/kernels/softmax/list.h @@ -28,15 +28,52 @@ namespace arm_compute { namespace cpu { -#define DECLARE_SOFTMAX_KERNEL(func_name) \ - template <bool IS_LOG> \ - void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window) +#define DECLARE_SOFTMAX_KERNEL(func_name) \ + template <bool IS_LOG> \ + void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window, \ + const float *lut_ptr) DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax); DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax); DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax); DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax); +#ifdef ARM_COMPUTE_ENABLE_SME2 + +void sme2_fp32_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_fp16_softmax(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_qasymm8_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +void sme2_qasymm8_signed_softmax_lut_512VL(const ITensor *in, + void *const tmp, + ITensor *out, + const float beta, + int axis, + const Window &window, + const float *lut_ptr); + +#endif // ARM_COMPUTE_ENABLE_SME2 + #undef DECLARE_SOFTMAX_KERNEL } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index e035de0131..905e86c185 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -53,6 +53,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.fast_mode = info.fast_math(); asm_info.fixed_format = info.fixed_format(); asm_info.weight_format = info.weight_format(); + asm_info.accumulate = info.accumulate(); asm_info.transpose_b = info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method @@ -219,6 +220,16 @@ Status CpuGemm::validate(const ITensorInfo *a, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); + // When using accumulation(in place summation), for now, the only supported values for alpha and beta are 1 respectively 0. + // Do the appropriate checks before proceeding. 
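The checks that follow enforce that in-place accumulation (the new accumulate flag forwarded to the assembly metadata above) is only used in its plain form, D += A * B, i.e. alpha == 1 and no scaled bias (beta == 0 or c == nullptr). A minimal reference loop contrasting the two modes; this is a sketch of the semantics, not the library's blocked, vectorised implementation:

#include <cstddef>

// d is MxN, a is MxK, b is KxN, all dense row-major.
void gemm_reference(const float *a, const float *b, float *d,
                    std::size_t m, std::size_t n, std::size_t k, bool accumulate)
{
    for (std::size_t i = 0; i < m; ++i)
    {
        for (std::size_t j = 0; j < n; ++j)
        {
            // accumulate == true keeps the existing output (D += A*B);
            // otherwise the output is overwritten (D = A*B).
            float acc = accumulate ? d[i * n + j] : 0.0f;
            for (std::size_t p = 0; p < k; ++p)
            {
                acc += a[i * k + p] * b[p * n + j];
            }
            d[i * n + j] = acc;
        }
    }
}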
+ if (gemm_info.accumulate()) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + (beta != 0 && c != nullptr), + "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c"); + } + const bool is_c_bias = beta == 1 && c != nullptr; const bool run_addition = c != nullptr && beta != 0 && beta != 1; // Check if we should use the pretransposed_b or original b diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp index 7460f2020c..55d950ff4a 100644 --- a/src/cpu/operators/CpuGemmConv2d.cpp +++ b/src/cpu/operators/CpuGemmConv2d.cpp @@ -809,9 +809,16 @@ void CpuGemmConv2d::run(ITensorPack &tensors) if (!_skip_im2col) { // Run input reshaping - unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); - ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}}; - NEScheduler::get().schedule_op(_im2col_kernel.get(), y_dim, _im2col_kernel->window(), pack); + unsigned int hint_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); + unsigned int x_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + unsigned int hint_dim_iterations = _im2col_kernel->window().num_iterations(hint_dim); + unsigned int x_dim_iterations = _im2col_kernel->window().num_iterations(x_dim); + if (hint_dim_iterations < NEScheduler::get().num_threads() && x_dim_iterations > hint_dim_iterations) + { + hint_dim = x_dim; + } + ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}}; + NEScheduler::get().schedule_op(_im2col_kernel.get(), hint_dim, _im2col_kernel->window(), pack); gemm_input_to_use = im2col_output.get(); } diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index b25505a85d..f3396fbb5c 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -65,6 +65,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.activation_info = info.activation_info(); asm_info.output_stage = info.gemmlowp_output_stage(); asm_info.fast_mode = info.fast_math(); + asm_info.accumulate = info.accumulate(); return asm_info; } @@ -127,6 +128,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure( _reshape_b_only_on_first_run; _gemm_info = gemm_info; + // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic). 
+ // It is not needed if the datatype is symmetric, because there is no offset + bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic(); + bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic(); + _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>(); const ITensorInfo *a_to_use = a; @@ -228,8 +234,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure( // Build reduction info const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false); - // Initialize matrix B reduction kernel only if _a_offset is not equal to 0 - if (_a_offset != 0) + if (a_offset_kernel_needed) { _vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32); @@ -238,8 +243,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure( _mtx_b_reduction_kernel->configure(b, &_vector_sum_col, reduction_info); } - // Initialize Matrix A reduction kernel only if _b_offset is not equal to 0 - if (_b_offset != 0) + if (b_offset_kernel_needed) { _vector_sum_row = TensorInfo(compute_reductionB_shape(*a_to_use), 1, DataType::S32); @@ -260,8 +264,8 @@ void CpuGemmLowpMatrixMultiplyCore::configure( _offset_contribution_output_stage_kernel = std::make_unique<kernels::CpuGemmLowpOffsetContributionOutputStageKernel>(); _offset_contribution_output_stage_kernel->configure( - &_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, - _b_offset == 0 ? nullptr : &_vector_sum_row, c, _flip_signedness ? &_signed_output : dst, + &_mm_result_s32, a_offset_kernel_needed ? &_vector_sum_col : nullptr, + b_offset_kernel_needed ? &_vector_sum_row : nullptr, c, _flip_signedness ? &_signed_output : dst, a->dimension(0), _a_offset, _b_offset, info.gemmlowp_output_stage()); if (_flip_signedness) @@ -272,6 +276,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure( } else { + // This scale is needed for the s8_f32 kernel where the multiplication output is dequantized to F32. + const float dequantize_scale = + (dst->data_type() == DataType::F32) + ? a->quantization_info().uniform().scale * b->quantization_info().uniform().scale + : 1.0f; // Configure matrix multiply kernel if (!_assembly_path) { @@ -280,9 +289,9 @@ void CpuGemmLowpMatrixMultiplyCore::configure( } // Configure offset contribution kernel _offset_contribution_kernel = std::make_unique<kernels::CpuGemmLowpOffsetContributionKernel>(); - _offset_contribution_kernel->configure(dst, _a_offset == 0 ? nullptr : &_vector_sum_col, - _b_offset == 0 ? nullptr : &_vector_sum_row, a_to_use->dimension(0), - _a_offset, _b_offset); + _offset_contribution_kernel->configure(dst, a_offset_kernel_needed ? &_vector_sum_col : nullptr, + b_offset_kernel_needed ? &_vector_sum_row : nullptr, + a_to_use->dimension(0), _a_offset, _b_offset, dequantize_scale); } } // Configure activation @@ -305,11 +314,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure( } // Request memory for LHS and RHS reshape matrix - _aux_mem[VectorSumCol] = - MemoryInfo(offset_int_vec(VectorSumCol), - !_fused_assembly_path && _a_offset != 0 && _reshape_b_only_on_first_run ? MemoryLifetime::Persistent - : MemoryLifetime::Temporary, - _vector_sum_col.total_size()); + _aux_mem[VectorSumCol] = MemoryInfo(offset_int_vec(VectorSumCol), + !_fused_assembly_path && a_offset_kernel_needed && _reshape_b_only_on_first_run + ? 
MemoryLifetime::Persistent + : MemoryLifetime::Temporary, + _vector_sum_col.total_size()); _aux_mem[VectorSumRow] = MemoryInfo(offset_int_vec(VectorSumRow), MemoryLifetime::Temporary, _vector_sum_row.total_size()); _aux_mem[TmpA] = MemoryInfo(offset_int_vec(TmpA), MemoryLifetime::Temporary, _tmp_a.total_size()); @@ -333,8 +342,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32, DataType::QASYMM8, - DataType::QASYMM8_SIGNED); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr && + DataType::QASYMM8_SIGNED, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr && output->data_type() != DataType::F32 && gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::NONE, "Bias addition not supported in NEGEMMLowpMatrixMultiplyCore for output S32"); ARM_COMPUTE_RETURN_ERROR_ON_MSG( @@ -343,6 +352,16 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported"); + // When using accumulation(in place summation), for now, the only supported DataType for output is S32. + if (gemm_info.accumulate()) + { +#ifdef __arm__ + ARM_COMPUTE_RETURN_ERROR_MSG("Accumulation is not supported for armv7"); +#endif /* __arm__ */ + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE, + "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED"); + } + GEMMInfo info = gemm_info; const ITensorInfo *matrix_a_info = a; const ITensorInfo *matrix_b_info = b; @@ -356,6 +375,10 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, int32_t a_offset = a->quantization_info().uniform().offset; int32_t b_offset = b->quantization_info().uniform().offset; + // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic). + bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic(); + bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic(); + bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE; if (fuse_output_stage) { @@ -478,7 +501,7 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false); // Validate matrix B reduction kernel only if _a_offset is not equal to 0 - if (a_offset != 0) + if (a_offset_kernel_needed) { info_vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32); @@ -488,7 +511,7 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, } // Validate Matrix A reduction kernel only if _b_offset is not equal to 0 - if (b_offset != 0) + if (b_offset_kernel_needed) { info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32); @@ -514,9 +537,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, // Validate offset contribution kernel ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionOutputStageKernel::validate( - &mm_result_s32_info, a_offset == 0 ? nullptr : &info_vector_sum_col, - b_offset == 0 ? nullptr : &info_vector_sum_row, c, flip_signedness ? 
&signed_output : output, a_offset, - b_offset, info.gemmlowp_output_stage())); + &mm_result_s32_info, a_offset_kernel_needed ? &info_vector_sum_col : nullptr, + b_offset_kernel_needed ? &info_vector_sum_row : nullptr, c, flip_signedness ? &signed_output : output, + a_offset, b_offset, info.gemmlowp_output_stage())); } else { @@ -534,8 +557,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, } // Validate offset contribution kernel ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionKernel::validate( - output, a_offset == 0 ? nullptr : &info_vector_sum_col, b_offset == 0 ? nullptr : &info_vector_sum_row, - a_offset, b_offset)); + output, a_offset_kernel_needed ? &info_vector_sum_col : nullptr, + b_offset_kernel_needed ? &info_vector_sum_row : nullptr, a_offset, b_offset)); } } @@ -569,6 +592,14 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors) CpuAuxTensorHandler signed_a(offset_int_vec(SignedA), _signed_a, tensors, false); CpuAuxTensorHandler signed_output(offset_int_vec(SignedOutput), _signed_output, tensors, false); + const QuantizationInfo a_qinfo = a->info()->quantization_info(); + const QuantizationInfo b_qinfo = b->info()->quantization_info(); + + if (a_qinfo.is_dynamic()) + _a_offset = a_qinfo.uniform().offset; + if (b_qinfo.is_dynamic()) + _b_offset = b_qinfo.uniform().offset; + // Convert QASYMM8->QASYMM8_SIGNED if (_flip_signedness) { @@ -651,6 +682,11 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors) if (_fuse_output_stage) { + if (a_qinfo.is_dynamic()) + _offset_contribution_output_stage_kernel->set_a_offset(_a_offset); + if (b_qinfo.is_dynamic()) + _offset_contribution_output_stage_kernel->set_b_offset(_b_offset); + ITensorPack pack; pack.add_tensor(TensorType::ACL_SRC_0, mm_result_s32.get()); pack.add_tensor(TensorType::ACL_SRC_1, _a_offset == 0 ? nullptr : vector_sum_col.get()); @@ -664,6 +700,16 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors) } else { + if (a_qinfo.is_dynamic()) + _offset_contribution_kernel->set_a_offset(_a_offset); + if (b_qinfo.is_dynamic()) + _offset_contribution_kernel->set_b_offset(_b_offset); + if (a_qinfo.is_dynamic() || b_qinfo.is_dynamic()) + { + const float dequantize_scale = a_qinfo.uniform().scale * b_qinfo.uniform().scale; + _offset_contribution_kernel->set_scale(dequantize_scale); + } + ITensorPack pack; pack.add_tensor(TensorType::ACL_SRC_0, _a_offset == 0 ? nullptr : vector_sum_col.get()); pack.add_tensor(TensorType::ACL_SRC_1, _b_offset == 0 ? nullptr : vector_sum_row.get()); diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h index 78065a8953..38121c9bb4 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023 Arm Limited. + * Copyright (c) 2021, 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -92,6 +92,7 @@ public: * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 | * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 | * |QASYMM8_SIGNED |QSYMM8 |S32 |S32 | + * |QASYMM8_SIGNED |QASYMM8_SIGNED |F32 |F32 | * * @note GEMM_LOWP: low precision GEMM kernel * This kernel performs the following computations: @@ -100,12 +101,12 @@ public: * -# Convert b values from QASYMM8 to int32 add b_offset to each of them. * -# Compute the matrix product of the resulting a * b in int32. * - * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. 
It is QASYMM8/QASYMM8_SIGNED otherwise + * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED/F32 otherwise * * @param[in] a First input tensor info (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED. * @param[in] b Second input tensor info (Matrix B). Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL. - * @param[in] c Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32 - * @param[out] dst Output tensor info. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED + * @param[in] c Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32/F32 + * @param[out] dst Output tensor info. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED/F32 * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should be executed only for the first run */ diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index efe2a7a67e..a4c856bb8f 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -540,6 +540,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo * { configure_indirect(a, b, d, gemm_info); } + + if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value) + { + // Output dequantization is just the two src scales multiplied together + _gemm_kernel_asm->set_dequantize_scale(a->quantization_info().uniform().scale * + b->quantization_info().uniform().scale); + } } template <typename TypeInput, typename TypeOutput, class OutputStage> @@ -630,6 +637,15 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) auto d = tensors.get_tensor(TensorType::ACL_DST); ARM_COMPUTE_ERROR_ON_NULLPTR(a, d); + // Only update at runtime if the src quantization is dynamic + if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value && + (a->info()->quantization_info().is_dynamic() || b->info()->quantization_info().is_dynamic())) + { + // Output dequantization is just the two src scales multiplied together + _gemm_kernel_asm->set_dequantize_scale(a->info()->quantization_info().uniform().scale * + b->info()->quantization_info().uniform().scale); + } + int lda = a->info()->strides_in_bytes().y() / a->info()->element_size(); int ldb = 0; const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size(); @@ -775,7 +791,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge arm_gemm::GemmConfig cfg; cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, - info.fixed_format, info.fast_mode, &cfg); + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); @@ -784,6 +800,39 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge } template <typename TypeInput, typename TypeOutput> +void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, + const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + ITensorInfo *d, + arm_gemm::Activation activation, + const AsmGemmInfo &info) +{ + ARM_COMPUTE_UNUSED(activation); + + Params p = 
extract_parameters(a, b, d, info); + const CPUInfo &ci = NEScheduler::get().cpu_info(); + const unsigned int num_threads = NEScheduler::get().num_threads(); + + arm_gemm::GemmConfig cfg; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, + info.fixed_format, info.fast_mode, info.accumulate, &cfg); + + // Create arm_gemm fallback + auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>(); + + // Configure requantization info + const GEMMLowpOutputStageInfo os_info = info.output_stage; + + arm_gemm::DequantizeFloat gemm_dequant_info{}; + gemm_dequant_info = arm_gemm::DequantizeFloat(d->quantization_info().uniform().scale); + + fallback->configure(a, b, c, d, args, info, gemm_dequant_info); + arm_gemm = std::move(fallback); +} + +template <typename TypeInput, typename TypeOutput> void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, const ITensorInfo *a, const ITensorInfo *b, @@ -800,7 +849,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> & arm_gemm::GemmConfig cfg; cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, - info.fixed_format, info.fast_mode, &cfg); + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); @@ -855,8 +904,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, - info.fixed_format, info.fast_mode, &cfg); - + info.fixed_format, info.fast_mode, info.accumulate, &cfg); // TODO: Incorporate info.transpose_b COMPMID-6595 switch (a->data_type()) { @@ -897,6 +945,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected } break; #endif /* __aarch64__ */ + #if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: { @@ -915,13 +964,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected break; } #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(ENABLE_FP16_KERNELS) case DataType::F16: ARM_COMPUTE_RETURN_ERROR_ON_MSG( !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F16 input and F16 output"); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#endif /* ENABLE_FP16_KERNELS */ default: ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. 
Could not find a kernel"); break; @@ -1032,6 +1082,10 @@ void CpuGemmAssemblyDispatch::configure( { create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info); } + else if (d->data_type() == DataType::F32) + { + create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info); + } else { create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info); @@ -1050,11 +1104,11 @@ void CpuGemmAssemblyDispatch::configure( } break; #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ENABLE_FP16_KERNELS case DataType::F16: create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#endif /* ENABLE_FP16_KERNELS */ default: break; } diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 671a222fed..44c5c189a5 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023 Arm Limited. + * Copyright (c) 2018-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -57,6 +57,7 @@ struct AsmGemmInfo bool fixed_format{false}; arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED}; bool reshape_b_only_on_first_run{true}; + bool accumulate{false}; /** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array * @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for * by the reshape_b_only_on_first_run flag diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp index 4544a66e39..c4117b8a1a 100644 --- a/src/gpu/cl/ClKernelLibrary.cpp +++ b/src/gpu/cl/ClKernelLibrary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2023 Arm Limited. + * Copyright (c) 2016-2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -441,6 +441,8 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map = {"reorg_layer_nhwc", "nhwc/reorg_layer.cl"}, {"scale_nearest_neighbour_nhwc", "nhwc/scale.cl"}, {"scale_bilinear_nhwc", "nhwc/scale.cl"}, + {"scatter_mp1d_2d_mpnd", "common/scatter.cl"}, + {"scatter1D", "common/scatter.cl"}, {"space_to_batch_nhwc", "nhwc/space_to_batch.cl"}, {"space_to_batch_static_nhwc", "nhwc/space_to_batch.cl"}, {"space_to_depth_nhwc", "nhwc/space_to_depth.cl"}, @@ -591,6 +593,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map = #include "./cl_kernels/common/gather.clembed" }, { + "common/scatter.cl", +#include "./cl_kernels/common/scatter.clembed" + }, + { "common/gemm.cl", #include "./cl_kernels/common/gemm.clembed" }, diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp index 720164366e..19adc1ef34 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.cpp +++ b/src/gpu/cl/kernels/ClScatterKernel.cpp @@ -26,6 +26,15 @@ #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/DataTypeUtils.h" +#include "arm_compute/core/utils/helpers/AdjustVecSize.h" + +#include "src/common/utils/Log.h" +#include "src/core/helpers/WindowHelpers.h" +#include "support/Cast.h" + +#include <cstdint> namespace arm_compute { @@ -33,44 +42,207 @@ namespace opencl { namespace kernels { + +namespace +{ +constexpr int max_index_length = 5; +} // namespace + ClScatterKernel::ClScatterKernel() { } -Status ClScatterKernel::validate(const ITensorInfo *src, - const ITensorInfo *updates, +Status ClScatterKernel::validate(const ITensorInfo *updates, const ITensorInfo *indices, const ITensorInfo *dst, const ScatterInfo &info) { - ARM_COMPUTE_UNUSED(src); - ARM_COMPUTE_UNUSED(updates); - ARM_COMPUTE_UNUSED(indices); - ARM_COMPUTE_UNUSED(dst); ARM_COMPUTE_UNUSED(info); + const TensorShape &ind_shape = indices->tensor_shape(); + const TensorShape &upt_shape = updates->tensor_shape(); + const TensorShape &dst_shape = dst->tensor_shape(); + + const int32_t upt_dims = upt_shape.num_dimensions(); + const int32_t dst_dims = dst_shape.num_dimensions(); + const int32_t ind_dims = ind_shape.num_dimensions(); + const int32_t data_dim = upt_dims - (ind_dims - 1); // Number of batch dims is the number of indices dims - 1 + + const int32_t index_len = ind_shape[0]; + bool unsupported_padding_config = + (dst_dims == index_len) && index_len > 1 && (dst->has_padding() || updates->has_padding()); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(unsupported_padding_config, "Padding is not supported with these shapes."); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32, DataType::F16, DataType::S32, DataType::S16, + DataType::S8, DataType::U32, DataType::U16, DataType::U8); + + // Check data dims in update tensor and output tensor are equal + for (int32_t i = 0; i < data_dim; i++) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_shape[i] != dst_shape[i], + "Data dims should be same size in both updates and ouput tensor."); + } + + // Check if batch dims in indices and updates tensor are equal. 
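The dimension bookkeeping above (index_len = ind_shape[0], data_dim = upt_dims - (ind_dims - 1)) together with the batch-dimension check that follows describe a ScatterND-style contract: each row of the indices tensor addresses an index_len-dimensional coordinate in the destination, and the remaining lower dimensions form the data block that is copied or combined. A much-simplified scalar sketch for the 1-D-block case, in the spirit of the ADD_OP/UPDATE_OP build options further down; function and parameter names are illustrative, and negative or out-of-range indices are skipped, as the operator documents:

#include <cstddef>
#include <cstdint>
#include <vector>

// dst holds dst.size()/block_len blocks of block_len floats; updates holds one block per index.
void scatter_add_blocks(std::vector<float>         &dst,
                        const std::vector<float>   &updates,
                        const std::vector<int32_t> &indices,
                        std::size_t                 block_len)
{
    const std::size_t num_dst_blocks = dst.size() / block_len;
    for (std::size_t u = 0; u < indices.size(); ++u)
    {
        const int32_t idx = indices[u];
        if (idx < 0 || static_cast<std::size_t>(idx) >= num_dst_blocks)
        {
            continue; // out-of-bounds (including negative) indices are ignored
        }
        for (std::size_t e = 0; e < block_len; ++e)
        {
            // ADD_OP shown here; UPDATE_OP would assign, MAX_OP/MIN_OP would take max/min.
            dst[static_cast<std::size_t>(idx) * block_len + e] += updates[u * block_len + e];
        }
    }
}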
+ for (int32_t i = 0; i < ind_dims - 1; i++) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_shape[data_dim + i] != ind_shape[i + 1], + "Batch dimensions should be the same in updates and indices tensor."); + } + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_shape[1] != upt_shape[data_dim], + "Height of indices tensor should match size of highest dimension in updates tensor " + "(Excluding batch dimension)"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + data_dim >= dst_dims, "Update tensor cannot have more dims than output tensor. (Excluding batch dimensions)"); + ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - data_dim); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((ind_dims < 2), "Shape of Indices tensor must be at least 2D"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > dst_dims && dst_dims != 1, + "Index length should be smaller than or equal to number of output dims"); + return Status{}; } + void ClScatterKernel::configure(const ClCompileContext &compile_context, - const ITensorInfo *src, const ITensorInfo *updates, const ITensorInfo *indices, ITensorInfo *dst, const ScatterInfo &info) { - ARM_COMPUTE_UNUSED(compile_context); - ARM_COMPUTE_UNUSED(src); - ARM_COMPUTE_UNUSED(updates); - ARM_COMPUTE_UNUSED(indices); - ARM_COMPUTE_UNUSED(dst); - ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices); + ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info); + + const TensorShape &dst_shape = dst->tensor_shape(); + const int index_len = indices->dimension(0); + + // Check for single element data block + const bool is_scalar_block = (dst->num_dimensions() == static_cast<uint32_t>(index_len)); + + const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 
1 : updates->dimension(0)); + const int partial_n0 = updates->dimension(0) % n0; + + // The GWS will be 2D [x, y] + // x-dimension refers to the x coordinate of the dst tensor + // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor + Window win; + + if (!is_scalar_block) + { + win = calculate_max_window(dst_shape, Steps(n0)); + + // Collapse the dimensions corresponding to indices in the execution window + for (int i = 0; i < index_len; ++i) + { + win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1)); + } + + win = win.collapse(win, 1); + } + + // Set build options + CLBuildOptions build_opts; + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type())); + build_opts.add_option_if(is_data_type_float(dst->data_type()), "-DIS_FLOAT"); + + const int num_dims = dst->num_dimensions(); + TensorShape ind_collapsed = indices->tensor_shape().collapsed_from(1); + build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(ind_collapsed[1])); + build_opts.add_option("-DINDEX_LENGTH=" + support::cpp11::to_string(index_len)); + + // We provide 5 variables to use in a constant array + for (int i = 1; i <= max_index_length; i++) + { + build_opts.add_option("-DOUT_SHAPE_N_MINUS_" + support::cpp11::to_string(i) + "=" + + support::cpp11::to_string(dst_shape[std::max(num_dims - i, 0)])); + } + + build_opts.add_option("-DN0=" + support::cpp11::to_string(n0)); + build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_n0)); + + switch (info.func) + { + case ScatterFunction::Update: + build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP"); + build_opts.add_option("-DSKIP_OUTPUT_READ"); + break; + case ScatterFunction::Add: + build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP"); + break; + case ScatterFunction::Sub: + build_opts.add_option("-DSCATTER_FUNCTION=SUB_OP"); + break; + case ScatterFunction::Max: + build_opts.add_option("-DSCATTER_FUNCTION=MAX_OP"); + break; + case ScatterFunction::Min: + build_opts.add_option("-DSCATTER_FUNCTION=MIN_OP"); + break; + default: + ARM_COMPUTE_ERROR("Not implemented"); + } + + // Create kernel + std::string kernel_name = "scatter_mp1d_2d_mpnd"; + build_opts.add_option("-D" + upper_string(kernel_name)); + + ICLKernel::configure_internal(win); + _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Set config_id for enabling LWS tuning + _config_id = kernel_name; + _config_id += "_"; + _config_id += lower_string(string_from_data_type(updates->data_type())); + _config_id += "_"; + _config_id += support::cpp11::to_string(dst->dimension(1)); + _config_id += "_"; + _config_id += support::cpp11::to_string(dst->dimension(0)); + _config_id += "_"; + _config_id += support::cpp11::to_string(dst->dimension(2)); + _config_id += "_"; } void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) { - ARM_COMPUTE_UNUSED(tensors); - ARM_COMPUTE_UNUSED(window); - ARM_COMPUTE_UNUSED(queue); + const auto updates = + utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const auto indices = + utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); + + const ITensorInfo *dst_info = dst->info(); + const ITensorInfo *upd_info = updates->info(); + const int num_dims = dst_info->num_dimensions(); + const int ind_dims = 
indices->info()->num_dimensions(); + const int index_len = indices->info()->dimension(0); + + bool unsupported_padding_config = + num_dims == index_len && index_len > 1 && (dst_info->has_padding() || upd_info->has_padding()); + if (unsupported_padding_config) + { + ARM_COMPUTE_ERROR("Unsupported Configuration! Padding not supported with these shapes."); + } + + // calculate m-dimensional data block strides in updates and destination tensors + const int upt_block_stride = + updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - (ind_dims - 1)]; + + const int out_block_stride = dst_info->strides_in_bytes()[num_dims - index_len]; + + unsigned int idx = 0; + + add_2D_tensor_argument(idx, updates, window); + add_2D_tensor_argument(idx, indices, window); + add_2D_tensor_argument(idx, dst, window); + + _kernel.setArg<cl_int>(idx++, upt_block_stride); + _kernel.setArg<cl_int>(idx++, out_block_stride); + + enqueue(queue, *this, window, lws_hint()); } } // namespace kernels diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h index dda614ff3e..e1b469c88e 100644 --- a/src/gpu/cl/kernels/ClScatterKernel.h +++ b/src/gpu/cl/kernels/ClScatterKernel.h @@ -37,6 +37,7 @@ namespace opencl { namespace kernels { + class ClScatterKernel : public IClKernel { public: @@ -44,15 +45,15 @@ public: ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClScatterKernel); /** Initialise the kernel's input and output. * + * @note Negative indices are treated as out of bounds. + * * @param[in] compile_context The compile context to be used. - * @param[in] src Input tensor info for the source matrix. * @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src - * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32. + * @param[in] indices Input tensor info for the Indices matrix. Data type supported: S32. * @param[out] dst Output tensor info. Data type supported: same as @p src * @param[in] info Attributes for Scatter Kernel */ void configure(const ClCompileContext &compile_context, - const ITensorInfo *src, const ITensorInfo *updates, const ITensorInfo *indices, ITensorInfo *dst, @@ -63,11 +64,8 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *src, - const ITensorInfo *updates, - const ITensorInfo *indices, - const ITensorInfo *dst, - const ScatterInfo &info); + static Status + validate(const ITensorInfo *updates, const ITensorInfo *indices, const ITensorInfo *dst, const ScatterInfo &info); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp index af5fbb86f3..a11ecd7e6a 100644 --- a/src/gpu/cl/operators/ClScatter.cpp +++ b/src/gpu/cl/operators/ClScatter.cpp @@ -27,6 +27,7 @@ #include "arm_compute/runtime/CL/CLScheduler.h" #include "src/common/utils/Log.h" +#include "src/gpu/cl/kernels/ClCopyKernel.h" #include "src/gpu/cl/kernels/ClFillKernel.h" #include "src/gpu/cl/kernels/ClScatterKernel.h" @@ -47,9 +48,19 @@ Status ClScatter::validate(const ITensorInfo *src, const ScatterInfo &info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32); + if (src != nullptr) + { + // Check dst/src are same shape and datatype. 
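The validation here, together with the configure()/run() changes further down, turns the operator into a three-stage sequence: the destination is either zero-filled (ClFillKernel, when no source is given) or seeded with a copy of the source (ClCopyKernel, when src and dst are distinct tensors), and the scatter kernel then updates it in place. A small sketch of that sequencing decision; the struct and function names are illustrative and not the library API:

// Which preparation step runs before the scatter kernel.
struct ScatterPrep
{
    bool fill_dst_with_zero; // ClFillKernel path (_fill_zero)
    bool copy_src_to_dst;    // ClCopyKernel path (_run_copy)
};

ScatterPrep choose_prep(bool zero_init, bool src_is_dst)
{
    ScatterPrep prep{};
    prep.fill_dst_with_zero = zero_init;                 // dst starts as all zeros
    prep.copy_src_to_dst    = !zero_init && !src_is_dst; // dst starts as a copy of src
    return prep;                                         // ClScatterKernel then writes dst in place
}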
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(src->tensor_shape(), dst->tensor_shape()); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, updates, dst); + ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCopyKernel::validate(src, dst)); // Validate Copy kernel + } + if (src != dst) + { + ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClFillKernel::validate(dst, PixelValue(0.0f))); // Validate Fill kernel. + } - return kernels::ClScatterKernel::validate(src, updates, indices, dst, info); + return kernels::ClScatterKernel::validate(updates, indices, dst, info); } void ClScatter::configure(const CLCompileContext &compile_context, @@ -61,11 +72,6 @@ void ClScatter::configure(const CLCompileContext &compile_context, { ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst); ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info); - ARM_COMPUTE_UNUSED(src); - ARM_COMPUTE_UNUSED(updates); - ARM_COMPUTE_UNUSED(indices); - ARM_COMPUTE_UNUSED(dst); - ARM_COMPUTE_UNUSED(info); // Perform validation step ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info)); @@ -74,19 +80,50 @@ void ClScatter::configure(const CLCompileContext &compile_context, // If necessary, create fill kernel to fill dst tensor. if (_fill_zero) { - _fill_kernel = std::make_unique<kernels::ClFillKernel>(); + auto f = std::make_unique<kernels::ClFillKernel>(); + f->configure(compile_context, dst, PixelValue(0.0f)); + _fill_kernel = std::move(f); + } + else if (src != dst) // Check whether copying is necessary + { + // Fill dst with src copy here. + auto j = std::make_unique<kernels::ClCopyKernel>(); + j->configure(compile_context, src, dst); + _copy_kernel = std::move(j); + _run_copy = true; } // Configure ClScatterKernel auto k = std::make_unique<kernels::ClScatterKernel>(); k->set_target(CLScheduler::get().target()); - k->configure(compile_context, src, updates, indices, dst, info); + k->configure(compile_context, updates, indices, dst, info); _scatter_kernel = std::move(k); } void ClScatter::run(ITensorPack &tensors) { - ARM_COMPUTE_UNUSED(tensors); + // Get tensors. + auto src = tensors.get_const_tensor(ACL_SRC_0); + auto updates = tensors.get_const_tensor(ACL_SRC_1); + auto indices = tensors.get_const_tensor(ACL_SRC_2); + auto dst = tensors.get_tensor(ACL_DST); + + if (_fill_zero) + { + // Fill destination tensor with 0 values if zero init. + ITensorPack fill_pack{{ACL_SRC, dst}}; + CLScheduler::get().enqueue_op(*_fill_kernel, fill_pack, false); + } + + if (_run_copy) + { + // copy src to dst before scatter op. + ITensorPack copy_pack{{ACL_SRC, src}, {ACL_DST, dst}}; + CLScheduler::get().enqueue_op(*_copy_kernel, copy_pack, false); + } + + ITensorPack scatter_pack{{ACL_SRC_0, updates}, {ACL_SRC_1, indices}, {ACL_DST, dst}}; + CLScheduler::get().enqueue_op(*_scatter_kernel, scatter_pack, false); } } // namespace opencl diff --git a/src/gpu/cl/operators/ClScatter.h b/src/gpu/cl/operators/ClScatter.h index 433f7ca3a4..a1b32fed45 100644 --- a/src/gpu/cl/operators/ClScatter.h +++ b/src/gpu/cl/operators/ClScatter.h @@ -39,6 +39,7 @@ namespace opencl // Forward declaration class ClFillKernel; class ClScatterKernel; +class ClCopyKernel; /** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels: * @@ -56,13 +57,14 @@ public: * Valid data layouts: * - All * - * @note indices must always be U32 + * @note indices must always be S32. + * @note Negative indices are treated as out of bounds. * @note src, updates and dst tensors must be same datatype. 
* * @param[in] compile_context The compile context to be used. * @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization. * @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src. - * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only. + * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: S32 only. * @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates. * @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo. */ @@ -89,7 +91,9 @@ public: private: std::unique_ptr<opencl::IClKernel> _scatter_kernel{nullptr}; std::unique_ptr<opencl::IClKernel> _fill_kernel{nullptr}; + std::unique_ptr<opencl::IClKernel> _copy_kernel{nullptr}; bool _fill_zero{false}; + bool _run_copy{false}; }; } // namespace opencl } // namespace arm_compute diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp index d4d6193fce..baffa8cbb2 100644 --- a/src/runtime/OMP/OMPScheduler.cpp +++ b/src/runtime/OMP/OMPScheduler.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2023 Arm Limited. + * Copyright (c) 2017-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -32,10 +32,21 @@ namespace arm_compute { +#if !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + (defined(__arm__) || defined(__aarch64__)) && defined(__ANDROID__) OMPScheduler::OMPScheduler() // NOLINT - : _num_threads(omp_get_max_threads()) + : _num_threads(cpu_info().get_cpu_num_excluding_little()), + _nonlittle_num_cpus(cpu_info().get_cpu_num_excluding_little()) { } +#else /* !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + (defined(__arm__) || defined(__aarch64__)) && defined(__ANDROID__)*/ +OMPScheduler::OMPScheduler() // NOLINT + : _num_threads(omp_get_max_threads()), _nonlittle_num_cpus(cpu_info().get_cpu_num_excluding_little()) +{ +} +#endif /* !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + (defined(__arm__) || defined(__aarch64__)) && defined(__ANDROID__)*/ unsigned int OMPScheduler::num_threads() const { @@ -45,7 +56,15 @@ unsigned int OMPScheduler::num_threads() const void OMPScheduler::set_num_threads(unsigned int num_threads) { const unsigned int num_cores = omp_get_max_threads(); - _num_threads = (num_threads == 0) ? num_cores : num_threads; +#if !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + (defined(__arm__) || defined(__aarch64__)) && defined(__ANDROID__) + const unsigned int adjusted_num_threads = std::min(_nonlittle_num_cpus, num_threads); + _num_threads = (num_threads == 0) ? num_cores : adjusted_num_threads; +#else /* !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + (defined(__arm__) || defined(__aarch64__)) && defined(__ANDROID__)*/ + _num_threads = (num_threads == 0) ? 
num_cores : num_threads; +#endif /* !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + (defined(__arm__) || defined(__aarch64__)) && defined(__ANDROID__)*/ } void OMPScheduler::schedule(ICPPKernel *kernel, const Hints &hints) @@ -99,9 +118,15 @@ void OMPScheduler::run_workloads(std::vector<arm_compute::IScheduler::Workload> } ThreadInfo info; - info.cpu_info = &cpu_info(); + info.cpu_info = &cpu_info(); + +#if !defined(__ANDROID__) + info.num_threads = _num_threads; +#else /* !__ANDROID__ */ info.num_threads = num_threads_to_use; -#pragma omp parallel for firstprivate(info) num_threads(num_threads_to_use) default(shared) proc_bind(close) \ +#endif /* __ANDROID__ */ + +#pragma omp parallel for firstprivate(info) num_threads(info.num_threads) default(shared) proc_bind(close) \ schedule(static, 1) for (unsigned int wid = 0; wid < amount_of_work; ++wid) {