aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Android.bp6
-rw-r--r--CMakeLists.txt4
-rw-r--r--README.md24
-rw-r--r--SConscript5
-rw-r--r--arm_compute/core/QuantizationInfo.h72
-rw-r--r--arm_compute/function_info/GEMMInfo.h31
-rw-r--r--arm_compute/runtime/CL/functions/CLScatter.h7
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h15
-rw-r--r--docs/user_guide/operator_list.dox3
-rw-r--r--docs/user_guide/release_version_and_change_log.dox11
-rw-r--r--examples/CMakeLists.txt8
-rw-r--r--examples/neon_gemm_s8_f32.cpp239
-rw-r--r--filelist.json9
-rwxr-xr-xscripts/generate_android_bp.py4
-rw-r--r--src/BUILD.bazel7
-rw-r--r--src/CMakeLists.txt7
-rw-r--r--src/core/CL/cl_kernels/common/scatter.cl166
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp31
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_int8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp117
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp142
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp2
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp111
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp3240
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp93
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp417
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp93
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp448
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp93
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp513
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp6
-rw-r--r--src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp8
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantized.cpp60
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantized.hpp6
-rw-r--r--src/core/common/Registrars.h16
-rw-r--r--src/core/utils/helpers/tensor_transform.cpp7
-rw-r--r--src/core/utils/quantization/AsymmHelpers.cpp16
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp324
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h41
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp21
-rw-r--r--src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h26
-rw-r--r--src/cpu/kernels/CpuKernelSelectionTypes.h3
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.cpp10
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp24
-rw-r--r--src/cpu/kernels/assembly/gemm_common.hpp6
-rw-r--r--src/cpu/kernels/softmax/generic/sme2/fp16.cpp774
-rw-r--r--src/cpu/kernels/softmax/generic/sme2/fp32.cpp578
-rw-r--r--src/cpu/kernels/softmax/list.h10
-rw-r--r--src/cpu/operators/CpuGemm.cpp13
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp13
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp94
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h9
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp60
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.h3
-rw-r--r--src/gpu/cl/ClKernelLibrary.cpp8
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.cpp162
-rw-r--r--src/gpu/cl/kernels/ClScatterKernel.h14
-rw-r--r--src/gpu/cl/operators/ClScatter.cpp57
-rw-r--r--src/gpu/cl/operators/ClScatter.h8
-rw-r--r--tests/datasets/LargeGEMMDataset.h21
-rw-r--r--tests/datasets/ScatterDataset.h67
-rw-r--r--tests/datasets/SmallGEMMDataset.h19
-rw-r--r--tests/validation/CL/GEMMLowp.cpp13
-rw-r--r--tests/validation/CL/ScatterLayer.cpp154
-rw-r--r--tests/validation/CPP/DFT.cpp4
-rw-r--r--tests/validation/NEON/ConvolutionLayer.cpp77
-rw-r--r--tests/validation/NEON/GEMM.cpp145
-rw-r--r--tests/validation/NEON/GEMMLowp.cpp125
-rw-r--r--tests/validation/NEON/SoftmaxLayer.cpp37
-rw-r--r--tests/validation/UNIT/CPPScheduler.cpp8
-rw-r--r--tests/validation/fixtures/GEMMFixture.h60
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h226
-rw-r--r--tests/validation/fixtures/ScatterLayerFixture.h117
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp9
-rw-r--r--tests/validation/reference/GEMM.cpp30
-rw-r--r--tests/validation/reference/GEMM.h11
-rw-r--r--tests/validation/reference/GEMMLowp.cpp12
-rw-r--r--tests/validation/reference/GEMMLowp.h11
-rw-r--r--tests/validation/reference/QuantizationLayer.cpp2
-rw-r--r--tests/validation/reference/ScatterLayer.cpp74
-rw-r--r--tests/validation/reference/ScatterLayer.h4
85 files changed, 9055 insertions, 496 deletions
diff --git a/Android.bp b/Android.bp
index bb0486403b..ab554a8ca2 100644
--- a/Android.bp
+++ b/Android.bp
@@ -65,6 +65,7 @@ opencl_srcs = [
"src/core/CL/cl_kernels/common/roi_align_layer.cl",
"src/core/CL/cl_kernels/common/roi_align_layer_quantized.cl",
"src/core/CL/cl_kernels/common/roi_pooling_layer.cl",
+ "src/core/CL/cl_kernels/common/scatter.cl",
"src/core/CL/cl_kernels/common/select.cl",
"src/core/CL/cl_kernels/common/slice_ops.cl",
"src/core/CL/cl_kernels/common/softmax_layer.cl",
@@ -331,6 +332,7 @@ cc_library_static {
"src/core/NEON/kernels/arm_gemm/gemm_int8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
+ "src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp",
"src/core/NEON/kernels/arm_gemm/interleave-8way.cpp",
@@ -1214,6 +1216,7 @@ cc_library_static {
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp",
@@ -1301,6 +1304,9 @@ cc_library_static {
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0462c2f085..c67479ce41 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -28,7 +28,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
list(APPEND CMAKE_MESSAGE_CONTEXT ArmCompute)
project(
ArmCompute
- VERSION 35.0.1
+ VERSION 36.0.0
DESCRIPTION
"The Arm Compute Library is a collection of low-level machine learning functions optimized for Arm® Cortex®-A CPU and Arm® Mali™ GPU architectures"
LANGUAGES C CXX ASM)
@@ -57,7 +57,7 @@ endif()
# ---------------------------------------------------------------------
# Configuration
-set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -gdwarf-2 -DARM_COMPUTE_ASSERTS_ENABLED")
+set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -gdwarf-2 -DARM_COMPUTE_ASSERTS_ENABLED -DARM_COMPUTE_DEBUG_ENABLED")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
# Default to Release Build
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
diff --git a/README.md b/README.md
index 51d4dfbe65..112f40225d 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
<img src="https://raw.githubusercontent.com/ARM-software/ComputeLibrary/gh-pages/ACL_logo.png"/><br><br>
</div>
-# Compute Library ![](https://img.shields.io/badge/latest_release-24.02.1-green)
+# Compute Library ![](https://img.shields.io/badge/latest_release-24.04-green)
The Compute Library is a collection of low-level machine learning functions optimized for Arm® Cortex®-A, Arm® Neoverse® and Arm® Mali™ GPUs architectures.<br>
@@ -37,7 +37,7 @@ Key Features:
<br>
## Documentation
-[![Documentation](https://img.shields.io/badge/documentation-24.02.1-green)](https://arm-software.github.io/ComputeLibrary/latest)
+[![Documentation](https://img.shields.io/badge/documentation-24.04-green)](https://arm-software.github.io/ComputeLibrary/latest)
> Note: The documentation includes the reference API, changelogs, build guide, contribution guide, errata, etc.
@@ -50,24 +50,24 @@ All the binaries can be downloaded from [here](https://github.com/ARM-software/C
| Platform | Operating System | Release archive (Download) |
| -------------- | ---------------- | -------------------------- |
-| Raspberry Pi 4 | Linux® 32bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-neon.tar.gz) |
-| Raspberry Pi 4 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) |
-| Odroid N2 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon-cl.tar.gz) |
-| HiKey960 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon-cl.tar.gz) |
+| Raspberry Pi 4 | Linux® 32bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-neon.tar.gz) |
+| Raspberry Pi 4 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) |
+| Odroid N2 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon-cl.tar.gz) |
+| HiKey960 | Linux® 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon-cl.tar.gz) |
<br>
| Architecture | Operating System | Release archive (Download) |
| ------------ | ---------------- | -------------------------- |
-| armv7 | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-armv7a-neon-cl.tar.gz) |
-| arm64-v8a | Android™ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8a-neon-cl.tar.gz) |
-| arm64-v8a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8a-neon-cl.tar.gz) |
-| arm64-v8.2-a | Android™ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-android-arm64-v8.2-a-neon-cl.tar.gz) |
-| arm64-v8.2-a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.02.1/arm_compute-v24.02.1-bin-linux-arm64-v8.2-a-neon-cl.tar.gz) |
+| armv7 | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-armv7a-neon-cl.tar.gz) |
+| arm64-v8a | Android™ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8a-neon-cl.tar.gz) |
+| arm64-v8a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8a-neon-cl.tar.gz) |
+| arm64-v8.2-a | Android™ | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-android-arm64-v8.2-a-neon-cl.tar.gz) |
+| arm64-v8.2-a | Linux® | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v24.04/arm_compute-v24.04-bin-linux-arm64-v8.2-a-neon-cl.tar.gz) |
<br>
-Please refer to the following link for more pre-built binaries: [![](https://img.shields.io/badge/v24.02.1-bins-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/tag/v24.02.1)
+Please refer to the following link for more pre-built binaries: [![](https://img.shields.io/badge/v24.04-bins-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/tag/v24.04)
Pre-build binaries are generated with the following security / good coding practices related flags:
> -Wall, -Wextra, -Wformat=2, -Winit-self, -Wstrict-overflow=2, -Wswitch-default, -Woverloaded-virtual, -Wformat-security, -Wctor-dtor-privacy, -Wsign-promo, -Weffc++, -pedantic, -fstack-protector-strong
diff --git a/SConscript b/SConscript
index f1b81ec2a5..80aa87cae8 100644
--- a/SConscript
+++ b/SConscript
@@ -33,9 +33,9 @@ import codecs
import platform
VERSION = "v0.0-unreleased"
-LIBRARY_VERSION_MAJOR = 35
+LIBRARY_VERSION_MAJOR = 36
LIBRARY_VERSION_MINOR = 0
-LIBRARY_VERSION_PATCH = 1
+LIBRARY_VERSION_PATCH = 0
SONAME_VERSION = str(LIBRARY_VERSION_MAJOR) + "." + str(LIBRARY_VERSION_MINOR) + "." + str(LIBRARY_VERSION_PATCH)
Import('env')
@@ -429,6 +429,7 @@ if env['opencl'] and env['embed_kernels']:
'src/core/CL/cl_kernels/common/fill_border.cl',
'src/core/CL/cl_kernels/common/floor.cl',
'src/core/CL/cl_kernels/common/gather.cl',
+ 'src/core/CL/cl_kernels/common/scatter.cl',
'src/core/CL/cl_kernels/common/gemm.cl',
'src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl',
'src/core/CL/cl_kernels/common/gemm_utils.cl',
diff --git a/arm_compute/core/QuantizationInfo.h b/arm_compute/core/QuantizationInfo.h
index 471b8c57ab..aecba3712e 100644
--- a/arm_compute/core/QuantizationInfo.h
+++ b/arm_compute/core/QuantizationInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023 Arm Limited.
+ * Copyright (c) 2019-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_QUANTIZATION_INFO_H
-#define ARM_COMPUTE_QUANTIZATION_INFO_H
+#ifndef ACL_ARM_COMPUTE_CORE_QUANTIZATIONINFO_H
+#define ACL_ARM_COMPUTE_CORE_QUANTIZATIONINFO_H
#include "arm_compute/core/Rounding.h"
#include "arm_compute/core/utils/misc/Utility.h"
@@ -84,10 +84,12 @@ public:
*
* @note Used for asymmetric quantization
*
- * @param[in] scale Scale.
- * @param[in] offset Offset.
+ * @param[in] scale Scale.
+ * @param[in] offset Offset.
+ * @param[in] is_dynamic Whether this QuantizationInfo is dynamic, i.e. the scale and offset may change.
*/
- QuantizationInfo(float scale, int offset) : _scale(1, scale), _offset(1, offset)
+ QuantizationInfo(float scale, int offset, bool is_dynamic = false)
+ : _scale(1, scale), _offset(1, offset), _is_dynamic(is_dynamic)
{
}
/** Construct quantization info.
@@ -103,10 +105,12 @@ public:
*
* @note Used for asymmetric per channel quantization
*
- * @param[in] scale Scale.
- * @param[in] offset Offset.
+ * @param[in] scale Scale.
+ * @param[in] offset Offset.
+ * @param[in] is_dynamic Whether this QuantizationInfo is dynamic, i.e. the scale and offset may change.
*/
- QuantizationInfo(std::vector<float> scale, std::vector<int32_t> offset) : _scale(scale), _offset(offset)
+ QuantizationInfo(std::vector<float> scale, std::vector<int32_t> offset, bool is_dynamic = false)
+ : _scale(scale), _offset(offset), _is_dynamic(is_dynamic)
{
}
/** Scale vector accessor
@@ -125,6 +129,14 @@ public:
{
return _offset;
}
+ /** is_dynamic accessor
+ *
+ * @return If true, the scale and offset may change, so operators will need to read on every run
+ */
+ bool is_dynamic() const
+ {
+ return _is_dynamic;
+ }
/** Indicates whether this QuantizationInfo has valid settings or not
*
* @return True if the this has invalid settings.
@@ -149,6 +161,8 @@ public:
private:
std::vector<float> _scale; /**< Vector containing scaling factors */
std::vector<int32_t> _offset; /**< Vector containing zero offsets */
+ bool _is_dynamic =
+ false; /**< If true, the scale and offset may change, so operators will need to read on every run */
};
/** Check whether two quantization info are equal.
@@ -430,6 +444,19 @@ inline float dequantize(uint16_t value, float scale, int32_t offset)
return (static_cast<int>(value) - offset) * scale;
}
+/** Dequantize a value given a 32-bit asymmetric quantization scheme
+ *
+ * @param[in] value Value to dequantize
+ * @param[in] scale Scale to use for dequantization
+ * @param[in] offset Zero-offset to use for dequantization
+ *
+ * @return Dequantized value
+ */
+inline float dequantize(int32_t value, float scale, int32_t offset)
+{
+ return (static_cast<int>(value) - offset) * scale;
+}
+
/** Quantize a value given a 16-bit symmetric quantization scheme
*
* @param[in] value Value to quantize
@@ -536,6 +563,31 @@ inline float dequantize_qasymm16(uint16_t value, const QuantizationInfo &qinfo)
return dequantize_qasymm16(value, qinfo.uniform());
}
+/** Dequantize a value given a 32-bit asymmetric quantization scheme
+ *
+ * @param[in] value Value to dequantize
+ * @param[in] qinfo Quantization information to use for dequantizing
+ *
+ * @return Dequantized value
+ */
+inline float dequantize_s32(int32_t value, const UniformQuantizationInfo &qinfo)
+{
+ return (static_cast<int>(value) - qinfo.offset) * qinfo.scale;
+}
+
+/** Dequantize a value given a 32-bit asymmetric quantization scheme
+ *
+ * @param[in] value Value to dequantize
+ * @param[in] qinfo Quantization information to use for dequantizing
+ *
+ * @return Dequantized value
+ */
+
+inline float dequantize_s32(int32_t value, const QuantizationInfo &qinfo)
+{
+ return dequantize_s32(value, qinfo.uniform());
+}
+
/*
* In case of requantization of a quantized input tensor to an output tensor with another quantization
* instead of applying dequantization and then a quantization functions, we just compute new scale and
@@ -581,4 +633,4 @@ inline UniformQuantizationInfo compute_requantization_scale_offset(const Uniform
}
} // namespace arm_compute
-#endif /* ARM_COMPUTE_QUANTIZATION_INFO_H */
+#endif // ACL_ARM_COMPUTE_CORE_QUANTIZATIONINFO_H
diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h
index a827c79fda..74fe30454e 100644
--- a/arm_compute/function_info/GEMMInfo.h
+++ b/arm_compute/function_info/GEMMInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,8 @@ public:
_pretranspose_B(false),
_activation_info(),
_fixed_format(false),
- _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
+ _accumulate(false)
{
}
/** Constructor
@@ -106,6 +107,7 @@ public:
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
* @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
* @param[in] pretranspose_B (Optional) Pretranspose matrix B (transposition of its lowest 2 dimensions), in addition to and before, any further transformations of B
+ * @param[in] accumulate (Optional) Whether to accumulate in destination or not
*/
GEMMInfo(bool is_a_reshaped,
bool is_b_reshaped,
@@ -120,7 +122,8 @@ public:
const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
bool fixed_format = false,
arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED,
- bool pretranspose_B = false) noexcept
+ bool pretranspose_B = false,
+ bool accumulate = false) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -135,7 +138,8 @@ public:
_pretranspose_B(pretranspose_B),
_activation_info(activation_info),
_fixed_format(fixed_format),
- _weight_format(weight_format)
+ _weight_format(weight_format),
+ _accumulate(accumulate)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -294,7 +298,14 @@ public:
{
return _fixed_format;
}
-
+ /** Flag which specifies if GEMM should accumulate the result in destination or not.
+ *
+ * @return True if GEMM is accumulating the result.
+ */
+ bool accumulate() const
+ {
+ return _accumulate;
+ }
/** Set fixed-format flag
*
* @param[in] fixed_format sets whether or not to use fixed-format kernels
@@ -303,12 +314,19 @@ public:
{
_fixed_format = fixed_format;
}
+ /** Set accumulate flag
+ *
+ * @param[in] accumulate sets whether or not to use accumulation
+ */
+ void set_accumulate(bool accumulate)
+ {
+ _accumulate = accumulate;
+ }
arm_compute::WeightFormat weight_format() const
{
return _weight_format;
}
-
/** Set weight format to be used
*
* @param[in] weight_format arm_compute::WeightFormat enumeration
@@ -334,6 +352,7 @@ private:
ActivationLayerInfo _activation_info;
bool _fixed_format;
arm_compute::WeightFormat _weight_format;
+ bool _accumulate;
};
} //namespace arm_compute
#endif // ACL_ARM_COMPUTE_FUNCTION_INFO_GEMMINFO_H
diff --git a/arm_compute/runtime/CL/functions/CLScatter.h b/arm_compute/runtime/CL/functions/CLScatter.h
index 1c90d208bd..973953624e 100644
--- a/arm_compute/runtime/CL/functions/CLScatter.h
+++ b/arm_compute/runtime/CL/functions/CLScatter.h
@@ -55,14 +55,15 @@ public:
~CLScatter();
/** Initialise the kernel's inputs and outputs
*
+ * @note Negative indices are treated as out of bounds.
+ *
* Valid data layouts:
* - All
*
- *
* @param[in] compile_context The compile context to be used.
* @param[in] src Source tensor. Values used to fill output. Can be nullptr when zero initialization is true.
* @param[in] updates Tensor containing values used to update output tensor. Data types supported: same as @p src
- * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
+ * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : S32
* @param[out] output Destination tensor. Data types supported: same as @p src.
* @param[in] info Scatter info object.
*/
@@ -85,7 +86,7 @@ public:
*
* @param[in] src Source tensor.
* @param[in] updates Tensor containing values used for updating the output Tensor. Data types supported : same as @p src
- * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : U32
+ * @param[in] indices Tensor containing Indices to change in the output Tensor. Data types supported : S32
* @param[in] output Destination tensor. Data types supported: same as @p src.
* @param[in] info Scatter info containing type of scatter.
*
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
index 824c4443ad..6d07675d3d 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H
-#define ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H
+#ifndef ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEGEMMLOWPMATRIXMULTIPLYCORE_H
+#define ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEGEMMLOWPMATRIXMULTIPLYCORE_H
#include "arm_compute/core/Types.h"
#include "arm_compute/function_info/GEMMInfo.h"
@@ -80,6 +80,7 @@ public:
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8 |S32 |S32 |
+ * |QASYMM8_SIGNED |QASYMM8_SIGNED |F32 |F32 |
*
* @note GEMM_LOWP: low precision GEMM kernel
* This kernel performs the following computations:
@@ -88,12 +89,12 @@ public:
* -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
* -# Compute the matrix product of the resulting a * b in int32.
*
- * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED otherwise
+ * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED/F32 otherwise
*
* @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
* @param[in] b Second input tensor (Matrix B). Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL.
- * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32
- * @param[out] output Output tensor. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED
+ * @param[in] c Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32/F32
+ * @param[out] output Output tensor. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED/F32
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
* if the reshape of matrix B should be executed only for the first run
*/
@@ -120,4 +121,4 @@ private:
std::unique_ptr<Impl> _impl;
};
} // namespace arm_compute
-#endif /*ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H */
+#endif // ACL_ARM_COMPUTE_RUNTIME_NEON_FUNCTIONS_NEGEMMLOWPMATRIXMULTIPLYCORE_H
diff --git a/docs/user_guide/operator_list.dox b/docs/user_guide/operator_list.dox
index 36275e68bf..e7f1823f8b 100644
--- a/docs/user_guide/operator_list.dox
+++ b/docs/user_guide/operator_list.dox
@@ -1,5 +1,5 @@
///
-/// Copyright (c) 2021-2023,2024 Arm Limited.
+/// Copyright (c) 2021-2024 Arm Limited.
///
/// SPDX-License-Identifier: MIT
///
@@ -1773,6 +1773,7 @@ where N = batches, C = channels, H = height, W = width, D = depth
<tr><td>QASYMM8_SIGNED<td>QASYMM8_SIGNED<td>S32<td>S32
<tr><td>QASYMM8_SIGNED<td>QSYMM8_PER_CHANNEL<td>S32<td>S32
<tr><td>QASYMM8_SIGNED<td>QSYMM8<td>S32<td>S32
+ <tr><td>QASYMM8_SIGNED<td>QASYMM8_SIGNED<td>F32<td>F32
</table>
<tr>
<td>CLGEMMLowpMatrixMultiplyCore
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index 31b756070d..952753effb 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -41,11 +41,18 @@ If there is more than one release in a month then an extra sequential number is
@section S2_2_changelog Changelog
+v24.05 Public major release
+ - Add @ref CLScatter operator for FP32 data type
+
v24.04 Public major release
- Add Bfloat16 data type support for @ref NEMatMul.
- - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
- - Optimize @ref NEConvolutionLayer for input tensor size > 1e7 bytes and weight tensor height > 7
+ - Add support for SoftMax in SME2 for FP32 and FP16.
+ - Add support for in place accumulation to CPU GEMM kernels.
+ - Add low-precision Int8 * Int8 -> FP32 CPU GEMM which dequantizes after multiplication
+ - Add is_dynamic flag to QuantizationInfo to signal to operators that it may change after configuration
- Performance optimizations:
+ - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
+ - Optimize @ref NEConvolutionLayer for input tensor size > 1e7 bytes and weight tensor height > 7
- Optimize @ref NESoftmaxLayer for axis != 0 by natively supporting higher axes up to axis 3.
v24.02.1 Public patch release
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 139b968e4e..6b7fbded5d 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2023-2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -48,6 +48,10 @@ set(EXAMPLE_GRAPH_NAMES
PARENT_SCOPE)
set(EXAMPLE_NEON_NAMES
- neon_cnn neon_copy_objects neon_gemm_qasymm8 neon_permute neon_scale
+ neon_cnn neon_copy_objects
+ neon_gemm_qasymm8
+ neon_gemm_s8_f32
+ neon_permute
+ neon_scale
neon_sgemm
PARENT_SCOPE)
diff --git a/examples/neon_gemm_s8_f32.cpp b/examples/neon_gemm_s8_f32.cpp
new file mode 100644
index 0000000000..7c1497ec41
--- /dev/null
+++ b/examples/neon_gemm_s8_f32.cpp
@@ -0,0 +1,239 @@
+/*
+ * Copyright (c) 2020-2021, 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "arm_compute/core/WindowIterator.h"
+#include "arm_compute/runtime/NEON/NEFunctions.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+#include "support/ToolchainSupport.h"
+#include "utils/Utils.h"
+
+#include <cstdlib>
+
+using namespace arm_compute;
+using namespace utils;
+
+QuantizationInfo dynamic_qinfo(QuantizationInfo qinfo)
+{
+ return QuantizationInfo(qinfo.scale(), qinfo.offset(), true);
+}
+void set_qinfo_dynamic(Tensor &t)
+{
+ t.info()->set_quantization_info(dynamic_qinfo(t.info()->quantization_info()));
+}
+
+void quantize(Tensor &qt, const Tensor &t, float min, float max)
+{
+ DataType dt = DataType::QASYMM8_SIGNED;
+
+ // Determine the scale
+ const float scale = (max - min) / 256.0f;
+
+ // Determine the zero-point; using affine equation val = (qval-zerop) * scale
+ const float zero_point = -128.0f - min / scale;
+
+ QuantizationInfo qinfo(scale, (int32_t)round(zero_point), true);
+
+ // We now have the quantisation info and can configure the quantised tensor
+ qt.allocator()->init(TensorInfo(t.info()->tensor_shape(), 1, dt, qinfo));
+ qt.allocator()->allocate();
+ NEQuantizationLayer quantization;
+ quantization.configure(&t, &qt);
+ quantization.run();
+}
+
+void invert_qinfo_offset(Tensor &t)
+{
+ QuantizationInfo qinfo = t.info()->quantization_info();
+ t.info()->set_quantization_info(QuantizationInfo(qinfo.scale()[0], -qinfo.offset()[0], qinfo.is_dynamic()));
+}
+
+void print_quantization_info(const Tensor &t, const std::string &name_prefix)
+{
+ QuantizationInfo qinfo = t.info()->quantization_info();
+ std::cout << name_prefix << "_qinfo="
+ << "QuantizationInfo(" << qinfo.scale()[0] << ", " << qinfo.offset()[0] << ")\n";
+}
+
+int main(int argc, char **argv)
+{
+ size_t M = 4;
+ size_t N = 4;
+ size_t K = 4;
+
+ // Parse args
+ if (argc < 3) /* case default matrix sizes */
+ {
+ // Print help
+ std::cout << "Usage: ./build/neon_gemm_qasymm8 M N K\n";
+ std::cout << "Too few or no inputs provided. Using default M=4, N=4, K=4\n\n";
+ }
+ else /* case M N K arguments provided */
+ {
+ M = strtol(argv[1], nullptr, 10);
+ N = strtol(argv[2], nullptr, 10);
+ K = strtol(argv[3], nullptr, 10);
+ }
+
+ /*** Floating point matrix multiplication ***/
+
+ // Initialise input matrices
+ NEGEMM fgemm{};
+
+ Tensor src1;
+ Tensor src2;
+ Tensor dst;
+ src1.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));
+ src2.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32));
+ dst.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));
+ fgemm.configure(&src1, &src2, nullptr, &dst, 1, 0);
+
+ // Allocate matrices
+ src1.allocator()->allocate();
+ src2.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ float min1 = 0.0f;
+ float max1 = 1.0f;
+ fill_random_tensor(src1, 0, min1, max1);
+
+ float min2 = -1.0f;
+ float max2 = 2.0f;
+ fill_random_tensor(src2, 1, min2, max2);
+
+ // Run single precision gemm and print result
+ fgemm.run();
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ std::cout << "# F32 GEMM result:\n";
+ std::cout << "src1=[ \n";
+ src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "src2=[ \n";
+ src2.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "dst=[ \n";
+ dst.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+
+ Tensor q_src1;
+ quantize(q_src1, src1, min1, max1);
+ print_quantization_info(q_src1, "src1");
+ q_src1.info()->set_are_values_constant(false);
+
+ // NEGEMMLowpMatrixMultiplyCore adopts the opposite convention for the offset
+ // compared to NEQuantizeLayer
+ invert_qinfo_offset(q_src1);
+
+ Tensor q_src2;
+ quantize(q_src2, src2, min2, max2);
+ print_quantization_info(q_src2, "src2");
+ q_src2.info()->set_are_values_constant(false);
+
+ // NEGEMMLowpMatrixMultiplyCore adopts the opposite convention for the offset
+ // compared to NEQuantizeLayer
+ invert_qinfo_offset(q_src2);
+
+ // q_dst will be Dequantized to F32 so it doesn't need a QuantizationInfo
+ Tensor q_dst;
+ q_dst.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));
+
+ // Configure low precision gemm and initialise result tensor (pre-output)
+ NEGEMMLowpMatrixMultiplyCore qgemm;
+ qgemm.configure(&q_src1, &q_src2, nullptr, &q_dst);
+
+ q_dst.allocator()->allocate();
+
+ // Run low precision matrix multiply kernel
+ qgemm.run();
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ // Print quantized source matrices
+ std::cout << "q_src1=[ \n";
+ q_src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "q_src2=[ \n";
+ q_src2.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "# Lowp GEMM output (FP32):\n";
+ std::cout << "q_dst=[ \n";
+ q_dst.print(std::cout);
+ std::cout << "] \n";
+
+ // Expected result
+ std::cout << "# Expected result:\n";
+ std::cout << "dst=[ \n";
+ dst.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+
+ // Rerun to test the ability to modify the Tensor contents and QuantizationInfo (dynamic quantization)
+ min1 = -1.0f;
+ max1 = 1.0f;
+ fill_random_tensor(src1, 2, min1, max1);
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ std::cout << "# Refilled src1\n";
+ std::cout << "src1=[ \n";
+ src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "src2=[ \n";
+ src2.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+
+ fgemm.run();
+
+ quantize(q_src1, src1, min1, max1);
+ set_qinfo_dynamic(q_src1);
+ print_quantization_info(q_src1, "src1");
+
+ // NEGEMMLowpMatrixMultiplyCore adopts the opposite convention for the offset
+ // compared to NEQuantizeLayer
+ invert_qinfo_offset(q_src1);
+
+ qgemm.run();
+
+#if ARM_COMPUTE_DEBUG_ENABLED
+ // Print quantized source matrices
+ std::cout << "q_src1=[ \n";
+ q_src1.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "q_src2=[ \n";
+ q_src2.print(std::cout);
+ std::cout << "] \n";
+ std::cout << "# Lowp GEMM output (FP32):\n";
+ std::cout << "q_dst=[ \n";
+ q_dst.print(std::cout);
+ std::cout << "] \n";
+
+ // Expected result
+ std::cout << "# Expected result:\n";
+ std::cout << "dst=[ \n";
+ dst.print(std::cout);
+ std::cout << "] \n";
+#endif // ARM_COMPUTE_DEBUG_ENABLED
+}
diff --git a/filelist.json b/filelist.json
index ab7f16bc90..2c3621cd8b 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1598,6 +1598,7 @@
"src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_int16.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_int8.cpp",
+ "src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
"src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
@@ -1695,6 +1696,7 @@
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp",
@@ -1722,6 +1724,9 @@
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp",
+ "src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp",
"src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp",
@@ -2236,7 +2241,9 @@
"common": [ "src/cpu/kernels/softmax/generic/sve/impl.cpp" ]
},
"sve2":{
- "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"]
+ "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"],
+ "fp32" :["src/cpu/kernels/softmax/generic/sme2/fp32.cpp"],
+ "fp16" :["src/cpu/kernels/softmax/generic/sme2/fp16.cpp"]
}
}
},
diff --git a/scripts/generate_android_bp.py b/scripts/generate_android_bp.py
index 6efd072acd..d5b268f522 100755
--- a/scripts/generate_android_bp.py
+++ b/scripts/generate_android_bp.py
@@ -45,7 +45,9 @@ excluded_paths = ["build",
"/sve/",
"/SVE/",
"/sve2/",
- "/SVE2/"
+ "/SVE2/",
+ "/sme/",
+ "/sme2/",
]
excluded_files = ["TracePoint.cpp"]
diff --git a/src/BUILD.bazel b/src/BUILD.bazel
index d4a3b61836..e3cac07de1 100644
--- a/src/BUILD.bazel
+++ b/src/BUILD.bazel
@@ -117,6 +117,8 @@ filegroup(
"cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp",
"cpu/kernels/elementwise_unary/generic/sve2/q8.cpp",
"cpu/kernels/lut/generic/sve2/u8.cpp",
+ "cpu/kernels/softmax/generic/sme2/fp16.cpp",
+ "cpu/kernels/softmax/generic/sme2/fp32.cpp",
"cpu/kernels/softmax/generic/sve2/impl.cpp"] +
glob(["**/*.h",
"**/*.hpp",
@@ -261,6 +263,9 @@ filegroup(
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp",
@@ -516,6 +521,7 @@ filegroup(
"core/NEON/kernels/arm_gemm/gemm_int8.cpp",
"core/NEON/kernels/arm_gemm/gemm_qint8.cpp",
"core/NEON/kernels/arm_gemm/gemm_quint8.cpp",
+ "core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp",
"core/NEON/kernels/arm_gemm/gemm_uint16.cpp",
"core/NEON/kernels/arm_gemm/gemm_uint8.cpp",
"core/NEON/kernels/arm_gemm/interleave-8way.cpp",
@@ -524,6 +530,7 @@ filegroup(
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp",
+ "core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp",
"core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp",
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c6410714d2..984db79c18 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -238,6 +238,9 @@ target_sources(
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_1VLx4VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_2VLx2VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8q_mopa_4VLx1VL/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_2VLx2VL/generic.cpp
core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8s32_mopa_4VLx1VL/generic.cpp
@@ -335,6 +338,8 @@ target_sources(
cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp
cpu/kernels/elementwise_unary/generic/sve2/q8.cpp
cpu/kernels/lut/generic/sve2/u8.cpp
+ cpu/kernels/softmax/generic/sme2/fp16.cpp
+ cpu/kernels/softmax/generic/sme2/fp32.cpp
cpu/kernels/softmax/generic/sve2/impl.cpp
)
@@ -507,6 +512,7 @@ target_sources(
core/NEON/kernels/arm_gemm/gemm_int8.cpp
core/NEON/kernels/arm_gemm/gemm_qint8.cpp
core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+ core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
core/NEON/kernels/arm_gemm/gemm_uint16.cpp
core/NEON/kernels/arm_gemm/gemm_uint8.cpp
core/NEON/kernels/arm_gemm/interleave-8way.cpp
@@ -515,6 +521,7 @@ target_sources(
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp
+ core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp
core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp
diff --git a/src/core/CL/cl_kernels/common/scatter.cl b/src/core/CL/cl_kernels/common/scatter.cl
new file mode 100644
index 0000000000..ac9f828df2
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/scatter.cl
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+#include "tile_helpers.h"
+
+// The below defines the various reduce operations for our purposes.
+// Where a corresponds to the existing value, and b the new value.
+#define ADD_OP(a, b) ((a) + (b))
+#define SUB_OP(a, b) ((a) - (b))
+#define MAX_OP(a, b) fmax(a, b)
+#define MIN_OP(a, b) fmin(a, b)
+#define UPDATE_OP(a, b) (b)
+
+#ifdef SCATTER_MP1D_2D_MPND
+
+/** This kernel performs scatter operation
+ *
+ * @note Datatype should be given as a compile-time argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
+ * @note Number of indices should be given as a compile-time argument using -DNUM_INDICES, e.g. -DNUM_INDICES=3
+ * @note Index length should be given as a compile-time argument using -DINDEX_LENGTH, e.g. -DINDEX_LENGTH=2
+ * @note Outermost output shapes should be given as a compile-time argument using -DOUT_SHAPE_N_MINUS_X, where
+ * X must be 1,2,3,4,5, e.g. -DOUT_SHAPE_N_MINUS_1=3, ...
+ * @note Number of elements to copy in a row should be given as a compile-time argument using -DN0, e.g. -DN0=4
+ * @note Number of partial elements at the edge to copy in a row should be given as a compile-time argument using
+ * -DPARTIAL_N0, e.g. -DPARTIAL_N0=2
+ * @note Scatter function should be given as a compile-time argument using -DSCATTER_FUNCTION, e.g. -DSCATTER_FUNCTION=ADD
+ * @note If the kernel should skip reading the output tensor, -DSKIP_OUTPUT_READ option should be provided.
+ * @note Kernel name in uppercase letters should be provided as a compile-time argument, e.g. -DSCATTER_MP1D_2D_MPND
+ *
+ * @param[in] updates_ptr Pointer to the updates tensor. Data Types: F32
+ * @param[in] updates_stride_x Stride of the updates tensor in X dimension (in bytes)
+ * @param[in] updates_step_x updates_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] updates_stride_y Stride of the updates tensor in Y dimension (in bytes)
+ * @param[in] updates_step_y updates_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] updates_offset_first_element_in_bytes The offset of the first element in the updates tensor
+ * @param[in] indices_ptr Pointer to the indices tensor. Data Types: S32
+ * @param[in] indices_stride_x Stride of the indices tensor in X dimension (in bytes)
+ * @param[in] indices_step_x indices_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] indices_stride_y Stride of the indices tensor in Y dimension (in bytes)
+ * @param[in] indices_step_y indices_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] indices_offset_first_element_in_bytes The offset of the first element in the indices tensor
+ * @param[out] output_ptr Pointer to the destination tensor. Same as @p upt_ptr
+ * @param[in] output_stride_x Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] output_step_x output_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] output_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] output_step_y output_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] output_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] upt_block_stride Update tensor data block stride in bytes
+ * @param[in] out_block_stride Output tensor data block stride in bytes
+ */
+__kernel void scatter_mp1d_2d_mpnd(
+ IMAGE_DECLARATION(updates),
+ IMAGE_DECLARATION(indices),
+ IMAGE_DECLARATION(output),
+ int upt_block_stride,
+ int out_block_stride
+ )
+{
+ const int out_shape[5] = {OUT_SHAPE_N_MINUS_1, OUT_SHAPE_N_MINUS_2, OUT_SHAPE_N_MINUS_3,
+ OUT_SHAPE_N_MINUS_4, OUT_SHAPE_N_MINUS_5};
+
+ const int x = GET_SPATIAL_IDX(0, N0, PARTIAL_N0); // x-coordinate in the tensor
+ const int y = get_global_id(1); // collapsed y-coordinate (ignoring the outermost dimensions)
+
+ const bool x_cond = (PARTIAL_N0 != 0 && get_global_id(0) == 0);
+
+ uchar *ind_ptr_raw = indices_ptr + indices_offset_first_element_in_bytes;
+ const uchar *out_ptr_raw = output_ptr + output_offset_first_element_in_bytes
+ + x * sizeof(DATA_TYPE) + y * output_stride_y;
+
+ const uchar *upt_ptr_raw = updates_ptr + updates_offset_first_element_in_bytes
+ + x * sizeof(DATA_TYPE) + y * updates_stride_y;
+
+ for(int index_element = 0; index_element < NUM_INDICES; ++index_element)
+ {
+ const int *ind_ptr = (const int *) (ind_ptr_raw);
+
+ // Out of bounds check
+ bool out_of_bounds = false;
+ LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH,
+ {
+ if(ind_ptr[i] >= out_shape[i] || ind_ptr[i] < 0)
+ {
+ out_of_bounds = true;
+ }
+ });
+
+ ind_ptr_raw += indices_stride_y;
+
+ if(out_of_bounds)
+ {
+ continue;
+ }
+
+ // Index calculation
+ int index = 0;
+ LOOP_UNROLLING(int, i, 0, 1, INDEX_LENGTH,
+ {
+ index = index * out_shape[i] + ind_ptr[i];
+ });
+
+ DATA_TYPE *out_ptr = (DATA_TYPE *) (out_ptr_raw + index * out_block_stride);
+
+ const DATA_TYPE *upt_ptr = (const DATA_TYPE *) (upt_ptr_raw + index_element * upt_block_stride);
+
+ VEC_DATA_TYPE(DATA_TYPE, N0) data_in0 = VLOAD(N0)(0, (__global DATA_TYPE *) upt_ptr);
+
+#ifdef SKIP_OUTPUT_READ
+ STORE_VECTOR_SELECT(data_in, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond);
+#else // ifdef SKIP_OUTPUT_READ
+ VEC_DATA_TYPE(DATA_TYPE, N0) data_out0 = VLOAD(N0)(0, (__global DATA_TYPE *) out_ptr);
+ data_out0 = SCATTER_FUNCTION(data_out0, data_in0);
+
+ STORE_VECTOR_SELECT(data_out, DATA_TYPE, (__global DATA_TYPE *) out_ptr, N0, PARTIAL_N0, x_cond);
+#endif // ifdef SKIP_OUTPUT_READ
+ }
+}
+
+#endif // SCATTER_MP1D_2D_MPND
+
+#ifdef SCATTER1D_PARALLEL
+
+// NOTE : This code is non-deterministic and can only be excecuted with the "update" ScatterFunction
+// This code is currently unusued as it requires changes to the existing test suite.
+/** Performs the Scatter1D operation with multiple threads.
+ * Similar to @ref scatter1D()
+ */
+__kernel void scatter1D_parallel(
+ TENSOR4D_DECLARATION(updates),
+ TENSOR4D_DECLARATION(indices),
+ TENSOR4D_DECLARATION(output))
+{
+ // Currently 1D - only iterate through x dimension of indices.
+ const int px = get_global_id(0);
+ const int index_value = *(uchar*)(indices_ptr + indices_offset_first_element_in_bytes + (sizeof(int) * px));
+
+ if(index_value < OUT_SHAPE_X)
+ {
+ const DATA_TYPE update = *(DATA_TYPE *)(updates_ptr + updates_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * px));
+ __global uchar *out_addr = output_ptr + indices_offset_first_element_in_bytes + (sizeof(DATA_TYPE) * index_value);
+ *(__global DATA_TYPE *)(out_addr) = update;
+ }
+}
+
+#endif // SCATTER1D_PARALLEL
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 44a7bb894a..af0d38ec37 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
+#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
@@ -123,14 +124,14 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"sme2_gemv_fp32bf16fp32_dot_16VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"sme2_gemv_fp32_mla_16VL",
- [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); }
},
@@ -138,7 +139,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); }
@@ -147,7 +148,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
- [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); }
@@ -156,7 +157,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); }
@@ -165,7 +166,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_fp32_mopa_4VLx1VL",
- [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
[](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); }
@@ -174,7 +175,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
- [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); }
},
@@ -182,7 +183,7 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_INTERLEAVED,
"sme2_interleaved_nomerge_fp32_mopa_2VLx2VL",
- [](const GemmArgs &args) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); }
},
@@ -292,14 +293,14 @@ GemmImplementation<float, float>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_fp32_mla_8x4",
- [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_fp32_mla_6x4",
- [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; },
+ [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
nullptr,
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); }
},
@@ -350,6 +351,14 @@ GemmImplementation<float, float>::with_estimate(
[](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); },
[](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); }
),
+GemmImplementation<float, float>::with_estimate(
+ GemmMethod::GEMM_HYBRID,
+ "a64_ffhybrid_fp32bf16fp32_mmla_6x16",
+ KernelWeightFormat::VL256_BL64_BF16,
+ [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
+ [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(args); },
+ [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16, float, float>(args); }
+),
#endif // BF16
GemmImplementation<float, float>::with_estimate(
GemmMethod::GEMM_INTERLEAVED,
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 89c2d5a23e..0cc4d4f3d9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -530,7 +530,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
} else if (_convolver) {
@@ -563,7 +563,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
} else {
@@ -579,7 +579,7 @@ public:
(m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
(this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
last_pass ? _args._act : Activation(),
- !first_pass,
+ !first_pass || _args._accumulate,
// Quantization parameters
_os, _col_bias+(multi * _args._Nsize), n0);
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index fd20e53f60..0dc0d55b27 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -128,14 +128,14 @@ GemmImplementation<int8_t, int32_t>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_s8s32_dot_8x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_s8s32_dot_6x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 4f732f7d94..ae344f09b5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -29,7 +29,6 @@
#include "arm_gemm.hpp"
#include "bfloat.hpp"
#include "convolver.hpp"
-#include "kernel_weight_format.hpp"
#include "kernel_traits.hpp"
#include "kernel_weight_format.hpp"
#include "mergeresults.hpp"
@@ -247,6 +246,84 @@ void kernel_and_merge<true, false, Requantize32>::run(
}
}
+// Run a kernel with integrated merge, dequantizing to FP32
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<false, false, DequantizeFloat>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max,
+ unsigned int n_0, unsigned int n_max, const Tr *bias,
+ const Activation &act, bool accumulate, const DequantizeFloat &dq, const int32_t *col_bias,
+ Tab *acc_buff)
+{
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
+#endif
+
+ const int32_t *offset_col_bias = nullptr;
+ const Tr *offset_bias = nullptr;
+
+ if (col_bias) {
+ offset_col_bias = col_bias + n_0;
+ }
+
+ if (bias) {
+ offset_bias = bias + n_0;
+ }
+
+ strat.kernel(// A and B pointers are just the packed panels.
+ a_ptr, b_panel,
+ // Provide relevant part of output array and row stride.
+ c_ptr ? (c_ptr + m_0 * ldc + n_0) : nullptr, ldc,
+ // M, N, K sizes
+ m_max-m_0, n_max - n_0, kern_k,
+ // Bias, activation, accumulation. Need to offset the bias as needed.
+ offset_col_bias, dq, offset_bias, act, accumulate, acc_buff);
+}
+
+template<>
+template<typename strategy, typename To, typename Tr, typename Tri, typename Tab>
+void kernel_and_merge<true, false, DequantizeFloat>::run(
+#ifdef CYCLE_PROFILING
+ profiler &prof,
+#endif
+ strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel,
+ Tr *c_ptr, int ldc, int kern_k, unsigned int m_0,
+ unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *bias,
+ const Activation &act, bool accumulate, const DequantizeFloat &qp, const int32_t *,
+ Tab *)
+{
+ const int bblocks = iceildiv(n_max - n_0, strategy::out_width());
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
+#endif
+
+ strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
+ }
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p=prof.ScopedProfiler(PROFILE_QUANTIZE, ((m_max-m_0) * bblocks * strategy::out_width() * sizeof(Tr)));
+#endif
+ auto out_area = strategy::out_width() * strategy::out_height();
+ for (int i=0; i<bblocks; i++) {
+ const unsigned int n_start = n_0 + (strategy::out_width() * i);
+ const unsigned int n_end = std::min(n_start + strategy::out_width(), n_max);
+
+ dequantize_block_32(qp, (n_end - n_start), (m_max - m_0),
+ c_panel + (i * out_area), strategy::out_width(),
+ c_ptr + m_0 * ldc + n_start, ldc,
+ bias != nullptr ? bias + n_start : nullptr, accumulate, act);
+
+ }
+ }
+}
+
// Integer GEMMs can be used in two contexts - "normal" where the full 32-bit output is required, or in
// "requantizing" context where the output will be requantized.
//
@@ -280,6 +357,12 @@ public:
typedef int32_t type;
};
+template<typename strategy>
+class accumulate_buffer_type<strategy, DequantizeFloat, false> {
+public:
+ typedef int32_t type;
+};
+
template<typename strategy, typename OutputStage>
class accumulate_buffer_type<strategy, OutputStage, true> {
public:
@@ -350,6 +433,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
const bool _thread_columns;
const Activation _act;
+ const bool _accumulate;
const int _maxthreads;
int _nthreads;
@@ -680,7 +764,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os(os) { }
@@ -690,7 +774,7 @@ public:
_Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
_rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
_nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
- _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+ _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
_k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
_os() { }
@@ -763,6 +847,9 @@ public:
const bool first_pass = (k0==0);
const bool last_pass = (kmax==_Ktotal);
+ // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass.
+ const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass;
+
// Figure out how many "K" the kernel will actually process.
unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll());
@@ -821,9 +908,9 @@ public:
// K size, and M/N ranges
kern_k, start_row, end_row, start_x, end_x,
// Only do bias on the first pass
- ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
+ ((bias_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (multi * _Nsize),
// Accumulation buffer
@@ -948,6 +1035,9 @@ public:
const bool first_pass = (current.k0() == 0);
const bool last_pass = (current.kmax() == _Ktotal);
+ // Bias is passed for the first pass only, except for dequantizefloat nomerge cases where it's the last pass.
+ const bool bias_pass = (std::is_same<OutputStage, DequantizeFloat>::value && !MergeStep) ? last_pass : first_pass;
+
// Pointer to appropriate part of result array.
Tr *result_ptr = this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride);
@@ -969,9 +1059,9 @@ public:
// K size, and M/N ranges
kern_k, y, ymax, current.x0(), current.xmax(),
// Only do bias on the first pass
- ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
+ ((bias_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
// Only do activation on the last pass, and accumulation on any non-first pass.
- (last_pass ? _act : Activation()), !first_pass,
+ (last_pass ? _act : Activation()), (!first_pass || _accumulate),
// Pass in quantization parameters for requantizing kernels (others will ignore)
_os, col_bias + (current.multi() * _Nsize),
// Accumulation buffer
@@ -1184,6 +1274,13 @@ public:
}
}
+ void set_dequantize_scale(const float scale) override {
+ if(std::is_same<OutputStage, DequantizeFloat>::value) {
+ DequantizeFloat* df = reinterpret_cast<DequantizeFloat *>(&_os);
+ df->scale = scale;
+ }
+ }
+
void set_indirect_parameters(size_t string_len, const To * const * const *ptr) override {
assert(string_len == _Ksize);
_indirect_buf = ptr;
@@ -1248,4 +1345,10 @@ using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved<strat
template<typename strategy, typename To, typename Tr>
using GemmInterleavedQuantized = GemmInterleaved<strategy, To, Tr, Requantize32>;
+template<typename strategy, typename To, typename Tr>
+using GemmInterleavedNoMergeDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat, false>;
+
+template<typename strategy, typename To, typename Tr>
+using GemmInterleavedDequantized = GemmInterleaved<strategy, To, Tr, DequantizeFloat>;
+
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
new file mode 100644
index 0000000000..782399df8c
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+
+#include "kernels/a64_gemm_s16_8x12.hpp"
+#include "kernels/a64_gemm_s8_8x12.hpp"
+#include "kernels/a64_gemm_s8_4x4.hpp"
+#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp"
+
+#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp"
+#include "kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SME2
+#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
+#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
+#endif // ARM_COMPUTE_ENABLE_SVE
+
+#include "gemm_implementation.hpp"
+#include "gemm_interleaved.hpp"
+#include "utils.hpp"
+
+#include <cstdint>
+#include <vector>
+namespace arm_gemm {
+
+static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_methods[] =
+{
+#ifdef ARM_COMPUTE_ENABLE_SVE
+#ifdef ARM_COMPUTE_ENABLE_SME2
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
+ return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+ [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, float>(args, dq); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8qfp32_mopa_4Vx1VL.hpp",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
+ return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
+ [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, float>(args, dq); }
+},
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "sme2_interleaved_nomerge_s8qfp32_mopa_2Vx2VL.hpp",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+ nullptr,
+ [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, float>(args, dq); }
+},
+#endif // ARM_COMPUTE_ENABLE_SME2
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_interleaved_s8s32_mmla_8x3VL",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_svei8mm(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, float>(args, qp); }
+),
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "sve_interleaved_s8s32_dot_8x3VL",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sve(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, float>(args, qp); }
+),
+#endif // ARM_COMPUTE_ENABLE_SVE
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_interleaved_s8s32_mmla_8x12",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_i8mm(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, float>(args, qp); }
+),
+{
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s16_8x12",
+ nullptr,
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->get_cpu_model() == CPUModel::A53 && ((args._Msize > 28) || ((args._Msize % 8) > 4)); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s16_8x12, int8_t, float>(args, qp); }
+},
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s8_8x12",
+ [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_dotprod(); },
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_8x12, int8_t, float>(args, qp); }
+),
+GemmImplementation<int8_t, float, DequantizeFloat>::with_estimate(
+ GemmMethod::GEMM_INTERLEAVED,
+ "a64_gemm_s8_4x4",
+ nullptr,
+ [](const GemmArgs &args, const DequantizeFloat &) { return GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>::estimate_cycles<int8_t>(args); },
+ [](const GemmArgs &args, const DequantizeFloat &qp) { return new GemmInterleavedDequantized<cls_a64_gemm_s8_4x4, int8_t, float>(args, qp); }
+),
+{
+ GemmMethod::DEFAULT,
+ "",
+ nullptr,
+ nullptr,
+ nullptr
+}
+};
+
+template<>
+const GemmImplementation<int8_t, float, DequantizeFloat> *gemm_implementation_list<int8_t, float, DequantizeFloat>() {
+ return gemm_s8fp32_methods;
+}
+
+template UniqueGemmCommon<int8_t, float> gemm<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os);
+template KernelDescription get_gemm_method<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os);
+template std::vector<KernelDescription> get_compatible_kernels<int8_t, float, DequantizeFloat>(const GemmArgs &args, const DequantizeFloat &os);
+
+} // namespace arm_gemm
+
+#endif // __aarch64__ \ No newline at end of file
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index af5cfbbf2b..dfacb687a8 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -94,14 +94,14 @@ GemmImplementation<uint8_t, uint32_t>::with_estimate(
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_u8u32_dot_8x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
},
{
GemmMethod::GEMM_HYBRID,
"a64_smallK_hybrid_u8u32_dot_6x4",
- [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+ [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
[](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
[](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
},
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 92c884ce18..dbada36052 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -180,7 +180,7 @@ public:
this->_Cptr + (multi * this->_C_multi_stride) + n,
(nmax - n), (kmax-k0),
this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
- _args._act, (k0 != 0),
+ _args._act, (k0 != 0) || _args._accumulate,
_os, col_bias, n + (_args._Nsize * multi));
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
index 923d008bb1..ac3cbf943f 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
{
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 23.64 };
default:
- return { 28.48 };
+ return { 16.89 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp
new file mode 100644
index 0000000000..98f7fc9403
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16.hpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+#ifdef __aarch64__
+
+#include "../std_transforms_fixed.hpp"
+#include "../bfloat.hpp"
+#include "../kernel_weight_format.hpp"
+#include "../performance_parameters.hpp"
+
+#define ARGLIST \
+ unsigned int, const unsigned int *, \
+ IndirectInputArg<float>, \
+ size_t, size_t, \
+ const bfloat16 *, \
+ size_t, \
+ IndirectOutputArg<float>, \
+ const float *, Activation, bool
+
+namespace arm_gemm
+{
+// Actual kernel implementations
+void a64_ffhybrid_fp32bf16fp32_mmla_6x16( ARGLIST );
+
+class cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16
+{
+public:
+ typedef float lhs_operand_type;
+ typedef bfloat16 rhs_operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)( ARGLIST );
+
+ /* Kernel blocking parameters */
+ static constexpr unsigned int out_height()
+ {
+ return 6;
+ }
+ static unsigned int stripe_width()
+ {
+ return 4;
+ }
+
+ static KernelWeightFormat kernel_weight_format()
+ {
+ return KernelWeightFormat::VL256_BL64_BF16;
+ }
+
+ static unsigned int out_width()
+ {
+ return 16;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 4> transforms = {};
+ template<typename T>
+ static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
+ {
+ if (std::is_same<T, float>::value) {
+ switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 21.05 };
+ default:
+ return { 15.27 };
+ }
+ }
+
+ return { 1.0 };
+ }
+
+ // Default to the generic kernel
+ kern_type kernel=a64_ffhybrid_fp32bf16fp32_mmla_6x16;
+ cls_a64_ffhybrid_fp32bf16fp32_mmla_6x16(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#undef ARGLIST
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp
new file mode 100644
index 0000000000..9ab4aa98f9
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_6x16/generic.cpp
@@ -0,0 +1,3240 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+#include "../../utils.hpp"
+#include "../../bfloat.hpp"
+
+#include <cassert>
+#include <limits>
+
+namespace arm_gemm {
+
+void a64_ffhybrid_fp32bf16fp32_mmla_6x16 (
+ unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg,
+ size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg,
+ const float *bias, Activation act, bool accumulate
+)
+{
+ struct KernelArgs {
+ float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+ float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+ unsigned int num_strings = {};
+ const unsigned int *string_lengths = {};
+ size_t N = {};
+ const bfloat16 *B_ptr = {};
+ const bfloat16 *cur_B_ptr = {};
+ size_t B_stride = {};
+ size_t output_offset = {};
+ size_t input_initial_col = {};
+ size_t input_offset = {};
+ void *output_ptr = nullptr;
+ const float *bias = nullptr;
+ } ka;
+
+ unsigned long flags=0;
+ void *input_ptr;
+
+ if (output_arg.is_indirect) {
+ ka.output_ptr=(void *)(output_arg.indirect.ptr);
+ ka.output_offset=output_arg.indirect.offset;
+ flags |= 0x4;
+ } else {
+ ka.output_ptr=(void *)(output_arg.direct.base);
+ ka.output_offset=output_arg.direct.stride;
+ }
+
+ if (A_arg.is_indirect) {
+ input_ptr=(void *)(A_arg.indirect.ptr);
+ ka.input_offset=A_arg.indirect.start_row;
+ ka.input_initial_col=A_arg.indirect.start_col;
+ flags |= 0x8;
+ } else {
+ assert(num_strings==1);
+ input_ptr=(void *)(A_arg.direct.base);
+ ka.input_offset=A_arg.direct.stride;
+ }
+ if (accumulate) {
+ flags |= 0x1;
+ }
+ ka.num_strings = num_strings;
+ ka.string_lengths = string_lengths;
+ ka.N = N;
+ ka.B_ptr = B_ptr;
+ ka.bias = bias;
+ ka.B_stride = B_stride;
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ ka.maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ ka.minval = 0;
+ flags |= 0x2;
+ break;
+ }
+ __asm__ __volatile__(
+ "1:" // Row loop
+ "cmp %x[M], #0x6\n"
+ "bge 181f\n"
+ "cmp %x[M], #0x4\n"
+ "bgt 145f\n"
+ "beq 109f\n"
+ "cmp %x[M], #0x2\n"
+ "bgt 73f\n"
+ "beq 37f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "2:" // Height 1: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 3f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 3f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 3f\n"
+ "mov x11, x12\n"
+ "3:" // Height 1: B setup done
+ "cbz x15, 4f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "b 16f\n"
+ "4:" // Height 1: no bias
+ "tbz %x[flags], #0, 15f\n"
+ "cmp x14, #0x10\n"
+ "bge 13f\n"
+ "tbz x14, #3, 8f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "tbz x14, #2, 6f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 5f\n"
+ "ldr d16, [x13], #0x8\n"
+ "mov x20, #0x38\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "b 12f\n"
+ "5:" // Height 1: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 12f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "b 12f\n"
+ "6:" // Height 1: Partial accumulate: partial_2_8
+ "tbz x14, #1, 7f\n"
+ "ldr d11, [x13], #0x8\n"
+ "mov x20, #0x28\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "b 12f\n"
+ "7:" // Height 1: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 12f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "b 12f\n"
+ "8:" // Height 1: Partial accumulate: partial_4_0
+ "tbz x14, #2, 10f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 9f\n"
+ "ldr d10, [x13], #0x8\n"
+ "mov x20, #0x18\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "b 12f\n"
+ "9:" // Height 1: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 12f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "b 12f\n"
+ "10:" // Height 1: Partial accumulate: partial_2_0
+ "tbz x14, #1, 11f\n"
+ "ldr d9, [x13], #0x8\n"
+ "mov x20, #0x8\n"
+ "tbz x14, #0, 12f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "b 12f\n"
+ "11:" // Height 1: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "mov x20, #0x0\n"
+ "12:" // Height 1: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 14f\n"
+ "13:" // Height 1: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "14:" // Height 1: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "b 16f\n"
+ "15:" // Height 1: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "16:" // Height 1: setup done
+ "mov x28, #0x0\n"
+ "17:" // Height 1: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 18f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "cbnz x28, 19f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "b 19f\n"
+ "18:" // Height 1: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "19:" // Height 1: input setup done
+ "cmp x27, #0x4\n"
+ "blt 22f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "cmp x27, #0x8\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 21f\n"
+ "20:" // Height 1: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 20b\n"
+ "21:" // Height 1: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "22:" // Height 1: Multiply loop: Main loop skip
+ "cbz x27, 25f\n"
+ "cbz x27, 25f\n"
+ "tbz x27, #1, 23f\n"
+ "ldr d0, [x26], #0x8\n"
+ "tbz x27, #0, 24f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "b 24f\n"
+ "23:" // Height 1: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "24:" // Height 1: Multiply loop: Ragged operand read: Done
+ "ldr q18, [x12, #0x0]\n"
+ "ldr q17, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "25:" // Height 1: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 17b\n"
+ "uzp1 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v10.2d, v10.2d, v14.2d\n"
+ "uzp1 v11.2d, v11.2d, v15.2d\n"
+ "tbz %x[flags], #1, 26f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v18.4s }, [x21]\n"
+ "ld1r { v17.4s }, [x20]\n"
+ "fmin v8.4s, v8.4s, v18.4s\n"
+ "fmin v9.4s, v9.4s, v18.4s\n"
+ "fmin v10.4s, v10.4s, v18.4s\n"
+ "fmin v11.4s, v11.4s, v18.4s\n"
+ "fmax v8.4s, v8.4s, v17.4s\n"
+ "fmax v9.4s, v9.4s, v17.4s\n"
+ "fmax v10.4s, v10.4s, v17.4s\n"
+ "fmax v11.4s, v11.4s, v17.4s\n"
+ "26:" // Height 1: No activation
+ "cmp x14, #0x10\n"
+ "bge 35f\n"
+ "tbz x14, #3, 30f\n"
+ "st1 { v8.4s }, [x13], #0x10\n"
+ "st1 { v9.4s }, [x13], #0x10\n"
+ "tbz x14, #2, 28f\n"
+ "st1 { v10.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 27f\n"
+ "str d11, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v11.s }[2], [x13]\n"
+ "b 34f\n"
+ "27:" // Height 1: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 34f\n"
+ "str s11, [x13, #0x0]\n"
+ "b 34f\n"
+ "28:" // Height 1: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 29f\n"
+ "str d10, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v10.s }[2], [x13]\n"
+ "b 34f\n"
+ "29:" // Height 1: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 34f\n"
+ "str s10, [x13, #0x0]\n"
+ "b 34f\n"
+ "30:" // Height 1: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 32f\n"
+ "st1 { v8.4s }, [x13], #0x10\n"
+ "tbz x14, #1, 31f\n"
+ "str d9, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v9.s }[2], [x13]\n"
+ "b 34f\n"
+ "31:" // Height 1: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 34f\n"
+ "str s9, [x13, #0x0]\n"
+ "b 34f\n"
+ "32:" // Height 1: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 33f\n"
+ "str d8, [x13], #0x8\n"
+ "tbz x14, #0, 34f\n"
+ "st1 { v8.s }[2], [x13]\n"
+ "b 34f\n"
+ "33:" // Height 1: Partial direct writeback: partial_1_0
+ "str s8, [x13, #0x0]\n"
+ "34:" // Height 1: Partial direct writeback: Done
+ "b 36f\n"
+ "35:" // Height 1: Full writeback
+ "str q8, [x13, #0x0]\n"
+ "str q9, [x13, #0x10]\n"
+ "str q10, [x13, #0x20]\n"
+ "str q11, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "36:" // Height 1: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 2b\n"
+ "b 218f\n"
+ "37:" // Height 2
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "38:" // Height 2: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 39f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 39f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 39f\n"
+ "mov x11, x12\n"
+ "39:" // Height 2: B setup done
+ "cbz x15, 40f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "b 52f\n"
+ "40:" // Height 2: no bias
+ "tbz %x[flags], #0, 51f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "bge 49f\n"
+ "tbz x14, #3, 44f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "tbz x14, #2, 42f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 41f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "b 48f\n"
+ "41:" // Height 2: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 48f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "b 48f\n"
+ "42:" // Height 2: Partial accumulate: partial_2_8
+ "tbz x14, #1, 43f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "b 48f\n"
+ "43:" // Height 2: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 48f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "b 48f\n"
+ "44:" // Height 2: Partial accumulate: partial_4_0
+ "tbz x14, #2, 46f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 45f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "b 48f\n"
+ "45:" // Height 2: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 48f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "b 48f\n"
+ "46:" // Height 2: Partial accumulate: partial_2_0
+ "tbz x14, #1, 47f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "tbz x14, #0, 48f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "b 48f\n"
+ "47:" // Height 2: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "48:" // Height 2: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 50f\n"
+ "49:" // Height 2: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "50:" // Height 2: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "b 52f\n"
+ "51:" // Height 2: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "52:" // Height 2: setup done
+ "mov x28, #0x0\n"
+ "53:" // Height 2: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 54f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "cbnz x28, 55f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "b 55f\n"
+ "54:" // Height 2: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "55:" // Height 2: input setup done
+ "cmp x27, #0x4\n"
+ "blt 58f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 57f\n"
+ "56:" // Height 2: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 56b\n"
+ "57:" // Height 2: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "58:" // Height 2: Multiply loop: Main loop skip
+ "cbz x27, 61f\n"
+ "cbz x27, 61f\n"
+ "tbz x27, #1, 59f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "tbz x27, #0, 60f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "b 60f\n"
+ "59:" // Height 2: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "60:" // Height 2: Multiply loop: Ragged operand read: Done
+ "ldr q18, [x12, #0x0]\n"
+ "ldr q17, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e52ec08 // bfmmla v8.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x11, #0x0]\n"
+ ".inst 0x6e51ec0c // bfmmla v12.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e52ec09 // bfmmla v9.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x10, #0x0]\n"
+ ".inst 0x6e51ec0d // bfmmla v13.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e52ec0a // bfmmla v10.4s, v0.8h, v18.8h\n"
+ "ldr q18, [x9, #0x0]\n"
+ ".inst 0x6e51ec0e // bfmmla v14.4s, v0.8h, v17.8h\n"
+ "ldr q17, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e52ec0b // bfmmla v11.4s, v0.8h, v18.8h\n"
+ ".inst 0x6e51ec0f // bfmmla v15.4s, v0.8h, v17.8h\n"
+ "61:" // Height 2: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 53b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "tbz %x[flags], #1, 62f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v18.4s }, [x21]\n"
+ "ld1r { v17.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v18.4s\n"
+ "fmin v12.4s, v12.4s, v18.4s\n"
+ "fmin v13.4s, v13.4s, v18.4s\n"
+ "fmin v14.4s, v14.4s, v18.4s\n"
+ "fmin v8.4s, v8.4s, v18.4s\n"
+ "fmin v9.4s, v9.4s, v18.4s\n"
+ "fmin v10.4s, v10.4s, v18.4s\n"
+ "fmin v11.4s, v11.4s, v18.4s\n"
+ "fmax v6.4s, v6.4s, v17.4s\n"
+ "fmax v12.4s, v12.4s, v17.4s\n"
+ "fmax v13.4s, v13.4s, v17.4s\n"
+ "fmax v14.4s, v14.4s, v17.4s\n"
+ "fmax v8.4s, v8.4s, v17.4s\n"
+ "fmax v9.4s, v9.4s, v17.4s\n"
+ "fmax v10.4s, v10.4s, v17.4s\n"
+ "fmax v11.4s, v11.4s, v17.4s\n"
+ "62:" // Height 2: No activation
+ "cmp x14, #0x10\n"
+ "bge 71f\n"
+ "tbz x14, #3, 66f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "tbz x14, #2, 64f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 63f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "b 70f\n"
+ "63:" // Height 2: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 70f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "b 70f\n"
+ "64:" // Height 2: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 65f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "b 70f\n"
+ "65:" // Height 2: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 70f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "b 70f\n"
+ "66:" // Height 2: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 68f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "tbz x14, #1, 67f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "b 70f\n"
+ "67:" // Height 2: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 70f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "b 70f\n"
+ "68:" // Height 2: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 69f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "tbz x14, #0, 70f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "b 70f\n"
+ "69:" // Height 2: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "70:" // Height 2: Partial direct writeback: Done
+ "b 72f\n"
+ "71:" // Height 2: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "72:" // Height 2: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 38b\n"
+ "b 218f\n"
+ "73:" // Height 3
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "74:" // Height 3: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 75f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 75f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 75f\n"
+ "mov x11, x12\n"
+ "75:" // Height 3: B setup done
+ "cbz x15, 76f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "b 88f\n"
+ "76:" // Height 3: no bias
+ "tbz %x[flags], #0, 87f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "bge 85f\n"
+ "tbz x14, #3, 80f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "tbz x14, #2, 78f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 77f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "b 84f\n"
+ "77:" // Height 3: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 84f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "b 84f\n"
+ "78:" // Height 3: Partial accumulate: partial_2_8
+ "tbz x14, #1, 79f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "b 84f\n"
+ "79:" // Height 3: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 84f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "b 84f\n"
+ "80:" // Height 3: Partial accumulate: partial_4_0
+ "tbz x14, #2, 82f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 81f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "b 84f\n"
+ "81:" // Height 3: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 84f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "b 84f\n"
+ "82:" // Height 3: Partial accumulate: partial_2_0
+ "tbz x14, #1, 83f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "tbz x14, #0, 84f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "b 84f\n"
+ "83:" // Height 3: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "84:" // Height 3: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 86f\n"
+ "85:" // Height 3: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "86:" // Height 3: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "b 88f\n"
+ "87:" // Height 3: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "88:" // Height 3: setup done
+ "mov x28, #0x0\n"
+ "89:" // Height 3: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 90f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "cbnz x28, 91f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "b 91f\n"
+ "90:" // Height 3: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "91:" // Height 3: input setup done
+ "cmp x27, #0x4\n"
+ "blt 94f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 93f\n"
+ "92:" // Height 3: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 92b\n"
+ "93:" // Height 3: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "94:" // Height 3: Multiply loop: Main loop skip
+ "cbz x27, 97f\n"
+ "cbz x27, 97f\n"
+ "tbz x27, #1, 95f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "tbz x27, #0, 96f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "b 96f\n"
+ "95:" // Height 3: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "96:" // Height 3: Multiply loop: Ragged operand read: Done
+ "ldr q26, [x12, #0x0]\n"
+ "ldr q25, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n"
+ ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "97:" // Height 3: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 89b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x25, x26, x20, LSL #2\n"
+ "uzp1 v16.2d, v16.2d, v20.2d\n"
+ "uzp1 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v19.2d, v19.2d, v23.2d\n"
+ "tbz %x[flags], #1, 98f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v26.4s }, [x21]\n"
+ "ld1r { v25.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v26.4s\n"
+ "fmin v12.4s, v12.4s, v26.4s\n"
+ "fmin v13.4s, v13.4s, v26.4s\n"
+ "fmin v14.4s, v14.4s, v26.4s\n"
+ "fmin v8.4s, v8.4s, v26.4s\n"
+ "fmin v9.4s, v9.4s, v26.4s\n"
+ "fmin v10.4s, v10.4s, v26.4s\n"
+ "fmin v11.4s, v11.4s, v26.4s\n"
+ "fmin v16.4s, v16.4s, v26.4s\n"
+ "fmin v17.4s, v17.4s, v26.4s\n"
+ "fmin v18.4s, v18.4s, v26.4s\n"
+ "fmin v19.4s, v19.4s, v26.4s\n"
+ "fmax v6.4s, v6.4s, v25.4s\n"
+ "fmax v12.4s, v12.4s, v25.4s\n"
+ "fmax v13.4s, v13.4s, v25.4s\n"
+ "fmax v14.4s, v14.4s, v25.4s\n"
+ "fmax v8.4s, v8.4s, v25.4s\n"
+ "fmax v9.4s, v9.4s, v25.4s\n"
+ "fmax v10.4s, v10.4s, v25.4s\n"
+ "fmax v11.4s, v11.4s, v25.4s\n"
+ "fmax v16.4s, v16.4s, v25.4s\n"
+ "fmax v17.4s, v17.4s, v25.4s\n"
+ "fmax v18.4s, v18.4s, v25.4s\n"
+ "fmax v19.4s, v19.4s, v25.4s\n"
+ "98:" // Height 3: No activation
+ "cmp x14, #0x10\n"
+ "bge 107f\n"
+ "tbz x14, #3, 102f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v16.4s }, [x25], #0x10\n"
+ "st1 { v17.4s }, [x25], #0x10\n"
+ "tbz x14, #2, 100f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v18.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 99f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d19, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v19.s }[2], [x25]\n"
+ "b 106f\n"
+ "99:" // Height 3: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 106f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s19, [x25, #0x0]\n"
+ "b 106f\n"
+ "100:" // Height 3: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 101f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d18, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v18.s }[2], [x25]\n"
+ "b 106f\n"
+ "101:" // Height 3: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 106f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s18, [x25, #0x0]\n"
+ "b 106f\n"
+ "102:" // Height 3: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 104f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v16.4s }, [x25], #0x10\n"
+ "tbz x14, #1, 103f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d17, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v17.s }[2], [x25]\n"
+ "b 106f\n"
+ "103:" // Height 3: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 106f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s17, [x25, #0x0]\n"
+ "b 106f\n"
+ "104:" // Height 3: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 105f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d16, [x25], #0x8\n"
+ "tbz x14, #0, 106f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v16.s }[2], [x25]\n"
+ "b 106f\n"
+ "105:" // Height 3: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s16, [x25, #0x0]\n"
+ "106:" // Height 3: Partial direct writeback: Done
+ "b 108f\n"
+ "107:" // Height 3: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q16, [x25, #0x0]\n"
+ "str q17, [x25, #0x10]\n"
+ "str q18, [x25, #0x20]\n"
+ "str q19, [x25, #0x30]\n"
+ "108:" // Height 3: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 74b\n"
+ "b 218f\n"
+ "109:" // Height 4
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "110:" // Height 4: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 111f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 111f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 111f\n"
+ "mov x11, x12\n"
+ "111:" // Height 4: B setup done
+ "cbz x15, 112f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "b 124f\n"
+ "112:" // Height 4: no bias
+ "tbz %x[flags], #0, 123f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "bge 121f\n"
+ "tbz x14, #3, 116f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "ld1 { v21.4s }, [x24], #0x10\n"
+ "tbz x14, #2, 114f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "ld1 { v22.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 113f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "ldr d23, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "ld1 { v23.s }[2], [x24]\n"
+ "b 120f\n"
+ "113:" // Height 4: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 120f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "ldr s23, [x24, #0x0]\n"
+ "b 120f\n"
+ "114:" // Height 4: Partial accumulate: partial_2_8
+ "tbz x14, #1, 115f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "ldr d22, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "ld1 { v22.s }[2], [x24]\n"
+ "b 120f\n"
+ "115:" // Height 4: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 120f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "ldr s22, [x24, #0x0]\n"
+ "b 120f\n"
+ "116:" // Height 4: Partial accumulate: partial_4_0
+ "tbz x14, #2, 118f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 117f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "ldr d21, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "ld1 { v21.s }[2], [x24]\n"
+ "b 120f\n"
+ "117:" // Height 4: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 120f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "ldr s21, [x24, #0x0]\n"
+ "b 120f\n"
+ "118:" // Height 4: Partial accumulate: partial_2_0
+ "tbz x14, #1, 119f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "ldr d20, [x24], #0x8\n"
+ "tbz x14, #0, 120f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "ld1 { v20.s }[2], [x24]\n"
+ "b 120f\n"
+ "119:" // Height 4: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "ldr s20, [x24, #0x0]\n"
+ "120:" // Height 4: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 122f\n"
+ "121:" // Height 4: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "ldr q20, [x24, #0x0]\n"
+ "ldr q21, [x24, #0x10]\n"
+ "ldr q22, [x24, #0x20]\n"
+ "ldr q23, [x24, #0x30]\n"
+ "122:" // Height 4: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "b 124f\n"
+ "123:" // Height 4: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "124:" // Height 4: setup done
+ "mov x28, #0x0\n"
+ "125:" // Height 4: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 126f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "ldr x23, [x20, #0x18]\n"
+ "cbnz x28, 127f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "add x23, x23, x20, LSL #2\n"
+ "b 127f\n"
+ "126:" // Height 4: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "add x23, x24, x21, LSL #2\n"
+ "127:" // Height 4: input setup done
+ "cmp x27, #0x4\n"
+ "blt 130f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 129f\n"
+ "128:" // Height 4: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 128b\n"
+ "129:" // Height 4: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "130:" // Height 4: Multiply loop: Main loop skip
+ "cbz x27, 133f\n"
+ "cbz x27, 133f\n"
+ "tbz x27, #1, 131f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "ldr d3, [x23], #0x8\n"
+ "tbz x27, #0, 132f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "ld1 { v3.s }[2], [x23]\n"
+ "b 132f\n"
+ "131:" // Height 4: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "ldr s3, [x23, #0x0]\n"
+ "132:" // Height 4: Multiply loop: Ragged operand read: Done
+ "ldr q26, [x12, #0x0]\n"
+ "ldr q25, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e5aec08 // bfmmla v8.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec50 // bfmmla v16.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x11, #0x0]\n"
+ ".inst 0x6e59ec0c // bfmmla v12.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec54 // bfmmla v20.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e5aec09 // bfmmla v9.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec51 // bfmmla v17.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x10, #0x0]\n"
+ ".inst 0x6e59ec0d // bfmmla v13.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec55 // bfmmla v21.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e5aec0a // bfmmla v10.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec52 // bfmmla v18.4s, v2.8h, v26.8h\n"
+ "ldr q26, [x9, #0x0]\n"
+ ".inst 0x6e59ec0e // bfmmla v14.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec56 // bfmmla v22.4s, v2.8h, v25.8h\n"
+ "ldr q25, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e5aec0b // bfmmla v11.4s, v0.8h, v26.8h\n"
+ ".inst 0x6e5aec53 // bfmmla v19.4s, v2.8h, v26.8h\n"
+ ".inst 0x6e59ec0f // bfmmla v15.4s, v0.8h, v25.8h\n"
+ ".inst 0x6e59ec57 // bfmmla v23.4s, v2.8h, v25.8h\n"
+ "133:" // Height 4: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 125b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x24, x25, x20, LSL #2\n"
+ "uzp1 v15.2d, v16.2d, v20.2d\n"
+ "uzp2 v16.2d, v16.2d, v20.2d\n"
+ "uzp1 v20.2d, v17.2d, v21.2d\n"
+ "uzp2 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v21.2d, v18.2d, v22.2d\n"
+ "uzp2 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v22.2d, v19.2d, v23.2d\n"
+ "uzp2 v19.2d, v19.2d, v23.2d\n"
+ "tbz %x[flags], #1, 134f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v26.4s }, [x21]\n"
+ "ld1r { v25.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v26.4s\n"
+ "fmin v12.4s, v12.4s, v26.4s\n"
+ "fmin v13.4s, v13.4s, v26.4s\n"
+ "fmin v14.4s, v14.4s, v26.4s\n"
+ "fmin v8.4s, v8.4s, v26.4s\n"
+ "fmin v9.4s, v9.4s, v26.4s\n"
+ "fmin v10.4s, v10.4s, v26.4s\n"
+ "fmin v11.4s, v11.4s, v26.4s\n"
+ "fmin v15.4s, v15.4s, v26.4s\n"
+ "fmin v20.4s, v20.4s, v26.4s\n"
+ "fmin v21.4s, v21.4s, v26.4s\n"
+ "fmin v22.4s, v22.4s, v26.4s\n"
+ "fmin v16.4s, v16.4s, v26.4s\n"
+ "fmin v17.4s, v17.4s, v26.4s\n"
+ "fmin v18.4s, v18.4s, v26.4s\n"
+ "fmin v19.4s, v19.4s, v26.4s\n"
+ "fmax v6.4s, v6.4s, v25.4s\n"
+ "fmax v12.4s, v12.4s, v25.4s\n"
+ "fmax v13.4s, v13.4s, v25.4s\n"
+ "fmax v14.4s, v14.4s, v25.4s\n"
+ "fmax v8.4s, v8.4s, v25.4s\n"
+ "fmax v9.4s, v9.4s, v25.4s\n"
+ "fmax v10.4s, v10.4s, v25.4s\n"
+ "fmax v11.4s, v11.4s, v25.4s\n"
+ "fmax v15.4s, v15.4s, v25.4s\n"
+ "fmax v20.4s, v20.4s, v25.4s\n"
+ "fmax v21.4s, v21.4s, v25.4s\n"
+ "fmax v22.4s, v22.4s, v25.4s\n"
+ "fmax v16.4s, v16.4s, v25.4s\n"
+ "fmax v17.4s, v17.4s, v25.4s\n"
+ "fmax v18.4s, v18.4s, v25.4s\n"
+ "fmax v19.4s, v19.4s, v25.4s\n"
+ "134:" // Height 4: No activation
+ "cmp x14, #0x10\n"
+ "bge 143f\n"
+ "tbz x14, #3, 138f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v17.4s }, [x24], #0x10\n"
+ "tbz x14, #2, 136f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v18.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 135f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d19, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x24]\n"
+ "b 142f\n"
+ "135:" // Height 4: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 142f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s19, [x24, #0x0]\n"
+ "b 142f\n"
+ "136:" // Height 4: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 137f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d18, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v18.s }[2], [x24]\n"
+ "b 142f\n"
+ "137:" // Height 4: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 142f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s18, [x24, #0x0]\n"
+ "b 142f\n"
+ "138:" // Height 4: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 140f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "tbz x14, #1, 139f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d17, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v17.s }[2], [x24]\n"
+ "b 142f\n"
+ "139:" // Height 4: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 142f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s17, [x24, #0x0]\n"
+ "b 142f\n"
+ "140:" // Height 4: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 141f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d15, [x25], #0x8\n"
+ "str d16, [x24], #0x8\n"
+ "tbz x14, #0, 142f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v15.s }[2], [x25]\n"
+ "st1 { v16.s }[2], [x24]\n"
+ "b 142f\n"
+ "141:" // Height 4: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s15, [x25, #0x0]\n"
+ "str s16, [x24, #0x0]\n"
+ "142:" // Height 4: Partial direct writeback: Done
+ "b 144f\n"
+ "143:" // Height 4: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q15, [x25, #0x0]\n"
+ "str q20, [x25, #0x10]\n"
+ "str q21, [x25, #0x20]\n"
+ "str q22, [x25, #0x30]\n"
+ "str q16, [x24, #0x0]\n"
+ "str q17, [x24, #0x10]\n"
+ "str q18, [x24, #0x20]\n"
+ "str q19, [x24, #0x30]\n"
+ "144:" // Height 4: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 110b\n"
+ "b 218f\n"
+ "145:" // Height 5
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "146:" // Height 5: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 147f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 147f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 147f\n"
+ "mov x11, x12\n"
+ "147:" // Height 5: B setup done
+ "cbz x15, 148f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "mov v24.16b, v8.16b\n"
+ "mov v28.16b, v12.16b\n"
+ "mov v25.16b, v9.16b\n"
+ "mov v29.16b, v13.16b\n"
+ "mov v26.16b, v10.16b\n"
+ "mov v30.16b, v14.16b\n"
+ "mov v27.16b, v11.16b\n"
+ "mov v31.16b, v15.16b\n"
+ "b 160f\n"
+ "148:" // Height 5: no bias
+ "tbz %x[flags], #0, 159f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "add x23, x24, x20, LSL #2\n"
+ "bge 157f\n"
+ "tbz x14, #3, 152f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "ld1 { v21.4s }, [x24], #0x10\n"
+ "ld1 { v26.4s }, [x23], #0x10\n"
+ "tbz x14, #2, 150f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "ld1 { v22.4s }, [x24], #0x10\n"
+ "ld1 { v27.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 149f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "ldr d23, [x24], #0x8\n"
+ "ldr d6, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "ld1 { v23.s }[2], [x24]\n"
+ "ld1 { v6.s }[2], [x23]\n"
+ "b 156f\n"
+ "149:" // Height 5: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 156f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "ldr s23, [x24, #0x0]\n"
+ "ldr s6, [x23, #0x0]\n"
+ "b 156f\n"
+ "150:" // Height 5: Partial accumulate: partial_2_8
+ "tbz x14, #1, 151f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "ldr d22, [x24], #0x8\n"
+ "ldr d27, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "ld1 { v22.s }[2], [x24]\n"
+ "ld1 { v27.s }[2], [x23]\n"
+ "b 156f\n"
+ "151:" // Height 5: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 156f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "ldr s22, [x24, #0x0]\n"
+ "ldr s27, [x23, #0x0]\n"
+ "b 156f\n"
+ "152:" // Height 5: Partial accumulate: partial_4_0
+ "tbz x14, #2, 154f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 153f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "ldr d21, [x24], #0x8\n"
+ "ldr d26, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "ld1 { v21.s }[2], [x24]\n"
+ "ld1 { v26.s }[2], [x23]\n"
+ "b 156f\n"
+ "153:" // Height 5: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 156f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "ldr s21, [x24, #0x0]\n"
+ "ldr s26, [x23, #0x0]\n"
+ "b 156f\n"
+ "154:" // Height 5: Partial accumulate: partial_2_0
+ "tbz x14, #1, 155f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "ldr d20, [x24], #0x8\n"
+ "ldr d25, [x23], #0x8\n"
+ "tbz x14, #0, 156f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "ld1 { v20.s }[2], [x24]\n"
+ "ld1 { v25.s }[2], [x23]\n"
+ "b 156f\n"
+ "155:" // Height 5: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "ldr s20, [x24, #0x0]\n"
+ "ldr s25, [x23, #0x0]\n"
+ "156:" // Height 5: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 158f\n"
+ "157:" // Height 5: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "ldr q20, [x24, #0x0]\n"
+ "ldr q21, [x24, #0x10]\n"
+ "ldr q22, [x24, #0x20]\n"
+ "ldr q23, [x24, #0x30]\n"
+ "ldr q25, [x23, #0x0]\n"
+ "ldr q26, [x23, #0x10]\n"
+ "ldr q27, [x23, #0x20]\n"
+ "ldr q6, [x23, #0x30]\n"
+ "158:" // Height 5: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "zip1 v24.2d, v25.2d, v28.2d\n"
+ "zip2 v28.2d, v25.2d, v28.2d\n"
+ "zip1 v25.2d, v26.2d, v29.2d\n"
+ "zip2 v29.2d, v26.2d, v29.2d\n"
+ "zip1 v26.2d, v27.2d, v30.2d\n"
+ "zip2 v30.2d, v27.2d, v30.2d\n"
+ "zip1 v27.2d, v6.2d, v31.2d\n"
+ "zip2 v31.2d, v6.2d, v31.2d\n"
+ "b 160f\n"
+ "159:" // Height 5: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v26.16b, #0x0\n"
+ "movi v27.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "movi v29.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v31.16b, #0x0\n"
+ "160:" // Height 5: setup done
+ "mov x28, #0x0\n"
+ "161:" // Height 5: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 162f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "ldr x23, [x20, #0x18]\n"
+ "ldr x22, [x20, #0x20]\n"
+ "cbnz x28, 163f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "add x23, x23, x20, LSL #2\n"
+ "add x22, x22, x20, LSL #2\n"
+ "b 163f\n"
+ "162:" // Height 5: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "add x23, x24, x21, LSL #2\n"
+ "add x22, x23, x21, LSL #2\n"
+ "163:" // Height 5: input setup done
+ "cmp x27, #0x4\n"
+ "blt 166f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 165f\n"
+ "164:" // Height 5: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q6, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q5, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x10, #0x0]\n"
+ ".inst 0x6e45ec0d // bfmmla v13.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e45ec55 // bfmmla v21.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e45ec9d // bfmmla v29.4s, v4.8h, v5.8h\n"
+ "ldr q5, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x9, #0x0]\n"
+ ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e45ec56 // bfmmla v22.4s, v2.8h, v5.8h\n"
+ ".inst 0x6e45ec9e // bfmmla v30.4s, v4.8h, v5.8h\n"
+ "ldr q5, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e45ec0f // bfmmla v15.4s, v0.8h, v5.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e45ec57 // bfmmla v23.4s, v2.8h, v5.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ ".inst 0x6e45ec9f // bfmmla v31.4s, v4.8h, v5.8h\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 164b\n"
+ "165:" // Height 5: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "166:" // Height 5: Multiply loop: Main loop skip
+ "cbz x27, 169f\n"
+ "cbz x27, 169f\n"
+ "tbz x27, #1, 167f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "ldr d3, [x23], #0x8\n"
+ "ldr d4, [x22], #0x8\n"
+ "tbz x27, #0, 168f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "ld1 { v3.s }[2], [x23]\n"
+ "ld1 { v4.s }[2], [x22]\n"
+ "b 168f\n"
+ "167:" // Height 5: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "ldr s3, [x23, #0x0]\n"
+ "ldr s4, [x22, #0x0]\n"
+ "168:" // Height 5: Multiply loop: Ragged operand read: Done
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q5, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ ".inst 0x6e45ec0c // bfmmla v12.4s, v0.8h, v5.8h\n"
+ ".inst 0x6e45ec9c // bfmmla v28.4s, v4.8h, v5.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e45ec54 // bfmmla v20.4s, v2.8h, v5.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "169:" // Height 5: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 161b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "uzp1 v15.2d, v16.2d, v20.2d\n"
+ "uzp2 v16.2d, v16.2d, v20.2d\n"
+ "add x23, x24, x20, LSL #2\n"
+ "uzp1 v20.2d, v17.2d, v21.2d\n"
+ "uzp2 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v21.2d, v18.2d, v22.2d\n"
+ "uzp2 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v22.2d, v19.2d, v23.2d\n"
+ "uzp2 v19.2d, v19.2d, v23.2d\n"
+ "uzp1 v24.2d, v24.2d, v28.2d\n"
+ "uzp1 v25.2d, v25.2d, v29.2d\n"
+ "uzp1 v26.2d, v26.2d, v30.2d\n"
+ "uzp1 v27.2d, v27.2d, v31.2d\n"
+ "tbz %x[flags], #1, 170f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v1.4s }, [x21]\n"
+ "ld1r { v0.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v1.4s\n"
+ "fmin v12.4s, v12.4s, v1.4s\n"
+ "fmin v13.4s, v13.4s, v1.4s\n"
+ "fmin v14.4s, v14.4s, v1.4s\n"
+ "fmin v8.4s, v8.4s, v1.4s\n"
+ "fmin v9.4s, v9.4s, v1.4s\n"
+ "fmin v10.4s, v10.4s, v1.4s\n"
+ "fmin v11.4s, v11.4s, v1.4s\n"
+ "fmin v15.4s, v15.4s, v1.4s\n"
+ "fmin v20.4s, v20.4s, v1.4s\n"
+ "fmin v21.4s, v21.4s, v1.4s\n"
+ "fmin v22.4s, v22.4s, v1.4s\n"
+ "fmin v16.4s, v16.4s, v1.4s\n"
+ "fmin v17.4s, v17.4s, v1.4s\n"
+ "fmin v18.4s, v18.4s, v1.4s\n"
+ "fmin v19.4s, v19.4s, v1.4s\n"
+ "fmin v24.4s, v24.4s, v1.4s\n"
+ "fmin v25.4s, v25.4s, v1.4s\n"
+ "fmin v26.4s, v26.4s, v1.4s\n"
+ "fmin v27.4s, v27.4s, v1.4s\n"
+ "fmax v6.4s, v6.4s, v0.4s\n"
+ "fmax v12.4s, v12.4s, v0.4s\n"
+ "fmax v13.4s, v13.4s, v0.4s\n"
+ "fmax v14.4s, v14.4s, v0.4s\n"
+ "fmax v8.4s, v8.4s, v0.4s\n"
+ "fmax v9.4s, v9.4s, v0.4s\n"
+ "fmax v10.4s, v10.4s, v0.4s\n"
+ "fmax v11.4s, v11.4s, v0.4s\n"
+ "fmax v15.4s, v15.4s, v0.4s\n"
+ "fmax v20.4s, v20.4s, v0.4s\n"
+ "fmax v21.4s, v21.4s, v0.4s\n"
+ "fmax v22.4s, v22.4s, v0.4s\n"
+ "fmax v16.4s, v16.4s, v0.4s\n"
+ "fmax v17.4s, v17.4s, v0.4s\n"
+ "fmax v18.4s, v18.4s, v0.4s\n"
+ "fmax v19.4s, v19.4s, v0.4s\n"
+ "fmax v24.4s, v24.4s, v0.4s\n"
+ "fmax v25.4s, v25.4s, v0.4s\n"
+ "fmax v26.4s, v26.4s, v0.4s\n"
+ "fmax v27.4s, v27.4s, v0.4s\n"
+ "170:" // Height 5: No activation
+ "cmp x14, #0x10\n"
+ "bge 179f\n"
+ "tbz x14, #3, 174f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v17.4s }, [x24], #0x10\n"
+ "st1 { v24.4s }, [x23], #0x10\n"
+ "st1 { v25.4s }, [x23], #0x10\n"
+ "tbz x14, #2, 172f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v18.4s }, [x24], #0x10\n"
+ "st1 { v26.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 171f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d19, [x24], #0x8\n"
+ "str d27, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x24]\n"
+ "st1 { v27.s }[2], [x23]\n"
+ "b 178f\n"
+ "171:" // Height 5: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 178f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s19, [x24, #0x0]\n"
+ "str s27, [x23, #0x0]\n"
+ "b 178f\n"
+ "172:" // Height 5: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 173f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d18, [x24], #0x8\n"
+ "str d26, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v18.s }[2], [x24]\n"
+ "st1 { v26.s }[2], [x23]\n"
+ "b 178f\n"
+ "173:" // Height 5: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 178f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s18, [x24, #0x0]\n"
+ "str s26, [x23, #0x0]\n"
+ "b 178f\n"
+ "174:" // Height 5: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 176f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v24.4s }, [x23], #0x10\n"
+ "tbz x14, #1, 175f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d17, [x24], #0x8\n"
+ "str d25, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v17.s }[2], [x24]\n"
+ "st1 { v25.s }[2], [x23]\n"
+ "b 178f\n"
+ "175:" // Height 5: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 178f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s17, [x24, #0x0]\n"
+ "str s25, [x23, #0x0]\n"
+ "b 178f\n"
+ "176:" // Height 5: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 177f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d15, [x25], #0x8\n"
+ "str d16, [x24], #0x8\n"
+ "str d24, [x23], #0x8\n"
+ "tbz x14, #0, 178f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v15.s }[2], [x25]\n"
+ "st1 { v16.s }[2], [x24]\n"
+ "st1 { v24.s }[2], [x23]\n"
+ "b 178f\n"
+ "177:" // Height 5: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s15, [x25, #0x0]\n"
+ "str s16, [x24, #0x0]\n"
+ "str s24, [x23, #0x0]\n"
+ "178:" // Height 5: Partial direct writeback: Done
+ "b 180f\n"
+ "179:" // Height 5: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q15, [x25, #0x0]\n"
+ "str q20, [x25, #0x10]\n"
+ "str q21, [x25, #0x20]\n"
+ "str q22, [x25, #0x30]\n"
+ "str q16, [x24, #0x0]\n"
+ "str q17, [x24, #0x10]\n"
+ "str q18, [x24, #0x20]\n"
+ "str q19, [x24, #0x30]\n"
+ "str q24, [x23, #0x0]\n"
+ "str q25, [x23, #0x10]\n"
+ "str q26, [x23, #0x20]\n"
+ "str q27, [x23, #0x30]\n"
+ "180:" // Height 5: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 146b\n"
+ "b 218f\n"
+ "181:" // Height 6
+ "ldr x20, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+ "ldr x15, [%x[args_ptr], %[offsetof_bias]]\n"
+ "mov x21, #0x18\n"
+ "ldr x14, [%x[args_ptr], %[offsetof_N]]\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "ldr x13, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "madd x21, x20, x21, x13\n"
+ "str x21, [%x[args_ptr], %[offsetof_output_ptr]]\n"
+ "182:" // Height 6: Column loop
+ "ldr x12, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_B_stride]]\n"
+ "cmp x14, #0xc\n"
+ "add x11, x12, x20, LSL #1\n"
+ "add x10, x11, x20, LSL #1\n"
+ "add x9, x10, x20, LSL #1\n"
+ "add x20, x9, x20, LSL #1\n"
+ "str x20, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+ "bgt 183f\n"
+ "cmp x14, #0x8\n"
+ "mov x9, x12\n"
+ "bgt 183f\n"
+ "cmp x14, #0x4\n"
+ "mov x10, x12\n"
+ "bgt 183f\n"
+ "mov x11, x12\n"
+ "183:" // Height 6: B setup done
+ "cbz x15, 184f\n"
+ "ldr q8, [x15, #0x0]\n"
+ "ldr q9, [x15, #0x10]\n"
+ "ldr q10, [x15, #0x20]\n"
+ "ldr q11, [x15, #0x30]\n"
+ "add x15, x15, #0x40\n"
+ "zip2 v12.2d, v8.2d, v8.2d\n"
+ "zip1 v8.2d, v8.2d, v8.2d\n"
+ "zip2 v13.2d, v9.2d, v9.2d\n"
+ "zip1 v9.2d, v9.2d, v9.2d\n"
+ "zip2 v14.2d, v10.2d, v10.2d\n"
+ "zip1 v10.2d, v10.2d, v10.2d\n"
+ "zip2 v15.2d, v11.2d, v11.2d\n"
+ "zip1 v11.2d, v11.2d, v11.2d\n"
+ "mov v16.16b, v8.16b\n"
+ "mov v20.16b, v12.16b\n"
+ "mov v17.16b, v9.16b\n"
+ "mov v21.16b, v13.16b\n"
+ "mov v18.16b, v10.16b\n"
+ "mov v22.16b, v14.16b\n"
+ "mov v19.16b, v11.16b\n"
+ "mov v23.16b, v15.16b\n"
+ "mov v24.16b, v8.16b\n"
+ "mov v28.16b, v12.16b\n"
+ "mov v25.16b, v9.16b\n"
+ "mov v29.16b, v13.16b\n"
+ "mov v26.16b, v10.16b\n"
+ "mov v30.16b, v14.16b\n"
+ "mov v27.16b, v11.16b\n"
+ "mov v31.16b, v15.16b\n"
+ "b 196f\n"
+ "184:" // Height 6: no bias
+ "tbz %x[flags], #0, 195f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "cmp x14, #0x10\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "add x23, x24, x20, LSL #2\n"
+ "add x22, x23, x20, LSL #2\n"
+ "bge 193f\n"
+ "tbz x14, #3, 188f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "ld1 { v28.4s }, [x22], #0x10\n"
+ "ld1 { v10.4s }, [x13], #0x10\n"
+ "ld1 { v13.4s }, [x26], #0x10\n"
+ "ld1 { v18.4s }, [x25], #0x10\n"
+ "ld1 { v21.4s }, [x24], #0x10\n"
+ "ld1 { v26.4s }, [x23], #0x10\n"
+ "ld1 { v29.4s }, [x22], #0x10\n"
+ "tbz x14, #2, 186f\n"
+ "ld1 { v11.4s }, [x13], #0x10\n"
+ "ld1 { v14.4s }, [x26], #0x10\n"
+ "ld1 { v19.4s }, [x25], #0x10\n"
+ "ld1 { v22.4s }, [x24], #0x10\n"
+ "ld1 { v27.4s }, [x23], #0x10\n"
+ "ld1 { v30.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 185f\n"
+ "ldr d16, [x13], #0x8\n"
+ "ldr d15, [x26], #0x8\n"
+ "mov x20, #0x38\n"
+ "ldr d24, [x25], #0x8\n"
+ "ldr d23, [x24], #0x8\n"
+ "ldr d6, [x23], #0x8\n"
+ "ldr d31, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v16.s }[2], [x13]\n"
+ "ld1 { v15.s }[2], [x26]\n"
+ "ld1 { v24.s }[2], [x25]\n"
+ "ld1 { v23.s }[2], [x24]\n"
+ "ld1 { v6.s }[2], [x23]\n"
+ "ld1 { v31.s }[2], [x22]\n"
+ "b 192f\n"
+ "185:" // Height 6: Partial accumulate: partial_1_12
+ "mov x20, #0x30\n"
+ "tbz x14, #0, 192f\n"
+ "ldr s16, [x13, #0x0]\n"
+ "ldr s15, [x26, #0x0]\n"
+ "ldr s24, [x25, #0x0]\n"
+ "ldr s23, [x24, #0x0]\n"
+ "ldr s6, [x23, #0x0]\n"
+ "ldr s31, [x22, #0x0]\n"
+ "b 192f\n"
+ "186:" // Height 6: Partial accumulate: partial_2_8
+ "tbz x14, #1, 187f\n"
+ "ldr d11, [x13], #0x8\n"
+ "ldr d14, [x26], #0x8\n"
+ "mov x20, #0x28\n"
+ "ldr d19, [x25], #0x8\n"
+ "ldr d22, [x24], #0x8\n"
+ "ldr d27, [x23], #0x8\n"
+ "ldr d30, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v11.s }[2], [x13]\n"
+ "ld1 { v14.s }[2], [x26]\n"
+ "ld1 { v19.s }[2], [x25]\n"
+ "ld1 { v22.s }[2], [x24]\n"
+ "ld1 { v27.s }[2], [x23]\n"
+ "ld1 { v30.s }[2], [x22]\n"
+ "b 192f\n"
+ "187:" // Height 6: Partial accumulate: partial_1_8
+ "mov x20, #0x20\n"
+ "tbz x14, #0, 192f\n"
+ "ldr s11, [x13, #0x0]\n"
+ "ldr s14, [x26, #0x0]\n"
+ "ldr s19, [x25, #0x0]\n"
+ "ldr s22, [x24, #0x0]\n"
+ "ldr s27, [x23, #0x0]\n"
+ "ldr s30, [x22, #0x0]\n"
+ "b 192f\n"
+ "188:" // Height 6: Partial accumulate: partial_4_0
+ "tbz x14, #2, 190f\n"
+ "ld1 { v9.4s }, [x13], #0x10\n"
+ "ld1 { v12.4s }, [x26], #0x10\n"
+ "ld1 { v17.4s }, [x25], #0x10\n"
+ "ld1 { v20.4s }, [x24], #0x10\n"
+ "ld1 { v25.4s }, [x23], #0x10\n"
+ "ld1 { v28.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 189f\n"
+ "ldr d10, [x13], #0x8\n"
+ "ldr d13, [x26], #0x8\n"
+ "mov x20, #0x18\n"
+ "ldr d18, [x25], #0x8\n"
+ "ldr d21, [x24], #0x8\n"
+ "ldr d26, [x23], #0x8\n"
+ "ldr d29, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v10.s }[2], [x13]\n"
+ "ld1 { v13.s }[2], [x26]\n"
+ "ld1 { v18.s }[2], [x25]\n"
+ "ld1 { v21.s }[2], [x24]\n"
+ "ld1 { v26.s }[2], [x23]\n"
+ "ld1 { v29.s }[2], [x22]\n"
+ "b 192f\n"
+ "189:" // Height 6: Partial accumulate: partial_1_4
+ "mov x20, #0x10\n"
+ "tbz x14, #0, 192f\n"
+ "ldr s10, [x13, #0x0]\n"
+ "ldr s13, [x26, #0x0]\n"
+ "ldr s18, [x25, #0x0]\n"
+ "ldr s21, [x24, #0x0]\n"
+ "ldr s26, [x23, #0x0]\n"
+ "ldr s29, [x22, #0x0]\n"
+ "b 192f\n"
+ "190:" // Height 6: Partial accumulate: partial_2_0
+ "tbz x14, #1, 191f\n"
+ "ldr d9, [x13], #0x8\n"
+ "ldr d12, [x26], #0x8\n"
+ "mov x20, #0x8\n"
+ "ldr d17, [x25], #0x8\n"
+ "ldr d20, [x24], #0x8\n"
+ "ldr d25, [x23], #0x8\n"
+ "ldr d28, [x22], #0x8\n"
+ "tbz x14, #0, 192f\n"
+ "ld1 { v9.s }[2], [x13]\n"
+ "ld1 { v12.s }[2], [x26]\n"
+ "ld1 { v17.s }[2], [x25]\n"
+ "ld1 { v20.s }[2], [x24]\n"
+ "ld1 { v25.s }[2], [x23]\n"
+ "ld1 { v28.s }[2], [x22]\n"
+ "b 192f\n"
+ "191:" // Height 6: Partial accumulate: partial_1_0
+ "ldr s9, [x13, #0x0]\n"
+ "ldr s12, [x26, #0x0]\n"
+ "mov x20, #0x0\n"
+ "ldr s17, [x25, #0x0]\n"
+ "ldr s20, [x24, #0x0]\n"
+ "ldr s25, [x23, #0x0]\n"
+ "ldr s28, [x22, #0x0]\n"
+ "192:" // Height 6: Partial accumulate: Done
+ "sub x13, x13, x20\n"
+ "b 194f\n"
+ "193:" // Height 6: full accumulate
+ "ldr q9, [x13, #0x0]\n"
+ "ldr q10, [x13, #0x10]\n"
+ "ldr q11, [x13, #0x20]\n"
+ "ldr q16, [x13, #0x30]\n"
+ "ldr q12, [x26, #0x0]\n"
+ "ldr q13, [x26, #0x10]\n"
+ "ldr q14, [x26, #0x20]\n"
+ "ldr q15, [x26, #0x30]\n"
+ "ldr q17, [x25, #0x0]\n"
+ "ldr q18, [x25, #0x10]\n"
+ "ldr q19, [x25, #0x20]\n"
+ "ldr q24, [x25, #0x30]\n"
+ "ldr q20, [x24, #0x0]\n"
+ "ldr q21, [x24, #0x10]\n"
+ "ldr q22, [x24, #0x20]\n"
+ "ldr q23, [x24, #0x30]\n"
+ "ldr q25, [x23, #0x0]\n"
+ "ldr q26, [x23, #0x10]\n"
+ "ldr q27, [x23, #0x20]\n"
+ "ldr q6, [x23, #0x30]\n"
+ "ldr q28, [x22, #0x0]\n"
+ "ldr q29, [x22, #0x10]\n"
+ "ldr q30, [x22, #0x20]\n"
+ "ldr q31, [x22, #0x30]\n"
+ "194:" // Height 6: MMLA fixup
+ "zip1 v8.2d, v9.2d, v12.2d\n"
+ "zip2 v12.2d, v9.2d, v12.2d\n"
+ "zip1 v9.2d, v10.2d, v13.2d\n"
+ "zip2 v13.2d, v10.2d, v13.2d\n"
+ "zip1 v10.2d, v11.2d, v14.2d\n"
+ "zip2 v14.2d, v11.2d, v14.2d\n"
+ "zip1 v11.2d, v16.2d, v15.2d\n"
+ "zip2 v15.2d, v16.2d, v15.2d\n"
+ "zip1 v16.2d, v17.2d, v20.2d\n"
+ "zip2 v20.2d, v17.2d, v20.2d\n"
+ "zip1 v17.2d, v18.2d, v21.2d\n"
+ "zip2 v21.2d, v18.2d, v21.2d\n"
+ "zip1 v18.2d, v19.2d, v22.2d\n"
+ "zip2 v22.2d, v19.2d, v22.2d\n"
+ "zip1 v19.2d, v24.2d, v23.2d\n"
+ "zip2 v23.2d, v24.2d, v23.2d\n"
+ "zip1 v24.2d, v25.2d, v28.2d\n"
+ "zip2 v28.2d, v25.2d, v28.2d\n"
+ "zip1 v25.2d, v26.2d, v29.2d\n"
+ "zip2 v29.2d, v26.2d, v29.2d\n"
+ "zip1 v26.2d, v27.2d, v30.2d\n"
+ "zip2 v30.2d, v27.2d, v30.2d\n"
+ "zip1 v27.2d, v6.2d, v31.2d\n"
+ "zip2 v31.2d, v6.2d, v31.2d\n"
+ "b 196f\n"
+ "195:" // Height 6: no accumulate
+ "movi v8.16b, #0x0\n"
+ "movi v9.16b, #0x0\n"
+ "movi v10.16b, #0x0\n"
+ "movi v11.16b, #0x0\n"
+ "movi v12.16b, #0x0\n"
+ "movi v13.16b, #0x0\n"
+ "movi v14.16b, #0x0\n"
+ "movi v15.16b, #0x0\n"
+ "movi v16.16b, #0x0\n"
+ "movi v17.16b, #0x0\n"
+ "movi v18.16b, #0x0\n"
+ "movi v19.16b, #0x0\n"
+ "movi v20.16b, #0x0\n"
+ "movi v21.16b, #0x0\n"
+ "movi v22.16b, #0x0\n"
+ "movi v23.16b, #0x0\n"
+ "movi v24.16b, #0x0\n"
+ "movi v25.16b, #0x0\n"
+ "movi v26.16b, #0x0\n"
+ "movi v27.16b, #0x0\n"
+ "movi v28.16b, #0x0\n"
+ "movi v29.16b, #0x0\n"
+ "movi v30.16b, #0x0\n"
+ "movi v31.16b, #0x0\n"
+ "196:" // Height 6: setup done
+ "mov x28, #0x0\n"
+ "197:" // Height 6: String loop
+ "ldr x20, [%x[args_ptr], %[offsetof_string_lengths]]\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "ldr w27, [x20, x28, LSL #0x2]\n"
+ "tbz %x[flags], #3, 198f\n"
+ "ldr x20, [%x[input_ptr], x28, LSL #0x3]\n"
+ "add x20, x20, x21, LSL #3\n"
+ "ldr x26, [x20, #0x0]\n"
+ "ldr x25, [x20, #0x8]\n"
+ "ldr x24, [x20, #0x10]\n"
+ "ldr x23, [x20, #0x18]\n"
+ "ldr x22, [x20, #0x20]\n"
+ "ldr x21, [x20, #0x28]\n"
+ "cbnz x28, 199f\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_input_initial_col]]\n"
+ "add x26, x26, x20, LSL #2\n"
+ "add x25, x25, x20, LSL #2\n"
+ "add x24, x24, x20, LSL #2\n"
+ "add x23, x23, x20, LSL #2\n"
+ "add x22, x22, x20, LSL #2\n"
+ "add x21, x21, x20, LSL #2\n"
+ "b 199f\n"
+ "198:" // Height 6: setup direct input
+ "mov x26, %x[input_ptr]\n"
+ "add x25, x26, x21, LSL #2\n"
+ "add x24, x25, x21, LSL #2\n"
+ "add x23, x24, x21, LSL #2\n"
+ "add x22, x23, x21, LSL #2\n"
+ "add x21, x22, x21, LSL #2\n"
+ "199:" // Height 6: input setup done
+ "cmp x27, #0x4\n"
+ "blt 202f\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ "cmp x27, #0x8\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ "ld1 { v5.4s }, [x21], #0x10\n"
+ "ldr q6, [x12, #0x0]\n"
+ "ldr q7, [x12, #0x10]\n"
+ "blt 201f\n"
+ "200:" // Height 6: Multiply loop: Main loop head
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "cmp x27, #0x8\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ "ld1 { v1.4s }, [x25], #0x10\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ "ld1 { v3.4s }, [x23], #0x10\n"
+ ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n"
+ "ld1 { v5.4s }, [x21], #0x10\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x11, #0x0]\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ "ldr q7, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x10, #0x0]\n"
+ ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e47ec9d // bfmmla v29.4s, v4.8h, v7.8h\n"
+ "ldr q7, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x9, #0x0]\n"
+ ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e47ec9e // bfmmla v30.4s, v4.8h, v7.8h\n"
+ "ldr q7, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n"
+ "ldr q6, [x12, #0x0]\n"
+ ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n"
+ "ld1 { v0.4s }, [x26], #0x10\n"
+ ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n"
+ "ld1 { v2.4s }, [x24], #0x10\n"
+ ".inst 0x6e47ec9f // bfmmla v31.4s, v4.8h, v7.8h\n"
+ "ld1 { v4.4s }, [x22], #0x10\n"
+ "ldr q7, [x12, #0x10]\n"
+ "bge 200b\n"
+ "201:" // Height 6: Multiply loop: Single iteration only
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ "sub x27, x27, #0x4\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n"
+ ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "202:" // Height 6: Multiply loop: Main loop skip
+ "cbz x27, 205f\n"
+ "cbz x27, 205f\n"
+ "tbz x27, #1, 203f\n"
+ "ldr d0, [x26], #0x8\n"
+ "ldr d1, [x25], #0x8\n"
+ "ldr d2, [x24], #0x8\n"
+ "ldr d3, [x23], #0x8\n"
+ "ldr d4, [x22], #0x8\n"
+ "ldr d5, [x21], #0x8\n"
+ "tbz x27, #0, 204f\n"
+ "ld1 { v0.s }[2], [x26]\n"
+ "ld1 { v1.s }[2], [x25]\n"
+ "ld1 { v2.s }[2], [x24]\n"
+ "ld1 { v3.s }[2], [x23]\n"
+ "ld1 { v4.s }[2], [x22]\n"
+ "ld1 { v5.s }[2], [x21]\n"
+ "b 204f\n"
+ "203:" // Height 6: Multiply loop: Ragged operand read: partial_1_0
+ "ldr s0, [x26, #0x0]\n"
+ "ldr s1, [x25, #0x0]\n"
+ "ldr s2, [x24, #0x0]\n"
+ "ldr s3, [x23, #0x0]\n"
+ "ldr s4, [x22, #0x0]\n"
+ "ldr s5, [x21, #0x0]\n"
+ "204:" // Height 6: Multiply loop: Ragged operand read: Done
+ "ldr q7, [x12, #0x0]\n"
+ "ldr q6, [x12, #0x10]\n"
+ ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n"
+ ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n"
+ ".inst 0x0ea16884 // bfcvtn v4.4h, v4.4s\n"
+ "add x12, x12, #0x20\n"
+ ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n"
+ ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n"
+ ".inst 0x4ea168a4 // bfcvtn2 v4.8h, v5.4s\n"
+ ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n"
+ ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n"
+ ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n"
+ ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n"
+ "ldr q3, [x11, #0x0]\n"
+ ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n"
+ ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n"
+ "ldr q1, [x11, #0x10]\n"
+ "add x11, x11, #0x20\n"
+ ".inst 0x6e43ec09 // bfmmla v9.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec51 // bfmmla v17.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x10, #0x0]\n"
+ ".inst 0x6e41ec0d // bfmmla v13.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9d // bfmmla v29.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x10, #0x10]\n"
+ "add x10, x10, #0x20\n"
+ ".inst 0x6e43ec0a // bfmmla v10.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec52 // bfmmla v18.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9a // bfmmla v26.4s, v4.8h, v3.8h\n"
+ "ldr q3, [x9, #0x0]\n"
+ ".inst 0x6e41ec0e // bfmmla v14.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec56 // bfmmla v22.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9e // bfmmla v30.4s, v4.8h, v1.8h\n"
+ "ldr q1, [x9, #0x10]\n"
+ "add x9, x9, #0x20\n"
+ ".inst 0x6e43ec0b // bfmmla v11.4s, v0.8h, v3.8h\n"
+ ".inst 0x6e43ec53 // bfmmla v19.4s, v2.8h, v3.8h\n"
+ ".inst 0x6e43ec9b // bfmmla v27.4s, v4.8h, v3.8h\n"
+ ".inst 0x6e41ec0f // bfmmla v15.4s, v0.8h, v1.8h\n"
+ ".inst 0x6e41ec57 // bfmmla v23.4s, v2.8h, v1.8h\n"
+ ".inst 0x6e41ec9f // bfmmla v31.4s, v4.8h, v1.8h\n"
+ "205:" // Height 6: Multiply loop: No odd multiplies
+ "ldr w20, [%x[args_ptr], %[offsetof_num_strings]]\n"
+ "add x28, x28, #0x1\n"
+ "cmp x28, x20\n"
+ "bne 197b\n"
+ "ldr x20, [%x[args_ptr], %[offsetof_output_offset]]\n"
+ "uzp1 v6.2d, v8.2d, v12.2d\n"
+ "uzp2 v8.2d, v8.2d, v12.2d\n"
+ "uzp1 v12.2d, v9.2d, v13.2d\n"
+ "uzp2 v9.2d, v9.2d, v13.2d\n"
+ "uzp1 v13.2d, v10.2d, v14.2d\n"
+ "uzp2 v10.2d, v10.2d, v14.2d\n"
+ "add x26, x13, x20, LSL #2\n"
+ "add x25, x26, x20, LSL #2\n"
+ "add x24, x25, x20, LSL #2\n"
+ "uzp1 v14.2d, v11.2d, v15.2d\n"
+ "uzp2 v11.2d, v11.2d, v15.2d\n"
+ "add x23, x24, x20, LSL #2\n"
+ "uzp1 v15.2d, v16.2d, v20.2d\n"
+ "uzp2 v16.2d, v16.2d, v20.2d\n"
+ "add x22, x23, x20, LSL #2\n"
+ "uzp1 v20.2d, v17.2d, v21.2d\n"
+ "uzp2 v17.2d, v17.2d, v21.2d\n"
+ "uzp1 v21.2d, v18.2d, v22.2d\n"
+ "uzp2 v18.2d, v18.2d, v22.2d\n"
+ "uzp1 v22.2d, v19.2d, v23.2d\n"
+ "uzp2 v19.2d, v19.2d, v23.2d\n"
+ "uzp1 v23.2d, v24.2d, v28.2d\n"
+ "uzp2 v24.2d, v24.2d, v28.2d\n"
+ "uzp1 v28.2d, v25.2d, v29.2d\n"
+ "uzp2 v25.2d, v25.2d, v29.2d\n"
+ "uzp1 v29.2d, v26.2d, v30.2d\n"
+ "uzp2 v26.2d, v26.2d, v30.2d\n"
+ "uzp1 v30.2d, v27.2d, v31.2d\n"
+ "uzp2 v27.2d, v27.2d, v31.2d\n"
+ "tbz %x[flags], #1, 206f\n"
+ "add x21, %x[args_ptr], %[offset_max]\n"
+ "add x20, %x[args_ptr], %[offset_min]\n"
+ "ld1r { v1.4s }, [x21]\n"
+ "ld1r { v0.4s }, [x20]\n"
+ "fmin v6.4s, v6.4s, v1.4s\n"
+ "fmin v12.4s, v12.4s, v1.4s\n"
+ "fmin v13.4s, v13.4s, v1.4s\n"
+ "fmin v14.4s, v14.4s, v1.4s\n"
+ "fmin v8.4s, v8.4s, v1.4s\n"
+ "fmin v9.4s, v9.4s, v1.4s\n"
+ "fmin v10.4s, v10.4s, v1.4s\n"
+ "fmin v11.4s, v11.4s, v1.4s\n"
+ "fmin v15.4s, v15.4s, v1.4s\n"
+ "fmin v20.4s, v20.4s, v1.4s\n"
+ "fmin v21.4s, v21.4s, v1.4s\n"
+ "fmin v22.4s, v22.4s, v1.4s\n"
+ "fmin v16.4s, v16.4s, v1.4s\n"
+ "fmin v17.4s, v17.4s, v1.4s\n"
+ "fmin v18.4s, v18.4s, v1.4s\n"
+ "fmin v19.4s, v19.4s, v1.4s\n"
+ "fmin v23.4s, v23.4s, v1.4s\n"
+ "fmin v28.4s, v28.4s, v1.4s\n"
+ "fmin v29.4s, v29.4s, v1.4s\n"
+ "fmin v30.4s, v30.4s, v1.4s\n"
+ "fmin v24.4s, v24.4s, v1.4s\n"
+ "fmin v25.4s, v25.4s, v1.4s\n"
+ "fmin v26.4s, v26.4s, v1.4s\n"
+ "fmin v27.4s, v27.4s, v1.4s\n"
+ "fmax v6.4s, v6.4s, v0.4s\n"
+ "fmax v12.4s, v12.4s, v0.4s\n"
+ "fmax v13.4s, v13.4s, v0.4s\n"
+ "fmax v14.4s, v14.4s, v0.4s\n"
+ "fmax v8.4s, v8.4s, v0.4s\n"
+ "fmax v9.4s, v9.4s, v0.4s\n"
+ "fmax v10.4s, v10.4s, v0.4s\n"
+ "fmax v11.4s, v11.4s, v0.4s\n"
+ "fmax v15.4s, v15.4s, v0.4s\n"
+ "fmax v20.4s, v20.4s, v0.4s\n"
+ "fmax v21.4s, v21.4s, v0.4s\n"
+ "fmax v22.4s, v22.4s, v0.4s\n"
+ "fmax v16.4s, v16.4s, v0.4s\n"
+ "fmax v17.4s, v17.4s, v0.4s\n"
+ "fmax v18.4s, v18.4s, v0.4s\n"
+ "fmax v19.4s, v19.4s, v0.4s\n"
+ "fmax v23.4s, v23.4s, v0.4s\n"
+ "fmax v28.4s, v28.4s, v0.4s\n"
+ "fmax v29.4s, v29.4s, v0.4s\n"
+ "fmax v30.4s, v30.4s, v0.4s\n"
+ "fmax v24.4s, v24.4s, v0.4s\n"
+ "fmax v25.4s, v25.4s, v0.4s\n"
+ "fmax v26.4s, v26.4s, v0.4s\n"
+ "fmax v27.4s, v27.4s, v0.4s\n"
+ "206:" // Height 6: No activation
+ "cmp x14, #0x10\n"
+ "bge 215f\n"
+ "tbz x14, #3, 210f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v12.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v9.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v20.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v17.4s }, [x24], #0x10\n"
+ "st1 { v23.4s }, [x23], #0x10\n"
+ "st1 { v28.4s }, [x23], #0x10\n"
+ "st1 { v24.4s }, [x22], #0x10\n"
+ "st1 { v25.4s }, [x22], #0x10\n"
+ "tbz x14, #2, 208f\n"
+ "st1 { v13.4s }, [x13], #0x10\n"
+ "st1 { v10.4s }, [x26], #0x10\n"
+ "st1 { v21.4s }, [x25], #0x10\n"
+ "st1 { v18.4s }, [x24], #0x10\n"
+ "st1 { v29.4s }, [x23], #0x10\n"
+ "st1 { v26.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 207f\n"
+ "str d14, [x13], #0x8\n"
+ "str d11, [x26], #0x8\n"
+ "str d22, [x25], #0x8\n"
+ "str d19, [x24], #0x8\n"
+ "str d30, [x23], #0x8\n"
+ "str d27, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v14.s }[2], [x13]\n"
+ "st1 { v11.s }[2], [x26]\n"
+ "st1 { v22.s }[2], [x25]\n"
+ "st1 { v19.s }[2], [x24]\n"
+ "st1 { v30.s }[2], [x23]\n"
+ "st1 { v27.s }[2], [x22]\n"
+ "b 214f\n"
+ "207:" // Height 6: Partial direct writeback: partial_1_12
+ "tbz x14, #0, 214f\n"
+ "str s14, [x13, #0x0]\n"
+ "str s11, [x26, #0x0]\n"
+ "str s22, [x25, #0x0]\n"
+ "str s19, [x24, #0x0]\n"
+ "str s30, [x23, #0x0]\n"
+ "str s27, [x22, #0x0]\n"
+ "b 214f\n"
+ "208:" // Height 6: Partial direct writeback: partial_2_8
+ "tbz x14, #1, 209f\n"
+ "str d13, [x13], #0x8\n"
+ "str d10, [x26], #0x8\n"
+ "str d21, [x25], #0x8\n"
+ "str d18, [x24], #0x8\n"
+ "str d29, [x23], #0x8\n"
+ "str d26, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v13.s }[2], [x13]\n"
+ "st1 { v10.s }[2], [x26]\n"
+ "st1 { v21.s }[2], [x25]\n"
+ "st1 { v18.s }[2], [x24]\n"
+ "st1 { v29.s }[2], [x23]\n"
+ "st1 { v26.s }[2], [x22]\n"
+ "b 214f\n"
+ "209:" // Height 6: Partial direct writeback: partial_1_8
+ "tbz x14, #0, 214f\n"
+ "str s13, [x13, #0x0]\n"
+ "str s10, [x26, #0x0]\n"
+ "str s21, [x25, #0x0]\n"
+ "str s18, [x24, #0x0]\n"
+ "str s29, [x23, #0x0]\n"
+ "str s26, [x22, #0x0]\n"
+ "b 214f\n"
+ "210:" // Height 6: Partial direct writeback: partial_4_0
+ "tbz x14, #2, 212f\n"
+ "st1 { v6.4s }, [x13], #0x10\n"
+ "st1 { v8.4s }, [x26], #0x10\n"
+ "st1 { v15.4s }, [x25], #0x10\n"
+ "st1 { v16.4s }, [x24], #0x10\n"
+ "st1 { v23.4s }, [x23], #0x10\n"
+ "st1 { v24.4s }, [x22], #0x10\n"
+ "tbz x14, #1, 211f\n"
+ "str d12, [x13], #0x8\n"
+ "str d9, [x26], #0x8\n"
+ "str d20, [x25], #0x8\n"
+ "str d17, [x24], #0x8\n"
+ "str d28, [x23], #0x8\n"
+ "str d25, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v12.s }[2], [x13]\n"
+ "st1 { v9.s }[2], [x26]\n"
+ "st1 { v20.s }[2], [x25]\n"
+ "st1 { v17.s }[2], [x24]\n"
+ "st1 { v28.s }[2], [x23]\n"
+ "st1 { v25.s }[2], [x22]\n"
+ "b 214f\n"
+ "211:" // Height 6: Partial direct writeback: partial_1_4
+ "tbz x14, #0, 214f\n"
+ "str s12, [x13, #0x0]\n"
+ "str s9, [x26, #0x0]\n"
+ "str s20, [x25, #0x0]\n"
+ "str s17, [x24, #0x0]\n"
+ "str s28, [x23, #0x0]\n"
+ "str s25, [x22, #0x0]\n"
+ "b 214f\n"
+ "212:" // Height 6: Partial direct writeback: partial_2_0
+ "tbz x14, #1, 213f\n"
+ "str d6, [x13], #0x8\n"
+ "str d8, [x26], #0x8\n"
+ "str d15, [x25], #0x8\n"
+ "str d16, [x24], #0x8\n"
+ "str d23, [x23], #0x8\n"
+ "str d24, [x22], #0x8\n"
+ "tbz x14, #0, 214f\n"
+ "st1 { v6.s }[2], [x13]\n"
+ "st1 { v8.s }[2], [x26]\n"
+ "st1 { v15.s }[2], [x25]\n"
+ "st1 { v16.s }[2], [x24]\n"
+ "st1 { v23.s }[2], [x23]\n"
+ "st1 { v24.s }[2], [x22]\n"
+ "b 214f\n"
+ "213:" // Height 6: Partial direct writeback: partial_1_0
+ "str s6, [x13, #0x0]\n"
+ "str s8, [x26, #0x0]\n"
+ "str s15, [x25, #0x0]\n"
+ "str s16, [x24, #0x0]\n"
+ "str s23, [x23, #0x0]\n"
+ "str s24, [x22, #0x0]\n"
+ "214:" // Height 6: Partial direct writeback: Done
+ "b 216f\n"
+ "215:" // Height 6: Full writeback
+ "str q6, [x13, #0x0]\n"
+ "str q12, [x13, #0x10]\n"
+ "str q13, [x13, #0x20]\n"
+ "str q14, [x13, #0x30]\n"
+ "add x13, x13, #0x40\n"
+ "str q8, [x26, #0x0]\n"
+ "str q9, [x26, #0x10]\n"
+ "str q10, [x26, #0x20]\n"
+ "str q11, [x26, #0x30]\n"
+ "str q15, [x25, #0x0]\n"
+ "str q20, [x25, #0x10]\n"
+ "str q21, [x25, #0x20]\n"
+ "str q22, [x25, #0x30]\n"
+ "str q16, [x24, #0x0]\n"
+ "str q17, [x24, #0x10]\n"
+ "str q18, [x24, #0x20]\n"
+ "str q19, [x24, #0x30]\n"
+ "str q23, [x23, #0x0]\n"
+ "str q28, [x23, #0x10]\n"
+ "str q29, [x23, #0x20]\n"
+ "str q30, [x23, #0x30]\n"
+ "str q24, [x22, #0x0]\n"
+ "str q25, [x22, #0x10]\n"
+ "str q26, [x22, #0x20]\n"
+ "str q27, [x22, #0x30]\n"
+ "216:" // Height 6: Writeback done
+ "subs x14, x14, #0x10\n"
+ "bgt 182b\n"
+ "subs %x[M], %x[M], #0x6\n"
+ "beq 218f\n"
+ "ldr x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "tbz %x[flags], #3, 217f\n"
+ "add x21, x21, #0x6\n"
+ "str x21, [%x[args_ptr], %[offsetof_input_offset]]\n"
+ "b 1b\n"
+ "217:" // Update direct input
+ "mov x20, #0x18\n"
+ "madd %x[input_ptr], x20, x21, %x[input_ptr]\n"
+ "b 1b\n"
+ "218:" // Exit
+ : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr)
+ : [args_ptr] "r" (&ka), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_output_ptr] "I" (offsetof(KernelArgs, output_ptr)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths))
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+}
+
+} // namespace arm_gemm
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp
index cf4d74266a..1a8b0fd630 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 45.25, 4.29, 4.80 };
default:
- return { 38.10, 5.23, 3.15 };
+ return { 29.85, 2.60, 5.49 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp
new file mode 100644
index 0000000000..7792192856
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include <cstdint>
+#include "../std_transforms_sme.hpp"
+
+namespace arm_gemm
+{
+
+// Implementations
+void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+class cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL
+{
+public:
+ typedef int8_t operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+ /* Kernel blocking parameters */
+ static unsigned int out_height()
+ {
+ return sme::get_vector_length<int32_t>() * 1;
+ }
+
+ static unsigned int out_width()
+ {
+ return sme::get_vector_length<int32_t>() * 4;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ static constexpr bool is_sme()
+ {
+ return true;
+ }
+
+ // Default to the generic kernel
+ kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL;
+
+ StdTransformsSME<operand_type, result_type, 1, 4, 4> transforms = {};
+
+ cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp
new file mode 100644
index 0000000000..4b26a6578c
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL/generic.cpp
@@ -0,0 +1,417 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_gemm.hpp"
+
+#include <cstdint>
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer)
+{
+ struct KernelArgs
+ {
+ KernelArgs(
+ const int8_t *const A,
+ const int8_t *const B,
+ float *const C, const int ldc,
+ const int M, const int N, const int K,
+ const int32_t *const bias, const float *const late_bias, const Activation act,
+ bool accumulate,
+ int32_t *const accumulator_buffer
+ ) : A(A),
+ B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)),
+ C(C), ldcb(ldc * sizeof(float)),
+ M(M), N(N), K(K),
+ min(-std::numeric_limits<float>::infinity()),
+ max(std::numeric_limits<float>::infinity()),
+ bias(bias), late_bias(late_bias),
+ accumulator_buffer(accumulator_buffer),
+ flags(0x0)
+ {
+ if (accumulate)
+ {
+ flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER
+ }
+ if (C == nullptr)
+ {
+ flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER
+ }
+
+ // Initialise the activation values
+ switch (act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ this->max = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ this->min = static_cast<float>(0);
+ break;
+ }
+ }
+
+ const int8_t *const A;
+ const int8_t *const B;
+ const long kstride_bytes;
+ float *const C;
+ const long ldcb;
+ const long M, N, K;
+ float min = -std::numeric_limits<float>::infinity();
+ float max = std::numeric_limits<float>::infinity();
+
+ const int32_t *const bias;
+ const float *const late_bias;
+
+ int32_t *const accumulator_buffer;
+ uint64_t flags;
+ };
+
+ // Construct arguments for this kernel
+ KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer);
+
+ __asm__ __volatile__(
+ "ldr x13, [%x[args], %[offsetof_flags]]\n"
+ ".inst 0xd503477f // SMSTART ZA\n"
+ "ptrue p0.b\n"
+ ".inst 0x25207811 // ptrue pn9.b\n"
+ "ldr x11, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "ldr x10, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "tbz x13, #0, 2f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "1:" // Initial accumulator load from buffer: Loop
+ ".inst 0xa040c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11]\n"
+ ".inst 0xa041c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n"
+ ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n"
+ ".inst 0xa043c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n"
+ ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n"
+ "addvl x11, x11, #16\n"
+ ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n"
+ ".inst 0xc0840583 // mova za3h.s[x12], { z12.s-z15.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 1b\n"
+ "2:" // Initial accumulator load from buffer: End
+ "ldr w9, [%x[args], %[offsetof_M]]\n"
+ "mov x28, #0x0\n"
+ "mov x27, #0x0\n"
+ "ldr w26, [%x[args], %[offsetof_N]]\n"
+ "ldr x25, [%x[args], %[offsetof_A]]\n"
+ "3:" // M and N loop
+ "mov x24, x25\n"
+ ".inst 0x25ba6770 // whilelt pn8.s, x27, x26, VLx4\n"
+ "tbnz x13, #0, 4f\n"
+ "ldr x20, [%x[args], %[offsetof_bias]]\n"
+ ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
+ "cbz x20, 5f\n"
+ ".inst 0xa01bc288 // ld1w { z8.s-z11.s }, p8/Z, [x20, x27, LSL #2]\n"
+ ".inst 0xc0900100 // addha za0.s, p0/M, p0/M, z8.s\n"
+ ".inst 0xc0900121 // addha za1.s, p0/M, p0/M, z9.s\n"
+ ".inst 0xc0900142 // addha za2.s, p0/M, p0/M, z10.s\n"
+ ".inst 0xc0900163 // addha za3.s, p0/M, p0/M, z11.s\n"
+ "4:" // Prepare accumulators: Test for last block
+ "mov x20, x27\n"
+ "mov x21, x28\n"
+ "incw x20, ALL, MUL #4\n"
+ "incw x21\n"
+ "cmp x20, x26\n"
+ "mov x20, x13\n"
+ "csel x21, x28, x21, LT\n"
+ "bfm x13, XZR, #0x0, #0x0 // bfc x13, #0x0, #0x1\n"
+ "cmp x21, x9\n"
+ "csel x13, x20, x13, LT\n"
+ "5:" // Prepare accumulators: End
+ "ldr x20, [%x[args], %[offsetof_K]]\n"
+ "ldr x23, [%x[args], %[offsetof_B]]\n"
+ "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n"
+ "add x20, x20, #0x3\n"
+ "lsr x20, x20, #0x2\n"
+ "lsr x21, x20, #0x2\n"
+ "madd x23, x27, x22, x23\n" // bptr = B + n * kstride_bytes
+ "and x20, x20, #0x3\n"
+ "cbz x21, 8f\n"
+ "subs x21, x21, #0x1\n"
+ "ld1b { z31.b }, p0/Z, [x24]\n"
+ ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n"
+ "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n"
+ ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n"
+ ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
+ "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n"
+ "addvl x24, x24, #4\n"
+ ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n"
+ "addvl x23, x23, #16\n"
+ "ble 7f\n"
+ "6:" // K loop
+ ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n"
+ ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n"
+ ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n"
+ "ld1b { z31.b }, p0/Z, [x24]\n"
+ ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n"
+ ".inst 0xa04086e8 // ld1b { z8.b-z11.b }, pn9.b/Z, [x23]\n"
+ ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n"
+ ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n"
+ ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n"
+ "ld1b { z1.b }, p0/Z, [x24, #1, MUL VL]\n"
+ ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n"
+ ".inst 0xa04186e4 // ld1b { z4.b-z7.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n"
+ ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n"
+ ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n"
+ "ld1b { z0.b }, p0/Z, [x24, #2, MUL VL]\n"
+ ".inst 0xa04286ec // ld1b { z12.b-z15.b }, pn9.b/Z, [x23, #0x8, MUL VL]\n"
+ ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n"
+ ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n"
+ ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n"
+ ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n"
+ "ld1b { z3.b }, p0/Z, [x24, #3, MUL VL]\n"
+ "addvl x24, x24, #4\n"
+ ".inst 0xa04386f0 // ld1b { z16.b-z19.b }, pn9.b/Z, [x23, #0xc, MUL VL]\n"
+ "addvl x23, x23, #16\n"
+ "bgt 6b\n"
+ "7:" // K loop tail
+ ".inst 0xa08803e0 // smopa za0.s, p0/M, p0/M, z31.b, z8.b\n"
+ ".inst 0xa08903e1 // smopa za1.s, p0/M, p0/M, z31.b, z9.b\n"
+ ".inst 0xa08a03e2 // smopa za2.s, p0/M, p0/M, z31.b, z10.b\n"
+ ".inst 0xa08b03e3 // smopa za3.s, p0/M, p0/M, z31.b, z11.b\n"
+ ".inst 0xa0840020 // smopa za0.s, p0/M, p0/M, z1.b, z4.b\n"
+ ".inst 0xa0850021 // smopa za1.s, p0/M, p0/M, z1.b, z5.b\n"
+ ".inst 0xa0860022 // smopa za2.s, p0/M, p0/M, z1.b, z6.b\n"
+ ".inst 0xa0870023 // smopa za3.s, p0/M, p0/M, z1.b, z7.b\n"
+ ".inst 0xa08c0000 // smopa za0.s, p0/M, p0/M, z0.b, z12.b\n"
+ ".inst 0xa08d0001 // smopa za1.s, p0/M, p0/M, z0.b, z13.b\n"
+ ".inst 0xa08e0002 // smopa za2.s, p0/M, p0/M, z0.b, z14.b\n"
+ ".inst 0xa08f0003 // smopa za3.s, p0/M, p0/M, z0.b, z15.b\n"
+ ".inst 0xa0900060 // smopa za0.s, p0/M, p0/M, z3.b, z16.b\n"
+ ".inst 0xa0910061 // smopa za1.s, p0/M, p0/M, z3.b, z17.b\n"
+ ".inst 0xa0920062 // smopa za2.s, p0/M, p0/M, z3.b, z18.b\n"
+ ".inst 0xa0930063 // smopa za3.s, p0/M, p0/M, z3.b, z19.b\n"
+ "8:" // K oddments
+ "cbz x20, 10f\n"
+ "9:" // K oddments: Loop
+ "ld1b { z18.b }, p0/Z, [x24]\n"
+ "subs x20, x20, #0x1\n"
+ "addvl x24, x24, #1\n"
+ ".inst 0xa04086fc // ld1b { z28.b-z31.b }, pn9.b/Z, [x23]\n"
+ "addvl x23, x23, #4\n"
+ ".inst 0xa09c0240 // smopa za0.s, p0/M, p0/M, z18.b, z28.b\n"
+ ".inst 0xa09d0241 // smopa za1.s, p0/M, p0/M, z18.b, z29.b\n"
+ ".inst 0xa09e0242 // smopa za2.s, p0/M, p0/M, z18.b, z30.b\n"
+ ".inst 0xa09f0243 // smopa za3.s, p0/M, p0/M, z18.b, z31.b\n"
+ "bgt 9b\n"
+ "10:" // K oddments: End
+ "tbz x13, #1, 14f\n"
+ "tbz x13, #0, 12f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "11:" // Store to partial result buffer: Store and refill: Loop
+ ".inst 0xa040c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11]\n"
+ ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xa041c57c // ld1w { z28.s-z31.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n"
+ ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ ".inst 0xa042c578 // ld1w { z24.s-z27.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n"
+ ".inst 0xa043c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n"
+ ".inst 0xc0840400 // mova za0h.s[x12], { z0.s-z3.s }\n"
+ "addvl x11, x11, #16\n"
+ ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xa060c548 // st1w { z8.s-z11.s }, pn9.b, [x10]\n"
+ ".inst 0xc0840702 // mova za2h.s[x12], { z24.s-z27.s }\n"
+ ".inst 0xa061c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0x4, MUL VL]\n"
+ ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa062c544 // st1w { z4.s-z7.s }, pn9.b, [x10, #0x8, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa063c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0xc, MUL VL]\n"
+ "addvl x10, x10, #16\n"
+ "blt 11b\n"
+ "b 21f\n"
+ "12:" // Store to partial result buffer: Store only
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "13:" // Store to partial result buffer: Store only: Loop
+ ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n"
+ ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n"
+ ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xa060c544 // st1w { z4.s-z7.s }, pn9.b, [x10]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa061c550 // st1w { z16.s-z19.s }, pn9.b, [x10, #0x4, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa062c548 // st1w { z8.s-z11.s }, pn9.b, [x10, #0x8, MUL VL]\n"
+ ".inst 0xa063c54c // st1w { z12.s-z15.s }, pn9.b, [x10, #0xc, MUL VL]\n"
+ "addvl x10, x10, #16\n"
+ "blt 13b\n"
+ "b 21f\n"
+ "14:" // Store to output array
+ "ldr x23, [%x[args], %[offsetof_C]]\n"
+ "sub x21, x9, x28\n"
+ "ld1rw { z18.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n"
+ "fmov z20.s, #0x0\n"
+ "ldr x22, [%x[args], %[offsetof_ldcb]]\n"
+ "fmov z21.s, #0x0\n"
+ "fmov z22.s, #0x0\n"
+ "ldr x20, [%x[args], %[offsetof_late_bias]]\n"
+ "fmov z23.s, #0x0\n"
+ "add x23, x23, x27, LSL #2\n" // C += n
+ "madd x23, x28, x22, x23\n" // C += m * ldc
+ "cbz x20, 15f\n"
+ "add x20, x20, x27, LSL #2\n"
+ ".inst 0xa040c294 // ld1w { z20.s-z23.s }, p8/Z, [x20]\n"
+ "15:" // Store to output array: no late bias
+ "cntw x20\n"
+ "ld1rw { z17.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n"
+ "mov x12, #0x0\n"
+ "cmp x21, x20\n"
+ "ld1rw { z16.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n"
+ "csel x20, x21, x20, LT\n"
+ "lsr x21, x20, #0x2\n"
+ "and x20, x20, #0x3\n"
+ "cbz x21, 17f\n"
+ "16:" // Store to output array: Accumulator row 0 loop
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n"
+ ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z0.s, p0/M, z18.s, z20.s\n"
+ "fmad z1.s, p0/M, z18.s, z20.s\n"
+ "fmad z2.s, p0/M, z18.s, z20.s\n"
+ "fmad z3.s, p0/M, z18.s, z20.s\n"
+ "add x12, x12, #0x4\n"
+ "fmad z4.s, p0/M, z18.s, z21.s\n"
+ "fmad z5.s, p0/M, z18.s, z21.s\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z6.s, p0/M, z18.s, z21.s\n"
+ "fmad z7.s, p0/M, z18.s, z21.s\n"
+ "fmad z8.s, p0/M, z18.s, z22.s\n"
+ "fmad z9.s, p0/M, z18.s, z22.s\n"
+ "fmad z10.s, p0/M, z18.s, z22.s\n"
+ "fmad z11.s, p0/M, z18.s, z22.s\n"
+ "fmad z12.s, p0/M, z18.s, z23.s\n"
+ "fmad z13.s, p0/M, z18.s, z23.s\n"
+ "fmad z14.s, p0/M, z18.s, z23.s\n"
+ "fmad z15.s, p0/M, z18.s, z23.s\n"
+ ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n"
+ ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ ".inst 0xa160c2e3 // st1w { z3.s, z7.s, z11.s, z15.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ "blt 16b\n"
+ "17:" // Store to output array: Accumulator row 0 oddments
+ "cbz x20, 18f\n"
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n"
+ ".inst 0xc0860448 // mova { z8.s-z11.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e108 // scvtf { z8.s-z11.s }, { z8.s-z11.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z0.s, p0/M, z18.s, z20.s\n"
+ "fmad z1.s, p0/M, z18.s, z20.s\n"
+ "fmad z2.s, p0/M, z18.s, z20.s\n"
+ "fmad z3.s, p0/M, z18.s, z20.s\n"
+ "subs x20, x20, #0x1\n"
+ "fmad z4.s, p0/M, z18.s, z21.s\n"
+ "fmad z5.s, p0/M, z18.s, z21.s\n"
+ "fmad z6.s, p0/M, z18.s, z21.s\n"
+ "fmad z7.s, p0/M, z18.s, z21.s\n"
+ "fmad z8.s, p0/M, z18.s, z22.s\n"
+ "fmad z9.s, p0/M, z18.s, z22.s\n"
+ "fmad z10.s, p0/M, z18.s, z22.s\n"
+ "fmad z11.s, p0/M, z18.s, z22.s\n"
+ "fmad z12.s, p0/M, z18.s, z23.s\n"
+ "fmad z13.s, p0/M, z18.s, z23.s\n"
+ "fmad z14.s, p0/M, z18.s, z23.s\n"
+ "fmad z15.s, p0/M, z18.s, z23.s\n"
+ ".inst 0xc1b0ca20 // fclamp { z0.s-z3.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca24 // fclamp { z4.s-z7.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca28 // fclamp { z8.s-z11.s }, z17.s, z16.s\n"
+ ".inst 0xc1b0ca2c // fclamp { z12.s-z15.s }, z17.s, z16.s\n"
+ ".inst 0xa160c2e0 // st1w { z0.s, z4.s, z8.s, z12.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ "beq 18f\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xa160c2e1 // st1w { z1.s, z5.s, z9.s, z13.s }, p8, [x23]\n"
+ "add x23, x23, x22\n"
+ "beq 18f\n"
+ ".inst 0xa160c2e2 // st1w { z2.s, z6.s, z10.s, z14.s }, p8, [x23]\n"
+ "18:" // Store to output array: Accumulator row 0 oddments: End
+ "19:" // Store to output array: End
+ "tbz x13, #0, 21f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "20:" // Store to output array: Refill accumulators: Loop
+ ".inst 0xa040c574 // ld1w { z20.s-z23.s }, pn9.b/Z, [x11]\n"
+ ".inst 0xa041c56c // ld1w { z12.s-z15.s }, pn9.b/Z, [x11, #0x4, MUL VL]\n"
+ ".inst 0xa042c560 // ld1w { z0.s-z3.s }, pn9.b/Z, [x11, #0x8, MUL VL]\n"
+ ".inst 0xa043c568 // ld1w { z8.s-z11.s }, pn9.b/Z, [x11, #0xc, MUL VL]\n"
+ ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n"
+ "addvl x11, x11, #16\n"
+ ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 20b\n"
+ "21:" // End block
+ "incw x27, ALL, MUL #4\n"
+ "cmp x27, x26\n"
+ "blt 3b\n"
+ "incw x28\n"
+ "mov x27, #0x0\n"
+ "cmp x28, x9\n"
+ "mov x25, x24\n"
+ "blt 3b\n"
+ ".inst 0xd503467f // SMSTOP\n"
+ :
+ : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb))
+ : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp
new file mode 100644
index 0000000000..df2c9c0ca3
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL.hpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include <cstdint>
+#include "../std_transforms_sme.hpp"
+
+namespace arm_gemm
+{
+
+// Implementations
+void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+class cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL
+{
+public:
+ typedef int8_t operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+ /* Kernel blocking parameters */
+ static unsigned int out_height()
+ {
+ return sme::get_vector_length<int32_t>() * 2;
+ }
+
+ static unsigned int out_width()
+ {
+ return sme::get_vector_length<int32_t>() * 2;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ static constexpr bool is_sme()
+ {
+ return true;
+ }
+
+ // Default to the generic kernel
+ kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL;
+
+ StdTransformsSME<operand_type, result_type, 2, 2, 4> transforms = {};
+
+ cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp
new file mode 100644
index 0000000000..1631fae8e9
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL/generic.cpp
@@ -0,0 +1,448 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_gemm.hpp"
+
+#include <cstdint>
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer)
+{
+ struct KernelArgs
+ {
+ KernelArgs(
+ const int8_t *const A,
+ const int8_t *const B,
+ float *const C, const int ldc,
+ const int M, const int N, const int K,
+ const int32_t *const bias, const float *const late_bias, const Activation act,
+ bool accumulate,
+ int32_t *const accumulator_buffer
+ ) : A(A),
+ B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)),
+ C(C), ldcb(ldc * sizeof(float)),
+ M(M), N(N), K(K),
+ min(-std::numeric_limits<float>::infinity()),
+ max(std::numeric_limits<float>::infinity()),
+ bias(bias), late_bias(late_bias),
+ accumulator_buffer(accumulator_buffer),
+ flags(0x0)
+ {
+ if (accumulate)
+ {
+ flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER
+ }
+ if (C == nullptr)
+ {
+ flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER
+ }
+
+ // Initialise the activation values
+ switch (act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ this->max = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ this->min = static_cast<float>(0);
+ break;
+ }
+ }
+
+ const int8_t *const A;
+ const int8_t *const B;
+ const long kstride_bytes;
+ float *const C;
+ const long ldcb;
+ const long M, N, K;
+ float min = -std::numeric_limits<float>::infinity();
+ float max = std::numeric_limits<float>::infinity();
+
+ const int32_t *const bias;
+ const float *const late_bias;
+
+ int32_t *const accumulator_buffer;
+ uint64_t flags;
+ };
+
+ // Construct arguments for this kernel
+ KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer);
+
+ __asm__ __volatile__(
+ "ldr x16, [%x[args], %[offsetof_flags]]\n"
+ ".inst 0xd503477f // SMSTART ZA\n"
+ "ptrue p0.b\n"
+ ".inst 0x25207811 // ptrue pn9.b\n"
+ "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "tbz x16, #0, 2f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "1:" // Initial accumulator load from buffer: Loop
+ ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n"
+ ".inst 0xa041c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c5e0 // ld1w { z0.s-z3.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c5f8 // ld1w { z24.s-z27.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840681 // mova za1h.s[x12], { z20.s-z23.s }\n"
+ ".inst 0xc0840402 // mova za2h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840703 // mova za3h.s[x12], { z24.s-z27.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 1b\n"
+ "2:" // Initial accumulator load from buffer: End
+ "ldr w13, [%x[args], %[offsetof_M]]\n"
+ "mov x11, #0x0\n"
+ "mov x10, #0x0\n"
+ "ldr w9, [%x[args], %[offsetof_N]]\n"
+ "ldr x28, [%x[args], %[offsetof_A]]\n"
+ "3:" // M and N loop
+ "mov x27, x28\n"
+ ".inst 0x25a94550 // whilelt pn8.s, x10, x9, VLx2\n"
+ "tbnz x16, #0, 4f\n"
+ "ldr x20, [%x[args], %[offsetof_bias]]\n"
+ ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
+ "cbz x20, 5f\n"
+ ".inst 0xa10a4286 // ld1w { z6.s, z14.s }, p8/Z, [x20, x10, LSL #2]\n"
+ ".inst 0xc09000c0 // addha za0.s, p0/M, p0/M, z6.s\n"
+ ".inst 0xc09001c1 // addha za1.s, p0/M, p0/M, z14.s\n"
+ ".inst 0xc09000c2 // addha za2.s, p0/M, p0/M, z6.s\n"
+ ".inst 0xc09001c3 // addha za3.s, p0/M, p0/M, z14.s\n"
+ "4:" // Prepare accumulators: Test for last block
+ "mov x20, x10\n"
+ "mov x21, x11\n"
+ "incw x20, ALL, MUL #2\n"
+ "incw x21, ALL, MUL #2\n"
+ "cmp x20, x9\n"
+ "mov x20, x16\n"
+ "csel x21, x11, x21, LT\n"
+ "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n"
+ "cmp x21, x13\n"
+ "csel x16, x20, x16, LT\n"
+ "5:" // Prepare accumulators: End
+ "ldr x20, [%x[args], %[offsetof_K]]\n"
+ "ldr x23, [%x[args], %[offsetof_B]]\n"
+ "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n"
+ "add x20, x20, #0x3\n"
+ "lsr x20, x20, #0x2\n"
+ "lsr x21, x20, #0x2\n"
+ "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes
+ "and x20, x20, #0x3\n"
+ "cbz x21, 8f\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n"
+ ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n"
+ ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n"
+ ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n"
+ ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n"
+ "addvl x27, x27, #8\n"
+ ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n"
+ "addvl x23, x23, #8\n"
+ "ble 7f\n"
+ "6:" // K loop
+ ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n"
+ ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n"
+ ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n"
+ ".inst 0xa1400775 // ld1b { z21.b, z29.b }, pn9.b/Z, [x27]\n"
+ ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n"
+ ".inst 0xa04006f2 // ld1b { z18.b-z19.b }, pn9.b/Z, [x23]\n"
+ ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n"
+ ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n"
+ ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n"
+ ".inst 0xa041076a // ld1b { z10.b-z11.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n"
+ ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n"
+ ".inst 0xa14106e5 // ld1b { z5.b, z13.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n"
+ ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n"
+ ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n"
+ ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n"
+ ".inst 0xa1420767 // ld1b { z7.b, z15.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa14206f0 // ld1b { z16.b, z24.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
+ ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n"
+ ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n"
+ ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n"
+ ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n"
+ ".inst 0xa1430774 // ld1b { z20.b, z28.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n"
+ "addvl x27, x27, #8\n"
+ ".inst 0xa14306f7 // ld1b { z23.b, z31.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n"
+ "addvl x23, x23, #8\n"
+ "bgt 6b\n"
+ "7:" // K loop tail
+ ".inst 0xa09202a0 // smopa za0.s, p0/M, p0/M, z21.b, z18.b\n"
+ ".inst 0xa09302a1 // smopa za1.s, p0/M, p0/M, z21.b, z19.b\n"
+ ".inst 0xa09203a2 // smopa za2.s, p0/M, p0/M, z29.b, z18.b\n"
+ ".inst 0xa09303a3 // smopa za3.s, p0/M, p0/M, z29.b, z19.b\n"
+ ".inst 0xa0850140 // smopa za0.s, p0/M, p0/M, z10.b, z5.b\n"
+ ".inst 0xa08d0141 // smopa za1.s, p0/M, p0/M, z10.b, z13.b\n"
+ ".inst 0xa0850162 // smopa za2.s, p0/M, p0/M, z11.b, z5.b\n"
+ ".inst 0xa08d0163 // smopa za3.s, p0/M, p0/M, z11.b, z13.b\n"
+ ".inst 0xa09000e0 // smopa za0.s, p0/M, p0/M, z7.b, z16.b\n"
+ ".inst 0xa09800e1 // smopa za1.s, p0/M, p0/M, z7.b, z24.b\n"
+ ".inst 0xa09001e2 // smopa za2.s, p0/M, p0/M, z15.b, z16.b\n"
+ ".inst 0xa09801e3 // smopa za3.s, p0/M, p0/M, z15.b, z24.b\n"
+ ".inst 0xa0970280 // smopa za0.s, p0/M, p0/M, z20.b, z23.b\n"
+ ".inst 0xa09f0281 // smopa za1.s, p0/M, p0/M, z20.b, z31.b\n"
+ ".inst 0xa0970382 // smopa za2.s, p0/M, p0/M, z28.b, z23.b\n"
+ ".inst 0xa09f0383 // smopa za3.s, p0/M, p0/M, z28.b, z31.b\n"
+ "8:" // K oddments
+ "cbz x20, 10f\n"
+ "9:" // K oddments: Loop
+ ".inst 0xa040077e // ld1b { z30.b-z31.b }, pn9.b/Z, [x27]\n"
+ "subs x20, x20, #0x1\n"
+ "addvl x27, x27, #2\n"
+ ".inst 0xa14006e7 // ld1b { z7.b, z15.b }, pn9.b/Z, [x23]\n"
+ "addvl x23, x23, #2\n"
+ ".inst 0xa08703c0 // smopa za0.s, p0/M, p0/M, z30.b, z7.b\n"
+ ".inst 0xa08f03c1 // smopa za1.s, p0/M, p0/M, z30.b, z15.b\n"
+ ".inst 0xa08703e2 // smopa za2.s, p0/M, p0/M, z31.b, z7.b\n"
+ ".inst 0xa08f03e3 // smopa za3.s, p0/M, p0/M, z31.b, z15.b\n"
+ "bgt 9b\n"
+ "10:" // K oddments: End
+ "tbz x16, #1, 14f\n"
+ "tbz x16, #0, 12f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "11:" // Store to partial result buffer: Store and refill: Loop
+ ".inst 0xa040c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15]\n"
+ ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n"
+ ".inst 0xc0860428 // mova { z8.s-z11.s }, za1h.s[x12]\n"
+ ".inst 0xa041c5f0 // ld1w { z16.s-z19.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xc0860440 // mova { z0.s-z3.s }, za2h.s[x12]\n"
+ ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n"
+ ".inst 0xa042c5fc // ld1w { z28.s-z31.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840580 // mova za0h.s[x12], { z12.s-z15.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840601 // mova za1h.s[x12], { z16.s-z19.s }\n"
+ ".inst 0xa060c5c4 // st1w { z4.s-z7.s }, pn9.b, [x14]\n"
+ ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xa061c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0x4, MUL VL]\n"
+ ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa062c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14, #0x8, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa063c5d8 // st1w { z24.s-z27.s }, pn9.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 11b\n"
+ "b 24f\n"
+ "12:" // Store to partial result buffer: Store only
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "13:" // Store to partial result buffer: Store only: Loop
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n"
+ ".inst 0xc0860468 // mova { z8.s-z11.s }, za3h.s[x12]\n"
+ ".inst 0xa060c5c0 // st1w { z0.s-z3.s }, pn9.b, [x14]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa061c5cc // st1w { z12.s-z15.s }, pn9.b, [x14, #0x4, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa062c5d0 // st1w { z16.s-z19.s }, pn9.b, [x14, #0x8, MUL VL]\n"
+ ".inst 0xa063c5c8 // st1w { z8.s-z11.s }, pn9.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 13b\n"
+ "b 24f\n"
+ "14:" // Store to output array
+ "ldr x26, [%x[args], %[offsetof_C]]\n"
+ "sub x25, x13, x11\n"
+ "ld1rw { z3.s }, p0/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n"
+ "fmov z2.s, #0x0\n"
+ "ldr x24, [%x[args], %[offsetof_ldcb]]\n"
+ "fmov z10.s, #0x0\n"
+ "ldr x20, [%x[args], %[offsetof_late_bias]]\n"
+ "add x26, x26, x10, LSL #2\n" // C += n
+ "madd x26, x11, x24, x26\n" // C += m * ldc
+ "cbz x20, 15f\n"
+ "add x20, x20, x10, LSL #2\n"
+ ".inst 0xa1404282 // ld1w { z2.s, z10.s }, p8/Z, [x20]\n"
+ "15:" // Store to output array: no late bias
+ "cntw x23\n"
+ "ld1rw { z1.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_min]]\n"
+ "mov x12, #0x0\n"
+ "cmp x25, x23\n"
+ "ld1rw { z0.s }, p0/Z, [%x[args], %[offsetof_KernelArgs_max]]\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 17f\n"
+ "16:" // Store to output array: Accumulator row 0 loop
+ ".inst 0xc0860404 // mova { z4.s-z7.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z4.s, p0/M, z3.s, z2.s\n"
+ "fmad z5.s, p0/M, z3.s, z2.s\n"
+ "add x12, x12, #0x4\n"
+ "fmad z6.s, p0/M, z3.s, z2.s\n"
+ "fmad z7.s, p0/M, z3.s, z2.s\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z12.s, p0/M, z3.s, z10.s\n"
+ "fmad z13.s, p0/M, z3.s, z10.s\n"
+ "fmad z14.s, p0/M, z3.s, z10.s\n"
+ "fmad z15.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n"
+ ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604347 // st1w { z7.s, z15.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 16b\n"
+ "17:" // Store to output array: Accumulator row 0 oddments
+ "cbz x20, 18f\n"
+ ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n"
+ ".inst 0xc0860438 // mova { z24.s-z27.s }, za1h.s[x12]\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n"
+ "fmad z16.s, p0/M, z3.s, z2.s\n"
+ "fmad z17.s, p0/M, z3.s, z2.s\n"
+ "subs x20, x20, #0x1\n"
+ "fmad z18.s, p0/M, z3.s, z2.s\n"
+ "fmad z19.s, p0/M, z3.s, z2.s\n"
+ "fmad z24.s, p0/M, z3.s, z10.s\n"
+ "fmad z25.s, p0/M, z3.s, z10.s\n"
+ "fmad z26.s, p0/M, z3.s, z10.s\n"
+ "fmad z27.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c830 // fclamp { z16.s-z19.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c838 // fclamp { z24.s-z27.s }, z1.s, z0.s\n"
+ ".inst 0xa1604350 // st1w { z16.s, z24.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xa1604351 // st1w { z17.s, z25.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ ".inst 0xa1604352 // st1w { z18.s, z26.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "18:" // Store to output array: Accumulator row 0 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 22f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x20, x25, x23, LT\n"
+ "lsr x21, x20, #0x2\n"
+ "and x20, x20, #0x3\n"
+ "cbz x21, 20f\n"
+ "19:" // Store to output array: Accumulator row 1 loop
+ ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n"
+ ".inst 0xc086047c // mova { z28.s-z31.s }, za3h.s[x12]\n"
+ ".inst 0xc132e294 // scvtf { z20.s-z23.s }, { z20.s-z23.s }\n"
+ ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n"
+ "fmad z20.s, p0/M, z3.s, z2.s\n"
+ "fmad z21.s, p0/M, z3.s, z2.s\n"
+ "add x12, x12, #0x4\n"
+ "fmad z22.s, p0/M, z3.s, z2.s\n"
+ "fmad z23.s, p0/M, z3.s, z2.s\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z28.s, p0/M, z3.s, z10.s\n"
+ "fmad z29.s, p0/M, z3.s, z10.s\n"
+ "fmad z30.s, p0/M, z3.s, z10.s\n"
+ "fmad z31.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c834 // fclamp { z20.s-z23.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c83c // fclamp { z28.s-z31.s }, z1.s, z0.s\n"
+ ".inst 0xa1604354 // st1w { z20.s, z28.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604355 // st1w { z21.s, z29.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604356 // st1w { z22.s, z30.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ ".inst 0xa1604357 // st1w { z23.s, z31.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 19b\n"
+ "20:" // Store to output array: Accumulator row 1 oddments
+ "cbz x20, 21f\n"
+ ".inst 0xc0860444 // mova { z4.s-z7.s }, za2h.s[x12]\n"
+ ".inst 0xc086046c // mova { z12.s-z15.s }, za3h.s[x12]\n"
+ ".inst 0xc132e084 // scvtf { z4.s-z7.s }, { z4.s-z7.s }\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "fmad z4.s, p0/M, z3.s, z2.s\n"
+ "fmad z5.s, p0/M, z3.s, z2.s\n"
+ "subs x20, x20, #0x1\n"
+ "fmad z6.s, p0/M, z3.s, z2.s\n"
+ "fmad z7.s, p0/M, z3.s, z2.s\n"
+ "fmad z12.s, p0/M, z3.s, z10.s\n"
+ "fmad z13.s, p0/M, z3.s, z10.s\n"
+ "fmad z14.s, p0/M, z3.s, z10.s\n"
+ "fmad z15.s, p0/M, z3.s, z10.s\n"
+ ".inst 0xc1a0c824 // fclamp { z4.s-z7.s }, z1.s, z0.s\n"
+ ".inst 0xc1a0c82c // fclamp { z12.s-z15.s }, z1.s, z0.s\n"
+ ".inst 0xa1604344 // st1w { z4.s, z12.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xa1604345 // st1w { z5.s, z13.s }, p8, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ ".inst 0xa1604346 // st1w { z6.s, z14.s }, p8, [x26]\n"
+ "21:" // Store to output array: Accumulator row 1 oddments: End
+ "22:" // Store to output array: End
+ "tbz x16, #0, 24f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "23:" // Store to output array: Refill accumulators: Loop
+ ".inst 0xa040c5f4 // ld1w { z20.s-z23.s }, pn9.b/Z, [x15]\n"
+ ".inst 0xa041c5ec // ld1w { z12.s-z15.s }, pn9.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c5e4 // ld1w { z4.s-z7.s }, pn9.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c5e8 // ld1w { z8.s-z11.s }, pn9.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xc0840482 // mova za2h.s[x12], { z4.s-z7.s }\n"
+ ".inst 0xc0840503 // mova za3h.s[x12], { z8.s-z11.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 23b\n"
+ "24:" // End block
+ "incw x10, ALL, MUL #2\n"
+ "cmp x10, x9\n"
+ "blt 3b\n"
+ "incw x11, ALL, MUL #2\n"
+ "mov x10, #0x0\n"
+ "cmp x11, x13\n"
+ "mov x28, x27\n"
+ "blt 3b\n"
+ ".inst 0xd503467f // SMSTOP\n"
+ :
+ : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb))
+ : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp
new file mode 100644
index 0000000000..70952f4f03
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL.hpp
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include <cstdint>
+#include "../std_transforms_sme.hpp"
+
+namespace arm_gemm
+{
+
+// Implementations
+void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+class cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL
+{
+public:
+ typedef int8_t operand_type;
+ typedef float result_type;
+
+ typedef void (*kern_type)(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer);
+
+ /* Kernel blocking parameters */
+ static unsigned int out_height()
+ {
+ return sme::get_vector_length<int32_t>() * 4;
+ }
+
+ static unsigned int out_width()
+ {
+ return sme::get_vector_length<int32_t>() * 1;
+ }
+
+ static constexpr unsigned int k_unroll()
+ {
+ return 4;
+ }
+
+ static constexpr bool supports_accumulate()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_bias()
+ {
+ return true;
+ }
+
+ static constexpr bool supports_activation()
+ {
+ return true;
+ }
+
+ static constexpr bool is_sme()
+ {
+ return true;
+ }
+
+ // Default to the generic kernel
+ kern_type kernel = sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL;
+
+ StdTransformsSME<operand_type, result_type, 4, 1, 4> transforms = {};
+
+ cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const CPUInfo *)
+ {
+ }
+};
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp
new file mode 100644
index 0000000000..bafb16bca8
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL/generic.cpp
@@ -0,0 +1,513 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_gemm.hpp"
+
+#include <cstdint>
+#include "../../asmlib.hpp"
+#include "../../utils.hpp"
+
+namespace arm_gemm {
+
+void sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL(const int8_t *const A, const int8_t *const B, float *const C, int ldc, const int M, const int N, const int K, const int32_t *const bias, const DequantizeFloat &dq, const float *const late_bias, const Activation act, bool accumulate, int32_t *const accumulator_buffer)
+{
+ struct KernelArgs
+ {
+ KernelArgs(
+ const int8_t *const A,
+ const int8_t *const B,
+ float *const C, const int ldc,
+ const int M, const int N, const int K,
+ const int32_t *const bias, const float *const late_bias, const Activation act,
+ bool accumulate,
+ int32_t *const accumulator_buffer
+ ) : A(A),
+ B(B), kstride_bytes(roundup(K, 4) * sizeof(int8_t)),
+ C(C), ldcb(ldc * sizeof(float)),
+ M(M), N(N), K(K),
+ min(-std::numeric_limits<float>::infinity()),
+ max(std::numeric_limits<float>::infinity()),
+ bias(bias), late_bias(late_bias),
+ accumulator_buffer(accumulator_buffer),
+ flags(0x0)
+ {
+ if (accumulate)
+ {
+ flags |= 1 << 0; // FILL_ACCUMULATORS_FROM_BUFFER
+ }
+ if (C == nullptr)
+ {
+ flags |= 1 << 1; // STORE_ACCUMULATORS_TO_BUFFER
+ }
+
+ // Initialise the activation values
+ switch (act.type)
+ {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ this->max = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ this->min = static_cast<float>(0);
+ break;
+ }
+ }
+
+ const int8_t *const A;
+ const int8_t *const B;
+ const long kstride_bytes;
+ float *const C;
+ const long ldcb;
+ const long M, N, K;
+ float min = -std::numeric_limits<float>::infinity();
+ float max = std::numeric_limits<float>::infinity();
+
+ const int32_t *const bias;
+ const float *const late_bias;
+
+ int32_t *const accumulator_buffer;
+ uint64_t flags;
+ };
+
+ // Construct arguments for this kernel
+ KernelArgs args(A, B, C, ldc, M, N, K, bias, late_bias, act, accumulate, accumulator_buffer);
+
+ __asm__ __volatile__(
+ "ldr x16, [%x[args], %[offsetof_flags]]\n"
+ ".inst 0xd503477f // SMSTART ZA\n"
+ "ptrue p1.b\n"
+ ".inst 0x25207810 // ptrue pn8.b\n"
+ "ldr x15, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "ldr x14, [%x[args], %[offsetof_accumulator_buffer]]\n"
+ "tbz x16, #0, 2f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "1:" // Initial accumulator load from buffer: Loop
+ ".inst 0xa040c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15]\n"
+ ".inst 0xa041c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c1f0 // ld1w { z16.s-z19.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840680 // mova za0h.s[x12], { z20.s-z23.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840781 // mova za1h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xc0840502 // mova za2h.s[x12], { z8.s-z11.s }\n"
+ ".inst 0xc0840603 // mova za3h.s[x12], { z16.s-z19.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 1b\n"
+ "2:" // Initial accumulator load from buffer: End
+ "ldr w13, [%x[args], %[offsetof_M]]\n"
+ "mov x11, #0x0\n"
+ "mov x10, #0x0\n"
+ "ldr w9, [%x[args], %[offsetof_N]]\n"
+ "ldr x28, [%x[args], %[offsetof_A]]\n"
+ "3:" // M and N loop
+ "mov x27, x28\n"
+ "whilelt p0.s, x10, x9\n"
+ "tbnz x16, #0, 4f\n"
+ "ldr x20, [%x[args], %[offsetof_bias]]\n"
+ ".inst 0xc00800ff // zero { zad0, zad1, zad2, zad3, zad4, zad5, zad6, zad7 }\n"
+ "cbz x20, 5f\n"
+ "ld1w { z23.s }, p0/Z, [x20, x10, LSL #2]\n"
+ ".inst 0xc09026e0 // addha za0.s, p1/M, p1/M, z23.s\n"
+ ".inst 0xc09026e1 // addha za1.s, p1/M, p1/M, z23.s\n"
+ ".inst 0xc09026e2 // addha za2.s, p1/M, p1/M, z23.s\n"
+ ".inst 0xc09026e3 // addha za3.s, p1/M, p1/M, z23.s\n"
+ "4:" // Prepare accumulators: Test for last block
+ "mov x20, x10\n"
+ "mov x21, x11\n"
+ "incw x20\n"
+ "incw x21, ALL, MUL #4\n"
+ "cmp x20, x9\n"
+ "mov x20, x16\n"
+ "csel x21, x11, x21, LT\n"
+ "bfm x16, XZR, #0x0, #0x0 // bfc x16, #0x0, #0x1\n"
+ "cmp x21, x13\n"
+ "csel x16, x20, x16, LT\n"
+ "5:" // Prepare accumulators: End
+ "ldr x20, [%x[args], %[offsetof_K]]\n"
+ "ldr x23, [%x[args], %[offsetof_B]]\n"
+ "ldr x22, [%x[args], %[offsetof_kstride_bytes]]\n"
+ "add x20, x20, #0x3\n"
+ "lsr x20, x20, #0x2\n"
+ "lsr x21, x20, #0x2\n"
+ "madd x23, x10, x22, x23\n" // bptr = B + n * kstride_bytes
+ "and x20, x20, #0x3\n"
+ "cbz x21, 8f\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n"
+ "ld1b { z4.b }, p1/Z, [x23]\n"
+ ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n"
+ "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n"
+ ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n"
+ "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n"
+ ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n"
+ "addvl x27, x27, #16\n"
+ "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n"
+ "addvl x23, x23, #4\n"
+ "ble 7f\n"
+ "6:" // K loop
+ ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n"
+ "subs x21, x21, #0x1\n"
+ ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n"
+ ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n"
+ ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n"
+ ".inst 0xa0408378 // ld1b { z24.b-z27.b }, pn8.b/Z, [x27]\n"
+ ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n"
+ "ld1b { z4.b }, p1/Z, [x23]\n"
+ ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n"
+ ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n"
+ ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n"
+ ".inst 0xa0418374 // ld1b { z20.b-z23.b }, pn8.b/Z, [x27, #0x4, MUL VL]\n"
+ ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n"
+ "ld1b { z2.b }, p1/Z, [x23, #1, MUL VL]\n"
+ ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n"
+ ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n"
+ ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n"
+ ".inst 0xa042836c // ld1b { z12.b-z15.b }, pn8.b/Z, [x27, #0x8, MUL VL]\n"
+ "ld1b { z11.b }, p1/Z, [x23, #2, MUL VL]\n"
+ ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n"
+ ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n"
+ ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n"
+ ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n"
+ ".inst 0xa0438370 // ld1b { z16.b-z19.b }, pn8.b/Z, [x27, #0xc, MUL VL]\n"
+ "addvl x27, x27, #16\n"
+ "ld1b { z28.b }, p1/Z, [x23, #3, MUL VL]\n"
+ "addvl x23, x23, #4\n"
+ "bgt 6b\n"
+ "7:" // K loop tail
+ ".inst 0xa0842700 // smopa za0.s, p1/M, p1/M, z24.b, z4.b\n"
+ ".inst 0xa0842721 // smopa za1.s, p1/M, p1/M, z25.b, z4.b\n"
+ ".inst 0xa0842742 // smopa za2.s, p1/M, p1/M, z26.b, z4.b\n"
+ ".inst 0xa0842763 // smopa za3.s, p1/M, p1/M, z27.b, z4.b\n"
+ ".inst 0xa0822680 // smopa za0.s, p1/M, p1/M, z20.b, z2.b\n"
+ ".inst 0xa08226a1 // smopa za1.s, p1/M, p1/M, z21.b, z2.b\n"
+ ".inst 0xa08226c2 // smopa za2.s, p1/M, p1/M, z22.b, z2.b\n"
+ ".inst 0xa08226e3 // smopa za3.s, p1/M, p1/M, z23.b, z2.b\n"
+ ".inst 0xa08b2580 // smopa za0.s, p1/M, p1/M, z12.b, z11.b\n"
+ ".inst 0xa08b25a1 // smopa za1.s, p1/M, p1/M, z13.b, z11.b\n"
+ ".inst 0xa08b25c2 // smopa za2.s, p1/M, p1/M, z14.b, z11.b\n"
+ ".inst 0xa08b25e3 // smopa za3.s, p1/M, p1/M, z15.b, z11.b\n"
+ ".inst 0xa09c2600 // smopa za0.s, p1/M, p1/M, z16.b, z28.b\n"
+ ".inst 0xa09c2621 // smopa za1.s, p1/M, p1/M, z17.b, z28.b\n"
+ ".inst 0xa09c2642 // smopa za2.s, p1/M, p1/M, z18.b, z28.b\n"
+ ".inst 0xa09c2663 // smopa za3.s, p1/M, p1/M, z19.b, z28.b\n"
+ "8:" // K oddments
+ "cbz x20, 10f\n"
+ "9:" // K oddments: Loop
+ ".inst 0xa1408373 // ld1b { z19.b, z23.b, z27.b, z31.b }, pn8.b/Z, [x27]\n"
+ "subs x20, x20, #0x1\n"
+ "addvl x27, x27, #4\n"
+ "ld1b { z16.b }, p1/Z, [x23]\n"
+ "addvl x23, x23, #1\n"
+ ".inst 0xa0902660 // smopa za0.s, p1/M, p1/M, z19.b, z16.b\n"
+ ".inst 0xa09026e1 // smopa za1.s, p1/M, p1/M, z23.b, z16.b\n"
+ ".inst 0xa0902762 // smopa za2.s, p1/M, p1/M, z27.b, z16.b\n"
+ ".inst 0xa09027e3 // smopa za3.s, p1/M, p1/M, z31.b, z16.b\n"
+ "bgt 9b\n"
+ "10:" // K oddments: End
+ "tbz x16, #1, 14f\n"
+ "tbz x16, #0, 12f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "11:" // Store to partial result buffer: Store and refill: Loop
+ ".inst 0xa040c1e8 // ld1w { z8.s-z11.s }, pn8.b/Z, [x15]\n"
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ ".inst 0xc0860424 // mova { z4.s-z7.s }, za1h.s[x12]\n"
+ ".inst 0xa041c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xc0860458 // mova { z24.s-z27.s }, za2h.s[x12]\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ ".inst 0xa042c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c1f4 // ld1w { z20.s-z23.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840500 // mova za0h.s[x12], { z8.s-z11.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840581 // mova za1h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xa060c1c0 // st1w { z0.s-z3.s }, pn8.b, [x14]\n"
+ ".inst 0xc0840782 // mova za2h.s[x12], { z28.s-z31.s }\n"
+ ".inst 0xa061c1c4 // st1w { z4.s-z7.s }, pn8.b, [x14, #0x4, MUL VL]\n"
+ ".inst 0xc0840683 // mova za3h.s[x12], { z20.s-z23.s }\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa062c1d8 // st1w { z24.s-z27.s }, pn8.b, [x14, #0x8, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 11b\n"
+ "b 30f\n"
+ "12:" // Store to partial result buffer: Store only
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "13:" // Store to partial result buffer: Store only: Loop
+ ".inst 0xc0860408 // mova { z8.s-z11.s }, za0h.s[x12]\n"
+ ".inst 0xc086042c // mova { z12.s-z15.s }, za1h.s[x12]\n"
+ ".inst 0xc0860454 // mova { z20.s-z23.s }, za2h.s[x12]\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ ".inst 0xa060c1c8 // st1w { z8.s-z11.s }, pn8.b, [x14]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xa061c1cc // st1w { z12.s-z15.s }, pn8.b, [x14, #0x4, MUL VL]\n"
+ "cmp x12, x20\n"
+ ".inst 0xa062c1d4 // st1w { z20.s-z23.s }, pn8.b, [x14, #0x8, MUL VL]\n"
+ ".inst 0xa063c1d0 // st1w { z16.s-z19.s }, pn8.b, [x14, #0xc, MUL VL]\n"
+ "addvl x14, x14, #16\n"
+ "blt 13b\n"
+ "b 30f\n"
+ "14:" // Store to output array
+ "ldr x26, [%x[args], %[offsetof_C]]\n"
+ "sub x25, x13, x11\n"
+ "ld1rw { z23.s }, p1/Z, [%x[dq], %[offset_DequantizeFloat_scale]]\n"
+ "fmov z22.s, #0x0\n"
+ "ldr x24, [%x[args], %[offsetof_ldcb]]\n"
+ "ldr x20, [%x[args], %[offsetof_late_bias]]\n"
+ "add x26, x26, x10, LSL #2\n" // C += n
+ "madd x26, x11, x24, x26\n" // C += m * ldc
+ "cbz x20, 15f\n"
+ "add x20, x20, x10, LSL #2\n"
+ "ld1w { z22.s }, p0/Z, [x20]\n"
+ "15:" // Store to output array: no late bias
+ "cntw x23\n"
+ "ld1rw { z21.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_min]]\n"
+ "mov x12, #0x0\n"
+ "cmp x25, x23\n"
+ "ld1rw { z20.s }, p1/Z, [%x[args], %[offsetof_KernelArgs_max]]\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 17f\n"
+ "16:" // Store to output array: Accumulator row 0 loop
+ ".inst 0xc0860400 // mova { z0.s-z3.s }, za0h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e000 // scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z0.s, p1/M, z23.s, z22.s\n"
+ "fmad z1.s, p1/M, z23.s, z22.s\n"
+ "fmad z2.s, p1/M, z23.s, z22.s\n"
+ "fmad z3.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4caa0 // fclamp { z0.s-z3.s }, z21.s, z20.s\n"
+ "st1w { z0.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z1.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z2.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z3.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 16b\n"
+ "17:" // Store to output array: Accumulator row 0 oddments
+ "cbz x20, 18f\n"
+ ".inst 0xc0860410 // mova { z16.s-z19.s }, za0h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 18f\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "18:" // Store to output array: Accumulator row 0 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 28f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 20f\n"
+ "19:" // Store to output array: Accumulator row 1 loop
+ ".inst 0xc0860430 // mova { z16.s-z19.s }, za1h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z19.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 19b\n"
+ "20:" // Store to output array: Accumulator row 1 oddments
+ "cbz x20, 21f\n"
+ ".inst 0xc086043c // mova { z28.s-z31.s }, za1h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e39c // scvtf { z28.s-z31.s }, { z28.s-z31.s }\n"
+ "fmad z28.s, p1/M, z23.s, z22.s\n"
+ "fmad z29.s, p1/M, z23.s, z22.s\n"
+ "fmad z30.s, p1/M, z23.s, z22.s\n"
+ "fmad z31.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cabc // fclamp { z28.s-z31.s }, z21.s, z20.s\n"
+ "st1w { z28.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z29.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 21f\n"
+ "st1w { z30.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "21:" // Store to output array: Accumulator row 1 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 28f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x22, x25, x23, LT\n"
+ "lsr x21, x22, #0x2\n"
+ "and x20, x22, #0x3\n"
+ "cbz x21, 23f\n"
+ "22:" // Store to output array: Accumulator row 2 loop
+ ".inst 0xc086044c // mova { z12.s-z15.s }, za2h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e18c // scvtf { z12.s-z15.s }, { z12.s-z15.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z12.s, p1/M, z23.s, z22.s\n"
+ "fmad z13.s, p1/M, z23.s, z22.s\n"
+ "fmad z14.s, p1/M, z23.s, z22.s\n"
+ "fmad z15.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4caac // fclamp { z12.s-z15.s }, z21.s, z20.s\n"
+ "st1w { z12.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z13.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z14.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z15.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 22b\n"
+ "23:" // Store to output array: Accumulator row 2 oddments
+ "cbz x20, 24f\n"
+ ".inst 0xc0860450 // mova { z16.s-z19.s }, za2h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 24f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 24f\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "24:" // Store to output array: Accumulator row 2 oddments: End
+ "subs x25, x25, x22\n"
+ "beq 28f\n"
+ "cmp x25, x23\n"
+ "mov x12, #0x0\n"
+ "csel x20, x25, x23, LT\n"
+ "lsr x21, x20, #0x2\n"
+ "and x20, x20, #0x3\n"
+ "cbz x21, 26f\n"
+ "25:" // Store to output array: Accumulator row 3 loop
+ ".inst 0xc0860478 // mova { z24.s-z27.s }, za3h.s[x12]\n"
+ "add x12, x12, #0x4\n"
+ ".inst 0xc132e318 // scvtf { z24.s-z27.s }, { z24.s-z27.s }\n"
+ "cmp x12, x21, LSL #2\n"
+ "fmad z24.s, p1/M, z23.s, z22.s\n"
+ "fmad z25.s, p1/M, z23.s, z22.s\n"
+ "fmad z26.s, p1/M, z23.s, z22.s\n"
+ "fmad z27.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab8 // fclamp { z24.s-z27.s }, z21.s, z20.s\n"
+ "st1w { z24.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z25.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z26.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "st1w { z27.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "blt 25b\n"
+ "26:" // Store to output array: Accumulator row 3 oddments
+ "cbz x20, 27f\n"
+ ".inst 0xc0860470 // mova { z16.s-z19.s }, za3h.s[x12]\n"
+ "subs x20, x20, #0x1\n"
+ ".inst 0xc132e210 // scvtf { z16.s-z19.s }, { z16.s-z19.s }\n"
+ "fmad z16.s, p1/M, z23.s, z22.s\n"
+ "fmad z17.s, p1/M, z23.s, z22.s\n"
+ "fmad z18.s, p1/M, z23.s, z22.s\n"
+ "fmad z19.s, p1/M, z23.s, z22.s\n"
+ ".inst 0xc1b4cab0 // fclamp { z16.s-z19.s }, z21.s, z20.s\n"
+ "st1w { z16.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 27f\n"
+ "subs x20, x20, #0x1\n"
+ "st1w { z17.s }, p0, [x26]\n"
+ "add x26, x26, x24\n"
+ "beq 27f\n"
+ "st1w { z18.s }, p0, [x26]\n"
+ "27:" // Store to output array: Accumulator row 3 oddments: End
+ "28:" // Store to output array: End
+ "tbz x16, #0, 30f\n"
+ "mov x12, #0x0\n"
+ "cntw x20\n"
+ "29:" // Store to output array: Refill accumulators: Loop
+ ".inst 0xa040c1fc // ld1w { z28.s-z31.s }, pn8.b/Z, [x15]\n"
+ ".inst 0xa041c1e0 // ld1w { z0.s-z3.s }, pn8.b/Z, [x15, #0x4, MUL VL]\n"
+ ".inst 0xa042c1ec // ld1w { z12.s-z15.s }, pn8.b/Z, [x15, #0x8, MUL VL]\n"
+ ".inst 0xa043c1e4 // ld1w { z4.s-z7.s }, pn8.b/Z, [x15, #0xc, MUL VL]\n"
+ ".inst 0xc0840780 // mova za0h.s[x12], { z28.s-z31.s }\n"
+ "addvl x15, x15, #16\n"
+ ".inst 0xc0840401 // mova za1h.s[x12], { z0.s-z3.s }\n"
+ ".inst 0xc0840582 // mova za2h.s[x12], { z12.s-z15.s }\n"
+ ".inst 0xc0840483 // mova za3h.s[x12], { z4.s-z7.s }\n"
+ "add x12, x12, #0x4\n"
+ "cmp x12, x20\n"
+ "blt 29b\n"
+ "30:" // End block
+ "incw x10\n"
+ "cmp x10, x9\n"
+ "blt 3b\n"
+ "incw x11, ALL, MUL #4\n"
+ "mov x10, #0x0\n"
+ "cmp x11, x13\n"
+ "mov x28, x27\n"
+ "blt 3b\n"
+ ".inst 0xd503467f // SMSTOP\n"
+ :
+ : [args] "r" (&args), [dq] "r" (&dq), [offset_DequantizeFloat_scale] "I" (offsetof(DequantizeFloat, scale)), [offsetof_A] "I" (offsetof(KernelArgs, A)), [offsetof_B] "I" (offsetof(KernelArgs, B)), [offsetof_C] "I" (offsetof(KernelArgs, C)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_KernelArgs_max] "I" (offsetof(KernelArgs, max)), [offsetof_KernelArgs_min] "I" (offsetof(KernelArgs, min)), [offsetof_M] "I" (offsetof(KernelArgs, M)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_accumulator_buffer] "I" (offsetof(KernelArgs, accumulator_buffer)), [offsetof_bias] "I" (offsetof(KernelArgs, bias)), [offsetof_flags] "I" (offsetof(KernelArgs, flags)), [offsetof_kstride_bytes] "I" (offsetof(KernelArgs, kstride_bytes)), [offsetof_late_bias] "I" (offsetof(KernelArgs, late_bias)), [offsetof_ldcb] "I" (offsetof(KernelArgs, ldcb))
+ : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+ );
+}
+
+} // namespace arm_gemm
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp
index 887d78e1de..23f686a902 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
{
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
+ case CPUModel::V1:
+ return { 28.74 };
default:
- return { 32.35 };
+ return { 15.27 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp
index d0ef531c33..1fe5f48da6 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -88,8 +88,10 @@ public:
if (std::is_same<T, float>::value) {
switch (ci->get_cpu_model()) {
- default:
- return { 39.66, 5.18, 4.37 };
+ case CPUModel::V1:
+ return { 53.48, 4.23, 6.53 };
+ default:
+ return { 29.07, 2.76, 5.39 };
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index 111d01ed3a..6da9f4be0e 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -1142,6 +1142,64 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+ const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride,
+ const float* bias_ptr, bool accumulate, const Activation &act)
+{
+ const float32x4_t vscale = vdupq_n_f32(qp.scale);
+ float maxval = std::numeric_limits<float>::infinity();
+ float minval = -std::numeric_limits<float>::infinity();
+
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ minval = 0;
+ break;
+ }
+
+ const float32x4_t vmin = vdupq_n_f32(minval);
+ const float32x4_t vmax = vdupq_n_f32(maxval);
+
+ for(unsigned int row=0; row<height; row++) {
+ auto row_in_ptr = in_ptr + (row * in_stride);
+ auto row_out_ptr = out_ptr + (row * out_stride);
+ unsigned int col=0;
+ if (width >= 4) {
+ for(; col <= (width - 4); col+= 4) {
+ const int32x4_t vin = vld1q_s32(row_in_ptr + col);
+ float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale);
+ if(bias_ptr) {
+ const float32x4_t bin = vld1q_f32(bias_ptr + col);
+ vdeq = vaddq_f32(vdeq, bin);
+ }
+ if(accumulate) {
+ vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col));
+ }
+ vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax);
+ vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq);
+ }
+ }
+ // left-over elements
+ for(; col < width; ++col) {
+ const int32_t val = *(row_in_ptr + col);
+ float res = static_cast<float>(val * qp.scale);
+ if(bias_ptr) {
+ res += static_cast<float>(*(bias_ptr + col));
+ }
+ if(accumulate) {
+ res += *(row_out_ptr + col);
+ }
+ res = std::min(std::max(res, minval), maxval);
+ *(row_out_ptr + col) = res;
+ }
+ }
+}
+
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp
index 31dd65b397..bc64fd967b 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.hpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019, 2023 Arm Limited.
+ * Copyright (c) 2019, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,4 +45,8 @@ template<typename T>
void row_sums_indirect(size_t num_strings, const unsigned int *string_lengths, IndirectInputArg<T> A_arg,
size_t M, int32_t *output_ptr, const Requantize32 *qp);
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+ const int32_t* input, unsigned int in_stride, float *output, unsigned int out_stride,
+ const float *row_bias, bool not_first_pass, const Activation &act);
+
} // namespace arm_gemm
diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h
index 50b3fc1284..a74316b486 100644
--- a/src/core/common/Registrars.h
+++ b/src/core/common/Registrars.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2023 Arm Limited.
+ * Copyright (c) 2020-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,6 +38,12 @@
#define REGISTER_FP16_SVE2(func_name) nullptr
#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
+#if defined(ARM_COMPUTE_ENABLE_SME2)
+#define REGISTER_FP16_SME2(func_name) &(func_name)
+#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */
+#define REGISTER_FP16_SME2(func_name) nullptr
+#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */
+
#if defined(ARM_COMPUTE_ENABLE_NEON)
#define REGISTER_FP16_NEON(func_name) &(func_name)
#else /* !defined(ARM_COMPUTE_ENABLE_NEON) */
@@ -48,6 +54,7 @@
#define REGISTER_FP16_NEON(func_name) nullptr
#define REGISTER_FP16_SVE(func_name) nullptr
#define REGISTER_FP16_SVE2(func_name) nullptr
+#define REGISTER_FP16_SME2(func_name) nullptr
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
#if defined(ENABLE_FP32_KERNELS)
@@ -64,6 +71,12 @@
#define REGISTER_FP32_SVE2(func_name) nullptr
#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
+#if defined(ARM_COMPUTE_ENABLE_SME2)
+#define REGISTER_FP32_SME2(func_name) &(func_name)
+#else /* !defined(ARM_COMPUTE_ENABLE_SME2) */
+#define REGISTER_FP32_SME2(func_name) nullptr
+#endif /* defined(ARM_COMPUTE_ENABLE_SME2) */
+
#if defined(ARM_COMPUTE_ENABLE_NEON)
#define REGISTER_FP32_NEON(func_name) &(func_name)
#else /* !defined(ARM_COMPUTE_ENABLE_NEON) */
@@ -74,6 +87,7 @@
#define REGISTER_FP32_NEON(func_name) nullptr
#define REGISTER_FP32_SVE(func_name) nullptr
#define REGISTER_FP32_SVE2(func_name) nullptr
+#define REGISTER_FP32_SME2(func_name) nullptr
#endif /* defined(ENABLE_FP32_KERNELS) */
#if defined(ENABLE_QASYMM8_SIGNED_KERNELS)
diff --git a/src/core/utils/helpers/tensor_transform.cpp b/src/core/utils/helpers/tensor_transform.cpp
index 19d0badd74..212cfdabaa 100644
--- a/src/core/utils/helpers/tensor_transform.cpp
+++ b/src/core/utils/helpers/tensor_transform.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -117,7 +117,10 @@ int calculate_end_on_index(TensorShape input_shape,
}
// Final clamp
- stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1);
+ if (stride > 0)
+ stop = utility::clamp(stop, 0, dim_size);
+ else
+ stop = utility::clamp(stop, -1, dim_size - 1);
return stop;
}
diff --git a/src/core/utils/quantization/AsymmHelpers.cpp b/src/core/utils/quantization/AsymmHelpers.cpp
index f66d3e7064..f8b74a985d 100644
--- a/src/core/utils/quantization/AsymmHelpers.cpp
+++ b/src/core/utils/quantization/AsymmHelpers.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -122,13 +122,13 @@ arm_compute::Status calculate_quantized_multipliers(const QuantizationInfo &iq_
ARM_COMPUTE_RETURN_ERROR_ON(iq_info.scale().empty());
ARM_COMPUTE_RETURN_ERROR_ON(wq_info.scale().empty());
ARM_COMPUTE_RETURN_ERROR_ON(oq_info.scale().empty());
-
- const unsigned int size = wq_info.scale().size();
-
- auto &quant_multipliers = stage_info.gemmlowp_multipliers;
- auto &quant_shifts = stage_info.gemmlowp_shifts;
- quant_multipliers.resize(size);
- quant_shifts.resize(size);
+ constexpr unsigned int padding_elems = 32; // assembly kernels assume the shifts and multipliers buffers are padded
+ const unsigned int size = wq_info.scale().size();
+ const size_t padded_size = (size == 1) ? 1 : size + padding_elems;
+ auto &quant_multipliers = stage_info.gemmlowp_multipliers;
+ auto &quant_shifts = stage_info.gemmlowp_shifts;
+ quant_multipliers.resize(padded_size);
+ quant_shifts.resize(padded_size);
const auto &w_scales = wq_info.scale();
const float i_scale = iq_info.scale().at(0);
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
index e290783021..2a76a5958d 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,17 +51,19 @@ Status validate_arguments(const ITensorInfo *mm_result,
int32_t a_offset,
int32_t b_offset)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32, DataType::F32);
- // If a_offset == 0, vector_sum_col can be a nullptr
- if (a_offset != 0)
+ // We run if the offset is nonzero or a sum col has been provided, we need
+ // the second option in case the QuantizationInfo is dynamic
+ if (a_offset != 0 || vector_sum_col != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
}
- // If b_offset == 0, vector_sum_row can be a nullptr
- if (b_offset != 0)
+ // We run if the offset is nonzero or a sum row has been provided, we need
+ // the second option in case the QuantizationInfo is dynamic
+ if (b_offset != 0 || vector_sum_row != nullptr)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
@@ -86,7 +88,7 @@ Status validate_arguments(const ITensorInfo *mm_result,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
"mm_result tensor must have the same number of batches of output tensor");
- if (a_offset != 0)
+ if (vector_sum_col != nullptr)
{
TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape();
vector_sum_col_shape.collapse_from(1);
@@ -102,6 +104,275 @@ Status validate_arguments(const ITensorInfo *mm_result,
return Status{};
}
+void run_offset_contribution_float(const Window &window,
+ ITensor *mm_result,
+ const ITensor *vector_sum_col,
+ const ITensor *vector_sum_row,
+ int32_t a_offset,
+ int32_t b_offset,
+ int32_t k_offset,
+ float scale,
+ bool slide_vector_sum_col,
+ bool is_gemm3d)
+{
+ Window collapsed_window = window.collapse_if_possible(window, Window::DimZ);
+ collapsed_window.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ const int height_input = is_gemm3d ? mm_result->info()->dimension(1) : 0;
+ const int depth_input = is_gemm3d ? mm_result->info()->dimension(2) : 1;
+
+ const int window_start_x = window.x().start();
+ const int window_end_x = window.x().end();
+ const int window_step_x = 16;
+
+ // if vector_sum_col is nullptr then stride_y is 0, else get stride_y
+ const size_t sum_col_stride_y = (vector_sum_col != nullptr) ? (vector_sum_col->info()->strides_in_bytes().y()) : 0;
+ Iterator mm_result_it(mm_result, collapsed_window);
+
+ if ((a_offset != 0) && (b_offset != 0) && (vector_sum_col != nullptr) && (vector_sum_row != nullptr)) // true, true
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
+ Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
+
+ const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset =
+ slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const size_t batch_offset_col = batch_id * (sum_col_stride_y);
+ auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col +
+ batch_id * vector_sum_col_batch_offset);
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ // Compute the leftover term due to b_offset.
+ int32_t b_offset_term_s32 =
+ *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
+ id.y() + (id.z() % depth_input) * height_input);
+ b_offset_term_s32 *= b_offset;
+
+ const int32x4_t b_offset_term_s32_vec = vdupq_n_s32(b_offset_term_s32);
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 = {
+ {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
+ vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
+
+ // Add a_offset_term_s32 and b_offset_term_s32
+ int32x4x4_t offset_term_s32 = {
+ {vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset), vdupq_n_s32(k_offset)}};
+
+ offset_term_s32.val[0] =
+ vaddq_s32(offset_term_s32.val[0], vaddq_s32(a_offset_term_s32.val[0], b_offset_term_s32_vec));
+ offset_term_s32.val[1] =
+ vaddq_s32(offset_term_s32.val[1], vaddq_s32(a_offset_term_s32.val[1], b_offset_term_s32_vec));
+ offset_term_s32.val[2] =
+ vaddq_s32(offset_term_s32.val[2], vaddq_s32(a_offset_term_s32.val[2], b_offset_term_s32_vec));
+ offset_term_s32.val[3] =
+ vaddq_s32(offset_term_s32.val[3], vaddq_s32(a_offset_term_s32.val[3], b_offset_term_s32_vec));
+
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Convert and scale the S32 offsets to match the already scaled GEMM results
+ float32x4x4_t offset_terms_scaled = {{
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[0]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[1]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[2]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(offset_term_s32.val[3]), scale),
+ }};
+
+ // Add the offset terms to the GEMM result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], offset_terms_scaled.val[0]);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], offset_terms_scaled.val[1]);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], offset_terms_scaled.val[2]);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], offset_terms_scaled.val[3]);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
+
+ a_offset_term_s32 *= a_offset;
+
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += (k_offset + a_offset_term_s32 + b_offset_term_s32) * scale;
+ }
+ },
+ vector_sum_col_it, vector_sum_row_it, mm_result_it);
+ }
+ else if ((a_offset == 0) && (b_offset != 0) && (vector_sum_row != nullptr)) // false, true
+ {
+ ARM_COMPUTE_ERROR_ON_NULLPTR(vector_sum_row);
+
+ // Set window for vector_sum_row
+ Window win_vector_sum_row(collapsed_window);
+ win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_row_it(vector_sum_row, win_vector_sum_row);
+
+ const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ // Compute the leftover term due to b_offset.
+ int32_t row_sum =
+ *(reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y) +
+ id.y() + (id.z() % depth_input) * height_input);
+ float scaled_b_offset_term_f32 = row_sum * b_offset * scale;
+
+ const float32x4_t b_offset_term_f32_vec = vdupq_n_f32(scaled_b_offset_term_f32);
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Add the offset terms to GEMM's result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], b_offset_term_f32_vec);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], b_offset_term_f32_vec);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], b_offset_term_f32_vec);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], b_offset_term_f32_vec);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += scaled_b_offset_term_f32;
+ }
+ },
+ vector_sum_row_it, mm_result_it);
+ }
+ else if ((a_offset != 0) && (b_offset == 0) && (vector_sum_col != nullptr)) // true, false
+ {
+ // Set window for vector_sum_col
+ Window win_vector_sum_col(collapsed_window);
+ win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+ win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator vector_sum_col_it(vector_sum_col, win_vector_sum_col);
+
+ // Offset in case vector_sum_col is batched
+ const int vector_sum_col_batch_offset =
+ slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ const int batch_id = id.z() / depth_input;
+ const size_t batch_offset_col =
+ batch_id *
+ (sum_col_stride_y); // Value to offset vector_sum_col_ptr to allow for iteration of y values in tensor
+ auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_offset_col +
+ batch_id * vector_sum_col_batch_offset);
+ auto mm_result_ptr = reinterpret_cast<float *>(mm_result_it.ptr());
+
+ int x = window_start_x;
+ for (; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ // Compute the leftover term due to a_offset.
+ int32x4x4_t a_offset_term_s32 = {
+ {vld1q_s32(vector_sum_col_ptr + x + 0), vld1q_s32(vector_sum_col_ptr + x + 4),
+ vld1q_s32(vector_sum_col_ptr + x + 8), vld1q_s32(vector_sum_col_ptr + x + 12)}};
+
+ a_offset_term_s32.val[0] = vmulq_n_s32(a_offset_term_s32.val[0], a_offset);
+ a_offset_term_s32.val[1] = vmulq_n_s32(a_offset_term_s32.val[1], a_offset);
+ a_offset_term_s32.val[2] = vmulq_n_s32(a_offset_term_s32.val[2], a_offset);
+ a_offset_term_s32.val[3] = vmulq_n_s32(a_offset_term_s32.val[3], a_offset);
+
+ float32x4x4_t a_offset_term_scaled = {{
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[0]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[1]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[2]), scale),
+ vmulq_n_f32(vcvtq_f32_s32(a_offset_term_s32.val[3]), scale),
+ }};
+
+ float32x4x4_t in_f32 = {{vld1q_f32(mm_result_ptr + x + 0), vld1q_f32(mm_result_ptr + x + 4),
+ vld1q_f32(mm_result_ptr + x + 8), vld1q_f32(mm_result_ptr + x + 12)}};
+
+ // Add the offset terms to GEMM's result
+ in_f32.val[0] = vaddq_f32(in_f32.val[0], a_offset_term_scaled.val[0]);
+ in_f32.val[1] = vaddq_f32(in_f32.val[1], a_offset_term_scaled.val[1]);
+ in_f32.val[2] = vaddq_f32(in_f32.val[2], a_offset_term_scaled.val[2]);
+ in_f32.val[3] = vaddq_f32(in_f32.val[3], a_offset_term_scaled.val[3]);
+
+ // Store the result with the offset contribution
+ vst1q_f32(mm_result_ptr + x + 0, in_f32.val[0]);
+ vst1q_f32(mm_result_ptr + x + 4, in_f32.val[1]);
+ vst1q_f32(mm_result_ptr + x + 8, in_f32.val[2]);
+ vst1q_f32(mm_result_ptr + x + 12, in_f32.val[3]);
+ }
+
+ // Left-overs loop
+ for (; x < window_end_x; ++x)
+ {
+ // Compute the leftover term due to a_offset.
+ const int32_t a_offset_term_s32 = *(vector_sum_col_ptr + x);
+
+ // Add the offset terms to GEMM's result
+ // Store the result with the offset contribution
+ mm_result_ptr[x] += a_offset_term_s32 * a_offset * scale;
+ }
+ },
+ vector_sum_col_it, mm_result_it);
+ }
+ else // false, false
+ {
+ // No offset contribution from matrix A and matrix B
+ return;
+ }
+}
+
void run_offset_contribution(const Window &window,
ITensor *mm_result,
const ITensor *vector_sum_col,
@@ -361,7 +632,8 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
ITensorInfo *vector_sum_row,
int32_t k,
int32_t a_offset,
- int32_t b_offset)
+ int32_t b_offset,
+ float scale)
{
// Perform validate step
ARM_COMPUTE_UNUSED(vector_sum_row);
@@ -370,10 +642,11 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
_a_offset = a_offset;
_b_offset = b_offset;
- _k_offset = a_offset * b_offset * k;
+ _k = k;
- // If a_offset == 0, vector_sum_col can be a nullptr
- if (a_offset != 0)
+ _scale = scale;
+
+ if (vector_sum_col != nullptr)
{
// Check if vector_sum_col_shape should be slidden or not
// Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
@@ -386,6 +659,21 @@ void CpuGemmLowpOffsetContributionKernel::configure(ITensorInfo *mm_result,
ICpuKernel::configure(win);
}
+void CpuGemmLowpOffsetContributionKernel::set_a_offset(int32_t a_offset)
+{
+ _a_offset = a_offset;
+}
+
+void CpuGemmLowpOffsetContributionKernel::set_b_offset(int32_t b_offset)
+{
+ _b_offset = b_offset;
+}
+
+void CpuGemmLowpOffsetContributionKernel::set_scale(float scale)
+{
+ _scale = scale;
+}
+
Status CpuGemmLowpOffsetContributionKernel::validate(const ITensorInfo *mm_result,
const ITensorInfo *vector_sum_col,
const ITensorInfo *vector_sum_row,
@@ -410,8 +698,18 @@ void CpuGemmLowpOffsetContributionKernel::run_op(ITensorPack &tensors, const Win
const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->info()->num_dimensions() > 1 &&
mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();
- run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, _k_offset,
- _slide_vector_sum_col, reinterpret_as_3d);
+ // check to see what is the output type of result
+ auto k_offset = _a_offset * _b_offset * _k;
+ if (mm_result->info()->data_type() == DataType::F32)
+ {
+ run_offset_contribution_float(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset,
+ _scale, _slide_vector_sum_col, reinterpret_as_3d);
+ }
+ else
+ {
+ run_offset_contribution(window, mm_result, vector_sum_col, vector_sum_row, _a_offset, _b_offset, k_offset,
+ _slide_vector_sum_col, reinterpret_as_3d);
+ }
}
const char *CpuGemmLowpOffsetContributionKernel::name() const
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
index 08b2d47529..ecbfb0c282 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2022,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,12 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
+#include <cstdint>
+
namespace arm_compute
{
namespace cpu
@@ -62,13 +64,16 @@ public:
* @param[in] k Number of matrix A columns or Matrix B rows
* @param[in] a_offset Offset to be added to each element of the matrix A.
* @param[in] b_offset Offset to be added to each element of the matrix B.
+ * @param[in] scale (Optional) multiplies the contribution to make it the same scale as the dst in the case where mm_result is float
+ * (and so has already been scaled). Default is 1.0
*/
void configure(ITensorInfo *mm_result,
ITensorInfo *vector_sum_col,
ITensorInfo *vector_sum_row,
int32_t k,
int32_t a_offset,
- int32_t b_offset);
+ int32_t b_offset,
+ float scale = 1.0f);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuGemmLowpOffsetContributionKernel::configure()
@@ -81,6 +86,29 @@ public:
int32_t a_offset,
int32_t b_offset);
+ /** Set the a offset
+ * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] a_offset Offset to be added to each element of the matrix A.
+ */
+ void set_a_offset(int32_t a_offset);
+
+ /** Set the b offset
+ * Warning: if b_offset is non-zero then vector_sum_row must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] b_offset Offset to be added to each element of the matrix B.
+ */
+ void set_b_offset(int32_t b_offset);
+
+ /** Set the dequantize scale
+ *
+ * @param[in] scale Multiplies the contribution to make it the same scale as the dst in the case where
+ * mm_result is float (and so has already been scaled).
+ */
+ void set_scale(float scale);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -88,10 +116,11 @@ public:
private:
int32_t _a_offset{0};
int32_t _b_offset{0};
- int32_t _k_offset{0};
+ int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term
+ float _scale{1.0};
bool _slide_vector_sum_col{true};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONKERNEL_H
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
index d008842398..3c113f2828 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021, 2023 Arm Limited.
+ * Copyright (c) 2019-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -919,7 +919,7 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::configure(const ITensorInfo
_a_offset = a_offset;
_b_offset = b_offset;
- _k_offset = a_offset * b_offset * k;
+ _k = k;
_output_stage = output_stage;
// If a_offset == 0, vector_sum_col can be a nullptr
@@ -958,6 +958,16 @@ Status CpuGemmLowpOffsetContributionOutputStageKernel::validate(const ITensorInf
return Status{};
}
+void CpuGemmLowpOffsetContributionOutputStageKernel::set_a_offset(int32_t a_offset)
+{
+ _a_offset = a_offset;
+}
+
+void CpuGemmLowpOffsetContributionOutputStageKernel::set_b_offset(int32_t b_offset)
+{
+ _b_offset = b_offset;
+}
+
void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &tensors,
const Window &window,
const ThreadInfo &info)
@@ -993,10 +1003,11 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te
// Check if symmetric per-channel execution
const bool is_symm = _output_stage.is_quantized_per_channel;
+ auto k_offset = _a_offset * _b_offset * _k;
if (is_symm)
{
run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst,
- _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched,
+ _a_offset, _b_offset, k_offset, _is_vector_sum_col_batched,
_output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
else
@@ -1004,13 +1015,13 @@ void CpuGemmLowpOffsetContributionOutputStageKernel::run_op(ITensorPack &te
if (is_signed)
{
run_offset_contribution_output_stage<int8_t>(
- window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
+ window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset,
_is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
else
{
run_offset_contribution_output_stage<uint8_t>(
- window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset,
+ window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, k_offset,
_is_vector_sum_col_batched, _output_stage, reinterpret_as_3d, is_bounded_relu, is_fixed_point);
}
}
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
index af477d4756..ff706ff3dc 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H
-#define ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
+#define ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
#include "arm_compute/core/KernelDescriptors.h"
@@ -110,6 +110,22 @@ public:
int32_t b_offset,
GEMMLowpOutputStageInfo output_stage);
+ /** Set the a offset
+ * Warning: if a_offset is non-zero then vector_sum_col must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] a_offset Offset to be added to each element of the matrix A.
+ */
+ void set_a_offset(int32_t a_offset);
+
+ /** Set the b offset
+ * Warning: if b_offset is non-zero then vector_sum_col must be set in run_op.
+ * Run configure or validate again if you aren't sure
+ *
+ * @param[in] b_offset Offset to be added to each element of the matrix B.
+ */
+ void set_b_offset(int32_t b_offset);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
@@ -118,11 +134,11 @@ private:
/** Function to use for the particular tensors passed to configure() */
int32_t _a_offset{0};
int32_t _b_offset{0};
- int32_t _k_offset{0};
+ int32_t _k{0}; // Number of columns of A or rows of B, used in last offset term
bool _is_vector_sum_col_batched{true};
GEMMLowpOutputStageInfo _output_stage{GEMMLowpOutputStageInfo()};
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_GEMMLOWP_OFFSETCONTRIBUTION_OUTPUTSTAGE_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_CPUGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 45ebeec394..d71789cc39 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -104,6 +104,7 @@ struct SoftmaxKernelDataTypeISASelectorData
DataType dt;
cpuinfo::CpuIsaInfo isa;
bool is_log;
+ int axis;
};
// Selector pointer types
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 54ff858eeb..5cf81f815c 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -50,9 +50,17 @@ namespace
{
/* Softmax */
static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = {
+ {"sme2_fp32_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP32_SME2(sme2_fp32_softmax)},
{"neon_fp32_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); },
REGISTER_FP32_NEON(neon_fp32_softmax<false>)},
+ {"sme2_fp16_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP16_SME2(sme2_fp16_softmax)},
{"neon_fp16_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data)
{ return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; },
@@ -150,7 +158,7 @@ void CpuSoftmaxKernel::configure(
}
const auto *uk = CpuSoftmaxKernel::get_implementation(
- SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log});
+ SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis});
ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel");
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 9a913c5c58..941fed0ba8 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+
#pragma once
#include "arm_gemm_local.hpp"
@@ -151,6 +155,7 @@ public:
int _maxthreads;
bool _fixed_format;
bool _fast_mode;
+ bool _accumulate;
const GemmConfig *_cfg;
GemmArgs(const CPUInfo *ci,
@@ -165,6 +170,7 @@ public:
const int maxthreads,
bool fixed_format = false,
bool fast_mode = false,
+ bool accumulate = false,
const GemmConfig *cfg = nullptr)
: _ci(ci),
_Msize(M),
@@ -178,6 +184,7 @@ public:
_maxthreads(maxthreads),
_fixed_format(fixed_format),
_fast_mode(fast_mode),
+ _accumulate(accumulate),
_cfg(cfg)
{
}
@@ -253,6 +260,19 @@ public:
}
};
+struct DequantizeFloat
+{
+public:
+ float scale = 0;
+
+ DequantizeFloat() = default;
+
+ // Constructor
+ DequantizeFloat(const float scale) : scale(scale)
+ {
+ }
+};
+
struct Nothing
{
};
@@ -278,3 +298,5 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 4825814e31..45d1e43274 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -166,6 +166,12 @@ public:
{
}
+ /*** Dequanize scale interface (optional) ***/
+ /* Set the dequantize scale for GEMMs when converting from int to float (float out = scale * float(int out) ) */
+ virtual void set_dequantize_scale(const float)
+ {
+ }
+
/*** Introspection interface ***/
/* Get the configuration of this GEMM */
virtual GemmConfig get_config() = 0;
diff --git a/src/cpu/kernels/softmax/generic/sme2/fp16.cpp b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp
new file mode 100644
index 0000000000..bcd34d1ca2
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/fp16.cpp
@@ -0,0 +1,774 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
+void sme2_f16_softmax_kernel( //
+ const float16_t *src,
+ float16_t *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, -inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+ // * z9: min_input
+ // * z10: 23, 0
+ // * z11: max_value
+ // * z12-z15: x, x_fp32_lower_halves, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: beta
+ // * z28-z31: sum_value, x_fp32_upper_halves
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+ // * p1: left-over predicate for find-max & normalize loops
+ // * p2-p4: left-over predicates for regularize loop
+ // * p4-p7: underflow in vector loop
+ // * p5-p6: underflow in leftover loop
+ // *
+ // * pn9: all-true
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+
+ dup z26.s, %w[beta] // beta
+ fcvt h26, s26
+ dup z26.h, z26.h[0]
+
+ mov w10, #0xfc00 // -inf: 0xfc00 for fp16
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cnth x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_2 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+
+ // ---------------------------------------------------------------- z16-z19: max_value = -inf
+ dup z16.h, w10
+ dup z17.h, w10
+ dup z18.h, w10
+ dup z19.h, w10
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+ dup z11.h, w10 // z11: max_value = -inf
+
+find_max_body_start%=:
+ cmp x9, x13
+ b.eq find_max_body_end%=
+
+ .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x
+ .inst 0xc16cb910 // fmax {z16.h-z19.h}, {z16.h-z19.h}, {z12.h-z15.h} // z16-z19: max_value = max(max_value, x)
+
+ inch x9, ALL, MUL #4
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.h, x9, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1h z12.h, p1/z, [x27, x9, LSL #1] // z12: x
+ fmax z16.h, p1/m, z16.h, z12.h // z16: max_value = max(max_value, x)
+
+ inch x9
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ // ---------------------------------------------------------------- z16: max_value
+ .inst 0xc172b110 // fmax {z16.h-z17.h}, {z16.h-z17.h}, {z18.s-z19.h}
+ fmax z16.h, p0/m, z16.h, z17.h
+ fmaxv h16, p0, z16.h
+
+ // ---------------------------------------------------------------- z11: max_value
+ dup z11.h, z16.h[0]
+
+ // ==================================================
+ // Step 2: Regularize, i.e. Calculate exp(x - max(x)
+ // ==================================================
+
+ .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value (in fp32)
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // ---------------------------------------------------- x9: index
+
+regularize_body_start%=:
+ cmp x9, x13
+ b.eq regularize_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009a76c // ld1h {z12.h-z15.h}, pn9/z, [x27, x9, LSL #1] // z12-z15: x
+
+ // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
+ fsub z12.h, z12.h, z11.h
+ fsub z13.h, z13.h, z11.h
+ fsub z14.h, z14.h, z11.h
+ fsub z15.h, z15.h, z11.h
+
+ // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
+ fmul z12.h, z12.h, z26.h
+ fmul z13.h, z13.h, z26.h
+ fmul z14.h, z14.h, z26.h
+ fmul z15.h, z15.h, z26.h
+
+ // ----------------------------------------------------------------
+ // Convert fp16 values to fp32. This results in four more registers.
+ // z12 --> z12, z28
+ fcvtlt z28.s, p0/m, z12.h
+ fcvt z12.s, p0/m, z12.h
+
+ // z13 --> z13, z29
+ fcvtlt z29.s, p0/m, z13.h
+ fcvt z13.s, p0/m, z13.h
+
+ // z14 --> z14, z30
+ fcvtlt z30.s, p0/m, z14.h
+ fcvt z14.s, p0/m, z14.h
+
+ // z15 --> z15, z31
+ fcvtlt z31.s, p0/m, z15.h
+ fcvt z15.s, p0/m, z15.h
+
+ // ----------------------------------------------------------------
+ // Process z12-z15
+ // ----------------------------------------------------------------
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z12.s, z9.s
+ fcmlt p5.s, p0/z, z13.s, z9.s
+ fcmlt p6.s, p0/z, z14.s, z9.s
+ fcmlt p7.s, p0/z, z15.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors. (z12-z13)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors (z14-z15)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z12.s, p4, z10.s, z16.s
+ sel z13.s, p5, z10.s, z17.s
+ sel z14.s, p6, z10.s, z18.s
+ sel z15.s, p7, z10.s, z19.s
+
+ // ---------------------------------------------------------------- sum in fp32
+ .inst 0xc1a17d80 // fadd za.s[w11, #0, VGx4], {z12.s-z15.s} za0-za3: sum_value = sum_value + poly
+
+ // ----------------------------------------------------------------
+ // Process z28-z31
+ // ----------------------------------------------------------------
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z28.s, z9.s
+ fcmlt p5.s, p0/z, z29.s, z9.s
+ fcmlt p6.s, p0/z, z30.s, z9.s
+ fcmlt p7.s, p0/z, z31.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z28.s, z6.s
+ fmla z17.s, p0/m, z29.s, z6.s
+ fmla z18.s, p0/m, z30.s, z6.s
+ fmla z19.s, p0/m, z31.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z24-z27: r_hi = x + n * neg_ln2_hi
+ fmla z28.s, p0/m, z20.s, z7.s
+ fmla z29.s, p0/m, z21.s, z7.s
+ fmla z30.s, p0/m, z22.s, z7.s
+ fmla z31.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z27-z30: r = r_hi + n * neg_ln2_lo
+ fmla z28.s, p0/m, z20.s, z8.s
+ fmla z29.s, p0/m, z21.s, z8.s
+ fmla z30.s, p0/m, z22.s, z8.s
+ fmla z31.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors. (z28-z29)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z28.s, z0.s
+ fmul z21.s, z29.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z28.s, z2.s
+ fmla z23.s, p0/m, z29.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z25: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z28.s, z4.s
+ fmla z25.s, p0/m, z29.s, z4.s
+
+ // ---------------------------------------------------------------- z28-z29: r2 = r * r
+ fmul z28.s, z28.s, z28.s
+ fmul z29.s, z29.s, z29.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z28.s, z24.s
+ fmla z23.s, p0/m, z29.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z28.s, z22.s
+ fmla z21.s, p0/m, z29.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors (z30-z31)
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z30.s, z0.s
+ fmul z21.s, z31.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z30.s, z2.s
+ fmla z23.s, p0/m, z31.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z30.s, z4.s
+ fmla z25.s, p0/m, z31.s, z4.s
+
+ // ---------------------------------------------------------------- z30-z31: r2 = r * r
+ fmul z30.s, z30.s, z30.s
+ fmul z31.s, z31.s, z31.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z30.s, z24.s
+ fmla z23.s, p0/m, z31.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z30.s, z22.s
+ fmla z21.s, p0/m, z31.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z28.s, p4, z10.s, z16.s
+ sel z29.s, p5, z10.s, z17.s
+ sel z30.s, p6, z10.s, z18.s
+ sel z31.s, p7, z10.s, z19.s
+
+ // ---------------------------------------------------------------- sum in fp32
+ .inst 0xc1a17f80 // fadd za.s[w11, #0, VGx4], {z28.s-z31.s} za0-za3: sum_value = sum_value + poly
+
+ fcvt z12.h, p0/m, z12.s
+ fcvtnt z12.h, p0/m, z28.s
+
+ fcvt z13.h, p0/m, z13.s
+ fcvtnt z13.h, p0/m, z29.s
+
+ fcvt z14.h, p0/m, z14.s
+ fcvtnt z14.h, p0/m, z30.s
+
+ fcvt z15.h, p0/m, z15.s
+ fcvtnt z15.h, p0/m, z31.s
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1]
+
+ inch x9, ALL, MUL #4
+ b regularize_body_start%=
+regularize_body_end%=:
+
+ // ---------------------------------------------------------------- z28: sum_value
+ .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
+ fadd z28.s, z28.s, z29.s
+ fadd z30.s, z30.s, z31.s
+ fadd z28.s, z28.s, z30.s
+
+ // Loop for processing the leftover part.
+regularize_leftover_start%=:
+ whilelo p2.h, x9, %x[length]
+ b.none regularize_leftover_end%=
+
+ ld1h z12.h, p2/z, [x27, x9, LSL #1] // x12: input_data
+
+ fsub z12.h, z12.h, z11.h // z12: x = input_data - max_value
+ fmul z12.h, z12.h, z26.h // z12: x = (input_data - max_value) * beta
+
+ // ---------------------------------------------------------------- z12.h --> z12.s, z13.s
+ fcvtlt z13.s, p2/m, z12.h
+ fcvt z12.s, p2/m, z12.h
+
+ // ---------------------------------------------------------------- p3, p4: predicates for z12, z14
+ pfalse p1.b
+ trn1 p3.h, p2.h, p1.h // for z12
+ trn2 p4.h, p2.h, p1.h // for z13
+
+ mov z16.d, z5.d // z16: shift
+ mov z17.d, z5.d // z17: shift
+ fcmlt p5.s, p3/z, z12.s, z9.s // p5: underflow = x < min_input
+ fcmlt p6.s, p4/z, z13.s, z9.s // p6: underflow = x < min_input
+ fmla z16.s, p3/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fmla z17.s, p4/m, z13.s, z6.s // z17: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fsub z21.s, z17.s, z5.s // z21: n = z - shift
+ fmla z12.s, p3/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z13.s, p4/m, z21.s, z7.s // z13: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p3/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ fmla z13.s, p4/m, z21.s, z8.s // z13: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p3/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ urshl z17.s, p4/m, z17.s, z10.s // z17: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ fmul z21.s, z13.s, z0.s // z21: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ mov z23.d, z1.d // z23: p23 = c2
+ fmla z22.s, p3/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ fmla z23.s, p4/m, z13.s, z2.s // z23: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ mov z25.d, z3.d // z25: c4
+ fmla z24.s, p3/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmla z25.s, p4/m, z13.s, z4.s // z25: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmul z13.s, z13.s, z13.s // z13: r2 = r * r
+ fmla z22.s, p3/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z23.s, p4/m, z13.s, z25.s // z23: p2345 = p23 + r2 * p45
+ fmla z20.s, p3/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z21.s, p4/m, z13.s, z23.s // z21: p12345 = p1 + r2 * p2345
+ fmla z16.s, p3/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ fmla z17.s, p4/m, z21.s, z17.s // z17: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p5, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+ sel z17.s, p6, z10.s, z17.s // z17: poly = underflow ? 0 : poly
+ fadd z28.s, p3/m, z28.s, z16.s // z28: sum_value = sum_value + poly
+ fadd z28.s, p4/m, z28.s, z17.s // z28: sum_value = sum_value + poly
+
+ fcvt z16.h, p3/m, z16.s
+ fcvtnt z16.h, p4/m, z17.s
+ st1h z16.h, p2, [x28, x9, LSL #1]
+
+ inch x9
+ b regularize_leftover_start%=
+regularize_leftover_end%=:
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+
+ // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
+ faddv s28, p0, z28.s
+ fmov s29, #1.0 // 1.0f
+ fdiv s28, s29, s28
+ fcvt h28, s28
+
+ dup z28.h, z28.h[0]
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+
+normalize_body_start%=:
+ cmp x9, x13
+ b.eq normalize_body_end%=
+
+ .inst 0xa009a78c // ld1h {z12.h-z15.h}, pn9/z, [x28, x9, LSL #1]
+
+ // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
+ fmul z12.h, z12.h, z28.h
+ fmul z13.h, z13.h, z28.h
+ fmul z14.h, z14.h, z28.h
+ fmul z15.h, z15.h, z28.h
+
+ .inst 0xa029a78c // st1h {z12.h-z15.h}, pn9, [x28, x9, LSL #1]
+
+ inch x9, ALL, MUL #4
+ b normalize_body_start%=
+normalize_body_end%=:
+
+ // Loop for processing the leftover part.
+normalize_leftover_start%=:
+ whilelo p1.h, x9, %x[length]
+ b.none normalize_leftover_end%=
+
+ ld1h z12.h, p1/z, [x28, x9, LSL #1] // z12: x
+ fmul z12.h, z12.h, z28.h // z12: result = x * inv_sum_value
+
+ st1h z12.h, p1, [x28, x9, LSL #1]
+
+ inch x9
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", "x14", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp16_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const auto *k_src = reinterpret_cast<const float16_t *>(in->buffer() + k_src_offset);
+ auto *k_dst = reinterpret_cast<float16_t *>(out->buffer() + k_dst_offset);
+
+ sme2_f16_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/generic/sme2/fp32.cpp b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp
new file mode 100644
index 0000000000..159039a320
--- /dev/null
+++ b/src/cpu/kernels/softmax/generic/sme2/fp32.cpp
@@ -0,0 +1,578 @@
+/*
+ * Copyright (c) 2023-2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Window.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+
+// SoftMax
+//
+// Steps:
+// * Find max: max_value = max(src)
+// * Regularize: dst[i] = exp(src[i] - max_value)
+// sum_value = sum(dst)
+// * Normalize: dst[i] = dst[i] / sum_value
+void sme2_f32_softmax_kernel( //
+ const float *src,
+ float *dst,
+ float beta,
+ const uintptr_t shape[4],
+ const uintptr_t src_strides[4],
+ const uintptr_t dst_strides[4])
+{
+ // Precondition:
+ // * src_strides[0] == sizeof(float)
+ // * dst_strides[0] == sizeof(float)
+
+ __asm__ volatile(
+ R"(
+ .inst 0xd503477f // smstart
+
+ // Registers
+ //
+ // * x9: temporary, index
+ // * x10: temporary, -inf
+ // * x11: temporary, 0
+ // * x12: temporary, 1.0f
+ // * x13: temporary, body_length
+ //
+ // * x20: index_3
+ // * x21: src_3
+ // * x22: dst_3
+ // * x23: index_2
+ // * x24: src_2
+ // * x25: dst_2
+ // * x26: index_1
+ // * x27: src_1
+ // * x28: dst_1
+ //
+ // * z0: c1
+ // * z1: c2
+ // * z2: c3
+ // * z3: c4
+ // * z4: c5
+ // * z5: shift
+ // * z6: inv_ln2
+ // * z7: neg_ln2_hi
+ // * z8: neg_ln2_lo
+ // * z9: min_input
+ // * z10: 23, 0
+ // * z11: max_value
+ // * z12-z15: x, r_hi, r, r2
+ // * z16-z19: max_value, shift, z, scale, poly
+ // * z20-z21: n, p1, p12345
+ // * z22-z23: n, p23, p2345
+ // * z24-z25: p45
+ // * z26: beta
+ // * z28-z31: sum_value
+ //
+ // * za0-za3: sum_value
+ //
+ // * p0: all-true
+ // * p1: left-over predicate
+ // * p4-p7: underflow
+ // * pn9: all-true
+
+ // Prepares all constant values
+
+ ptrue p0.b
+ .inst 0x25207811 // ptrue pn9.b
+
+ mov w9, #0xfff6 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ mov w10, #0xfedb // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ mov w11, #0xaf33 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ mov w12, #0x9f17 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ mov w13, #0x2010 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ movk w9, #0x3f7f, LSL #16 // c1: 0x1.ffffecp-1f = 0x3f7ffff6
+ movk w10, #0x3eff, LSL #16 // c2: 0x1.fffdb6p-2f = 0x3efffedb
+ movk x11, #0x3e2a, LSL #16 // c3: 0x1.555e66p-3f = 0x3e2aaf33
+ movk w12, #0x3d2b, LSL #16 // c4: 0x1.573e2ep-5f = 0x3d2b9f17
+ movk w13, #0x3c07, LSL #16 // c5: 0x1.0e4020p-7f = 0x3c072010
+
+ dup z0.s, w9 // c1.
+ dup z1.s, w10 // c2.
+ dup z2.s, w11 // c3.
+ dup z3.s, w12 // c4.
+ dup z4.s, w13 // c5.
+
+ mov w9, #0x007f // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ mov w10, #0xaa3b // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ mov w11, #0x7200 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ mov w12, #0xbe8e // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ mov w13, #0x47ae // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ movk w9, #0x4b00, LSL #16 // shift: 2^23 + 127 = 0x1.0000fep23f = 0x4b00007f
+ movk w10, #0x3fb8, LSL #16 // inv_ln2: 1 / ln(2) = 0x1.715476p+0f = 0x3fb8aa3b
+ movk w11, #0xbf31, LSL #16 // neg_ln2_hi: -ln(2) from bits -1 to -19 = -0x1.62e400p-1f = 0xbf317200
+ movk w12, #0xb5bf, LSL #16 // neg_ln2_lo: -ln(2) from bits -20 to -42 = -0x1.7f7d1cp-20f = 0xb5bfbe8e
+ movk w13, #0xc2ad, LSL #16 // min_input (Approximately ln 2^-125): -86.64 = 0xc2ad47ae
+
+ dup z5.s, w9 // shift
+ dup z6.s, w10 // inv_ln2
+ dup z7.s, w11 // neg_ln2_hi
+ dup z8.s, w12 // neg_ln2_lo
+ dup z9.s, w13 // min_input
+
+ dup z26.s, %w[beta] // beta
+
+ mov w10, #0x0000 // -inf: 0xff800000
+ movk w10, #0xff80 // -inf: 0xff800000
+
+ mov w11, #0 // 0
+
+ // ---------------------------------------------------------------- x13: body_length = (length / vl) * vl
+ cntw x13, ALL, MUL #4
+ udiv x9, %x[length], x13
+ mul x13, x13, x9
+
+ // ==================================================
+ // 3D loop opening
+ // ==================================================
+
+ mov x20, %x[shape_3]
+ mov x21, %x[src]
+ mov x22, %x[dst]
+
+loop_3_start%=:
+ // for index_3 in shape_3 downto 1
+ cmp x20, #0
+ b.eq loop_3_end%=
+ sub x20, x20, #1
+
+ mov x23, %x[shape_2]
+ mov x24, x21
+ mov x25, x22
+
+loop_2_start%=:
+ // for index_2 in shape_2 downto 1
+ cmp x23, #0
+ b.eq loop_2_end%=
+ sub x23, x23, #1
+
+ mov x26, %x[shape_1]
+ mov x27, x24
+ mov x28, x25
+
+loop_1_start%=:
+ // for index_1 in shape_2 downto 1
+ cmp x26, #0
+ b.eq loop_1_end%=
+ sub x26, x26, #1
+
+ // ==================================================
+ // Step 1: Find max
+ // ==================================================
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+ dup z11.s, w10 // z11: max_value = -inf
+
+ // ---------------------------------------------------------------- z16-z19: max_value = -inf
+ mov z16.d, z11.d
+ mov z17.d, z11.d
+ mov z18.d, z11.d
+ mov z19.d, z11.d
+
+find_max_body_start%=:
+ cmp x9, x13
+ b.eq find_max_body_end%=
+
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2] // z12-z15: x
+ .inst 0xc1acb910 // fmax {z16.s-z19.s}, {z16.s-z19.s}, {z12.s-z15.s} // z16-z19: max_value = max(max_value, x)
+
+ incw x9, ALL, MUL #4
+ b find_max_body_start%=
+find_max_body_end%=:
+
+ // Loop for processing the leftover part.
+find_max_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none find_max_leftover_end%=
+
+ ld1w z12.s, p1/z, [x27, x9, LSL #2] // z12: x
+ fmax z16.s, p1/m, z16.s, z12.s // z16: max_value = max(max_value, x)
+
+ incw x9
+ b find_max_leftover_start%=
+find_max_leftover_end%=:
+
+ // ---------------------------------------------------------------- z16: max_value
+ .inst 0xc1b2b110 // fmax {z16.s-z17.s}, {z16.s-z17.s}, {z18.s-z19.s}
+ fmax z16.s, p0/m, z16.s, z17.s
+ fmaxv s16, p0, z16.s
+
+ // ---------------------------------------------------------------- z11: max_value
+ dup z11.s, z16.s[0]
+
+ // ==================================================
+ // Step 2: Regularize
+ // ==================================================
+
+ .inst 0xc00800ff // zero {za0.s, za1.s, za2.s, za3.s} za0-za3: sum_value
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // ---------------------------------------------------- x9: index
+
+regularize_body_start%=:
+ cmp x9, x13
+ b.eq regularize_body_end%=
+
+ // Loads the input data to 4 consecutive registers ---------------- z12-z15: input_data
+ .inst 0xa009c76c // ld1w {z12.s-z15.s}, pn9/z, [x27, x9, LSL #2]
+
+ // ---------------------------------------------------------------- z12-z15: x = input_data - max_value
+ fsub z12.s, z12.s, z11.s
+ fsub z13.s, z13.s, z11.s
+ fsub z14.s, z14.s, z11.s
+ fsub z15.s, z15.s, z11.s
+
+ // ---------------------------------------------------------------- z12-z15: x = (input_data - max_value) * beta
+ fmul z12.s, z12.s, z26.s
+ fmul z13.s, z13.s, z26.s
+ fmul z14.s, z14.s, z26.s
+ fmul z15.s, z15.s, z26.s
+
+ // ---------------------------------------------------------------- z16-z19: shift
+ mov z16.d, z5.d
+ mov z17.d, z5.d
+ mov z18.d, z5.d
+ mov z19.d, z5.d
+
+ // ---------------------------------------------------------------- p4-p7: underflow = x < min_input
+ fcmlt p4.s, p0/z, z12.s, z9.s
+ fcmlt p5.s, p0/z, z13.s, z9.s
+ fcmlt p6.s, p0/z, z14.s, z9.s
+ fcmlt p7.s, p0/z, z15.s, z9.s
+
+ // ---------------------------------------------------------------- z16-z19: z = shift + x * inv_ln2
+ fmla z16.s, p0/m, z12.s, z6.s
+ fmla z17.s, p0/m, z13.s, z6.s
+ fmla z18.s, p0/m, z14.s, z6.s
+ fmla z19.s, p0/m, z15.s, z6.s
+
+ // ---------------------------------------------------------------- z20-z23: n = z - shift
+ fsub z20.s, z16.s, z5.s
+ fsub z21.s, z17.s, z5.s
+ fsub z22.s, z18.s, z5.s
+ fsub z23.s, z19.s, z5.s
+
+ // ---------------------------------------------------------------- z12-z15: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p0/m, z20.s, z7.s
+ fmla z13.s, p0/m, z21.s, z7.s
+ fmla z14.s, p0/m, z22.s, z7.s
+ fmla z15.s, p0/m, z23.s, z7.s
+
+ // ---------------------------------------------------------------- z12-z15: r = r_hi + n * neg_ln2_lo
+ fmla z12.s, p0/m, z20.s, z8.s
+ fmla z13.s, p0/m, z21.s, z8.s
+ fmla z14.s, p0/m, z22.s, z8.s
+ fmla z15.s, p0/m, z23.s, z8.s
+
+ // ---------------------------------------------------------------- z16-z19: scale = z << 23 (2^n)
+ dup z10.s, #23
+ urshl z16.s, p0/m, z16.s, z10.s
+ urshl z17.s, p0/m, z17.s, z10.s
+ urshl z18.s, p0/m, z18.s, z10.s
+ urshl z19.s, p0/m, z19.s, z10.s
+
+ // Processes the first 2 vectors.
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z12.s, z0.s
+ fmul z21.s, z13.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z12.s, z2.s
+ fmla z23.s, p0/m, z13.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z12.s, z4.s
+ fmla z25.s, p0/m, z13.s, z4.s
+
+ // ---------------------------------------------------------------- z12-z13: r2 = r * r
+ fmul z12.s, z12.s, z12.s
+ fmul z13.s, z13.s, z13.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z12.s, z24.s
+ fmla z23.s, p0/m, z13.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z12.s, z22.s
+ fmla z21.s, p0/m, z13.s, z23.s
+
+ // ---------------------------------------------------------------- z16-z17: poly = scale + p12345 * scale
+ fmla z16.s, p0/m, z20.s, z16.s
+ fmla z17.s, p0/m, z21.s, z17.s
+
+ // Processes the last 2 vectors
+
+ // ---------------------------------------------------------------- z20-z21: p1 = r * c1
+ fmul z20.s, z14.s, z0.s
+ fmul z21.s, z15.s, z0.s
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2
+ mov z22.d, z1.d
+ mov z23.d, z1.d
+
+ // ---------------------------------------------------------------- z22-z23: p23 = c2 + r * c3
+ fmla z22.s, p0/m, z14.s, z2.s
+ fmla z23.s, p0/m, z15.s, z2.s
+
+ // ---------------------------------------------------------------- z24-z35: c4
+ mov z24.d, z3.d
+ mov z25.d, z3.d
+
+ // ---------------------------------------------------------------- z24-z25: p45 = c4 + r * c5
+ fmla z24.s, p0/m, z14.s, z4.s
+ fmla z25.s, p0/m, z15.s, z4.s
+
+ // ---------------------------------------------------------------- z14-z15: r2 = r * r
+ fmul z14.s, z14.s, z14.s
+ fmul z15.s, z15.s, z15.s
+
+ // ---------------------------------------------------------------- z22-z23: p2345 = p23 + r2 * p45
+ fmla z22.s, p0/m, z14.s, z24.s
+ fmla z23.s, p0/m, z15.s, z25.s
+
+ // ---------------------------------------------------------------- z20-z21: p12345 = p1 + r2 * p2345
+ fmla z20.s, p0/m, z14.s, z22.s
+ fmla z21.s, p0/m, z15.s, z23.s
+
+ // ---------------------------------------------------------------- z18-z19: poly = scale + p12345 * scale
+ fmla z18.s, p0/m, z20.s, z18.s
+ fmla z19.s, p0/m, z21.s, z19.s
+
+ // ---------------------------------------------------------------- z16-z19: poly = underflow ? 0 : poly
+ dup z10.s, #0
+ sel z16.s, p4, z10.s, z16.s
+ sel z17.s, p5, z10.s, z17.s
+ sel z18.s, p6, z10.s, z18.s
+ sel z19.s, p7, z10.s, z19.s
+
+ // Stores 4 consecutive registers to the output
+ .inst 0xa029c790 // st1w {z16.s-z19.s}, pn9, [x28, x9, LSL #2]
+
+ .inst 0xc1a17e00 // fadd za.s[w11, #0, VGx4], {z16.s-z19.s} za0-za3: sum_value = sum_value + poly
+
+ incw x9, ALL, MUL #4
+ b regularize_body_start%=
+regularize_body_end%=:
+
+ // ---------------------------------------------------------------- z28: sum_value
+ .inst 0xc0066c1c // mova {z28.s-z31.s}, za.s[w11, #0, VGx4]
+ fadd z28.s, z28.s, z29.s
+ fadd z30.s, z30.s, z31.s
+ fadd z28.s, z28.s, z30.s
+
+ // Loop for processing the leftover part.
+regularize_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none regularize_leftover_end%=
+
+ ld1w z12.s, p1/z, [x27, x9, LSL #2] // x12: input_data
+
+ fsub z12.s, z12.s, z11.s // z12: x = input_data - max_value
+ fmul z12.s, z12.s, z26.s // z12: x = (input_data - max_value) * beta
+
+ mov z16.d, z5.d // z16: shift
+ fcmlt p4.s, p1/z, z12.s, z9.s // p4: underflow = x < min_input
+ fmla z16.s, p1/m, z12.s, z6.s // z16: z = shift + x * inv_ln2
+ fsub z20.s, z16.s, z5.s // z20: n = z - shift
+ fmla z12.s, p1/m, z20.s, z7.s // z12: r_hi = x + n * neg_ln2_hi
+ fmla z12.s, p1/m, z20.s, z8.s // z12: r = r_hi + n * neg_ln2_lo
+ dup z10.s, #23 // z10: 23
+ urshl z16.s, p1/m, z16.s, z10.s // z16: scale = z << 23 (2^n)
+ fmul z20.s, z12.s, z0.s // z20: p1 = r * c1
+ mov z22.d, z1.d // z22: p23 = c2
+ fmla z22.s, p1/m, z12.s, z2.s // z22: p23 = c2 + r * c3
+ mov z24.d, z3.d // z24: c4
+ fmla z24.s, p1/m, z12.s, z4.s // z24: p45 = c4 + r * c5
+ fmul z12.s, z12.s, z12.s // z12: r2 = r * r
+ fmla z22.s, p1/m, z12.s, z24.s // z22: p2345 = p23 + r2 * p45
+ fmla z20.s, p1/m, z12.s, z22.s // z20: p12345 = p1 + r2 * p2345
+ fmla z16.s, p1/m, z20.s, z16.s // z16: poly = scale + p12345 * scale
+ dup z10.s, #0 // z10: 0
+ sel z16.s, p4, z10.s, z16.s // z16: poly = underflow ? 0 : poly
+
+ st1w z16.s, p1, [x28, x9, LSL #2]
+
+ fadd z28.s, p1/m, z28.s, z16.s // z28: sum_value = sum_value + poly
+
+ incw x9
+ b regularize_leftover_start%=
+regularize_leftover_end%=:
+
+ // ==================================================
+ // Step 3: Normalize
+ // ==================================================
+
+ // ---------------------------------------------------------------- z28: inv_sum_value = 1 / sum_value
+ fmov s29, #1.0 // 1.0f
+ faddv s28, p0, z28.s
+ fdiv s28, s29, s28
+ dup z28.s, z28.s[0]
+
+ // Loop for processing 4 vectors per iteration.
+ mov x9, #0 // x9: index
+
+normalize_body_start%=:
+ cmp x9, x13
+ b.eq normalize_body_end%=
+
+ .inst 0xa009c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, x9, LSL #2] // z12-z15: x
+
+ // ---------------------------------------------------------------- z12-z15: result = x * inv_sum_value
+ fmul z12.s, z12.s, z28.s
+ fmul z13.s, z13.s, z28.s
+ fmul z14.s, z14.s, z28.s
+ fmul z15.s, z15.s, z28.s
+
+ .inst 0xa029c78c // st1w {z12.s-z15.s}, pn9, [x28, x9, LSL #2]
+
+ incw x9, ALL, MUL #4
+ b normalize_body_start%=
+normalize_body_end%=:
+
+ // Loop for processing the leftover part.
+normalize_leftover_start%=:
+ whilelo p1.s, x9, %x[length]
+ b.none normalize_leftover_end%=
+
+ ld1w z12.s, p1/z, [x28, x9, LSL #2] // z12: x
+ fmul z12.s, z12.s, z28.s // z12: result = x * inv_sum_value
+
+ st1w z12.s, p1, [x28, x9, LSL #2]
+
+ incw x9
+ b normalize_leftover_start%=
+normalize_leftover_end%=:
+
+ // ==================================================
+ // 3D loop closing
+ // ==================================================
+
+ add x27, x27, %x[src_stride_1]
+ add x28, x28, %x[dst_stride_1]
+ b loop_1_start%=
+loop_1_end%=:
+
+ add x24, x24, %x[src_stride_2]
+ add x25, x25, %x[dst_stride_2]
+ b loop_2_start%=
+loop_2_end%=:
+
+ add x21, x21, %x[src_stride_3]
+ add x22, x22, %x[dst_stride_3]
+ b loop_3_start%=
+loop_3_end%=:
+
+ .inst 0xd503467f // smstop
+ )"
+ :
+ : [src] "r"(src), [dst] "r"(dst), [beta] "r"(beta), //
+ [shape_1] "r"(shape[1]), [shape_2] "r"(shape[2]), [shape_3] "r"(shape[3]), //
+ [src_stride_1] "r"(src_strides[1]), [src_stride_2] "r"(src_strides[2]),
+ [src_stride_3] "r"(src_strides[3]), //
+ [dst_stride_1] "r"(dst_strides[1]), [dst_stride_2] "r"(dst_strides[2]),
+ [dst_stride_3] "r"(dst_strides[3]), //
+ [length] "r"(shape[0]) //
+ : "cc", "memory", //
+ "p0", "p4", "p5", "p6", "p7", "p9", //
+ "x9", "x10", "x11", "x12", "x13", //
+ "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", //
+ "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", //
+ "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", //
+ "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", //
+ "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" //
+ );
+}
+
+void sme2_fp32_softmax(const ITensor *in, void *const, ITensor *out, const float beta, int axis, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(axis);
+
+ const auto *src_info = in->info();
+ const auto *dst_info = out->info();
+
+ const auto &full_shape = dst_info->tensor_shape();
+ const auto &src_strides = src_info->strides_in_bytes();
+ const auto &dst_strides = dst_info->strides_in_bytes();
+
+ const uintptr_t k_shape[] = {
+ full_shape[0],
+ window.num_iterations(1),
+ window.num_iterations(2),
+ window.num_iterations(3),
+ };
+
+ const uintptr_t k_src_strides[] = {
+ src_strides[0],
+ src_strides[1],
+ src_strides[2],
+ src_strides[3],
+ };
+
+ const uintptr_t k_dst_strides[] = {
+ dst_strides[0],
+ dst_strides[1],
+ dst_strides[2],
+ dst_strides[3],
+ };
+
+ const uintptr_t k_src_offset = window[0].start() * src_strides[0] + //
+ window[1].start() * src_strides[1] + //
+ window[2].start() * src_strides[2] + //
+ window[3].start() * src_strides[3];
+
+ const uintptr_t k_dst_offset = window[0].start() * dst_strides[0] + //
+ window[1].start() * dst_strides[1] + //
+ window[2].start() * dst_strides[2] + //
+ window[3].start() * dst_strides[3];
+
+ const auto *k_src = reinterpret_cast<const float *>(in->buffer() + k_src_offset);
+ auto *k_dst = reinterpret_cast<float *>(out->buffer() + k_dst_offset);
+
+ sme2_f32_softmax_kernel(k_src, k_dst, beta, k_shape, k_src_strides, k_dst_strides);
+}
+
+} // namespace cpu
+} // namespace arm_compute
+
+#endif // ARM_COMPUTE_ENABLE_SME2
diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h
index f9295ebbcc..1bb8ed50f0 100644
--- a/src/cpu/kernels/softmax/list.h
+++ b/src/cpu/kernels/softmax/list.h
@@ -37,6 +37,16 @@ DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax);
DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax);
DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax);
+#ifdef ARM_COMPUTE_ENABLE_SME2
+
+void sme2_fp32_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+
+void sme2_fp16_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+
+#endif // ARM_COMPUTE_ENABLE_SME2
+
#undef DECLARE_SOFTMAX_KERNEL
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index e035de0131..905e86c185 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,6 +53,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.fast_mode = info.fast_math();
asm_info.fixed_format = info.fixed_format();
asm_info.weight_format = info.weight_format();
+ asm_info.accumulate = info.accumulate();
asm_info.transpose_b =
info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method
@@ -219,6 +220,16 @@ Status CpuGemm::validate(const ITensorInfo *a,
const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
+ // When using accumulation(in place summation), for now, the only supported values for alpha and beta are 1 respectively 0.
+ // Do the appropriate checks before proceeding.
+ if (gemm_info.accumulate())
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ (beta != 0 && c != nullptr),
+ "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c");
+ }
+
const bool is_c_bias = beta == 1 && c != nullptr;
const bool run_addition = c != nullptr && beta != 0 && beta != 1;
// Check if we should use the pretransposed_b or original b
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 7460f2020c..55d950ff4a 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -809,9 +809,16 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
if (!_skip_im2col)
{
// Run input reshaping
- unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
- ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}};
- NEScheduler::get().schedule_op(_im2col_kernel.get(), y_dim, _im2col_kernel->window(), pack);
+ unsigned int hint_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
+ unsigned int x_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
+ unsigned int hint_dim_iterations = _im2col_kernel->window().num_iterations(hint_dim);
+ unsigned int x_dim_iterations = _im2col_kernel->window().num_iterations(x_dim);
+ if (hint_dim_iterations < NEScheduler::get().num_threads() && x_dim_iterations > hint_dim_iterations)
+ {
+ hint_dim = x_dim;
+ }
+ ITensorPack pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, im2col_output.get()}};
+ NEScheduler::get().schedule_op(_im2col_kernel.get(), hint_dim, _im2col_kernel->window(), pack);
gemm_input_to_use = im2col_output.get();
}
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index b25505a85d..f3396fbb5c 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -65,6 +65,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.output_stage = info.gemmlowp_output_stage();
asm_info.fast_mode = info.fast_math();
+ asm_info.accumulate = info.accumulate();
return asm_info;
}
@@ -127,6 +128,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_reshape_b_only_on_first_run;
_gemm_info = gemm_info;
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ // It is not needed if the datatype is symmetric, because there is no offset
+ bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
const ITensorInfo *a_to_use = a;
@@ -228,8 +234,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
// Build reduction info
const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false);
- // Initialize matrix B reduction kernel only if _a_offset is not equal to 0
- if (_a_offset != 0)
+ if (a_offset_kernel_needed)
{
_vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32);
@@ -238,8 +243,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_mtx_b_reduction_kernel->configure(b, &_vector_sum_col, reduction_info);
}
- // Initialize Matrix A reduction kernel only if _b_offset is not equal to 0
- if (_b_offset != 0)
+ if (b_offset_kernel_needed)
{
_vector_sum_row = TensorInfo(compute_reductionB_shape(*a_to_use), 1, DataType::S32);
@@ -260,8 +264,8 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_offset_contribution_output_stage_kernel =
std::make_unique<kernels::CpuGemmLowpOffsetContributionOutputStageKernel>();
_offset_contribution_output_stage_kernel->configure(
- &_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col,
- _b_offset == 0 ? nullptr : &_vector_sum_row, c, _flip_signedness ? &_signed_output : dst,
+ &_mm_result_s32, a_offset_kernel_needed ? &_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &_vector_sum_row : nullptr, c, _flip_signedness ? &_signed_output : dst,
a->dimension(0), _a_offset, _b_offset, info.gemmlowp_output_stage());
if (_flip_signedness)
@@ -272,6 +276,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
}
else
{
+ // This scale is needed for the s8_f32 kernel where the multiplication output is dequantized to F32.
+ const float dequantize_scale =
+ (dst->data_type() == DataType::F32)
+ ? a->quantization_info().uniform().scale * b->quantization_info().uniform().scale
+ : 1.0f;
// Configure matrix multiply kernel
if (!_assembly_path)
{
@@ -280,9 +289,9 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
}
// Configure offset contribution kernel
_offset_contribution_kernel = std::make_unique<kernels::CpuGemmLowpOffsetContributionKernel>();
- _offset_contribution_kernel->configure(dst, _a_offset == 0 ? nullptr : &_vector_sum_col,
- _b_offset == 0 ? nullptr : &_vector_sum_row, a_to_use->dimension(0),
- _a_offset, _b_offset);
+ _offset_contribution_kernel->configure(dst, a_offset_kernel_needed ? &_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &_vector_sum_row : nullptr,
+ a_to_use->dimension(0), _a_offset, _b_offset, dequantize_scale);
}
}
// Configure activation
@@ -305,11 +314,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
}
// Request memory for LHS and RHS reshape matrix
- _aux_mem[VectorSumCol] =
- MemoryInfo(offset_int_vec(VectorSumCol),
- !_fused_assembly_path && _a_offset != 0 && _reshape_b_only_on_first_run ? MemoryLifetime::Persistent
- : MemoryLifetime::Temporary,
- _vector_sum_col.total_size());
+ _aux_mem[VectorSumCol] = MemoryInfo(offset_int_vec(VectorSumCol),
+ !_fused_assembly_path && a_offset_kernel_needed && _reshape_b_only_on_first_run
+ ? MemoryLifetime::Persistent
+ : MemoryLifetime::Temporary,
+ _vector_sum_col.total_size());
_aux_mem[VectorSumRow] =
MemoryInfo(offset_int_vec(VectorSumRow), MemoryLifetime::Temporary, _vector_sum_row.total_size());
_aux_mem[TmpA] = MemoryInfo(offset_int_vec(TmpA), MemoryLifetime::Temporary, _tmp_a.total_size());
@@ -333,8 +342,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32, DataType::QASYMM8,
- DataType::QASYMM8_SIGNED);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr &&
+ DataType::QASYMM8_SIGNED, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(c != nullptr && output->data_type() != DataType::F32 &&
gemm_info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::NONE,
"Bias addition not supported in NEGEMMLowpMatrixMultiplyCore for output S32");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
@@ -343,6 +352,16 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
+ // When using accumulation(in place summation), for now, the only supported DataType for output is S32.
+ if (gemm_info.accumulate())
+ {
+#ifdef __arm__
+ ARM_COMPUTE_RETURN_ERROR_MSG("Accumulation is not supported for armv7");
+#endif /* __arm__ */
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE,
+ "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED");
+ }
+
GEMMInfo info = gemm_info;
const ITensorInfo *matrix_a_info = a;
const ITensorInfo *matrix_b_info = b;
@@ -356,6 +375,10 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
int32_t a_offset = a->quantization_info().uniform().offset;
int32_t b_offset = b->quantization_info().uniform().offset;
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
+
bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if (fuse_output_stage)
{
@@ -478,7 +501,7 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
const GEMMLowpReductionKernelInfo reduction_info(a_to_use->dimension(0), false, 0, false);
// Validate matrix B reduction kernel only if _a_offset is not equal to 0
- if (a_offset != 0)
+ if (a_offset_kernel_needed)
{
info_vector_sum_col = TensorInfo(compute_reductionA_shape(*b), 1, DataType::S32);
@@ -488,7 +511,7 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
}
// Validate Matrix A reduction kernel only if _b_offset is not equal to 0
- if (b_offset != 0)
+ if (b_offset_kernel_needed)
{
info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32);
@@ -514,9 +537,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
// Validate offset contribution kernel
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionOutputStageKernel::validate(
- &mm_result_s32_info, a_offset == 0 ? nullptr : &info_vector_sum_col,
- b_offset == 0 ? nullptr : &info_vector_sum_row, c, flip_signedness ? &signed_output : output, a_offset,
- b_offset, info.gemmlowp_output_stage()));
+ &mm_result_s32_info, a_offset_kernel_needed ? &info_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &info_vector_sum_row : nullptr, c, flip_signedness ? &signed_output : output,
+ a_offset, b_offset, info.gemmlowp_output_stage()));
}
else
{
@@ -534,8 +557,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
}
// Validate offset contribution kernel
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuGemmLowpOffsetContributionKernel::validate(
- output, a_offset == 0 ? nullptr : &info_vector_sum_col, b_offset == 0 ? nullptr : &info_vector_sum_row,
- a_offset, b_offset));
+ output, a_offset_kernel_needed ? &info_vector_sum_col : nullptr,
+ b_offset_kernel_needed ? &info_vector_sum_row : nullptr, a_offset, b_offset));
}
}
@@ -569,6 +592,14 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
CpuAuxTensorHandler signed_a(offset_int_vec(SignedA), _signed_a, tensors, false);
CpuAuxTensorHandler signed_output(offset_int_vec(SignedOutput), _signed_output, tensors, false);
+ const QuantizationInfo a_qinfo = a->info()->quantization_info();
+ const QuantizationInfo b_qinfo = b->info()->quantization_info();
+
+ if (a_qinfo.is_dynamic())
+ _a_offset = a_qinfo.uniform().offset;
+ if (b_qinfo.is_dynamic())
+ _b_offset = b_qinfo.uniform().offset;
+
// Convert QASYMM8->QASYMM8_SIGNED
if (_flip_signedness)
{
@@ -651,6 +682,11 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
if (_fuse_output_stage)
{
+ if (a_qinfo.is_dynamic())
+ _offset_contribution_output_stage_kernel->set_a_offset(_a_offset);
+ if (b_qinfo.is_dynamic())
+ _offset_contribution_output_stage_kernel->set_b_offset(_b_offset);
+
ITensorPack pack;
pack.add_tensor(TensorType::ACL_SRC_0, mm_result_s32.get());
pack.add_tensor(TensorType::ACL_SRC_1, _a_offset == 0 ? nullptr : vector_sum_col.get());
@@ -664,6 +700,16 @@ void CpuGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors)
}
else
{
+ if (a_qinfo.is_dynamic())
+ _offset_contribution_kernel->set_a_offset(_a_offset);
+ if (b_qinfo.is_dynamic())
+ _offset_contribution_kernel->set_b_offset(_b_offset);
+ if (a_qinfo.is_dynamic() || b_qinfo.is_dynamic())
+ {
+ const float dequantize_scale = a_qinfo.uniform().scale * b_qinfo.uniform().scale;
+ _offset_contribution_kernel->set_scale(dequantize_scale);
+ }
+
ITensorPack pack;
pack.add_tensor(TensorType::ACL_SRC_0, _a_offset == 0 ? nullptr : vector_sum_col.get());
pack.add_tensor(TensorType::ACL_SRC_1, _b_offset == 0 ? nullptr : vector_sum_row.get());
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 78065a8953..38121c9bb4 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -92,6 +92,7 @@ public:
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8_SIGNED |QSYMM8 |S32 |S32 |
+ * |QASYMM8_SIGNED |QASYMM8_SIGNED |F32 |F32 |
*
* @note GEMM_LOWP: low precision GEMM kernel
* This kernel performs the following computations:
@@ -100,12 +101,12 @@ public:
* -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
* -# Compute the matrix product of the resulting a * b in int32.
*
- * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED otherwise
+ * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED/F32 otherwise
*
* @param[in] a First input tensor info (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
* @param[in] b Second input tensor info (Matrix B). Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL.
- * @param[in] c Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32
- * @param[out] dst Output tensor info. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED
+ * @param[in] c Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32/F32
+ * @param[out] dst Output tensor info. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED/F32
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
* if the reshape of matrix B should be executed only for the first run
*/
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index efe2a7a67e..7d85885654 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -540,6 +540,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
{
configure_indirect(a, b, d, gemm_info);
}
+
+ if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value)
+ {
+ // Output dequantization is just the two src scales multiplied together
+ _gemm_kernel_asm->set_dequantize_scale(a->quantization_info().uniform().scale *
+ b->quantization_info().uniform().scale);
+ }
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -630,6 +637,15 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
auto d = tensors.get_tensor(TensorType::ACL_DST);
ARM_COMPUTE_ERROR_ON_NULLPTR(a, d);
+ // Only update at runtime if the src quantization is dynamic
+ if (std::is_same<OutputStage, arm_gemm::DequantizeFloat>::value &&
+ (a->info()->quantization_info().is_dynamic() || b->info()->quantization_info().is_dynamic()))
+ {
+ // Output dequantization is just the two src scales multiplied together
+ _gemm_kernel_asm->set_dequantize_scale(a->info()->quantization_info().uniform().scale *
+ b->info()->quantization_info().uniform().scale);
+ }
+
int lda = a->info()->strides_in_bytes().y() / a->info()->element_size();
int ldb = 0;
const int ldd = d->info()->strides_in_bytes().y() / d->info()->element_size();
@@ -775,7 +791,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -784,6 +800,39 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
}
template <typename TypeInput, typename TypeOutput>
+void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+ const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::Activation activation,
+ const AsmGemmInfo &info)
+{
+ ARM_COMPUTE_UNUSED(activation);
+
+ Params p = extract_parameters(a, b, d, info);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ const unsigned int num_threads = NEScheduler::get().num_threads();
+
+ arm_gemm::GemmConfig cfg;
+ cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
+
+ // Create arm_gemm fallback
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
+
+ // Configure requantization info
+ const GEMMLowpOutputStageInfo os_info = info.output_stage;
+
+ arm_gemm::DequantizeFloat gemm_dequant_info{};
+ gemm_dequant_info = arm_gemm::DequantizeFloat(d->quantization_info().uniform().scale);
+
+ fallback->configure(a, b, c, d, args, info, gemm_dequant_info);
+ arm_gemm = std::move(fallback);
+}
+
+template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -800,7 +849,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -855,8 +904,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
- info.fixed_format, info.fast_mode, &cfg);
-
+ info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
@@ -1032,6 +1080,10 @@ void CpuGemmAssemblyDispatch::configure(
{
create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
+ else if (d->data_type() == DataType::F32)
+ {
+ create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ }
else
{
create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 671a222fed..44c5c189a5 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,6 +57,7 @@ struct AsmGemmInfo
bool fixed_format{false};
arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED};
bool reshape_b_only_on_first_run{true};
+ bool accumulate{false};
/** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array
* @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for
* by the reshape_b_only_on_first_run flag
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 4544a66e39..c4117b8a1a 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2023 Arm Limited.
+ * Copyright (c) 2016-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -441,6 +441,8 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{"reorg_layer_nhwc", "nhwc/reorg_layer.cl"},
{"scale_nearest_neighbour_nhwc", "nhwc/scale.cl"},
{"scale_bilinear_nhwc", "nhwc/scale.cl"},
+ {"scatter_mp1d_2d_mpnd", "common/scatter.cl"},
+ {"scatter1D", "common/scatter.cl"},
{"space_to_batch_nhwc", "nhwc/space_to_batch.cl"},
{"space_to_batch_static_nhwc", "nhwc/space_to_batch.cl"},
{"space_to_depth_nhwc", "nhwc/space_to_depth.cl"},
@@ -591,6 +593,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_program_source_map =
#include "./cl_kernels/common/gather.clembed"
},
{
+ "common/scatter.cl",
+#include "./cl_kernels/common/scatter.clembed"
+ },
+ {
"common/gemm.cl",
#include "./cl_kernels/common/gemm.clembed"
},
diff --git a/src/gpu/cl/kernels/ClScatterKernel.cpp b/src/gpu/cl/kernels/ClScatterKernel.cpp
index 720164366e..9c25b63c72 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.cpp
+++ b/src/gpu/cl/kernels/ClScatterKernel.cpp
@@ -26,6 +26,14 @@
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
+
+#include "src/common/utils/Log.h"
+#include "src/core/helpers/WindowHelpers.h"
+#include "support/Cast.h"
+
+#include <cstdint>
namespace arm_compute
{
@@ -33,44 +41,166 @@ namespace opencl
{
namespace kernels
{
+
+namespace
+{
+constexpr int max_index_length = 5;
+} // namespace
+
ClScatterKernel::ClScatterKernel()
{
}
-Status ClScatterKernel::validate(const ITensorInfo *src,
- const ITensorInfo *updates,
+Status ClScatterKernel::validate(const ITensorInfo *updates,
const ITensorInfo *indices,
const ITensorInfo *dst,
const ScatterInfo &info)
{
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
ARM_COMPUTE_UNUSED(info);
+ const TensorShape &ind_shape = indices->tensor_shape();
+ const TensorShape &upt_shape = updates->tensor_shape();
+ const TensorShape &dst_shape = dst->tensor_shape();
+
+ const int32_t upt_dims = upt_shape.num_dimensions();
+ const int32_t dst_dims = dst_shape.num_dimensions();
+ const int32_t ind_dims = ind_shape.num_dimensions();
+
+ const int32_t index_len = ind_shape[0];
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(updates, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(indices, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(dst, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(ind_dims > 2, "Only 2D indices tensors are currently supported.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ ind_shape[1] != upt_shape[upt_dims - 1],
+ "Height of indices tensor should match size of highest dimension in updates tensor.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(upt_dims > dst_dims, "Update tensor cannot have more dims than output tensor.");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(index_len > max_index_length, "Maximum supported index length is 5!");
+ ARM_COMPUTE_RETURN_ERROR_ON(index_len != dst_dims - upt_dims + 1);
+
return Status{};
}
+
void ClScatterKernel::configure(const ClCompileContext &compile_context,
- const ITensorInfo *src,
const ITensorInfo *updates,
const ITensorInfo *indices,
ITensorInfo *dst,
const ScatterInfo &info)
{
- ARM_COMPUTE_UNUSED(compile_context);
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
- ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(updates, dst, indices);
+ ARM_COMPUTE_LOG_PARAMS(updates, indices, dst, info);
+
+ const TensorShape &dst_shape = dst->tensor_shape();
+
+ const bool is_scalar_block = updates->num_dimensions() == 1;
+ const int n0 = adjust_vec_size(16 / updates->element_size(), is_scalar_block ? 1 : updates->dimension(0));
+
+ const int partial_n0 = updates->dimension(0) % n0;
+
+ // The GWS will be 2D [x, y]
+ // x-dimension refers to the x coordinate of the dst tensor
+ // y-dimension refers to the collapsed y-coordinate of the data part of the dst tensor
+ Window win = calculate_max_window(dst_shape, Steps(n0));
+ const int index_len = indices->dimension(0);
+
+ // Collapse the dimensions corresponding to indices in the execution window
+ for (int i = 0; i < index_len; ++i)
+ {
+ win.set(dst->num_dimensions() - (i + 1), Window::Dimension(0, 1, 1));
+ }
+
+ win = win.collapse(win, 1);
+
+ // Set build options
+ CLBuildOptions build_opts;
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
+
+ const int num_dims = dst->num_dimensions();
+
+ build_opts.add_option("-DNUM_INDICES=" + support::cpp11::to_string(indices->dimension(1)));
+ build_opts.add_option("-DINDEX_LENGTH=" + support::cpp11::to_string(index_len));
+
+ // We provide 5 variables to use in a constant array
+ for (int i = 1; i <= max_index_length; i++)
+ {
+ build_opts.add_option("-DOUT_SHAPE_N_MINUS_" + support::cpp11::to_string(i) + "=" +
+ support::cpp11::to_string(dst_shape[std::max(num_dims - i, 0)]));
+ }
+
+ build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
+ build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_n0));
+
+ switch (info.func)
+ {
+ case ScatterFunction::Update:
+ build_opts.add_option("-DSCATTER_FUNCTION=UPDATE_OP");
+ build_opts.add_option("-DSKIP_OUTPUT_READ");
+ break;
+ case ScatterFunction::Add:
+ build_opts.add_option("-DSCATTER_FUNCTION=ADD_OP");
+ break;
+ case ScatterFunction::Sub:
+ build_opts.add_option("-DSCATTER_FUNCTION=SUB_OP");
+ break;
+ case ScatterFunction::Max:
+ build_opts.add_option("-DSCATTER_FUNCTION=MAX_OP");
+ break;
+ case ScatterFunction::Min:
+ build_opts.add_option("-DSCATTER_FUNCTION=MIN_OP");
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+
+ // Create kernel
+ std::string kernel_name = "scatter_mp1d_2d_mpnd";
+ build_opts.add_option("-D" + upper_string(kernel_name));
+
+ ICLKernel::configure_internal(win);
+ _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+
+ // Set config_id for enabling LWS tuning
+ _config_id = kernel_name;
+ _config_id += "_";
+ _config_id += lower_string(string_from_data_type(updates->data_type()));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(0));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(dst->dimension(2));
+ _config_id += "_";
}
void ClScatterKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
{
- ARM_COMPUTE_UNUSED(tensors);
- ARM_COMPUTE_UNUSED(window);
- ARM_COMPUTE_UNUSED(queue);
+ const auto updates =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
+ const auto indices =
+ utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
+ auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
+
+ const ITensorInfo *dst_info = dst->info();
+ const int num_dims = dst_info->num_dimensions();
+
+ const int index_len = indices->info()->dimension(0);
+
+ // calculate m-dimensional data block strides in updates and destination tensors
+ const int upt_block_stride = updates->info()->strides_in_bytes()[updates->info()->num_dimensions() - 1];
+ const int out_block_stride = dst_info->strides_in_bytes()[num_dims - index_len];
+
+ unsigned int idx = 0;
+
+ add_2D_tensor_argument(idx, updates, window);
+ add_2D_tensor_argument(idx, indices, window);
+ add_2D_tensor_argument(idx, dst, window);
+
+ _kernel.setArg<cl_int>(idx++, upt_block_stride);
+ _kernel.setArg<cl_int>(idx++, out_block_stride);
+
+ enqueue(queue, *this, window, lws_hint());
}
} // namespace kernels
diff --git a/src/gpu/cl/kernels/ClScatterKernel.h b/src/gpu/cl/kernels/ClScatterKernel.h
index dda614ff3e..e1b469c88e 100644
--- a/src/gpu/cl/kernels/ClScatterKernel.h
+++ b/src/gpu/cl/kernels/ClScatterKernel.h
@@ -37,6 +37,7 @@ namespace opencl
{
namespace kernels
{
+
class ClScatterKernel : public IClKernel
{
public:
@@ -44,15 +45,15 @@ public:
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClScatterKernel);
/** Initialise the kernel's input and output.
*
+ * @note Negative indices are treated as out of bounds.
+ *
* @param[in] compile_context The compile context to be used.
- * @param[in] src Input tensor info for the source matrix.
* @param[in] updates Input tensor info for the Update matrix. Data type supported: same as @p src
- * @param[in] indices Input tensor info for the Indices matrix. Data type supported: U32.
+ * @param[in] indices Input tensor info for the Indices matrix. Data type supported: S32.
* @param[out] dst Output tensor info. Data type supported: same as @p src
* @param[in] info Attributes for Scatter Kernel
*/
void configure(const ClCompileContext &compile_context,
- const ITensorInfo *src,
const ITensorInfo *updates,
const ITensorInfo *indices,
ITensorInfo *dst,
@@ -63,11 +64,8 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *src,
- const ITensorInfo *updates,
- const ITensorInfo *indices,
- const ITensorInfo *dst,
- const ScatterInfo &info);
+ static Status
+ validate(const ITensorInfo *updates, const ITensorInfo *indices, const ITensorInfo *dst, const ScatterInfo &info);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
diff --git a/src/gpu/cl/operators/ClScatter.cpp b/src/gpu/cl/operators/ClScatter.cpp
index af5fbb86f3..a11ecd7e6a 100644
--- a/src/gpu/cl/operators/ClScatter.cpp
+++ b/src/gpu/cl/operators/ClScatter.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "src/common/utils/Log.h"
+#include "src/gpu/cl/kernels/ClCopyKernel.h"
#include "src/gpu/cl/kernels/ClFillKernel.h"
#include "src/gpu/cl/kernels/ClScatterKernel.h"
@@ -47,9 +48,19 @@ Status ClScatter::validate(const ITensorInfo *src,
const ScatterInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(updates, indices, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(indices, 1, DataType::U32);
+ if (src != nullptr)
+ {
+ // Check dst/src are same shape and datatype.
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(src->tensor_shape(), dst->tensor_shape());
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, updates, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCopyKernel::validate(src, dst)); // Validate Copy kernel
+ }
+ if (src != dst)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClFillKernel::validate(dst, PixelValue(0.0f))); // Validate Fill kernel.
+ }
- return kernels::ClScatterKernel::validate(src, updates, indices, dst, info);
+ return kernels::ClScatterKernel::validate(updates, indices, dst, info);
}
void ClScatter::configure(const CLCompileContext &compile_context,
@@ -61,11 +72,6 @@ void ClScatter::configure(const CLCompileContext &compile_context,
{
ARM_COMPUTE_ERROR_ON_NULLPTR(updates, indices, dst);
ARM_COMPUTE_LOG_PARAMS(src, indices, dst, info);
- ARM_COMPUTE_UNUSED(src);
- ARM_COMPUTE_UNUSED(updates);
- ARM_COMPUTE_UNUSED(indices);
- ARM_COMPUTE_UNUSED(dst);
- ARM_COMPUTE_UNUSED(info);
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate(src, updates, indices, dst, info));
@@ -74,19 +80,50 @@ void ClScatter::configure(const CLCompileContext &compile_context,
// If necessary, create fill kernel to fill dst tensor.
if (_fill_zero)
{
- _fill_kernel = std::make_unique<kernels::ClFillKernel>();
+ auto f = std::make_unique<kernels::ClFillKernel>();
+ f->configure(compile_context, dst, PixelValue(0.0f));
+ _fill_kernel = std::move(f);
+ }
+ else if (src != dst) // Check whether copying is necessary
+ {
+ // Fill dst with src copy here.
+ auto j = std::make_unique<kernels::ClCopyKernel>();
+ j->configure(compile_context, src, dst);
+ _copy_kernel = std::move(j);
+ _run_copy = true;
}
// Configure ClScatterKernel
auto k = std::make_unique<kernels::ClScatterKernel>();
k->set_target(CLScheduler::get().target());
- k->configure(compile_context, src, updates, indices, dst, info);
+ k->configure(compile_context, updates, indices, dst, info);
_scatter_kernel = std::move(k);
}
void ClScatter::run(ITensorPack &tensors)
{
- ARM_COMPUTE_UNUSED(tensors);
+ // Get tensors.
+ auto src = tensors.get_const_tensor(ACL_SRC_0);
+ auto updates = tensors.get_const_tensor(ACL_SRC_1);
+ auto indices = tensors.get_const_tensor(ACL_SRC_2);
+ auto dst = tensors.get_tensor(ACL_DST);
+
+ if (_fill_zero)
+ {
+ // Fill destination tensor with 0 values if zero init.
+ ITensorPack fill_pack{{ACL_SRC, dst}};
+ CLScheduler::get().enqueue_op(*_fill_kernel, fill_pack, false);
+ }
+
+ if (_run_copy)
+ {
+ // copy src to dst before scatter op.
+ ITensorPack copy_pack{{ACL_SRC, src}, {ACL_DST, dst}};
+ CLScheduler::get().enqueue_op(*_copy_kernel, copy_pack, false);
+ }
+
+ ITensorPack scatter_pack{{ACL_SRC_0, updates}, {ACL_SRC_1, indices}, {ACL_DST, dst}};
+ CLScheduler::get().enqueue_op(*_scatter_kernel, scatter_pack, false);
}
} // namespace opencl
diff --git a/src/gpu/cl/operators/ClScatter.h b/src/gpu/cl/operators/ClScatter.h
index 433f7ca3a4..a1b32fed45 100644
--- a/src/gpu/cl/operators/ClScatter.h
+++ b/src/gpu/cl/operators/ClScatter.h
@@ -39,6 +39,7 @@ namespace opencl
// Forward declaration
class ClFillKernel;
class ClScatterKernel;
+class ClCopyKernel;
/** Basic operator to execute Scatter on OpenCL. This operator calls the following OpenCL kernels:
*
@@ -56,13 +57,14 @@ public:
* Valid data layouts:
* - All
*
- * @note indices must always be U32
+ * @note indices must always be S32.
+ * @note Negative indices are treated as out of bounds.
* @note src, updates and dst tensors must be same datatype.
*
* @param[in] compile_context The compile context to be used.
* @param[in] src Source input tensor info. Can be nullptr when using "Add" Scatter Function with zero initialization.
* @param[in] updates Tensor info for tensor storing update values to use for scatter function. Data types supported: same as @p src.
- * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: U32 only.
+ * @param[in] indices Tensor info for tensor storing indices to use for scatter function. Data types supported: S32 only.
* @param[out] dst Output tensor to store the result of the Scatter Function. Data types supported: same as @p src and @p updates.
* @param[in] Scatter_info Contains Scatter operation information described in @ref ScatterInfo.
*/
@@ -89,7 +91,9 @@ public:
private:
std::unique_ptr<opencl::IClKernel> _scatter_kernel{nullptr};
std::unique_ptr<opencl::IClKernel> _fill_kernel{nullptr};
+ std::unique_ptr<opencl::IClKernel> _copy_kernel{nullptr};
bool _fill_zero{false};
+ bool _run_copy{false};
};
} // namespace opencl
} // namespace arm_compute
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 6cdff7f559..e45319ef57 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
-#define ARM_COMPUTE_TEST_LARGE_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
+#define ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -79,7 +79,20 @@ public:
add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
}
};
+
+class LargeAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ LargeAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_LARGE_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_LARGEGEMMDATASET_H
diff --git a/tests/datasets/ScatterDataset.h b/tests/datasets/ScatterDataset.h
index d204d17855..c0858941db 100644
--- a/tests/datasets/ScatterDataset.h
+++ b/tests/datasets/ScatterDataset.h
@@ -113,13 +113,76 @@ private:
std::vector<TensorShape> _dst_shapes{};
};
+
+// 1D dataset for simple scatter tests.
class Small1DScatterDataset final : public ScatterDataset
{
public:
Small1DScatterDataset()
{
- add_config(TensorShape(6U), TensorShape(6U), TensorShape(6U), TensorShape(6U));
- add_config(TensorShape(10U), TensorShape(2U), TensorShape(2U), TensorShape(10U));
+ add_config(TensorShape(6U), TensorShape(6U), TensorShape(1U, 6U), TensorShape(6U));
+ add_config(TensorShape(10U), TensorShape(2U), TensorShape(1U, 2U), TensorShape(10U));
+ }
+};
+
+// This dataset represents the (m+1)-D updates/dst case.
+class SmallScatterMultiDimDataset final : public ScatterDataset
+{
+public:
+ SmallScatterMultiDimDataset()
+ {
+ // NOTE: Config is src, updates, indices, output.
+ // - In this config, the dim replaced is the final number (largest tensor dimension)
+ // - Largest "updates" dim should match y-dim of indices.
+ // - src/updates/dst should all have same number of dims. Indices should be 2D.
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 2U), TensorShape(1U, 2U), TensorShape(6U, 5U));
+ add_config(TensorShape(9U, 3U, 4U), TensorShape(9U, 3U, 2U), TensorShape(1U, 2U), TensorShape(9U, 3U, 4U));
+ add_config(TensorShape(17U, 3U, 2U, 4U), TensorShape(17U, 3U, 2U, 7U), TensorShape(1U, 7U), TensorShape(17U, 3U, 2U, 4U));
+ }
+};
+
+// This dataset represents the (m+1)-D updates tensor, (m+n)-d output tensor cases
+class SmallScatterMultiIndicesDataset final : public ScatterDataset
+{
+public:
+ SmallScatterMultiIndicesDataset()
+ {
+ // NOTE: Config is src, updates, indices, output.
+ // NOTE: indices.shape.x = src.num_dimensions - updates.num_dimensions + 1
+
+ // index length is 2
+ add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 4U), TensorShape(2U, 4U), TensorShape(6U, 5U, 2U));
+ add_config(TensorShape(17U, 3U, 3U, 2U), TensorShape(17U, 3U, 2U), TensorShape(2U, 2U), TensorShape(17U, 3U, 3U, 2U));
+ add_config(TensorShape(11U, 3U, 3U, 2U, 4U), TensorShape(11U, 3U, 3U, 4U), TensorShape(2U, 4U), TensorShape(11U, 3U, 3U, 2U, 4U));
+ add_config(TensorShape(5U, 4U, 3U, 3U, 2U, 4U), TensorShape(5U, 4U, 3U, 3U, 5U), TensorShape(2U, 5U), TensorShape(5U, 4U, 3U, 3U, 2U, 4U));
+
+ // index length is 3
+ add_config(TensorShape(4U, 3U, 2U, 2U), TensorShape(4U, 2U), TensorShape(3U, 2U), TensorShape(4U, 3U, 2U, 2U));
+ add_config(TensorShape(17U, 4U, 3U, 2U, 2U), TensorShape(17U, 4U, 4U), TensorShape(3U, 4U), TensorShape(17U, 4U, 3U, 2U, 2U));
+ add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 4U, 5U, 3U), TensorShape(3U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
+
+ // index length is 4
+ add_config(TensorShape(35U, 4U, 3U, 2U, 2U), TensorShape(35U, 4U), TensorShape(4U, 4U), TensorShape(35U, 4U, 3U, 2U, 2U));
+ add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 4U, 3U), TensorShape(4U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
+
+ // index length is 5
+ add_config(TensorShape(10U, 4U, 5U, 3U, 2U, 2U), TensorShape(10U, 3U), TensorShape(5U, 3U), TensorShape(10U, 4U, 5U, 3U, 2U, 2U));
+ }
+};
+
+// This dataset represents the (m+k)-D updates tensor, (k+1)-d indices tensor and (m+n)-d output tensor cases
+class SmallScatterBatchedDataset final : public ScatterDataset
+{
+public:
+ SmallScatterBatchedDataset()
+ {
+ // NOTE: Config is src, updates, indices, output.
+ // NOTE: Updates/Indices tensors are now batched.
+ // NOTE: indices.shape.x = (updates_batched) ? (src.num_dimensions - updates.num_dimensions) + 2 : (src.num_dimensions - updates.num_dimensions) + 1
+ add_config(TensorShape(6U, 5U), TensorShape(6U, 2U, 2U), TensorShape(1U, 2U, 2U), TensorShape(6U, 5U));
+ add_config(TensorShape(6U, 5U, 2U), TensorShape(6U, 2U, 2U), TensorShape(2U, 2U, 2U), TensorShape(6U, 5U, 2U));
+ add_config(TensorShape(6U, 5U, 2U, 2U), TensorShape(3U, 2U), TensorShape(4U, 3U, 2U), TensorShape(6U, 5U, 2U, 2U));
+ add_config(TensorShape(5U, 5U, 4U, 2U, 2U), TensorShape(6U, 2U), TensorShape(5U, 6U, 2U), TensorShape(5U, 5U, 4U, 2U, 2U));
}
};
} // namespace datasets
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index c12f57b266..99c7abbf64 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
-#define ARM_COMPUTE_TEST_SMALL_GEMM_DATASET
+#ifndef ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
+#define ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
#include "tests/datasets/GEMMDataset.h"
@@ -97,7 +97,18 @@ public:
}
};
+class SmallAccumulateGEMMDataset final : public GEMMDataset
+{
+public:
+ SmallAccumulateGEMMDataset()
+ {
+ add_config(TensorShape(8U, 2U), TensorShape(16U, 8U), TensorShape(16U, 2U), TensorShape(16U, 2U), 1.0f, 0.0f);
+ add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f);
+ add_config(TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f);
+ }
+};
+
} // namespace datasets
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_SMALL_GEMM_DATASET */
+#endif // ACL_TESTS_DATASETS_SMALLGEMMDATASET_H
diff --git a/tests/validation/CL/GEMMLowp.cpp b/tests/validation/CL/GEMMLowp.cpp
index 1ae9e96626..78d794a9bb 100644
--- a/tests/validation/CL/GEMMLowp.cpp
+++ b/tests/validation/CL/GEMMLowp.cpp
@@ -71,7 +71,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyCoreFixture, framework:
}
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
@@ -84,7 +84,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
TEST_SUITE_END() // QASYMM8
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
@@ -98,7 +98,7 @@ TEST_SUITE_END() // BatchedMatMul
TEST_SUITE(FusedOffsetOutput)
TEST_SUITE(QASYMM8)
-using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
+using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -110,7 +110,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputUi
TEST_SUITE(Output3D)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputOutput3DUint8Fixture =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, true>;
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputOutput3DUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputOutput3DUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -123,7 +123,7 @@ TEST_SUITE_END() // Output3D
TEST_SUITE(InputOutput3D)
using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInputOutput3DUint8Fixture =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, true, true>;
+ GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, true, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInputOutput3DUint8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputInputOutput3DUint8Dataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -148,7 +148,8 @@ using CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInt8Fixture =
GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<CLTensor, CLAccessor, CLGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t>;
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyCoreFusedOffsetOutputInt8Fixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputInt8Dataset(),
- make("DataType", { DataType::QASYMM8_SIGNED })))
+ make("DataType", { DataType::QASYMM8_SIGNED }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_quant);
diff --git a/tests/validation/CL/ScatterLayer.cpp b/tests/validation/CL/ScatterLayer.cpp
index 56338f489f..4a2462c7d2 100644
--- a/tests/validation/CL/ScatterLayer.cpp
+++ b/tests/validation/CL/ScatterLayer.cpp
@@ -38,6 +38,10 @@ namespace test
{
namespace validation
{
+namespace
+{
+RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for fp32 data type */
+} // namespace
template <typename T>
using CLScatterLayerFixture = ScatterValidationFixture<CLTensor, CLAccessor, CLScatter, T>;
@@ -46,67 +50,123 @@ using framework::dataset::make;
TEST_SUITE(CL)
TEST_SUITE(Scatter)
-DATA_TEST_CASE(Validate, framework::DatasetMode::DISABLED, zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::PRECOMMIT, zip(
make("InputInfo", { TensorInfo(TensorShape(9U), 1, DataType::F32), // Mismatching data types
- TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
- TensorInfo(TensorShape(8U), 1, DataType::F32),
- TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
- TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
- TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
- }),
- make("UpdatesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::F16),
- TensorInfo(TensorShape(15U), 1, DataType::F32),
- TensorInfo(TensorShape(2U), 1, DataType::F32),
- TensorInfo(TensorShape(217U), 1, DataType::F32),
- TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(2U), 1, DataType::F32),
- }),
- make("IndicesInfo",{ TensorInfo(TensorShape(3U), 1, DataType::U32),
- TensorInfo(TensorShape(15U), 1, DataType::U32),
- TensorInfo(TensorShape(2U), 1, DataType::U32),
- TensorInfo(TensorShape(271U), 1, DataType::U32),
- TensorInfo(TensorShape(271U), 1, DataType::U32),
- TensorInfo(TensorShape(2U), 1 , DataType::S32)
- }),
- make("OutputInfo",{ TensorInfo(TensorShape(9U), 1, DataType::F16),
- TensorInfo(TensorShape(15U), 1, DataType::F32),
- TensorInfo(TensorShape(8U), 1, DataType::F32),
- TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
- TensorInfo(TensorShape(271U), 1, DataType::F32),
- TensorInfo(TensorShape(12U), 1, DataType::F32)
- }),
+ TensorInfo(TensorShape(15U), 1, DataType::F32), // Valid
+ TensorInfo(TensorShape(8U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U), 1, DataType::F32), // Mismatch input/output dims.
+ TensorInfo(TensorShape(217U), 1, DataType::F32), // Updates dim higher than Input/Output dims.
+ TensorInfo(TensorShape(12U), 1, DataType::F32), // Indices wrong datatype.
+ TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32), // Number of updates != number of indices
+ TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32), // index_len != (dst_dims - upt_dims + 1)
+ TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32), // index_len > 5
+ }),
+ make("UpdatesInfo",{TensorInfo(TensorShape(3U), 1, DataType::F16),
+ TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(2U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U), 1, DataType::F32),
+ TensorInfo(TensorShape(217U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(2U), 1, DataType::F32),
+ TensorInfo(TensorShape(9U, 3U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(17U, 3U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(1U), 1, DataType::F32),
+ }),
+ make("IndicesInfo",{TensorInfo(TensorShape(1U, 3U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 15U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 271U), 1, DataType::S32),
+ TensorInfo(TensorShape(1U, 2U), 1 , DataType::F32),
+ TensorInfo(TensorShape(1U, 4U), 1, DataType::S32),
+ TensorInfo(TensorShape(3U, 2U), 1, DataType::S32),
+ TensorInfo(TensorShape(6U, 2U), 1, DataType::S32),
+ }),
+ make("OutputInfo",{TensorInfo(TensorShape(9U), 1, DataType::F16),
+ TensorInfo(TensorShape(15U), 1, DataType::F32),
+ TensorInfo(TensorShape(8U), 1, DataType::F32),
+ TensorInfo(TensorShape(271U, 3U), 1, DataType::F32),
+ TensorInfo(TensorShape(271U), 1, DataType::F32),
+ TensorInfo(TensorShape(12U), 1, DataType::F32),
+ TensorInfo(TensorShape(9U, 3U, 4U), 1, DataType::F32),
+ TensorInfo(TensorShape(17U, 3U, 3U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(17U, 3U, 3U, 2U, 2U, 2U), 1, DataType::F32),
+ }),
make("ScatterInfo",{ ScatterInfo(ScatterFunction::Add, false),
- }),
- make("Expected", { false, true, true, false, false, false })),
+ ScatterInfo(ScatterFunction::Max, false),
+ ScatterInfo(ScatterFunction::Min, false),
+ ScatterInfo(ScatterFunction::Add, false),
+ ScatterInfo(ScatterFunction::Update, false),
+ ScatterInfo(ScatterFunction::Sub, false),
+ ScatterInfo(ScatterFunction::Sub, false),
+ ScatterInfo(ScatterFunction::Update, false),
+ ScatterInfo(ScatterFunction::Update, false),
+ }),
+ make("Expected", { false, true, true, false, false, false, false, false, false })),
input_info, updates_info, indices_info, output_info, scatter_info, expected)
{
- // TODO: Enable validation tests.
- ARM_COMPUTE_UNUSED(input_info);
- ARM_COMPUTE_UNUSED(updates_info);
- ARM_COMPUTE_UNUSED(indices_info);
- ARM_COMPUTE_UNUSED(output_info);
- ARM_COMPUTE_UNUSED(scatter_info);
- ARM_COMPUTE_UNUSED(expected);
+ const Status status = CLScatter::validate(&input_info, &updates_info, &indices_info, &output_info, scatter_info);
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
}
+const auto allScatterFunctions = make("ScatterFunction",
+ {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max });
+
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
- make("DataType", {DataType::F32}),
- make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add, ScatterFunction::Sub, ScatterFunction::Min, ScatterFunction::Max}),
- make("ZeroInit", {false})))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::Small1DScatterDataset(),
+ make("DataType", {DataType::F32}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
{
- // TODO: Add validate() here.
+ validate(CLAccessor(_target), _reference, tolerance_f32);
}
// With this test, src should be passed as nullptr.
-FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::Small1DScatterDataset(),
- make("DataType", {DataType::F32}),
- make("ScatterFunction", {ScatterFunction::Add}),
- make("ZeroInit", {true})))
+FIXTURE_DATA_TEST_CASE(RunSmallZeroInit, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::Small1DScatterDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Add}),
+ make("ZeroInit", {true}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// Updates/src/dst have same no. dims.
+FIXTURE_DATA_TEST_CASE(RunSmallMultiDim, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMultiDimDataset(),
+ make("DataType", {DataType::F32}),
+ allScatterFunctions,
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// m+1-D to m+n-D cases
+FIXTURE_DATA_TEST_CASE(RunSmallMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::PRECOMMIT,
+ combine(datasets::SmallScatterMultiIndicesDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
+ make("ZeroInit", {false}),
+ make("Inplace", {false, true})))
+{
+ validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+
+// m+k, k-1-D m+n-D case
+FIXTURE_DATA_TEST_CASE(RunSmallBatchedMultiIndices, CLScatterLayerFixture<float>, framework::DatasetMode::DISABLED,
+ combine(datasets::SmallScatterBatchedDataset(),
+ make("DataType", {DataType::F32}),
+ make("ScatterFunction", {ScatterFunction::Update, ScatterFunction::Add }),
+ make("ZeroInit", {false}),
+ make("Inplace", {false})))
{
- // TODO: Add validate() here
+ validate(CLAccessor(_target), _reference, tolerance_f32);
}
+
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
TEST_SUITE_END() // Scatter
diff --git a/tests/validation/CPP/DFT.cpp b/tests/validation/CPP/DFT.cpp
index e19e850589..84431399be 100644
--- a/tests/validation/CPP/DFT.cpp
+++ b/tests/validation/CPP/DFT.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -125,7 +125,7 @@ DATA_TEST_CASE(Real, framework::DatasetMode::ALL, shapes_2d_dft,
auto backward = reference::ridft_2d(forward, is_odd);
// Validate with input
- validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f));
+ validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f), 0.f, AbsoluteTolerance<float>(0.001f));
}
DATA_TEST_CASE(Complex, framework::DatasetMode::ALL, shapes_2d_dft,
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 7a9230d37a..d739d4e1a4 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -767,21 +767,33 @@ FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath
}
#if defined(ARM_COMPUTE_ENABLE_BF16)
-
+// These tests currently only works with SVE length 256
+// If other SVE length is used a kernel will fail to be found
+// This needs to be addressed in order to ensure it doesn't revert to FP32 kernels for systems with SVE length other than 256
FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+ }
}
FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16() && (arm_gemm::utils::get_vector_length<float>() == 8)){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+ }
}
#endif // ARM_COMPUTE_ENABLE_BF16
@@ -852,20 +864,36 @@ FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<c
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16()){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
}
FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
combine(framework::dataset::make("DataType", { DataType::F32 }),
framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
{
- ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ if(Scheduler::get().cpu_info().has_bf16()){
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
+ else{
+ ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS);
+ }
}
#endif // ARM_COMPUTE_ENABLE_BF16
@@ -1141,7 +1169,7 @@ TEST_SUITE(Float)
TEST_SUITE(BFLOAT16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::BFLOAT16)),
+ framework::dataset::make("DataType", Scheduler::get().cpu_info().has_bf16() ? DataType::BFLOAT16 : DataType::F32)),
framework::dataset::make("DataLayout", { DataLayout::NHWC })),
ActivationFunctionsDataset))
{
@@ -1329,6 +1357,27 @@ FIXTURE_DATA_TEST_CASE(RunSmallSigned, NEGEMMConvolutionLayerQuantizedPerChannel
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
+
+FIXTURE_DATA_TEST_CASE(MemoryStressLargeChannels, NEGEMMConvolutionLayerQuantizedPerChannelFixture<int8_t>,
+ framework::DatasetMode::ALL,
+ combine(
+ make("In", TensorShape(1U)),
+ make("Weights", TensorShape(1U, 1U, 1U, 17000U)),
+ make("Biases", TensorShape(17000U)),
+ make("Out", TensorShape(1U, 1U, 17000U)),
+ make("Info", PadStrideInfo(1, 1, 0, 0)),
+ make("Dilation", Size2D(1, 1)),
+ make("ReshapeWeights", { true }),
+ make("DataType", { DataType::QASYMM8_SIGNED }),
+ make("DataLayout", { DataLayout::NHWC }),
+ make("QuantizationInfo", QuantizationInfo(0.5f, 10)),
+ make("ActivationInfo", ActivationLayerInfo()),
+ make("WeightsDataType", { DataType::QSYMM8_PER_CHANNEL })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+
TEST_SUITE_END() // QSYMM8_PER_CHANNEL
TEST_SUITE_END() // Quantized
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index f956cdfeda..5f6a402204 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,8 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
+
namespace
{
constexpr AbsoluteTolerance<float> tolerance_f(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
@@ -60,7 +62,7 @@ const AbsoluteTolerance<float> abs_tolerance_f16(0.2f); /**< Absolute
constexpr float tolerance_num = 0.07f; /**< Tolerance number for FP16 data types */
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
/** CNN data types */
-const auto CNNDataTypes = framework::dataset::make("DataType",
+const auto CNNDataTypes = make("DataType",
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
DataType::F16,
@@ -68,8 +70,8 @@ const auto CNNDataTypes = framework::dataset::make("DataType",
DataType::F32,
});
-const auto data_interleave = framework::dataset::make("M", 8, 12) * framework::dataset::make("N", 8, 12);
-const auto data_transpose = framework::dataset::make("M", 8, 14) * framework::dataset::make("N", 7, 14);
+const auto data_interleave = make("M", 8, 12) * make("N", 8, 12);
+const auto data_transpose = make("M", 8, 14) * make("N", 7, 14);
/** Zero padding test */
template <typename FunctionType>
@@ -204,16 +206,16 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
+ make("LhsInfo", { TensorInfo(TensorShape(27U, 13U), 1, DataType::S32), // Unsupported data type
TensorInfo(TensorShape(27U, 13U), 1, DataType::F32),
}),
- framework::dataset::make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
+ make("RhsInfo",{ TensorInfo(TensorShape(8U, 27U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 27U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
+ make("OutputInfo",{ TensorInfo(TensorShape(8U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 13U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { false, true })),
+ make("Expected", { false, true })),
lhs_info, rhs_info, output_info, expected)
{
constexpr float alpha = 1.0;
@@ -226,8 +228,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
// *INDENT-ON*
TEST_SUITE(KERNEL_SELECTION)
DATA_TEST_CASE(KernelSelection_mul_and_add, framework::DatasetMode::ALL,
- combine(framework::dataset::make("CpuExt", std::string("NEON")),
- framework::dataset::make("DataType", { DataType::F32,
+ combine(make("CpuExt", std::string("NEON")),
+ make("DataType", { DataType::F32,
DataType::F16
})),
cpu_ext, data_type)
@@ -261,8 +263,8 @@ TEST_SUITE_END() // KERNEL_SELECTION
TEST_SUITE(TRANSPOSE_1XW)
using CpuGemmTranspose1xW = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmTranspose1xWKernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
- framework::dataset::make("N", { 1, 23, 63, 101 }),
- framework::dataset::make("K", { 1, 47, 29, 27 })),
+ make("N", { 1, 23, 63, 101 }),
+ make("K", { 1, 47, 29, 27 })),
n_value, k_value)
{
bool status = validate_zero_padding<CpuGemmTranspose1xW>(n_value, k_value);
@@ -271,7 +273,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
TEST_SUITE(U32)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint32_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U32))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -280,7 +282,7 @@ TEST_SUITE_END() // U32
TEST_SUITE(U16)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint16_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U16))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -289,7 +291,7 @@ TEST_SUITE_END() // U16
TEST_SUITE(U8)
using CpuGemmTranspose1xWFixture = GEMMTranspose1xWValidationFixture<Tensor, Accessor, CpuGemmTranspose1xW, uint8_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * framework::dataset::make("DataType", DataType::U8))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmTranspose1xWFixture, framework::DatasetMode::PRECOMMIT, data_transpose * make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -302,8 +304,8 @@ TEST_SUITE(INTERLEAVE_4X4)
using CpuGemmInterleave4x4 = NESynthetizeFunctionWithZeroConstantKernelBorder<cpu::kernels::CpuGemmInterleave4x4Kernel>;
DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
- framework::dataset::make("M", { 1, 23, 63, 101 }),
- framework::dataset::make("K", { 1, 47, 29, 27 })),
+ make("M", { 1, 23, 63, 101 }),
+ make("K", { 1, 47, 29, 27 })),
m_value, k_value)
{
bool status = validate_zero_padding<cpu::kernels::CpuGemmInterleave4x4Kernel>(m_value, k_value);
@@ -312,7 +314,7 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(
TEST_SUITE(U32)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint32_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U32))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -321,7 +323,7 @@ TEST_SUITE_END() // U32
TEST_SUITE(U16)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint16_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::U16))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -330,7 +332,7 @@ TEST_SUITE_END() // U16
TEST_SUITE(U8)
using CpuGemmInterleave4x4Fixture = GEMMInterleave4x4ValidationFixture<Tensor, Accessor, CpuGemmInterleave4x4, uint8_t>;
-FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * framework::dataset::make("DataType", DataType::QASYMM8))
+FIXTURE_DATA_TEST_CASE(RunSmall, CpuGemmInterleave4x4Fixture, framework::DatasetMode::PRECOMMIT, data_interleave * make("DataType", DataType::QASYMM8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -345,15 +347,18 @@ using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
template <typename T>
using NEBatchedMatMulFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true, false, false, false, false, true>;
+template <typename T>
+using NEGEMMAccumulateFixture = GEMMAccumulateValidationFixture<Tensor, Accessor, NEGEMM, T>;
+
TEST_SUITE(Float)
-DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::dataset::make("In0", { TensorShape(21U, 13U),
+DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(make("In0", { TensorShape(21U, 13U),
TensorShape(31U, 1U),
TensorShape(31U, 1U),
TensorShape(8U, 2U),
TensorShape(38U, 12U),
TensorShape(32U, 1U)
}),
- framework::dataset::make("In1", { TensorShape(33U, 21U),
+ make("In1", { TensorShape(33U, 21U),
TensorShape(23U, 31U),
TensorShape(23U, 31U),
TensorShape(16U, 8U),
@@ -366,75 +371,111 @@ DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip(framework::
ARM_COMPUTE_EXPECT(status, framework::LogLevel::ERRORS);
}
+DATA_TEST_CASE(ValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(make("In0",{ TensorShape(21U, 13U) }),
+ make("In1", { TensorShape(33U, 21U) }),
+ make("Dst", { TensorShape(33U, 13U) })),
+ zip(
+ make("alpha", { 1.0, 100.0, 1.0, 1.0 }),
+ make("beta", { 0.0, 0.0, 1.0, 1.0 }),
+ make("is_c_null", { false, false, false, true }),
+ make("Expected", { true, false, false, true }))),
+ shape_a, shape_b, shape_dst, alpha, beta, is_c_null, expected)
+{
+ /* Accumulation test for GEMM kernels */
+ // Create tensors
+ TensorInfo in_a(shape_a, 1, DataType::F32);
+ TensorInfo in_b(shape_b, 1, DataType::F32);
+ TensorInfo in_c(shape_dst, 1, DataType::F32);
+ TensorInfo dst(shape_dst, 1, DataType::F32);
+
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ // Validate accumulation
+ cpu::CpuGemm gemm;
+ Status status = gemm.validate(&in_a, &in_b, (is_c_null ? nullptr : &in_c), &dst, alpha, beta, gemm_info);
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
- framework::dataset::make("DataType", DataType::F16)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-
-TEST_SUITE(BATCHED_MATMUL)
-
-FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
- framework::dataset::make("ReshapeWeights", { false })),
- framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-TEST_SUITE_END()
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F16)))
+TEST_SUITE(BATCHED_MATMUL)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
+ make("ReshapeWeights", { false })),
+ make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-TEST_SUITE_END()
+TEST_SUITE_END() // BATCHED_MATMUL
+
+TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
- framework::dataset::make("ReshapeWeights", { true, false })),
-
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { true, false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
TEST_SUITE(BATCHED_MATMUL)
-
-TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEBatchedMatMulFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallBatchedMatMulDataset(),
- framework::dataset::make("ReshapeWeights", { false })),
- framework::dataset::make("DataType", DataType::F32)))
+ make("ReshapeWeights", { false })),
+ make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
-TEST_SUITE_END()
+TEST_SUITE_END() // BATCHED_MATMUL
-TEST_SUITE_END()
+TEST_SUITE(ACCUMULATE)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMAccumulateFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallAccumulateGEMMDataset(),
+ make("ReshapeWeights", { false }),
+ make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMAccumulateFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeAccumulateGEMMDataset(),
+ make("ReshapeWeights", { false }),
+ make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+TEST_SUITE_END() // ACCUMULATE
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // FP32
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // GEMM
+TEST_SUITE_END() // NEON
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index 9c4d1741eb..9b1da61ed7 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -47,12 +47,24 @@ namespace test
{
namespace validation
{
+using framework::dataset::make;
+
+namespace
+{
+ constexpr AbsoluteTolerance<float> tolerance_batched(1);
+ constexpr AbsoluteTolerance<float> tolerance_quant(1);
+} // namespace
+
+
TEST_SUITE(NEON)
TEST_SUITE(GEMMLowp)
TEST_SUITE(MatrixMultiplyCore)
using NEGEMMLowpMatrixMultiplyCoreFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
+using NEGEMMLowpMatrixMultiplyCoreAccumulateFixture = GEMMLowpMatrixMultiplyAccumulateValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
using NEGEMMLowpBatchedMatMulFixture = GEMMLowpMatrixMultiplyCoreValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, true>;
+using NEGEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture = GEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
+using NEGEMMLowpDequantizedMatrixMultiplyValidationFixture = GEMMLowpDequantizedMatrixMultiplyValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
using framework::dataset::make;
@@ -80,6 +92,46 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, framework::dataset::c
validate(b.info()->padding(), PaddingSize());
validate(c.info()->padding(), PaddingSize());
}
+// accumulation is not supported for Int8/UInt8 in aarch32
+#ifdef __aarch64__
+DATA_TEST_CASE(ValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(
+ make("In0",{ TensorShape(21U, 1U) }),
+ make("In1", { TensorShape(1U, 21U) }),
+ make("Dst", { TensorShape(1U, 1U) }),
+ make("a_offset", { -2 }),
+ make("a_offset", { 13 })
+ ),
+ zip(
+ make("OutputDataType", { DataType::S32, DataType::QASYMM8, DataType::QASYMM8_SIGNED}),
+ make("Expected", { true, false, false })
+ )),
+ shape_a, shape_b, shape_dst, a_offset, b_offset, output_data_type, expected)
+{
+ DataType input_data_type = (output_data_type == DataType::S32 ? DataType::QASYMM8 : output_data_type);
+ // Accumulation test for GEMM kernels
+ TensorInfo a(shape_a, 1, input_data_type, QuantizationInfo(1.0f / 255, a_offset));
+ TensorInfo b(shape_b, 1, input_data_type, QuantizationInfo(1.0f / 255, b_offset));
+ TensorInfo dst(shape_dst, 1, output_data_type, QuantizationInfo());
+
+ // Create and configure function
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ if (is_data_type_quantized(output_data_type))
+ {
+ GEMMLowpOutputStageInfo gemmLowpOutputStageInfo = GEMMLowpOutputStageInfo();
+ gemmLowpOutputStageInfo.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+
+ gemm_info.set_gemmlowp_output_stage(gemmLowpOutputStageInfo);
+ }
+
+ cpu::CpuGemmLowpMatrixMultiplyCore gemmlowp_mm;
+ Status status = gemmlowp_mm.validate(&a, &b, nullptr, &dst, gemm_info);
+
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+#endif // __arch64__
// *INDENT-OFF*
// clang-format off
@@ -226,13 +278,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFixture, framework:
validate(Accessor(_target), _reference);
}
-constexpr AbsoluteTolerance<float> tolerance_batched(1);
-
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
-
TEST_SUITE(BatchedMatMul)
TEST_SUITE(QASYMM8)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned =
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, uint8_t, uint8_t, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedUnsigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
make("DataType", { DataType::QASYMM8 }),
@@ -242,9 +291,9 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
}
TEST_SUITE_END() // QASYMM8
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
TEST_SUITE(QASYMM8_SIGNED)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned =
+ GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore, false, false, int8_t, int8_t, true>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixtureBatchedSigned, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedBatchedMatMulDataset(),
make("DataType", { DataType::QASYMM8_SIGNED }),
@@ -255,26 +304,76 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFi
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE_END() // BatchedMatMul
-using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
-constexpr AbsoluteTolerance<float> tolerance_quant(1);
-
TEST_SUITE(FusedOffsetOutput)
+using NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture = GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture<Tensor, Accessor, NEGEMMLowpMatrixMultiplyCore>;
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::ALL,
combine(datasets::SmallGEMMLowpFusedOffsetOutputUint8Dataset(),
- make("DataType", { DataType::QASYMM8 })))
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
-
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreFusedOffsetOutputFixture, framework::DatasetMode::NIGHTLY,
combine(datasets::LargeGEMMLowpFusedOffsetOutputUint8Dataset(),
- make("DataType", { DataType::QASYMM8 })))
+ make("DataType", { DataType::QASYMM8 }),
+ make("reshape_b_only_on_first_run", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_quant);
}
TEST_SUITE_END() // FusedOffsetOutput
+
+// accumulation is not supported for Int8/UInt8 in aarch32
+#ifdef __aarch64__
+TEST_SUITE(ACCUMULATION)
+TEST_SUITE(S32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreAccumulateFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreAccumulateFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // S32
+TEST_SUITE_END() // ACCUMULATION
+#endif // __arch64__
+
+TEST_SUITE(DynamicQuantization)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+TEST_SUITE_END() // DynamicQuantization
+
+#ifdef __aarch64__
+// Deqaunt tests involve returning F32 from the MatrixMultiplyCore kernels and is only implemented in aarch64
+TEST_SUITE(Dequant)
+constexpr AbsoluteTolerance<float> tolerance_dequantized(0.01f);
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL, datasets::SmallGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_dequantized);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY, datasets::LargeGEMMLowpDataset())
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_dequantized);
+}
+TEST_SUITE_END() // Dequant
+#endif // __aarch64__
+
TEST_SUITE_END() // MatrixMultiplyCore
TEST_SUITE_END() // GEMMLowp
TEST_SUITE_END() // NEON
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 2397d81547..8da5a0d953 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -122,40 +122,35 @@ template <typename T>
using NESoftmaxLayerFixture = SoftmaxValidationFixture<Tensor, Accessor, NESoftmaxLayer, T>;
DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL,
- concat(concat(
+ concat(
combine(
- make("CpuExt", std::string("NEON")),
+ make("CpuExt", std::string("neon")),
make("DataType", { DataType::F32,
DataType::F16,
DataType::QASYMM8,
DataType::QASYMM8_SIGNED})
),
combine(
- make("CpuExt", std::string("SVE")),
+ make("CpuExt", std::string("sme2")),
make("DataType", { DataType::F32,
DataType::F16}))
),
- combine(
- make("CpuExt", std::string("SVE2")),
- make("DataType", { DataType::QASYMM8,
- DataType::QASYMM8_SIGNED}))
- ),
cpu_ext, data_type)
{
using namespace cpu::kernels;
cpuinfo::CpuIsaInfo cpu_isa{};
- cpu_isa.neon = (cpu_ext == "NEON");
- cpu_isa.sve = (cpu_ext == "SVE");
- cpu_isa.sve2 = (cpu_ext == "SVE2");
+ cpu_isa.neon = (cpu_ext == "neon");
+ cpu_isa.sme2 = (cpu_ext == "sme2");
cpu_isa.fp16 = (data_type == DataType::F16);
const auto *selected_impl = CpuSoftmaxKernel::get_implementation(
- SoftmaxKernelDataTypeISASelectorData{ data_type, cpu_isa, false /* is_log */ }, cpu::KernelSelectionType::Preferred);
+ SoftmaxKernelDataTypeISASelectorData{ data_type, cpu_isa, false /* is_log */, 0 /* axis */},
+ cpu::KernelSelectionType::Preferred);
ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl);
- std::string expected = "neon_" + cpu_impl_dt(data_type) + "_softmax";
+ std::string expected = cpu_ext + "_" + cpu_impl_dt(data_type) + "_softmax";
std::string actual = selected_impl->name;
ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS);
@@ -164,9 +159,19 @@ DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL,
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall2D, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
+ combine(
+ datasets::SoftmaxLayerSmallShapes(),
+ make("DataType", DataType::F16),
+ make("Beta", { 1.0f, 2.0f }),
+ make("Axis", { 0, -1 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f16);
+}
FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT,
combine(
- datasets::Small4DShapes(),
+ datasets::SmallShapes(),
make("DataType", DataType::F16),
make("Beta", { 1.0f, 2.0f }),
make("Axis", { 0, 1 })))
@@ -178,7 +183,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall4D, NESoftmaxLayerFixture<half>, framework::Datas
combine(
datasets::Small4DShapes(),
make("DataType", DataType::F16),
- make("Beta", { 1.0f, 2.0f }),
+ make("Beta", { 1.0f }),
make("Axis", { 0, 2, -1 })))
{
// Validate output
diff --git a/tests/validation/UNIT/CPPScheduler.cpp b/tests/validation/UNIT/CPPScheduler.cpp
index 52431653b5..6a3f6819fc 100644
--- a/tests/validation/UNIT/CPPScheduler.cpp
+++ b/tests/validation/UNIT/CPPScheduler.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -68,8 +68,7 @@ public:
TEST_SUITE(UNIT)
TEST_SUITE(CPPScheduler)
-
-#if !defined(BARE_METAL)
+#if defined(ARM_COMPUTE_CPP_SCHEDULER) && !defined(BARE_METAL)
TEST_CASE(RethrowException, framework::DatasetMode::ALL)
{
CPPScheduler scheduler;
@@ -87,7 +86,6 @@ TEST_CASE(RethrowException, framework::DatasetMode::ALL)
}
ARM_COMPUTE_EXPECT_FAIL("Expected exception not caught", framework::LogLevel::ERRORS);
}
-#endif // !defined(BARE_METAL)
-
+#endif // defined(ARM_COMPUTE_CPP_SCHEDULER) && !defined(BARE_METAL)
TEST_SUITE_END()
TEST_SUITE_END()
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index afde3d8067..94bedc83e1 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,14 +46,14 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
-class GEMMValidationFixture : public framework::Fixture
+class GEMMGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type, bool accumulate=false)
{
ARM_COMPUTE_UNUSED(pretranspose);
- _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
- _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type);
+ _target = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type, accumulate);
+ _reference = compute_reference(shape_a, shape_b, output_shape, alpha, beta, data_type, accumulate);
}
protected:
@@ -80,7 +80,7 @@ protected:
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
// Create tensors
TensorType a = create_tensor<TensorType>(shape_a, data_type, 1);
@@ -99,7 +99,7 @@ protected:
&dst,
alpha, beta,
GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, false, (reinterpret_input_as_3d
- || reinterpret_output_as_3d)));
+ || reinterpret_output_as_3d), arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED, false /* pretranspose_B */, accumulate));
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
ARM_COMPUTE_ASSERT(c.info()->is_resizable());
@@ -121,11 +121,14 @@ protected:
// Fill tensors
fill(AccessorType(a), 0);
fill(AccessorType(b), 1);
+ if (accumulate)
+ {
+ fill(AccessorType(dst), 6);
+ }
if(!disable_c)
{
fill(AccessorType(c), 2);
}
-
// Run with variable inputs.
if(run_twice)
{
@@ -145,7 +148,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, float alpha, float beta,
- DataType data_type)
+ DataType data_type, bool accumulate=false)
{
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
@@ -158,6 +161,7 @@ protected:
SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
SimpleTensor<T> b{ shape_b, data_type, 1 };
SimpleTensor<T> c{ output_shape, data_type, 1 };
+ SimpleTensor<T> dst{ output_shape, data_type, 1 };
// Fill reference
fill(a, 0);
@@ -211,17 +215,51 @@ protected:
fill(c, 5);
}
+ // Do in place summation
+ if (accumulate)
+ {
+ fill(dst, 6);
+ }
+
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- auto r = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
- return r;
+ if (accumulate)
+ {
+ reference::gemm_accumulate<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta, dst);
+ return dst;
+ }
+ else
+ {
+ return reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, alpha, disable_c ? 0.f : beta);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, false /*accumulate*/);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool pretranspose_a = false, bool pretranspose_b = false, bool run_twice = false>
+class GEMMAccumulateValidationFixture : protected GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
+ {
+ bool accumulate = true;
+ GEMMGenericValidationFixture<TensorType, AccessorType, FunctionType, T, disable_c, reinterpret_input_as_3d, reinterpret_output_as_3d, pretranspose_a, pretranspose_b, run_twice>::setup(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type, accumulate);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename T, typename GEMMOperatorType>
class GEMMMatrixMultiplyValidationFixture : public framework::Fixture
{
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index a65a1e6bd8..6b7cbba92e 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -30,6 +30,8 @@
#include "tests/framework/Fixture.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/GEMMLowp.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
+#include "tests/validation/reference/DequantizationLayer.h"
#include <cstdint>
#include <vector>
@@ -42,20 +44,35 @@ namespace validation
{
namespace
{
-
template <typename U>
void fill(U &&tensor, int i)
{
+ library->fill_tensor_uniform(tensor, i);
+}
+
+template <typename U>
+void fill_quantized(U &&tensor, int i)
+{
ARM_COMPUTE_ASSERT(is_data_type_quantized(tensor.data_type()));
library->fill_tensor_uniform(tensor, i);
}
template <typename U>
-void fill_bias_s32(U &&tensor, int i, int32_t min, int32_t max)
+void fill(U &&tensor, int i, int32_t min, int32_t max)
{
- ARM_COMPUTE_ASSERT(tensor.data_type() == DataType::S32);
- std::uniform_int_distribution<int32_t> distribution(min, max);
- library->fill(tensor, distribution, i);
+ if (tensor.data_type() == DataType::S32) {
+ std::uniform_int_distribution<int32_t> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ }
+ else if(tensor.data_type() == DataType::F32)
+ {
+ std::uniform_real_distribution<float> distribution((float)min, (float)max);
+ library->fill(tensor, distribution, i);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
}
/** Information about how to fill tensors */
@@ -64,6 +81,11 @@ struct TensorFillInfo
// Bias fill range. Default values are arbitrary
int32_t min_bias {-20000};
int32_t max_bias {20000};
+
+ // Output fill range. Default values are arbitrary
+ int32_t min_output {-20000};
+ int32_t max_output {20000};
+
// Optional extra hash to randomize tensor filling
int32_t hash {0};
};
@@ -71,29 +93,42 @@ struct TensorFillInfo
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d, bool reinterpret_output_as_3d, typename OutputType, bool is_fused = false, bool run_twice = false>
TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo,
const QuantizationInfo& output_qinfo, DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8,
- GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), bool reshape_b_only_on_first_run = false, const TensorFillInfo& finfo = TensorFillInfo() )
+ GEMMLowpOutputStageInfo output_stage = GEMMLowpOutputStageInfo(), bool reshape_b_only_on_first_run = false, const TensorFillInfo& finfo = TensorFillInfo(),
+ bool accumulate = false, bool dynamic_qinfo = false, DataType data_type_output = DataType::UNKNOWN)
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
- // Create tensors
- const DataType data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
+ // If unknown, set to sensible defaults
+ if (data_type_output == DataType::UNKNOWN) {
+ data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
+ }
- TensorType a = create_tensor<TensorType>(shape_a, data_type_a, 1, a_qinfo);
- TensorType b = create_tensor<TensorType>(shape_b, data_type_b, 1, b_qinfo); // gemm output before output stage mismatch if i pass data_layout_output here. to be investigated
+ // Create tensors
+ TensorType a = create_tensor<TensorType>(shape_a, data_type_a, 1, dynamic_qinfo ? QuantizationInfo(1.0,0,true) : a_qinfo);
+ TensorType b = create_tensor<TensorType>(shape_b, data_type_b, 1, dynamic_qinfo ? QuantizationInfo(1.0,0,true) : b_qinfo); // gemm output before output stage mismatch if i pass data_layout_output here. to be investigated
TensorType output = create_tensor<TensorType>(shape_output, data_type_output, 1, output_qinfo /* output_qinfo will be ignored when output stage type is None */);
TensorType bias;
if(is_fused)
{
TensorShape bias_shape(shape_b[0]);
- bias = create_tensor<TensorType>(bias_shape, DataType::S32, 1);
+ bias = create_tensor<TensorType>(bias_shape,data_type_output == DataType::F32 ? DataType::F32 : DataType::S32, 1);
}
// Create and configure function
// The GEMMinfo includes the values of the depth in case of reinterpreted 3d input/output
FunctionType gemmlowp;
gemmlowp.configure(&a, &b, is_fused ? &bias : nullptr, &output, GEMMInfo(false, false, reshape_b_only_on_first_run, (reinterpret_output_as_3d ? shape_output[2] : 0), reinterpret_input_as_3d, false,
- output_stage));
+ output_stage, false /*fp_mixed_precision*/, false /*fast_math*/, false /*broadcast_bias*/,
+ arm_compute::ActivationLayerInfo(), false /* fixed_format */, arm_compute::WeightFormat::UNSPECIFIED,
+ false /* pretranspose_B */, accumulate));
+
+ // If the QuantizationInfo is dynamic, it needs to be settable after configure (note that we also force it to be dynamic)
+ if (dynamic_qinfo)
+ {
+ a.info()->set_quantization_info(QuantizationInfo(a_qinfo.scale(), a_qinfo.offset(), true));
+ b.info()->set_quantization_info(QuantizationInfo(b_qinfo.scale(), b_qinfo.offset(), true));
+ }
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
@@ -111,26 +146,32 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
// Fill tensors
- fill(AccessorType(a), 0 + finfo.hash);
- fill(AccessorType(b), 1 + finfo.hash);
+ fill_quantized(AccessorType(a), 0 + finfo.hash);
+ fill_quantized(AccessorType(b), 1 + finfo.hash);
+
+ if (accumulate)
+ {
+ ARM_COMPUTE_ASSERT(accumulate != run_twice);
+ fill(AccessorType(output), 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ }
if(is_fused)
{
ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
bias.allocator()->allocate();
ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
- fill_bias_s32(AccessorType(bias), 2 + finfo.hash, finfo.min_bias, finfo.max_bias);
+ fill(AccessorType(bias), 2 + finfo.hash, finfo.min_bias, finfo.max_bias);
}
// Run with variable inputs.
if(run_twice)
{
gemmlowp.run();
- fill(AccessorType(a), 3 + finfo.hash); // Fill tensors with new seed after run
- fill(AccessorType(b), 4 + finfo.hash);
+ fill_quantized(AccessorType(a), 3 + finfo.hash); // Fill tensors with new seed after run
+ fill_quantized(AccessorType(b), 4 + finfo.hash);
if(is_fused)
{
- fill_bias_s32(AccessorType(bias), 5 + finfo.hash, finfo.min_bias, finfo.max_bias);
+ fill(AccessorType(bias), 5 + finfo.hash, finfo.min_bias, finfo.max_bias);
}
}
@@ -168,8 +209,8 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
SimpleTensor<TW> b_transposed{ shape_b_transposed, data_type_b, 1, b_qinfo };
// Fill reference
- fill(a, 0 + finfo.hash);
- fill(b, 1 + finfo.hash);
+ fill_quantized(a, 0 + finfo.hash);
+ fill_quantized(b, 1 + finfo.hash);
// Transpose reference if required
/* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
@@ -189,11 +230,12 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
// Run with variable inputs.
const int32_t a_offset = a_qinfo.uniform().offset;
const int32_t b_offset = b_qinfo.uniform().offset;
+
if(run_twice)
{
reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
- fill((pretranspose_A) ? a_transposed : a, 3 + finfo.hash);
- fill((pretranspose_B) ? b_transposed : b, 4 + finfo.hash);
+ fill_quantized((pretranspose_A) ? a_transposed : a, 3 + finfo.hash);
+ fill_quantized((pretranspose_B) ? b_transposed : b, 4 + finfo.hash);
}
return reference::gemmlowp_matrix_multiply_core<int32_t, TI, TW>((pretranspose_A ? a_transposed : a), (pretranspose_B ? b_transposed : b), shape_output, a_offset, b_offset);
@@ -201,35 +243,77 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
} // namespace
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
-class GEMMLowpMatrixMultiplyCoreValidationFixture : public framework::Fixture
+class GEMMLowpGenericMatrixMultiplyCoreValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate=false, bool dynamic_qinfo = false)
{
const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
const auto b_qinfo = QuantizationInfo(1.0f / 255, b_offset);
- _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
+ TensorFillInfo finfo;
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate);
}
protected:
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
{
const auto output_qinfo = QuantizationInfo(); // No output stage
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8, DataType::QASYMM8, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo);
}
- SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo)
+ SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate)
{
- return compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo);
+ SimpleTensor<int32_t> ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
+ DataType::QASYMM8, DataType::QASYMM8, finfo);
+
+ if (accumulate)
+ {
+ SimpleTensor<int32_t> output{ shape_output, DataType::S32, 1 };
+ fill(output, 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ reference::arithmetic_operation<int32_t>(reference::ArithmeticOperation::ADD, output, ref_output, output, ConvertPolicy::SATURATE);
+ return output;
+ }
+
+ return ref_output;
}
TensorType _target{};
SimpleTensor<int32_t> _reference{};
};
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreValidationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, false /* accumulate */);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyAccumulateValidationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, true /* accumulate */);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreDynamicQuantizationFixture : protected GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, run_twice>::setup(shape_a, shape_b, shape_output, a_offset, b_offset, false /* accumulate */, true /* dynamic_qinfo */);
+ }
+};
+
template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture : public framework::Fixture
+class GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public framework::Fixture
{
public:
/** Dynamically initialize the quantization info with saturation awareness
@@ -363,16 +447,16 @@ protected:
TensorShape bias_shape(shape_b[0]);
SimpleTensor<int32_t> bias{ bias_shape, DataType::S32, 1 };
- (run_twice) ? fill_bias_s32(bias, 5 + finfo.hash, finfo.min_bias, finfo.max_bias) : fill_bias_s32(bias, 2 + finfo.hash, finfo.min_bias, finfo.max_bias); // Fill bias with same seed as last run of gemmlowp_target
+ (run_twice) ? fill(bias, 5 + finfo.hash, finfo.min_bias, finfo.max_bias) : fill(bias, 2 + finfo.hash, finfo.min_bias, finfo.max_bias); // Fill bias with same seed as last run of gemmlowp_target
switch(output_stage.type)
{
case GEMMLowpOutputStageType::QUANTIZE_DOWN:
- return reference::gemmlowp_quantize_down_scale<int32_t, TW>(output, bias,
+ return reference::gemmlowp_quantize_down_scale<int32_t, TI>(output, bias,
output_stage.gemmlowp_offset, output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
break;
case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
- return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TW>(output, bias,
+ return reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TI>(output, bias,
output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound);
break;
default:
@@ -384,15 +468,77 @@ protected:
SimpleTensor<TI> _reference{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t>
-class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, bool run_twice = false>
+class GEMMLowpDequantizedMatrixMultiplyValidationFixture : public framework::Fixture
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset)
+ {
+ // Accumulation is supported for Int8/UInt8 only in aarch64
+ bool accumulate = true;
+ // Accumulation is not supported for Int8/UInt8 in aarch32
+#ifdef __arm__
+ accumulate = false;
+#endif //__arm__
+ bool dynamic_qinfo = false;
+ const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
+ const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
+ TensorFillInfo finfo;
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ }
+
+protected:
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
+ {
+ const auto output_qinfo = QuantizationInfo();
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
+ }
+
+ SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
+ {
+ QuantizationInfo s32_ref_output_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0, dynamic_qinfo);
+
+ SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
+ DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+ s32_ref_output.quantization_info(s32_ref_output_quant_info);
+
+ SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);
+ f32_ref_output = reference::dequantization_layer<float, int32_t>(s32_ref_output);
+
+ if (accumulate)
+ {
+ SimpleTensor<float> output{ shape_output, DataType::F32, 1 };
+ fill(output, 6 + finfo.hash, finfo.min_output, finfo.max_output);
+ reference::arithmetic_operation<float>(reference::ArithmeticOperation::ADD, output, f32_ref_output, output, ConvertPolicy::SATURATE);
+ return output;
+ }
+
+ return f32_ref_output;
+ }
+
+ TensorType _target{};
+ SimpleTensor<float> _reference{};
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
+class GEMMLowpMatrixMultiplyCoreFusedOffsetOutputValidationFixture : public GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
+{
+public:
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
+ {
+ GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b,
+ shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false, typename TI = uint8_t, typename TW = uint8_t, bool run_twice = false>
+class GEMMLowpBatchedMatrixMultiplyCoreFusedOffsetOutputFixture : public GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, GEMMLowpOutputStageType output_stage_type, DataType data_type, bool reshape_b_only_on_first_run)
{
- GEMMLowpMatrixMultiplyCoreFusedOffsetOutputGenericValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW>::setup(shape_a, shape_b,
- shape_output, output_stage_type, data_type, false /* reshape_b_only_on_first_run */);
+ GEMMLowpGenericMatrixMultiplyCoreFusedOffsetOutputValidationFixture<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, TI, TW, run_twice>::setup(shape_a, shape_b, shape_output, output_stage_type, data_type, reshape_b_only_on_first_run);
}
};
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h
index bda5532a51..4fb2d7f127 100644
--- a/tests/validation/fixtures/ScatterLayerFixture.h
+++ b/tests/validation/fixtures/ScatterLayerFixture.h
@@ -27,8 +27,9 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "tests/Globals.h"
-#include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
+#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
#include "tests/validation/Validation.h"
#include "tests/validation/reference/ScatterLayer.h"
#include "tests/SimpleTensor.h"
@@ -46,15 +47,23 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterGenericValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
+ void setup(TensorShape src_shape, TensorShape updates_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterInfo scatter_info, bool inplace,
+ QuantizationInfo src_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
{
- _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, src_qinfo, o_qinfo);
+ // this is for improving randomness across tests
+ _hash = src_shape[0] + src_shape[1] + src_shape[2] + src_shape[3] + src_shape[4] + src_shape[5]
+ + updates_shape[0] + updates_shape[1] + updates_shape[2] + updates_shape[3]
+ + updates_shape[4] + updates_shape[5]
+ + indices_shape[0] + indices_shape[1] + indices_shape[2] + indices_shape[3];
+
+ _target = compute_target(src_shape, updates_shape, indices_shape, out_shape, data_type, scatter_info, inplace, src_qinfo, o_qinfo);
_reference = compute_reference(src_shape, updates_shape, indices_shape, out_shape, data_type,scatter_info, src_qinfo , o_qinfo);
}
protected:
template <typename U>
- void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
+ void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f)
{
switch(tensor.data_type())
{
@@ -71,37 +80,47 @@ protected:
}
}
- // This is used to fill indices tensor with U32 datatype.
+ // This is used to fill indices tensor with S32 datatype.
// Used to prevent ONLY having values that are out of bounds.
template <typename U>
void fill_indices(U &&tensor, int i, const TensorShape &shape)
{
- // Calculate max indices the shape should contain. Add an arbitrary constant to allow testing for some out of bounds values.
- const uint32_t max = std::max({shape[0] , shape[1], shape[2]}) + 5;
- library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(0), static_cast<uint32_t>(max));
+ // Calculate max indices the shape should contain. Add an arbitrary value to allow testing for some out of bounds values (In this case min dimension)
+ const int32_t max = std::max({shape[0] , shape[1], shape[2]});
+ library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(-2), static_cast<int32_t>(max));
}
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &out_shape, DataType data_type, const ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c,
+ const TensorShape &out_shape, DataType data_type, const ScatterInfo info, bool inplace,
+ QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// 1. Create relevant tensors using ScatterInfo data structure.
// ----------------------------------------------------
// In order - src, updates, indices, output.
TensorType src = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
TensorType updates = create_tensor<TensorType>(shape_b, data_type, 1, a_qinfo);
- TensorType indices = create_tensor<TensorType>(shape_c, DataType::U32, 1, QuantizationInfo());
+ TensorType indices = create_tensor<TensorType>(shape_c, DataType::S32, 1, QuantizationInfo());
TensorType dst = create_tensor<TensorType>(out_shape, data_type, 1, o_qinfo);
FunctionType scatter;
// Configure operator
- // When scatter_info.zero_initialization is true, pass nullptr to scatter function.
+ // When scatter_info.zero_initialization is true, pass nullptr for src
+ // because dst does not need to be initialized with src values.
if(info.zero_initialization)
{
scatter.configure(nullptr, &updates, &indices, &dst, info);
}
else
{
- scatter.configure(&src, &updates, &indices, &dst, info);
+ if(inplace)
+ {
+ scatter.configure(&src, &updates, &indices, &src, info);
+ }
+ else
+ {
+ scatter.configure(&src, &updates, &indices, &dst, info);
+ }
}
// Assertions
@@ -110,51 +129,88 @@ protected:
ARM_COMPUTE_ASSERT(indices.info()->is_resizable());
ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
+ add_padding_x({ &src, &updates, &indices});
+
+ if(!inplace)
+ {
+ add_padding_x({ &dst });
+ }
+
// Allocate tensors
src.allocator()->allocate();
updates.allocator()->allocate();
indices.allocator()->allocate();
- dst.allocator()->allocate();
+
+ if(!inplace)
+ {
+ dst.allocator()->allocate();
+ }
ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
ARM_COMPUTE_ASSERT(!updates.info()->is_resizable());
ARM_COMPUTE_ASSERT(!indices.info()->is_resizable());
- ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+ if(!inplace)
+ {
+ ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+ }
// Fill update (a) and indices (b) tensors.
- fill(AccessorType(src), 0);
- fill(AccessorType(updates), 1);
- fill_indices(AccessorType(indices), 2, out_shape);
+ fill(AccessorType(src), 0 + _hash);
+ fill(AccessorType(updates), 1+ _hash);
+ fill_indices(AccessorType(indices), 2 + _hash, out_shape);
scatter.run();
- return dst;
+ if(inplace)
+ {
+ return src;
+ }
+ else
+ {
+ return dst;
+ }
}
- SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape, const TensorShape &out_shape, DataType data_type,
- ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &c_shape,
+ const TensorShape &out_shape, DataType data_type, ScatterInfo info, QuantizationInfo a_qinfo, QuantizationInfo o_qinfo)
{
// Output Quantization not currently in use - fixture should be extended to support this.
ARM_COMPUTE_UNUSED(o_qinfo);
+ TensorShape src_shape = a_shape;
+ TensorShape updates_shape = b_shape;
+ TensorShape indices_shape = c_shape;
+
+ // 1. Collapse batch index into a single dim if necessary for update tensor and indices tensor.
+ if(c_shape.num_dimensions() >= 3)
+ {
+ indices_shape = indices_shape.collapsed_from(1);
+ updates_shape = updates_shape.collapsed_from(updates_shape.num_dimensions() - 2); // Collapses from last 2 dims
+ }
+
+ // 2. Collapse data dims into a single dim.
+ // Collapse all src dims into 2 dims. First one holding data, the other being the index we iterate over.
+ src_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse all data dims into single dim.
+ src_shape = src_shape.collapsed_from(1); // Collapse all index dims into a single dim
+ updates_shape.collapse(updates_shape.num_dimensions() - 1); // Collapse data dims (all except last dim which is batch dim)
// Create reference tensors
SimpleTensor<T> src{ a_shape, data_type, 1, a_qinfo };
SimpleTensor<T> updates{b_shape, data_type, 1, QuantizationInfo() };
- SimpleTensor<uint32_t> indices{ c_shape, DataType::U32, 1, QuantizationInfo() };
+ SimpleTensor<int32_t> indices{ c_shape, DataType::S32, 1, QuantizationInfo() };
// Fill reference
- fill(src, 0);
- fill(updates, 1);
- fill_indices(indices, 2, out_shape);
+ fill(src, 0 + _hash);
+ fill(updates, 1 + _hash);
+ fill_indices(indices, 2 + _hash, out_shape);
// Calculate individual reference.
- auto result = reference::scatter_layer<T>(src, updates, indices, out_shape, info);
-
- return result;
+ return reference::scatter_layer<T>(src, updates, indices, out_shape, info);
}
TensorType _target{};
SimpleTensor<T> _reference{};
+ int32_t _hash{};
};
// This fixture will use the same shape for updates as indices.
@@ -162,9 +218,12 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class ScatterValidationFixture : public ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
{
public:
- void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape, TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init)
+ void setup(TensorShape src_shape, TensorShape update_shape, TensorShape indices_shape,
+ TensorShape out_shape, DataType data_type, ScatterFunction func, bool zero_init, bool inplace)
{
- ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape, indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), QuantizationInfo(), QuantizationInfo());
+ ScatterGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, update_shape,
+ indices_shape, out_shape, data_type, ScatterInfo(func, zero_init), inplace,
+ QuantizationInfo(), QuantizationInfo());
}
};
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 64a89aa6a0..67d69c2c38 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -59,6 +59,12 @@ TOut dequantize(int16_t val, const UniformQuantizationInfo qinfo, DataType dt)
ARM_COMPUTE_UNUSED(dt);
return static_cast<TOut>(dequantize_qsymm16(val, qinfo));
}
+template <typename TOut>
+TOut dequantize(int32_t val, const UniformQuantizationInfo qinfo, DataType dt)
+{
+ ARM_COMPUTE_UNUSED(dt);
+ return static_cast<TOut>(dequantize_s32(val, qinfo));
+}
} // namespace
template <typename TOut, typename TIn>
SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
@@ -115,6 +121,7 @@ template SimpleTensor<half> dequantization_layer(const SimpleTensor<int8_t> &src
template SimpleTensor<float> dequantization_layer(const SimpleTensor<int8_t> &src);
template SimpleTensor<half> dequantization_layer(const SimpleTensor<int16_t> &src);
template SimpleTensor<float> dequantization_layer(const SimpleTensor<int16_t> &src);
+template SimpleTensor<float> dequantization_layer(const SimpleTensor<int32_t> &src);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 20f1139a02..d513343796 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021,2024 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
namespace arm_compute
{
@@ -180,17 +181,22 @@ SimpleTensor<T> gemm_mixed_precision(
return dst;
}
-template SimpleTensor<float>
-gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
-template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a,
- const SimpleTensor<bfloat16> &b,
- const SimpleTensor<bfloat16> &c,
- float alpha,
- float beta);
-template SimpleTensor<half>
-gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
-template SimpleTensor<half> gemm_mixed_precision(
- const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst)
+{
+ // Compute reference
+ SimpleTensor<T> dst_gemm = gemm(a, b, c, alpha, beta);
+ reference::arithmetic_operation<T>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
+template SimpleTensor<bfloat16> gemm(const SimpleTensor<bfloat16> &a, const SimpleTensor<bfloat16> &b, const SimpleTensor<bfloat16> &c, float alpha, float beta);
+template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
+template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+
+template void gemm_accumulate(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta, SimpleTensor<float> &dst);
+template void gemm_accumulate(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta, SimpleTensor<half> &dst);
+
+template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h
index 5feaeda584..1b97570122 100644
--- a/tests/validation/reference/GEMM.h
+++ b/tests/validation/reference/GEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_GEMM_H
-#define ARM_COMPUTE_TEST_GEMM_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
+#define ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -41,8 +41,11 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
+void gemm_accumulate(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta, SimpleTensor<T> &dst);
+
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_GEMM_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMM_H
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 1615b51e73..30c577d850 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "GEMMLowp.h"
#include "arm_compute/core/Types.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
#include "tests/validation/reference/UtilsQuantizedAsymm.h"
#include "support/ToolchainSupport.h"
@@ -230,6 +231,13 @@ SimpleTensor<T_out> gemmlowp_matrix_multiply_core(const SimpleTensor<T_in> &a, c
return c;
}
+template <typename T_out, typename T_in, typename T_in_1>
+void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<T_in> &a, const SimpleTensor<T_in_1> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<T_out> &dst)
+{
+ SimpleTensor<T_out> dst_gemm = gemmlowp_matrix_multiply_core<T_out, T_in, T_in_1>(a, b, shape_c, a_offset, b_offset);
+ reference::arithmetic_operation<T_out>(reference::ArithmeticOperation::ADD, dst, dst_gemm, dst, ConvertPolicy::SATURATE);
+}
+
// used to validate assembly kernels which don't know anything about offsets
template <typename T1, typename T2, typename T3>
SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c)
@@ -336,6 +344,8 @@ template SimpleTensor<int8_t> gemmlowp_quantize_down_scale(const SimpleTensor<in
std::vector<int32_t> result_shift, int32_t min, int32_t max);
template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
template SimpleTensor<int32_t> gemmlowp_matrix_multiply_core(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<int32_t> &dst);
+template void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<int32_t> &dst);
template SimpleTensor<int32_t> gemmlowp<int32_t, int8_t, int8_t>(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
template SimpleTensor<int32_t> gemmlowp<int32_t, uint8_t, uint8_t>(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, TensorShape shape_c);
template SimpleTensor<int32_t> gemmlowp<int32_t, uint8_t, int8_t>(const SimpleTensor<uint8_t> &a, const SimpleTensor<int8_t> &b, TensorShape shape_c);
diff --git a/tests/validation/reference/GEMMLowp.h b/tests/validation/reference/GEMMLowp.h
index 99015d71fb..6e471fdad1 100644
--- a/tests/validation/reference/GEMMLowp.h
+++ b/tests/validation/reference/GEMMLowp.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_GEMMLOWP_H
-#define ARM_COMPUTE_TEST_GEMMLOWP_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
+#define ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
#include "tests/SimpleTensor.h"
#include "tests/validation/Helpers.h"
@@ -38,6 +38,9 @@ namespace reference
template <typename T1, typename T2, typename T3>
SimpleTensor<T1> gemmlowp_matrix_multiply_core(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset);
+template <typename T1, typename T2, typename T3>
+void gemmlowp_matrix_multiply_core_accumulate(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c, int32_t a_offset, int32_t b_offset, SimpleTensor<T1> &dst_);
+
template <typename T1, typename T2, typename T3 = T2>
SimpleTensor<T1> gemmlowp(const SimpleTensor<T2> &a, const SimpleTensor<T3> &b, TensorShape shape_c);
@@ -71,4 +74,4 @@ SimpleTensor<TOut> gemmlowp_quantize_down_scale_by_float(const SimpleTensor<TIn>
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_GEMMLOWP_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_GEMMLOWP_H
diff --git a/tests/validation/reference/QuantizationLayer.cpp b/tests/validation/reference/QuantizationLayer.cpp
index 27665375c3..ad7ba7ac43 100644
--- a/tests/validation/reference/QuantizationLayer.cpp
+++ b/tests/validation/reference/QuantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/tests/validation/reference/ScatterLayer.cpp b/tests/validation/reference/ScatterLayer.cpp
index 920f2b9990..283022e8e2 100644
--- a/tests/validation/reference/ScatterLayer.cpp
+++ b/tests/validation/reference/ScatterLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "ScatterLayer.h"
#include "tests/validation/Helpers.h"
+#include "arm_compute/core/TensorShape.h"
namespace arm_compute
{
@@ -64,48 +65,79 @@ T reduce_op(const T &current,const T &update,const ScatterFunction func)
template float reduce_op(const float &current,const float &update,const ScatterFunction func);
}
-// Note : This function currently only supports 1D src, 1D updates, 2D indices, 1D output tensors.
+// NOTE: This function expects collapsed tensors as input.
+// Batch dims for update/indices tensors should be collapsed into a single dim.
+// Data dims should be collapsed into a single dim for both update and src tensors prior to calling this function.
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
{
+ // 1. If zero initialization variable is false, copy src data to dst.
SimpleTensor<T> dst{ out_shape, src.data_type(), 1 };
-
- // 1. If zero initialization variable is true, fill dst with 0 values. Else copy src data to dst.
- if(info.zero_initialization)
- {
- for (int i = 0; i < src.num_elements(); ++i)
- {
- dst[i] = static_cast<T>(0);
- }
- }
- else
+ if(!info.zero_initialization)
{
std::copy_n(src.data(), src.num_elements(), dst.data());
}
- // 2. Get max index of output tensor, then iterate over index tensor.
- const auto x_bound = dst.shape().x();
+ // Number of elements between each value of the dim being iterated through
+ const unsigned int data_stride = updates.shape().total_size_lower(updates.shape().num_dimensions() - 1);
+ const unsigned int no_output_dims = out_shape.num_dimensions();
+
+ // Calculate output stride at given index for all output dims.
+ std::vector<unsigned int> out_stride_at_idx(no_output_dims);
+ for (unsigned int i = 0 ; i < no_output_dims; i++)
+ {
+ out_stride_at_idx[i] = out_shape.total_size_lower(i);
+ }
+ const unsigned int indices_x_dim = static_cast<unsigned int>(indices.shape()[0]);
+ const unsigned int indices_y_dim = static_cast<unsigned int>(indices.shape()[1]);
- for(int i = 0; i < indices.num_elements(); ++i)
+ // 2. Iterate over indices tensor y-dim and replace sections of dst tensor with relevant areas of update tensor.
+ for(unsigned int i = 0; i < indices_y_dim; i++)
{
- // 3. Check whether index is out of bounds for dst, if not then apply reduce op.
- const auto index = indices[i];
- if (index < x_bound) // Note : index is always >= 0 as datatype is unsigned.
+ // NOTE : Currently, indices.shape() == [X, Y, 1, 1], where X is the indices dim and Y is the batch dim
+ // Starting index for both the update and indices tensors.
+ const unsigned int update_dim_start = i * data_stride;
+ const unsigned int indices_dim_start = i * indices_x_dim;
+ bool out_of_bounds = false;
+ unsigned int out_offset_acc = 0;
+
+ // Iterate over each indices value for the relevant batch and accumulate the offset.
+ for(unsigned int j = 0; j < indices_x_dim; j++)
+ {
+ // Get first index value with i * indices_x_dim (iterating through y-dim/batch idx), then iterate through x dim by adding k
+ const int index_value = indices[indices_dim_start + j];
+ const unsigned int out_dim = no_output_dims - (j+1); // Calculate corresponding output dim to current index value.
+ if(index_value < static_cast<int>(out_shape[out_dim]) && index_value >= 0)
+ {
+ out_offset_acc += (index_value * out_stride_at_idx[out_dim]); // offset accumulation
+ }
+ else
+ {
+ out_of_bounds = true;
+ break;
+ }
+ }
+
+ // If not out of bounds, copy update tensor elements to output
+ if(!out_of_bounds)
{
- dst[index] = reduce_op(dst[index], updates[i], info.func);
+ for (unsigned int j = 0 ; j < data_stride; j++)
+ {
+ dst[out_offset_acc + j] = reduce_op(dst[out_offset_acc + j], updates[update_dim_start + j], info.func);
+ }
}
}
return dst;
}
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info)
{
return scatter_layer_internal<T>(src, updates, indices, out_shape, info);
}
-template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<uint32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
+template SimpleTensor<float> scatter_layer(const SimpleTensor<float> &src, const SimpleTensor<float> &updates, const SimpleTensor<int32_t> &indices, const TensorShape &out_shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/ScatterLayer.h b/tests/validation/reference/ScatterLayer.h
index dc441a8894..97d5e70b0d 100644
--- a/tests/validation/reference/ScatterLayer.h
+++ b/tests/validation/reference/ScatterLayer.h
@@ -37,10 +37,10 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer_internal(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
template <typename T>
-SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<uint32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
+SimpleTensor<T> scatter_layer(const SimpleTensor<T> &src, const SimpleTensor<T> &update, const SimpleTensor<int32_t> &indices, const TensorShape &shape, const ScatterInfo &info);
} // namespace reference
} // namespace validation
} // namespace test