From 53832b2bcce44c71fe31a618a81765294df55750 Mon Sep 17 00:00:00 2001
From: Michele Di Giorgio
Date: Mon, 21 Jun 2021 14:45:44 +0100
Subject: Port NEGEMM to memory injecting interface (Part 2)

- Port NEGEMMMatrixMultiplyKernel to the new API

Partially resolves: COMPMID-4402

Signed-off-by: Michele Di Giorgio
Change-Id: I52b67055dc24bb3a417d6ec5aeeee86e21b74320
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5873
Reviewed-by: Georgios Pinitas
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
---
 Android.bp                                          |    2 +-
 .../runtime/NEON/functions/NEFullyConnectedLayer.h  |    2 +-
 arm_compute/runtime/NEON/functions/NEGEMM.h         |    2 +-
 docs/user_guide/release_version_and_change_log.dox  |    2 +-
 filelist.json                                       |    2 +-
 src/core/NEON/NEKernels.h                           |    1 -
 .../NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp     | 1170 -------------------
 src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h  |   94 --
 src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h  |    4 +-
 .../cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp     | 1174 ++++++++++++++++++++
 src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h  |   92 ++
 src/runtime/NEON/functions/NEGEMM.cpp               |   19 +-
 tests/validation/NEON/GEMM.cpp                      |   12 +-
 13 files changed, 1291 insertions(+), 1285 deletions(-)
 delete mode 100644 src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
 delete mode 100644 src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h
 create mode 100644 src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp
 create mode 100644 src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h

diff --git a/Android.bp b/Android.bp
index 5b9f4e0276..6250539d02 100644
--- a/Android.bp
+++ b/Android.bp
@@ -166,7 +166,6 @@ cc_library_static {
         "src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.cpp",
         "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.cpp",
         "src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp",
-        "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp",
         "src/core/NEON/kernels/NEGatherKernel.cpp",
         "src/core/NEON/kernels/NEGenerateProposalsLayerKernel.cpp",
         "src/core/NEON/kernels/NEIm2ColKernel.cpp",
@@ -283,6 +282,7 @@ cc_library_static {
         "src/core/cpu/kernels/CpuGemmLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.cpp",
         "src/core/cpu/kernels/CpuGemmLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.cpp",
         "src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp",
+        "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp",
         "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.cpp",
         "src/core/cpu/kernels/CpuMulKernel.cpp",
         "src/core/cpu/kernels/CpuPermuteKernel.cpp",
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 22ec9e0fec..e409a61ba1 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -79,7 +79,7 @@ private:
 /** Basic function to compute a Fully Connected layer. This function calls the following kernels:
  * -# @ref NEIm2ColKernel (called when the input comes from a convolutional layer)
  * -# @ref NETranspose (if @p are_weights_reshaped is set to false and transpose_weights is set to true ) (called once)
- * -# @ref NEGEMMMatrixMultiplyKernel or @ref NEGEMMLowpMatrixMultiplyCore (if quantized asymmetric)
+ * -# @ref NEGEMM or @ref NEGEMMLowpMatrixMultiplyCore (if quantized asymmetric)
  * -# @ref cpu::kernels::CpuGemmMatrixAdditionKernel or @ref NEGEMMLowpOutputStage (if quantized asymmetric) (if @p biases is not equal to nullptr)
  *
  * @note The fully connected layer accepts "weights" tensors only with 2 dimensions.
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index c1ae11bcbf..5daa0406a5 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -42,7 +42,7 @@ namespace arm_compute
  * Else:
  * -# @ref cpu::kernels::CpuGemmInterleave4x4Kernel (if the output tensor is a matrix)
  * -# @ref cpu::kernels::CpuGemmTranspose1xWKernel (if the output tensor is a matrix)
- * -# @ref NEGEMMMatrixMultiplyKernel
+ * -# @ref cpu::kernels::CpuGemmMatrixMultiplyKernel
  * In both cases:
  * -# @ref cpu::kernels::CpuGemmMatrixAdditionKernel (if c != nullptr and beta != 0.0 and is not reshaped once)
  * Else:
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index cccf5b9b8b..eb4c280295 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -249,7 +249,7 @@ v20.11 Public major release
     - NEConvolutionKernel
     - NEDepthwiseConvolutionLayerNativeKernel
     - @ref NEGEMMLowpMatrixMultiplyKernel
-    - @ref NEGEMMMatrixMultiplyKernel
+    - NEGEMMMatrixMultiplyKernel
     - NEDirectConvolutionLayerOutputStageKernel
     - @ref NEReductionOperationKernel
     - @ref NEGEMMLowpMatrixAReductionKernel
diff --git a/filelist.json b/filelist.json
index b8a69c5f6e..bde361f167 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1164,7 +1164,7 @@
       "files": {
         "kernel": [
           "src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp",
-          "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp",
+          "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp",
           "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.cpp",
           "src/core/cpu/kernels/CpuGemmInterleave4x4Kernel.cpp"
         ]
diff --git a/src/core/NEON/NEKernels.h b/src/core/NEON/NEKernels.h
index 0f7475c0b5..665c8c7fba 100644
--- a/src/core/NEON/NEKernels.h
+++ b/src/core/NEON/NEKernels.h
@@ -45,7 +45,6 @@
 #include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h"
 #include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
 #include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
-#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
 #include "src/core/NEON/kernels/NEGatherKernel.h"
 #include "src/core/NEON/kernels/NEGenerateProposalsLayerKernel.h"
 #include "src/core/NEON/kernels/NEIm2ColKernel.h"
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
deleted file mode 100644
index b4a3bb5e77..0000000000
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ /dev/null
@@ -1,1170 +0,0 @@
-/*
- * Copyright (c) 2017-2021 Arm Limited.
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h" - -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Utils.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/CPP/Validate.h" -#include "src/core/NEON/NEFixedPoint.h" -#include "src/core/helpers/AutoConfiguration.h" -#include "src/core/helpers/WindowHelpers.h" -#include "src/core/utils/helpers/float_ops.h" - -#include - -namespace arm_compute -{ -namespace -{ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha) -{ - const auto width_matrix_b = static_cast(output->info()->dimension(0)); - const auto in_b_stride = static_cast(input1->info()->strides_in_bytes()[1] / input1->info()->element_size()); - const auto num_elems_vec_a = static_cast(input0->info()->dimension(0)); - - // The implementation computes 32 elements per iteration - const int window_start_x = 32 * info.thread_id; - const int window_step_x = 32 * info.num_threads; - const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; - ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x"); - - Window win_out(window); - win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(input1->info()->num_dimensions() >= 3) - { - win_b = window; - } - win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Iterator ina(input0, win_a); - Iterator inb(input1, win_b); - Iterator out(output, win_out); - - const bool multiply_alpha = 
!(helpers::float_ops::is_one(alpha)); - - const float16x8_t alpha_f16 = vdupq_n_f16(alpha); - - execute_window_loop(win_out, [&](const Coordinates &) - { - int x = window_start_x; - // Here we don't check for x lower equal than (window_end_x - window_step_x) because of - // window_end_x is computed above which may cause out-of-bound writes to the output. - for(; x < (window_end_x - window_step_x); x += window_step_x) - { - if(x > width_matrix_b) - { - return; - } - - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - - float16x8_t acc0 = vdupq_n_f16(0.f); - float16x8_t acc1 = vdupq_n_f16(0.f); - float16x8_t acc2 = vdupq_n_f16(0.f); - float16x8_t acc3 = vdupq_n_f16(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4);) - { - const float16x4_t a0l = vld1_f16(vec_a); - - float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); - float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); - float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); - float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); - - matrix_b += 2 * in_b_stride; - - b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); - b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); - b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); - b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); - - vec_a += 4; - matrix_b += 2 * in_b_stride; - } - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float16_t a0 = *vec_a; - const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); - acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); - acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); - acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc0 = vmulq_f16(acc0, alpha_f16); - acc1 = vmulq_f16(acc1, alpha_f16); - acc2 = vmulq_f16(acc2, alpha_f16); - acc3 = 
vmulq_f16(acc3, alpha_f16); - } - - auto vec_out = reinterpret_cast(out.ptr()) + x; - - vst1q_f16(vec_out + 0, acc0); - vst1q_f16(vec_out + 8, acc1); - vst1q_f16(vec_out + 16, acc2); - vst1q_f16(vec_out + 24, acc3); - } - - for(; x < window_end_x; ++x) - { - if(x > width_matrix_b) - { - return; - } - - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - - float16x4_t vacc = vdup_n_f16(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) - { - const float16x4_t a0l = vld1_f16(vec_a); - - const float16x4_t b_col = - { - *(matrix_b + 0 * in_b_stride), - *(matrix_b + 1 * in_b_stride), - *(matrix_b + 2 * in_b_stride), - *(matrix_b + 3 * in_b_stride), - }; - - vacc = vadd_f16(vacc, vmul_f16(a0l, b_col)); - - matrix_b += 4 * in_b_stride; - } - - float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3); - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float16_t a0 = *vec_a; - const float16_t b00 = *matrix_b; - - acc += b00 * a0; - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc *= static_cast(alpha); - } - - auto vec_out = reinterpret_cast(out.ptr()) + x; - - *(vec_out) = acc; - } - }, - ina, inb, out); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, const ThreadInfo &info, float alpha) -{ - const auto width_matrix_b = static_cast(output->info()->dimension(0)); - const auto in_b_stride = static_cast(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type())); - const auto num_elems_vec_a = static_cast(input0->info()->dimension(0)); - - // The implementation computes 16 elements per iteration - const int window_start_x = 16 * info.thread_id; - const int window_step_x = 16 * info.num_threads; - // Make sure (window_end_x - window_start_x) is a multiple of window_step_x - const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; - - Window win_out(window); - win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(input1->info()->num_dimensions() >= 3) - { - win_b = window; - } - win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Iterator ina(input0, win_a); - Iterator inb(input1, win_b); - Iterator out(output, win_out); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float32x4_t alpha_f32 = vdupq_n_f32(alpha); - - execute_window_loop(win_out, [&](const Coordinates &) - { - int x = window_start_x; - // Here we don't check for x lower equal than (window_end_x - window_step_x) because of - // window_end_x is computed above which may cause out-of-bound writes to the output. 
- for(; x < (window_end_x - window_step_x); x += window_step_x) - { - if(x > width_matrix_b) - { - return; - } - - float32x4_t acc0 = vdupq_n_f32(0.f); - float32x4_t acc1 = vdupq_n_f32(0.f); - float32x4_t acc2 = vdupq_n_f32(0.f); - float32x4_t acc3 = vdupq_n_f32(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); -#endif /* __arm__ */ - - auto vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4);) - { - float32x2_t a0l = vld1_f32(vec_a); - - float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - - float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); - float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); - float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); - float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); -#endif /* __arm__ */ - - acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); - acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); - acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); - acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); - - acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); - acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); - acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); - acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); - - vec_a += 2; - matrix_b += 2 * in_b_stride; - - a0l = vld1_f32(vec_a); - - b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - - b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); - b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); - b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); - b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); - - acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); - acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); - acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); - acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); - - acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); - acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); - acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); - acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); - - vec_a += 2; - matrix_b += 2 * in_b_stride; - } - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float a0 = *vec_a; - - const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - - acc0 = vmlaq_n_f32(acc0, b00, a0); - acc1 = vmlaq_n_f32(acc1, b01, a0); - acc2 = vmlaq_n_f32(acc2, b02, a0); - acc3 = vmlaq_n_f32(acc3, b03, a0); - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix 
product (alpha) - if(multiply_alpha) - { - acc0 = vmulq_f32(acc0, alpha_f32); - acc1 = vmulq_f32(acc1, alpha_f32); - acc2 = vmulq_f32(acc2, alpha_f32); - acc3 = vmulq_f32(acc3, alpha_f32); - } - - const auto vec_out = reinterpret_cast(out.ptr()) + x; - - vst1q_f32(vec_out + 0, acc0); - vst1q_f32(vec_out + 4, acc1); - vst1q_f32(vec_out + 8, acc2); - vst1q_f32(vec_out + 12, acc3); - } - - // Left-over loop - for(; x < window_end_x; ++x) - { - if(x > width_matrix_b) - { - return; - } - - float32x4_t vacc = vdupq_n_f32(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); -#endif /* __arm__ */ - - auto vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) - { - const float32x4_t a0l = vld1q_f32(vec_a); - - const float32x4_t b_col = - { - *(matrix_b + 0 * in_b_stride), - *(matrix_b + 1 * in_b_stride), - *(matrix_b + 2 * in_b_stride), - *(matrix_b + 3 * in_b_stride), - }; - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); -#endif /* __arm__ */ - - vacc = vmlaq_f32(vacc, b_col, a0l); - - matrix_b += 4 * in_b_stride; - } - - float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3); - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float a0 = *vec_a; - - const float b00 = *matrix_b; - - acc += b00 * a0; - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc *= alpha; - } - - const auto vec_out = reinterpret_cast(out.ptr()) + x; - - *vec_out = acc; - } - }, - ina, inb, out); -} - -void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) -{ - const int out_width = static_cast(output->info()->dimension(0)); - const int out_height = static_cast(output->info()->dimension(1)); - const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); - const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); - const size_t out_stride2 = out_stride1 * 2; - const size_t out_stride3 = out_stride1 * 3; - const int num_elems_matrix_b_x = input1->info()->dimension(0); - - // Set step_x and step_y for matrix A. 
Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(input1->info()->num_dimensions() >= 3) - { - win_b = window; - } - // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the output matrix - // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4 - win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride)); - win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); - - Iterator ina(input0, win_a); - Iterator inb(input1, win_b); - Iterator out(output, window); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float32x4_t alpha_f32 = vdupq_n_f32(alpha); - - // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW - // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration - // All the values needed for computing a single 4x4 block will be read from consecutive memory positions - execute_window_loop(window, [&](const Coordinates & id) - { - auto mtx_a0 = reinterpret_cast(ina.ptr()); - auto mtx_b0 = reinterpret_cast(inb.ptr()); - auto mtx_b1 = mtx_b0 + in_b_stride; - - float32x4_t acc00 = vdupq_n_f32(0.f); - float32x4_t acc10 = vdupq_n_f32(0.f); - float32x4_t acc20 = vdupq_n_f32(0.f); - float32x4_t acc30 = vdupq_n_f32(0.f); - - float32x4_t acc01 = vdupq_n_f32(0.f); - float32x4_t acc11 = vdupq_n_f32(0.f); - float32x4_t acc21 = vdupq_n_f32(0.f); - float32x4_t acc31 = vdupq_n_f32(0.f); - -#if __arm__ - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - - auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; - for(; mtx_b0 <= (mtx_b0_end_addr - 32);) - { - float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); - float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); - float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); - float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); - - float32x4_t b00 = vld1q_f32(mtx_b0); - float32x4_t b10 = vld1q_f32(mtx_b1); - float32x4_t b01 = vld1q_f32(mtx_b0 + 4); - float32x4_t b11 = vld1q_f32(mtx_b1 + 4); - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4); - float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5); - float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6); - float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); 
- acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - - a0 = vld1q_dup_f32(mtx_a0 + 0); - a1 = vld1q_dup_f32(mtx_a0 + 1); - a2 = vld1q_dup_f32(mtx_a0 + 2); - a3 = vld1q_dup_f32(mtx_a0 + 3); - - b00 = vld1q_f32(mtx_b0); - b10 = vld1q_f32(mtx_b1); - b01 = vld1q_f32(mtx_b0 + 4); - b11 = vld1q_f32(mtx_b1 + 4); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - a4 = vld1q_dup_f32(mtx_a0 + 4); - a5 = vld1q_dup_f32(mtx_a0 + 5); - a6 = vld1q_dup_f32(mtx_a0 + 6); - a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - - a0 = vld1q_dup_f32(mtx_a0 + 0); - a1 = vld1q_dup_f32(mtx_a0 + 1); - a2 = vld1q_dup_f32(mtx_a0 + 2); - a3 = vld1q_dup_f32(mtx_a0 + 3); - b00 = vld1q_f32(mtx_b0); - b10 = vld1q_f32(mtx_b1); - b01 = vld1q_f32(mtx_b0 + 4); - b11 = vld1q_f32(mtx_b1 + 4); - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - a4 = vld1q_dup_f32(mtx_a0 + 4); - a5 = vld1q_dup_f32(mtx_a0 + 5); - a6 = vld1q_dup_f32(mtx_a0 + 6); - a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - - a0 = vld1q_dup_f32(mtx_a0 + 0); - a1 = vld1q_dup_f32(mtx_a0 + 1); - a2 = vld1q_dup_f32(mtx_a0 + 2); - a3 = vld1q_dup_f32(mtx_a0 + 3); - b00 = vld1q_f32(mtx_b0); - b10 = vld1q_f32(mtx_b1); - b01 = vld1q_f32(mtx_b0 + 4); - b11 = vld1q_f32(mtx_b1 + 4); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - a4 = vld1q_dup_f32(mtx_a0 + 4); - a5 = vld1q_dup_f32(mtx_a0 + 5); - a6 = vld1q_dup_f32(mtx_a0 + 6); - a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, 
b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - } - - for(; mtx_b0 < mtx_b0_end_addr;) - { - float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); - float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); - float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); - float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); - float32x4_t b00 = vld1q_f32(mtx_b0); - float32x4_t b10 = vld1q_f32(mtx_b1); - -#if __arm__ - asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - mtx_a0 += 4; - mtx_b0 += 4; - mtx_b1 += 4; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc00 = vmulq_f32(acc00, alpha_f32); - acc10 = vmulq_f32(acc10, alpha_f32); - acc20 = vmulq_f32(acc20, alpha_f32); - acc30 = vmulq_f32(acc30, alpha_f32); - acc01 = vmulq_f32(acc01, alpha_f32); - acc11 = vmulq_f32(acc11, alpha_f32); - acc21 = vmulq_f32(acc21, alpha_f32); - acc31 = vmulq_f32(acc31, alpha_f32); - } - - const auto mtx_out0 = reinterpret_cast(out.ptr()); - const auto mtx_out1 = mtx_out0 + 4; - - if(id.x() < (out_width - 8)) - { - vst1q_f32(mtx_out0, acc00); - vst1q_f32(mtx_out1, acc01); - if(id.y() + 1 < out_height) - { - vst1q_f32(mtx_out0 + out_stride1, acc10); - vst1q_f32(mtx_out1 + out_stride1, acc11); - if(id.y() + 2 < out_height) - { - vst1q_f32(mtx_out0 + out_stride2, acc20); - vst1q_f32(mtx_out1 + out_stride2, acc21); - if(id.y() + 3 < out_height) - { - vst1q_f32(mtx_out0 + out_stride3, acc30); - vst1q_f32(mtx_out1 + out_stride3, acc31); - } - } - } - } - else if(id.x() < (out_width - 4)) - { - vst1q_f32(mtx_out0, acc00); - if(id.y() + 1 < out_height) - { - vst1q_f32(mtx_out0 + out_stride1, acc10); - if(id.y() + 2 < out_height) - { - vst1q_f32(mtx_out0 + out_stride2, acc20); - if(id.y() + 3 < out_height) - { - vst1q_f32(mtx_out0 + out_stride3, acc30); - } - } - } - // Left-over columns - const int columns_left = out_width - id.x() - 4; - for(auto x = 0; x < columns_left; ++x) - { - *(mtx_out1 + x) = acc01[x]; - if(id.y() + 1 < out_height) - { - *(mtx_out1 + x + out_stride1) = acc11[x]; - if(id.y() + 2 < out_height) - { - *(mtx_out1 + x + out_stride2) = acc21[x]; - if(id.y() + 3 < out_height) - { - *(mtx_out1 + x + out_stride3) = acc31[x]; - } - } - } - } - } - else - { - // Left-over columns - const int columns_left = out_width - id.x(); - for(int x = 0; x < columns_left; ++x) - { - *(mtx_out0 + x) = acc00[x]; - if(id.y() + 1 < out_height) - { - *(mtx_out0 + x + out_stride1) = acc10[x]; - if(id.y() + 2 < out_height) - { - *(mtx_out0 + x + out_stride2) = acc20[x]; - if(id.y() + 3 < out_height) - { - *(mtx_out0 + x + out_stride3) = acc30[x]; - } - } - } - } - } - }, - ina, inb, out); 
-} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) -{ - const int out_width = static_cast(output->info()->dimension(0)); - const int out_height = static_cast(output->info()->dimension(1)); - const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); - const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); - const int num_elems_matrix_b_x = input1->info()->dimension(0); - - // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(input1->info()->num_dimensions() >= 3) - { - win_b = window; - } - // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix A has 8 times less the cols of the output matrix - win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 0)); - - Iterator ina(input0, win_a); - Iterator inb(input1, win_b); - Iterator out(output, window); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float16x8_t alpha_f16 = vdupq_n_f16(alpha); - - execute_window_loop(window, [&](const Coordinates & id) - { - const auto *mtx_a0 = reinterpret_cast(ina.ptr()); - const auto *mtx_b0 = reinterpret_cast(inb.ptr()); - auto *mtx_out = reinterpret_cast(out.ptr()); - float16x8x4_t c = - { - { - vdupq_n_f16(0.f), - vdupq_n_f16(0.f), - vdupq_n_f16(0.f), - vdupq_n_f16(0.f) - } - }; - - /* - This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values) - |a00 a01 a02 a03 | a04 a05 a06 a07| - |a10 a11 a12 a13 | a14 a15 a16 a17| - |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ... - |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a15 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ... - |a40 a41 a42 a43 | a44 a45 a46 a47| - |a50 a51 a52 a53 | a54 a55 a56 a57| - |a60 a61 a62 a63 | a64 a65 a66 a67| - |a70 a71 a72 a73 | a74 a75 a76 a77| - - After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ] - - B Matrix has been transposed as shown below - - |b00 b01 b02 b03 b04 b05 b06 b07| - |b10 b11 b12 b13 b14 b15 b16 b17| - |b20 b21 b22 b23 b24 b25 b26 b27| - |b30 b31 b32 b33 b34 b35 b36 b37| - -------------------> - - |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37| - - c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30 - c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31 - - The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size. 
- */ - const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; - - for(; mtx_b0 <= (mtx_b0_end_addr - 32);) - - { - const float16x8_t p00 = vld1q_f16(mtx_a0); - const float16x8_t p02 = vld1q_f16(mtx_a0 + 8); - - const float16x8_t q00 = vld1q_f16(mtx_b0); - const float16x8_t q02 = vld1q_f16(mtx_b0 + 8); - const float16x8_t q04 = vld1q_f16(mtx_b0 + 16); - const float16x8_t q06 = vld1q_f16(mtx_b0 + 24); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3))); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7))); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3))); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7))); - - mtx_a0 += 16; - mtx_b0 += 32; - } - - for(; mtx_b0 < mtx_b0_end_addr;) - - { - const float16x4_t p00 = vld1_f16(mtx_a0); - const float16x8_t q00 = vld1q_f16(mtx_b0); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3))); - - mtx_a0 += 4; - mtx_b0 += 8; - } - - if(multiply_alpha) - { - c.val[0] = vmulq_f16(c.val[0], alpha_f16); - c.val[1] = vmulq_f16(c.val[1], alpha_f16); - c.val[2] = vmulq_f16(c.val[2], alpha_f16); - c.val[3] = vmulq_f16(c.val[3], alpha_f16); - } - - if(id.x() < (out_width - 8)) - { - vst1q_f16(mtx_out, c.val[0]); - if(id.y() + 1 < out_height) - { - vst1q_f16(mtx_out + 1 * out_stride, c.val[1]); - if(id.y() + 2 < out_height) - { - vst1q_f16(mtx_out + 2 * out_stride, c.val[2]); - if(id.y() + 3 < out_height) - { - vst1q_f16(mtx_out + 3 * out_stride, c.val[3]); - } - } - } - } - else - { - // Left-over columns - const int columns_left = out_width - id.x(); - for(int x = 0; x < columns_left; ++x) - { - *(mtx_out + x) = c.val[0][x]; - if(id.y() + 1 < out_height) - { - *(mtx_out + x + 1 * out_stride) = c.val[1][x]; - if(id.y() + 2 < out_height) - { - *(mtx_out + x + 2 * out_stride) = c.val[2][x]; - if(id.y() + 3 < out_height) - { - *(mtx_out + x + 3 * out_stride) = c.val[3][x]; - } - } - } - } - } - }, - ina, inb, out); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info) -{ - ARM_COMPUTE_UNUSED(alpha); - - ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input0); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, 
DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output); - - if(!is_interleaved) - { - ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1)); - - if(output->total_size() != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0)); - ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output); - } - } - else - { - const int m = reshape_info.m(); - const int n = reshape_info.n(); - const int k = reshape_info.k(); - const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width(); - const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height(); - - /* Interleave */ - TensorShape tensor_shape0{ input0->tensor_shape() }; - tensor_shape0.set(0, k); - tensor_shape0.set(1, m); - - const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0); - const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0); - - if(n != 0) /* Transpose */ - { - TensorShape tensor_shape1{ input1->tensor_shape() }; - tensor_shape1.set(0, n); - tensor_shape1.set(1, k); - - const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1); - const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1); - } - - if(output->total_size() != 0) - { - if(n != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast(n)); - } - ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast(m)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output); - } - } - - return Status{}; -} -} // namespace - -NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel() - : _input0(nullptr), _input1(nullptr), _output(nullptr), _alpha(1.0f) -{ -} - -void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output); - - // Output tensor auto inizialitation if not yet initialized - TensorShape tensor_shape{ input0->info()->tensor_shape() }; - tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0)); - tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1)); - - auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape)); - - // Perform validate step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info)); - - _input0 = input0; - _input1 = input1; - _output = output; - _alpha = alpha; - - // Configure kernel window - Window win{}; - - // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication - if((output->info()->dimension(1) == 1)) - { - const unsigned int num_elems_processed_per_iteration_x = (input0->info()->data_type() == DataType::F32) ? 
16 : 32; - - win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x)); - } - else - { - constexpr unsigned int num_elems_processed_per_iteration_x = 8; - constexpr unsigned int num_elems_processed_per_iteration_y = 4; - - win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); - } - - INEKernel::configure(win); -} - -Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, - const GEMMReshapeInfo &reshape_info) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info)); - - return Status{}; -} - -void NEGEMMMatrixMultiplyKernel::run(const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - - // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication - const bool is_output_vector = (_output->info()->dimension(1) == 1); - switch(_input0->info()->data_type()) - { - case DataType::F32: - { - is_output_vector ? vector_matrix_multiply_f32(_input0, _input1, _output, window, info, _alpha) : - matrix_matrix_multiply_f32(_input0, _input1, _output, window, _alpha); - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - is_output_vector ? vector_matrix_multiply_f16(_input0, _input1, _output, window, info, _alpha) : - matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha); - break; - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - { - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - } -} -} // namespace arm_compute diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h deleted file mode 100644 index 4341ff00df..0000000000 --- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2017-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#ifndef ARM_COMPUTE_NEGEMMMATRIXMULTIPLYKERNEL_H -#define ARM_COMPUTE_NEGEMMMATRIXMULTIPLYKERNEL_H - -#include "src/core/NEON/INEKernel.h" - -namespace arm_compute -{ -class ITensor; - -/** Kernel to multiply two input matrices "A" and "B". 
All elements of the output matrix/vector will be multiplied by alpha after the matrix multiplication - * - * @note If the output tensor is a matrix, the implementation assumes that the input tensors @p input0 and @p input1 are both matrices and reshaped respectively with @ref cpu::kernels::CpuGemmInterleave4x4Kernel and @ref cpu::kernels::CpuGemmTranspose1xWKernel - * @note If the output tensor is a vector and the data type is F32, the implementation assumes that the first input tensor @p input0 is a vector and the second input tensor @p input1 a matrix. The implementation also assumes that both tensors have not been reshaped - * - */ -class NEGEMMMatrixMultiplyKernel : public INEKernel -{ -public: - const char *name() const override - { - return "NEGEMMMatrixMultiplyKernel"; - } - /** Constructor */ - NEGEMMMatrixMultiplyKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NEGEMMMatrixMultiplyKernel(const NEGEMMMatrixMultiplyKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NEGEMMMatrixMultiplyKernel &operator=(const NEGEMMMatrixMultiplyKernel &) = delete; - /** Allow instances of this class to be moved */ - NEGEMMMatrixMultiplyKernel(NEGEMMMatrixMultiplyKernel &&) = default; - /** Allow instances of this class to be moved */ - NEGEMMMatrixMultiplyKernel &operator=(NEGEMMMatrixMultiplyKernel &&) = default; - /** Initialise the kernel's input and output. - * - * @note If the output tensor is a matrix, the input matrices @p input0 and @p input1 should be the output of the kernels: @ref cpu::kernels::CpuGemmInterleave4x4Kernel and @ref cpu::kernels::CpuGemmTranspose1xWKernel - * These two kernels change the layout of the original matrices to be more cache-friendly. - * - * @param[in] input0 Input tensor containing the interleaved Matrix A or the vector A. Data types supported: F16/F32 - * @param[in] input1 Input tensor containing the transposed Matrix B if the first input tensor A is not a vector. - * If the output tensor is a vector, input1 must contain the matrix B not reshaped. Data type supported: same as @p input0 - * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] alpha Weight of the matrix product - * @param[in] is_interleaved (Optional) True if input0 and input1 have been reshaped respectively using @ref cpu::kernels::CpuGemmInterleave4x4Kernel and @ref cpu::kernels::CpuGemmTranspose1xWKernel - * @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped - */ - void configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo()); - /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMMatrixMultiplyKernel - * - * @param[in] input0 Input tensor containing the interleaved Matrix A or the vector A. Data types supported: F16/F32 - * @param[in] input1 Input tensor containing the transposed Matrix B if the first input tensor A is not a vector. - * If the output tensor is a vector, input1 must contain the matrix B not reshaped. Data type supported: same as @p input0 - * @param[in] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. 
- * @param[in] alpha Weight of the matrix product - * @param[in] is_interleaved (Optional) True if input0 and input1 have been reshaped respectively using @ref cpu::kernels::CpuGemmInterleave4x4Kernel and @ref cpu::kernels::CpuGemmTranspose1xWKernel - * @param[in] reshape_info (Optional) GEMM reshape info. If is_interleaved_transposed = true, this object must contain the information to understand how the matrix A and matrix B have been reshaped - * - * @return a status - */ - static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info); - - // Inherited methods overridden: - void run(const Window &window, const ThreadInfo &info) override; - -private: - const ITensor *_input0; - const ITensor *_input1; - ITensor *_output; - float _alpha; -}; -} // namespace arm_compute -#endif /*ARM_COMPUTE_NEGEMMMATRIXMULTIPLYKERNEL_H*/ diff --git a/src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h b/src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h index 216e61b5d5..c8e6fa9589 100644 --- a/src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h +++ b/src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h @@ -38,7 +38,7 @@ namespace kernels * @note [ MTX_OUT = MTX_0 + beta * MTX_1 ] with MTX_0 and MTX_1 of the same size * * @note This stage is used to finalize the GEMM result and it is computed if and only if beta != 0.0. In case this kernel is used for finalizing GEMM result, we have: - * - MTX_0 = A * B * alpha, where MTX_0 is the output of @ref NEGEMMMatrixMultiplyKernel + * - MTX_0 = A * B * alpha, where MTX_0 is the output of @ref CpuGemmMatrixMultiplyKernel * - MTX_1 = C */ class CpuGemmMatrixAdditionKernel : public ICpuKernel @@ -52,7 +52,7 @@ public: * @note The input and output tensor must have the same dimensions * * @param[in] src Input tensor info (Matrix C). Data types supported: F16/F32 - * @param[in, out] dst Output tensor info. If this kernel is used to finalize the GEMM result, output contains the result obtained by the kernel @ref NEGEMMMatrixMultiplyKernel. Data type supported: the same as @p src. + * @param[in, out] dst Output tensor info. If this kernel is used to finalize the GEMM result, output contains the result obtained by the kernel @ref CpuGemmMatrixMultiplyKernel. Data type supported: the same as @p src. * @param[in] beta Weight of matrix C */ void configure(const ITensorInfo *src, ITensorInfo *dst, float beta); diff --git a/src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp b/src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp new file mode 100644 index 0000000000..d86ea064de --- /dev/null +++ b/src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp @@ -0,0 +1,1174 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h" + +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Types.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/CPP/Validate.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/utils/helpers/float_ops.h" + +#include + +namespace arm_compute +{ +namespace cpu +{ +namespace kernels +{ +namespace +{ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + const auto width_matrix_b = static_cast(dst->info()->dimension(0)); + const auto in_b_stride = static_cast(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size()); + const auto num_elems_vec_a = static_cast(lhs->info()->dimension(0)); + + // The implementation computes 32 elements per iteration + const int window_start_x = 32 * info.thread_id; + const int window_step_x = 32 * info.num_threads; + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x"); + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, win_out); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float16x8_t alpha_f16 = vdupq_n_f16(alpha); + + execute_window_loop(win_out, [&](const Coordinates &) + { + int x = window_start_x; + // Here we don't check for x lower equal than (window_end_x - window_step_x) because of + // window_end_x is computed above which may cause out-of-bound writes to the dst. 
+ for(; x < (window_end_x - window_step_x); x += window_step_x) + { + if(x > width_matrix_b) + { + return; + } + + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + + float16x8_t acc0 = vdupq_n_f16(0.f); + float16x8_t acc1 = vdupq_n_f16(0.f); + float16x8_t acc2 = vdupq_n_f16(0.f); + float16x8_t acc3 = vdupq_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + const float16x4_t a0l = vld1_f16(vec_a); + + float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); + + matrix_b += 2 * in_b_stride; + + b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); + + vec_a += 4; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float16_t a0 = *vec_a; + const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); + acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); + acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); + acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f16(acc0, alpha_f16); + acc1 = vmulq_f16(acc1, alpha_f16); + acc2 = vmulq_f16(acc2, alpha_f16); + acc3 = vmulq_f16(acc3, alpha_f16); + } + + auto vec_out = reinterpret_cast(out.ptr()) + x; + + vst1q_f16(vec_out + 0, acc0); + vst1q_f16(vec_out + 8, acc1); + vst1q_f16(vec_out + 16, acc2); + vst1q_f16(vec_out + 24, acc3); + } + + for(; x < window_end_x; ++x) + { + if(x > width_matrix_b) + { + return; + } + + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + + 
float16x4_t vacc = vdup_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) + { + const float16x4_t a0l = vld1_f16(vec_a); + + const float16x4_t b_col = + { + *(matrix_b + 0 * in_b_stride), + *(matrix_b + 1 * in_b_stride), + *(matrix_b + 2 * in_b_stride), + *(matrix_b + 3 * in_b_stride), + }; + + vacc = vadd_f16(vacc, vmul_f16(a0l, b_col)); + + matrix_b += 4 * in_b_stride; + } + + float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3); + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float16_t a0 = *vec_a; + const float16_t b00 = *matrix_b; + + acc += b00 * a0; + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc *= static_cast(alpha); + } + + auto vec_out = reinterpret_cast(out.ptr()) + x; + + *(vec_out) = acc; + } + }, + ina, inb, out); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + const auto width_matrix_b = static_cast(dst->info()->dimension(0)); + const auto in_b_stride = static_cast(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type())); + const auto num_elems_vec_a = static_cast(lhs->info()->dimension(0)); + + // The implementation computes 16 elements per iteration + const int window_start_x = 16 * info.thread_id; + const int window_step_x = 16 * info.num_threads; + // Make sure (window_end_x - window_start_x) is a multiple of window_step_x + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, win_out); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float32x4_t alpha_f32 = vdupq_n_f32(alpha); + + execute_window_loop(win_out, [&](const Coordinates &) + { + int x = window_start_x; + // Here we don't check for x lower equal than (window_end_x - window_step_x) because of + // window_end_x is computed above which may cause out-of-bound writes to the dst. 
+ for(; x < (window_end_x - window_step_x); x += window_step_x) + { + if(x > width_matrix_b) + { + return; + } + + float32x4_t acc0 = vdupq_n_f32(0.f); + float32x4_t acc1 = vdupq_n_f32(0.f); + float32x4_t acc2 = vdupq_n_f32(0.f); + float32x4_t acc3 = vdupq_n_f32(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); +#endif /* __arm__ */ + + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + float32x2_t a0l = vld1_f32(vec_a); + + float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + + float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); + float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); + float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); + float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); +#endif /* __arm__ */ + + acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); + acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); + acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); + acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); + + acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); + acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); + acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); + acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); + + vec_a += 2; + matrix_b += 2 * in_b_stride; + + a0l = vld1_f32(vec_a); + + b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + + b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); + b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); + b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); + + acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); + acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); + acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); + acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); + + acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); + acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); + acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); + acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); + + vec_a += 2; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float a0 = *vec_a; + + const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + + acc0 = vmlaq_n_f32(acc0, b00, a0); + acc1 = vmlaq_n_f32(acc1, b01, a0); + acc2 = vmlaq_n_f32(acc2, b02, a0); + acc3 = vmlaq_n_f32(acc3, b03, a0); + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix 
product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f32(acc0, alpha_f32); + acc1 = vmulq_f32(acc1, alpha_f32); + acc2 = vmulq_f32(acc2, alpha_f32); + acc3 = vmulq_f32(acc3, alpha_f32); + } + + const auto vec_out = reinterpret_cast(out.ptr()) + x; + + vst1q_f32(vec_out + 0, acc0); + vst1q_f32(vec_out + 4, acc1); + vst1q_f32(vec_out + 8, acc2); + vst1q_f32(vec_out + 12, acc3); + } + + // Left-over loop + for(; x < window_end_x; ++x) + { + if(x > width_matrix_b) + { + return; + } + + float32x4_t vacc = vdupq_n_f32(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); +#endif /* __arm__ */ + + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) + { + const float32x4_t a0l = vld1q_f32(vec_a); + + const float32x4_t b_col = + { + *(matrix_b + 0 * in_b_stride), + *(matrix_b + 1 * in_b_stride), + *(matrix_b + 2 * in_b_stride), + *(matrix_b + 3 * in_b_stride), + }; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); +#endif /* __arm__ */ + + vacc = vmlaq_f32(vacc, b_col, a0l); + + matrix_b += 4 * in_b_stride; + } + + float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3); + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float a0 = *vec_a; + + const float b00 = *matrix_b; + + acc += b00 * a0; + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc *= alpha; + } + + const auto vec_out = reinterpret_cast(out.ptr()) + x; + + *vec_out = acc; + } + }, + ina, inb, out); +} + +void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + ARM_COMPUTE_UNUSED(info); + const int out_width = static_cast(dst->info()->dimension(0)); + const int out_height = static_cast(dst->info()->dimension(1)); + const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()); + const size_t out_stride1 = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type()); + const size_t out_stride2 = out_stride1 * 2; + const size_t out_stride3 = out_stride1 * 3; + const int num_elems_matrix_b_x = rhs->info()->dimension(0); + + // Set step_x and step_y for matrix A. 
Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the dst matrix + // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4 + win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride)); + win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, window); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float32x4_t alpha_f32 = vdupq_n_f32(alpha); + + // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW + // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration + // All the values needed for computing a single 4x4 block will be read from consecutive memory positions + execute_window_loop(window, [&](const Coordinates & id) + { + auto mtx_a0 = reinterpret_cast(ina.ptr()); + auto mtx_b0 = reinterpret_cast(inb.ptr()); + auto mtx_b1 = mtx_b0 + in_b_stride; + + float32x4_t acc00 = vdupq_n_f32(0.f); + float32x4_t acc10 = vdupq_n_f32(0.f); + float32x4_t acc20 = vdupq_n_f32(0.f); + float32x4_t acc30 = vdupq_n_f32(0.f); + + float32x4_t acc01 = vdupq_n_f32(0.f); + float32x4_t acc11 = vdupq_n_f32(0.f); + float32x4_t acc21 = vdupq_n_f32(0.f); + float32x4_t acc31 = vdupq_n_f32(0.f); + +#if __arm__ + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + + auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; + for(; mtx_b0 <= (mtx_b0_end_addr - 32);) + { + float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); + float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); + float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); + float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); + + float32x4_t b00 = vld1q_f32(mtx_b0); + float32x4_t b10 = vld1q_f32(mtx_b1); + float32x4_t b01 = vld1q_f32(mtx_b0 + 4); + float32x4_t b11 = vld1q_f32(mtx_b1 + 4); + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4); + float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5); + float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6); + float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = 
vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + + a0 = vld1q_dup_f32(mtx_a0 + 0); + a1 = vld1q_dup_f32(mtx_a0 + 1); + a2 = vld1q_dup_f32(mtx_a0 + 2); + a3 = vld1q_dup_f32(mtx_a0 + 3); + + b00 = vld1q_f32(mtx_b0); + b10 = vld1q_f32(mtx_b1); + b01 = vld1q_f32(mtx_b0 + 4); + b11 = vld1q_f32(mtx_b1 + 4); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + a4 = vld1q_dup_f32(mtx_a0 + 4); + a5 = vld1q_dup_f32(mtx_a0 + 5); + a6 = vld1q_dup_f32(mtx_a0 + 6); + a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + + a0 = vld1q_dup_f32(mtx_a0 + 0); + a1 = vld1q_dup_f32(mtx_a0 + 1); + a2 = vld1q_dup_f32(mtx_a0 + 2); + a3 = vld1q_dup_f32(mtx_a0 + 3); + b00 = vld1q_f32(mtx_b0); + b10 = vld1q_f32(mtx_b1); + b01 = vld1q_f32(mtx_b0 + 4); + b11 = vld1q_f32(mtx_b1 + 4); + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + a4 = vld1q_dup_f32(mtx_a0 + 4); + a5 = vld1q_dup_f32(mtx_a0 + 5); + a6 = vld1q_dup_f32(mtx_a0 + 6); + a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + + a0 = vld1q_dup_f32(mtx_a0 + 0); + a1 = vld1q_dup_f32(mtx_a0 + 1); + a2 = vld1q_dup_f32(mtx_a0 + 2); + a3 = vld1q_dup_f32(mtx_a0 + 3); + b00 = vld1q_f32(mtx_b0); + b10 = vld1q_f32(mtx_b1); + b01 = vld1q_f32(mtx_b0 + 4); + b11 = vld1q_f32(mtx_b1 + 4); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + a4 = vld1q_dup_f32(mtx_a0 + 4); + a5 = vld1q_dup_f32(mtx_a0 + 5); + a6 = vld1q_dup_f32(mtx_a0 + 6); + a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); 
+ acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + } + + for(; mtx_b0 < mtx_b0_end_addr;) + { + float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); + float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); + float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); + float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); + float32x4_t b00 = vld1q_f32(mtx_b0); + float32x4_t b10 = vld1q_f32(mtx_b1); + +#if __arm__ + asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + mtx_a0 += 4; + mtx_b0 += 4; + mtx_b1 += 4; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc00 = vmulq_f32(acc00, alpha_f32); + acc10 = vmulq_f32(acc10, alpha_f32); + acc20 = vmulq_f32(acc20, alpha_f32); + acc30 = vmulq_f32(acc30, alpha_f32); + acc01 = vmulq_f32(acc01, alpha_f32); + acc11 = vmulq_f32(acc11, alpha_f32); + acc21 = vmulq_f32(acc21, alpha_f32); + acc31 = vmulq_f32(acc31, alpha_f32); + } + + const auto mtx_out0 = reinterpret_cast(out.ptr()); + const auto mtx_out1 = mtx_out0 + 4; + + if(id.x() < (out_width - 8)) + { + vst1q_f32(mtx_out0, acc00); + vst1q_f32(mtx_out1, acc01); + if(id.y() + 1 < out_height) + { + vst1q_f32(mtx_out0 + out_stride1, acc10); + vst1q_f32(mtx_out1 + out_stride1, acc11); + if(id.y() + 2 < out_height) + { + vst1q_f32(mtx_out0 + out_stride2, acc20); + vst1q_f32(mtx_out1 + out_stride2, acc21); + if(id.y() + 3 < out_height) + { + vst1q_f32(mtx_out0 + out_stride3, acc30); + vst1q_f32(mtx_out1 + out_stride3, acc31); + } + } + } + } + else if(id.x() < (out_width - 4)) + { + vst1q_f32(mtx_out0, acc00); + if(id.y() + 1 < out_height) + { + vst1q_f32(mtx_out0 + out_stride1, acc10); + if(id.y() + 2 < out_height) + { + vst1q_f32(mtx_out0 + out_stride2, acc20); + if(id.y() + 3 < out_height) + { + vst1q_f32(mtx_out0 + out_stride3, acc30); + } + } + } + // Left-over columns + const int columns_left = out_width - id.x() - 4; + for(auto x = 0; x < columns_left; ++x) + { + *(mtx_out1 + x) = acc01[x]; + if(id.y() + 1 < out_height) + { + *(mtx_out1 + x + out_stride1) = acc11[x]; + if(id.y() + 2 < out_height) + { + *(mtx_out1 + x + out_stride2) = acc21[x]; + if(id.y() + 3 < out_height) + { + *(mtx_out1 + x + out_stride3) = acc31[x]; + } + } + } + } + } + else + { + // Left-over columns + const int columns_left = out_width - id.x(); + for(int x = 0; x < columns_left; ++x) + { + *(mtx_out0 + x) = acc00[x]; + if(id.y() + 1 < out_height) + { + *(mtx_out0 + x + out_stride1) = acc10[x]; + if(id.y() + 2 < out_height) + { + *(mtx_out0 + x + out_stride2) = acc20[x]; + if(id.y() + 3 < out_height) + { + *(mtx_out0 + x + out_stride3) = acc30[x]; + } + } + } + } + } + }, + ina, inb, out); +} + 
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + ARM_COMPUTE_UNUSED(info); + const int out_width = static_cast(dst->info()->dimension(0)); + const int out_height = static_cast(dst->info()->dimension(1)); + const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()); + const size_t out_stride = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type()); + const int num_elems_matrix_b_x = rhs->info()->dimension(0); + + // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix A has 8 times less the cols of the dst matrix + win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 0)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, window); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float16x8_t alpha_f16 = vdupq_n_f16(alpha); + + execute_window_loop(window, [&](const Coordinates & id) + { + const auto *mtx_a0 = reinterpret_cast(ina.ptr()); + const auto *mtx_b0 = reinterpret_cast(inb.ptr()); + auto *mtx_out = reinterpret_cast(out.ptr()); + float16x8x4_t c = + { + { + vdupq_n_f16(0.f), + vdupq_n_f16(0.f), + vdupq_n_f16(0.f), + vdupq_n_f16(0.f) + } + }; + + /* + This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values) + |a00 a01 a02 a03 | a04 a05 a06 a07| + |a10 a11 a12 a13 | a14 a15 a16 a17| + |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ... + |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a15 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ... + |a40 a41 a42 a43 | a44 a45 a46 a47| + |a50 a51 a52 a53 | a54 a55 a56 a57| + |a60 a61 a62 a63 | a64 a65 a66 a67| + |a70 a71 a72 a73 | a74 a75 a76 a77| + + After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ] + + B Matrix has been transposed as shown below + + |b00 b01 b02 b03 b04 b05 b06 b07| + |b10 b11 b12 b13 b14 b15 b16 b17| + |b20 b21 b22 b23 b24 b25 b26 b27| + |b30 b31 b32 b33 b34 b35 b36 b37| + -------------------> + + |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37| + + c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30 + c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31 + + The size of the dst tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size. 
+ */ + const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; + + for(; mtx_b0 <= (mtx_b0_end_addr - 32);) + + { + const float16x8_t p00 = vld1q_f16(mtx_a0); + const float16x8_t p02 = vld1q_f16(mtx_a0 + 8); + + const float16x8_t q00 = vld1q_f16(mtx_b0); + const float16x8_t q02 = vld1q_f16(mtx_b0 + 8); + const float16x8_t q04 = vld1q_f16(mtx_b0 + 16); + const float16x8_t q06 = vld1q_f16(mtx_b0 + 24); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3))); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7))); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3))); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7))); + + mtx_a0 += 16; + mtx_b0 += 32; + } + + for(; mtx_b0 < mtx_b0_end_addr;) + + { + const float16x4_t p00 = vld1_f16(mtx_a0); + const float16x8_t q00 = vld1q_f16(mtx_b0); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3))); + + mtx_a0 += 4; + mtx_b0 += 8; + } + + if(multiply_alpha) + { + c.val[0] = vmulq_f16(c.val[0], alpha_f16); + c.val[1] = vmulq_f16(c.val[1], alpha_f16); + c.val[2] = vmulq_f16(c.val[2], alpha_f16); + c.val[3] = vmulq_f16(c.val[3], alpha_f16); + } + + if(id.x() < (out_width - 8)) + { + vst1q_f16(mtx_out, c.val[0]); + if(id.y() + 1 < out_height) + { + vst1q_f16(mtx_out + 1 * out_stride, c.val[1]); + if(id.y() + 2 < out_height) + { + vst1q_f16(mtx_out + 2 * out_stride, c.val[2]); + if(id.y() + 3 < out_height) + { + vst1q_f16(mtx_out + 3 * out_stride, c.val[3]); + } + } + } + } + else + { + // Left-over columns + const int columns_left = out_width - id.x(); + for(int x = 0; x < columns_left; ++x) + { + *(mtx_out + x) = c.val[0][x]; + if(id.y() + 1 < out_height) + { + *(mtx_out + x + 1 * out_stride) = c.val[1][x]; + if(id.y() + 2 < out_height) + { + *(mtx_out + x + 2 * out_stride) = c.val[2][x]; + if(id.y() + 3 < out_height) + { + *(mtx_out + x + 3 * out_stride) = c.val[3][x]; + } + } + } + } + } + }, + ina, inb, out); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +inline Status validate_arguments(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info) +{ + ARM_COMPUTE_UNUSED(alpha); + + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, 
DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
+
+    if(!is_interleaved)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(0) != rhs->dimension(1));
+
+        if(dst->total_size() != 0)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON(rhs->dimension(0) != dst->dimension(0));
+            ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(1) != dst->dimension(1));
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
+        }
+    }
+    else
+    {
+        const int m                         = reshape_info.m();
+        const int n                         = reshape_info.n();
+        const int k                         = reshape_info.k();
+        const int mult_transpose1xW_width   = reshape_info.mult_transpose1xW_width();
+        const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
+
+        /* Interleave */
+        TensorShape tensor_shape0{ lhs->tensor_shape() };
+        tensor_shape0.set(0, k);
+        tensor_shape0.set(1, m);
+
+        const TensorInfo tensor_info0          = lhs->clone()->set_tensor_shape(tensor_shape0);
+        const TensorInfo tensor_info_reshaped0 = lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lhs, &tensor_info_reshaped0);
+
+        if(n != 0) /* Transpose */
+        {
+            TensorShape tensor_shape1{ rhs->tensor_shape() };
+            tensor_shape1.set(0, n);
+            tensor_shape1.set(1, k);
+
+            const TensorInfo tensor_info1          = rhs->clone()->set_tensor_shape(tensor_shape1);
+            const TensorInfo tensor_info_reshaped1 = rhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(rhs, &tensor_info_reshaped1);
+        }
+
+        if(dst->total_size() != 0)
+        {
+            if(n != 0)
+            {
+                ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(0) != static_cast<size_t>(n));
+            }
+            ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(1) != static_cast<size_t>(m));
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
+        }
+    }
+
+    return Status{};
+}
+} // namespace
+
+void CpuGemmMatrixMultiplyKernel::configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
+
+    // dst tensor auto initialization if not yet initialized
+    TensorShape tensor_shape{ lhs->tensor_shape() };
+    tensor_shape.set(0, is_interleaved ? reshape_info.n() : rhs->dimension(0));
+    tensor_shape.set(1, is_interleaved ? reshape_info.m() : lhs->dimension(1));
+
+    auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(tensor_shape));
+
+    // Perform validate step
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));
+
+    _alpha = alpha;
+
+    // Configure kernel window
+    Window win{};
+
+    // Check if the dst tensor is a vector. If so, the kernel runs the vector-matrix multiplication
+    const bool is_dst_vector = (dst->dimension(1) == 1);
+    if(is_dst_vector)
+    {
+        const unsigned int num_elems_processed_per_iteration_x = (lhs->data_type() == DataType::F32) ? 16 : 32;
+
+        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
+    }
+    else
+    {
+        constexpr unsigned int num_elems_processed_per_iteration_x = 8;
+        constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+    }
+
+    switch(lhs->data_type())
+    {
+        case DataType::F32:
+        {
+            _func = (is_dst_vector) ?
vector_matrix_multiply_f32 : matrix_matrix_multiply_f32; + break; + } +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + { + _func = (is_dst_vector) ? vector_matrix_multiply_f16 : matrix_matrix_multiply_f16; + break; + } +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + default: + { + ARM_COMPUTE_ERROR("Data type not supported"); + break; + } + } + ICPPKernel::configure(win); +} + +Status CpuGemmMatrixMultiplyKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, + const GEMMReshapeInfo &reshape_info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info)); + + return Status{}; +} + +void CpuGemmMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); + ARM_COMPUTE_ERROR_ON(tensors.empty()); + ARM_COMPUTE_ERROR_ON(_func == nullptr); + + const ITensor *lhs = tensors.get_const_tensor(TensorType::ACL_SRC_0); + const ITensor *rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1); + ITensor *dst = tensors.get_tensor(TensorType::ACL_DST); + + (*_func)(lhs, rhs, dst, window, info, _alpha); +} + +const char *CpuGemmMatrixMultiplyKernel::name() const +{ + return "CpuGemmMatrixMultiplyKernel"; +} +} // namespace kernels +} // namespace cpu +} // namespace arm_compute diff --git a/src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h b/src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h new file mode 100644 index 0000000000..bf13342739 --- /dev/null +++ b/src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CPU_GEMM_MATRIX_MULTIPLY_KERNEL_H +#define ARM_COMPUTE_CPU_GEMM_MATRIX_MULTIPLY_KERNEL_H + +#include "src/core/common/Macros.h" +#include "src/core/cpu/ICpuKernel.h" + +namespace arm_compute +{ +namespace cpu +{ +namespace kernels +{ +/** Kernel to multiply two input matrices "A" and "B". 
All elements of the output matrix/vector will be multiplied by alpha after the matrix multiplication
+ *
+ * @note If the output tensor is a matrix, the implementation assumes that the input tensors @p lhs and @p rhs are both matrices and have been reshaped respectively with @ref CpuGemmInterleave4x4Kernel and @ref CpuGemmTranspose1xWKernel
+ * @note If the output tensor is a vector and the data type is F32, the implementation assumes that the first input tensor @p lhs is a vector and the second input tensor @p rhs a matrix. The implementation also assumes that both tensors have not been reshaped
+ *
+ */
+class CpuGemmMatrixMultiplyKernel : public ICpuKernel
+{
+public:
+    /** Constructor */
+    CpuGemmMatrixMultiplyKernel() = default;
+    ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixMultiplyKernel);
+    /** Initialise the kernel's input and output.
+     *
+     * @note If the output tensor is a matrix, the input matrices @p lhs and @p rhs should be the output of the kernels: @ref CpuGemmInterleave4x4Kernel and @ref CpuGemmTranspose1xWKernel
+     *       These two kernels change the layout of the original matrices to be more cache-friendly.
+     *
+     * @param[in]  lhs            Left-hand side tensor info containing the interleaved Matrix A or the vector A. Data types supported: F16/F32
+     * @param[in]  rhs            Right-hand side tensor info containing the transposed Matrix B if the first input tensor A is not a vector.
+     *                            If the output tensor is a vector, rhs must contain the matrix B not reshaped. Data type supported: same as @p lhs
+     * @param[out] dst            Output tensor info to store the result of matrix multiplication. Data type supported: same as @p lhs.
+     * @param[in]  alpha          Weight of the matrix product
+     * @param[in]  is_interleaved (Optional) True if lhs and rhs have been reshaped respectively using @ref CpuGemmInterleave4x4Kernel and @ref CpuGemmTranspose1xWKernel
+     * @param[in]  reshape_info   (Optional) GEMM reshape info. If is_interleaved = true, this object must contain the information to understand how @p lhs and @p rhs have been reshaped
+     */
+    void configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo());
+    /** Static function to check if given info will lead to a valid configuration of @ref CpuGemmMatrixMultiplyKernel
+     *
+     * Similar to @ref CpuGemmMatrixMultiplyKernel::configure()
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info);
+
+    // Inherited methods overridden:
+    void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+    const char *name() const override;
+
+private:
+    /** Common signature for all the matrix multiply functions
+     *
+     * @param[in]  lhs    Left-hand side input tensor. Data types supported: F16/F32
+     * @param[in]  rhs    Right-hand side input tensor. Data types supported: same as @p lhs
+     * @param[out] dst    The output tensor. Data type supported: same as @p rhs
+     * @param[in]  window Region on which to execute the kernel.
+     * @param[in]  info   Thread info metadata.
+     * @param[in]  alpha  Weight of the matrix product.
+ */
+    using GemmFunctionPtr = void(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha);
+    /** Matrix multiply function to use for the particular tensor types passed to configure() */
+    GemmFunctionPtr *_func{ nullptr };
+    float            _alpha{ 1.f };
+};
+} // namespace kernels
+} // namespace cpu
+} // namespace arm_compute
+#endif /*ARM_COMPUTE_CPU_GEMM_MATRIX_MULTIPLY_KERNEL_H*/
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 99a7db72b6..a52ca79504 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -34,9 +34,9 @@
 #include "arm_compute/runtime/Tensor.h"
 #include "arm_compute/runtime/TensorAllocator.h"
 #include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
 #include "src/core/cpu/kernels/CpuGemmInterleave4x4Kernel.h"
 #include "src/core/cpu/kernels/CpuGemmMatrixAdditionKernel.h"
+#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
 #include "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.h"
 #include "src/core/helpers/AutoConfiguration.h"
 #include "src/core/helpers/MemoryHelpers.h"
@@ -70,7 +70,7 @@ struct NEGEMM::Impl
     std::unique_ptr interleave_kernel{ nullptr };
     std::unique_ptr transpose_kernel{ nullptr };
-    std::unique_ptr<NEGEMMMatrixMultiplyKernel> mm_kernel{ nullptr };
+    std::unique_ptr<cpu::kernels::CpuGemmMatrixMultiplyKernel> mm_kernel{ nullptr };
     std::unique_ptr asm_glue{ nullptr };
     std::unique_ptr ma_kernel{ nullptr };
     std::unique_ptr alpha_scale_func{ nullptr };
@@ -167,13 +167,13 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
         _impl->memory_group.manage(&_impl->tmp_d);
     }
 
-    _impl->mm_kernel = std::make_unique<NEGEMMMatrixMultiplyKernel>();
+    _impl->mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();
 
     // Select between GEMV and GEMM
     if(_impl->run_vector_matrix_multiplication)
     {
         // Configure the matrix multiply kernel
-        _impl->mm_kernel->configure(a, b, _impl->gemm_output_to_use, alpha, false);
+        _impl->mm_kernel->configure(a->info(), b->info(), _impl->gemm_output_to_use->info(), alpha, false);
     }
     else
     {
@@ -213,7 +213,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
         _impl->transpose_kernel->configure(b->info(), _impl->tmp_b.info());
 
         // Configure matrix multiplication kernel
-        _impl->mm_kernel->configure(&_impl->tmp_a, &_impl->tmp_b, _impl->gemm_output_to_use, alpha, true, GEMMReshapeInfo(m, n, k));
+        _impl->mm_kernel->configure(_impl->tmp_a.info(), _impl->tmp_b.info(), _impl->gemm_output_to_use->info(), alpha, true, GEMMReshapeInfo(m, n, k));
 
         // Allocate once all the configure methods have been called
         _impl->tmp_a.allocator()->allocate();
@@ -342,7 +342,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
 
     // Validate matrix multiply
     auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
-    ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
 
     if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
     {
@@ -383,6 +383,7 @@ void NEGEMM::run()
     }
     else
     {
+        ITensorPack mm_pack{ { ACL_SRC_0, _impl->a }, { ACL_SRC_1, _impl->original_b }, { ACL_DST, _impl->gemm_output_to_use } };
if(!_impl->run_vector_matrix_multiplication) { // Run interleave kernel @@ -395,9 +396,13 @@ void NEGEMM::run() ITensorPack transpose_pack{ { ACL_SRC, _impl->original_b }, { ACL_DST, &_impl->tmp_b } }; NEScheduler::get().schedule_op(_impl->transpose_kernel.get(), Window::DimY, _impl->transpose_kernel->window(), transpose_pack); } + + // Use reshaped matrices + mm_pack.add_const_tensor(ACL_SRC_0, &_impl->tmp_a); + mm_pack.add_const_tensor(ACL_SRC_1, &_impl->tmp_b); } - NEScheduler::get().schedule(_impl->mm_kernel.get(), _impl->run_vector_matrix_multiplication ? Window::DimX : Window::DimY); + NEScheduler::get().schedule_op(_impl->mm_kernel.get(), _impl->run_vector_matrix_multiplication ? Window::DimX : Window::DimY, _impl->mm_kernel->window(), mm_pack); // Run bias addition kernel if(_impl->run_bias_addition) diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp index ddd1bca5cc..36c1943f27 100644 --- a/tests/validation/NEON/GEMM.cpp +++ b/tests/validation/NEON/GEMM.cpp @@ -25,8 +25,8 @@ #include "arm_compute/runtime/NEON/functions/NEGEMM.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" -#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h" #include "src/core/cpu/kernels/CpuGemmInterleave4x4Kernel.h" +#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h" #include "src/core/cpu/kernels/CpuGemmTranspose1xWKernel.h" #include "tests/NEON/Accessor.h" #include "tests/NEON/Helper.h" @@ -113,15 +113,15 @@ bool validate_zero_padding_new(unsigned int dim0_value, unsigned int dim1_value) bool validate_gemm_zero_padding(const TensorShape shape0, const TensorShape shape1) { // Create tensors - Tensor in0 = create_tensor(shape0, DataType::F32); - Tensor in1 = create_tensor(shape1, DataType::F32); - Tensor dst; + TensorInfo in0(shape0, 1, DataType::F32); + TensorInfo in1(shape1, 1, DataType::F32); + TensorInfo dst; // Validate zero-padding - NEGEMMMatrixMultiplyKernel gemm; + cpu::kernels::CpuGemmMatrixMultiplyKernel gemm; gemm.configure(&in0, &in1, &dst, 1.0, false); - return in0.info()->padding().empty() && in1.info()->padding().empty() && dst.info()->padding().empty(); + return in0.padding().empty() && in1.padding().empty() && dst.padding().empty(); } } // namespace -- cgit v1.2.1
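Note: the hunks above capture the shape of the new interface: the kernel is configured on ITensorInfo only, and the backing tensors are injected at run time through an ITensorPack passed to schedule_op. The stand-alone sketch below shows one way this could be exercised for the non-interleaved (vector-matrix) path; the shapes, the alpha value and the wrapper function are illustrative assumptions, while the classes and calls (configure on ITensorInfo, ITensorPack, NEScheduler::get().schedule_op) are the ones the patch itself uses.

// Minimal usage sketch (hypothetical shapes and values) of the stateless kernel interface.
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/Tensor.h"
#include "src/core/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"

using namespace arm_compute;

void run_gemv_example()
{
    // lhs is a vector of K = 64 elements, rhs a non-reshaped matrix [N = 128, K = 64], dst a vector of N = 128 elements
    TensorInfo lhs_info(TensorShape(64U, 1U), 1, DataType::F32);
    TensorInfo rhs_info(TensorShape(128U, 64U), 1, DataType::F32);
    TensorInfo dst_info(TensorShape(128U, 1U), 1, DataType::F32);

    // Configuration works on tensor metadata only; no backing memory is required yet
    cpu::kernels::CpuGemmMatrixMultiplyKernel mm_kernel;
    mm_kernel.configure(&lhs_info, &rhs_info, &dst_info, 1.0f /* alpha */, false /* is_interleaved */);

    // The actual tensors are created and bound at run time through an ITensorPack
    Tensor lhs, rhs, dst;
    lhs.allocator()->init(lhs_info);
    rhs.allocator()->init(rhs_info);
    dst.allocator()->init(dst_info);
    lhs.allocator()->allocate();
    rhs.allocator()->allocate();
    dst.allocator()->allocate();

    ITensorPack pack{ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs }, { ACL_DST, &dst } };
    NEScheduler::get().schedule_op(&mm_kernel, Window::DimX, mm_kernel.window(), pack);
}

Because configure() no longer touches tensor memory, the same configured kernel can be re-run with different backing buffers, which is what the memory injecting interface is meant to enable.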