From 256ac62029d835c55b08951af0bdfa542b878956 Mon Sep 17 00:00:00 2001 From: Dana Zlotnik Date: Wed, 2 Feb 2022 15:06:11 +0200 Subject: Decouple CpuGemmMatrixMultiplyKernel and CpuGemmMatrixAdditionKernel Resolves COMPMID-4629, COMPMID-4631 Change-Id: Idceafc84735116ef63ec13a202895f954b87e32f Signed-off-by: Dana Zlotnik Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7095 Tested-by: Arm Jenkins Reviewed-by: SiCong Li Reviewed-by: Giorgio Arena Comments-Addressed: Arm Jenkins --- Android.bp | 6 + filelist.json | 8 +- src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp | 123 +-- src/cpu/kernels/CpuGemmMatrixAdditionKernel.h | 16 +- src/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp | 1024 +------------------- src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h | 18 +- .../kernels/gemm_matrix_add/generic/neon/fp16.cpp | 38 + .../kernels/gemm_matrix_add/generic/neon/fp32.cpp | 36 + .../kernels/gemm_matrix_add/generic/neon/impl.cpp | 119 +++ .../kernels/gemm_matrix_add/generic/neon/impl.h | 41 + src/cpu/kernels/gemm_matrix_add/list.h | 37 + .../kernels/gemm_matrix_mul/generic/neon/fp16.cpp | 38 + .../kernels/gemm_matrix_mul/generic/neon/fp32.cpp | 36 + .../kernels/gemm_matrix_mul/generic/neon/impl.cpp | 1016 +++++++++++++++++++ .../kernels/gemm_matrix_mul/generic/neon/impl.h | 46 + src/cpu/kernels/gemm_matrix_mul/list.h | 37 + tests/validation/NEON/GEMM.cpp | 36 +- 17 files changed, 1571 insertions(+), 1104 deletions(-) create mode 100644 src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp create mode 100644 src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp create mode 100644 src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp create mode 100644 src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h create mode 100644 src/cpu/kernels/gemm_matrix_add/list.h create mode 100644 src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp create mode 100644 src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp create mode 100644 src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp create mode 100644 src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h create mode 100644 src/cpu/kernels/gemm_matrix_mul/list.h diff --git a/Android.bp b/Android.bp index ebea73329e..3081386186 100644 --- a/Android.bp +++ b/Android.bp @@ -477,6 +477,12 @@ cc_library_static { "src/cpu/kernels/elementwise_unary/generic/sve/integer.cpp", "src/cpu/kernels/floor/neon/fp16.cpp", "src/cpu/kernels/floor/neon/fp32.cpp", + "src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp", + "src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp", + "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp", + "src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp", + "src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp", + "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp", "src/cpu/kernels/genproposals/generic/neon/fp16.cpp", "src/cpu/kernels/genproposals/generic/neon/fp32.cpp", "src/cpu/kernels/genproposals/generic/neon/impl.cpp", diff --git a/filelist.json b/filelist.json index b672a832b1..88d98ae76e 100644 --- a/filelist.json +++ b/filelist.json @@ -1480,8 +1480,14 @@ "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_6x4/a55.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_6x4/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/a55.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/generic.cpp" + "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/generic.cpp", + "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp", + "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp" ], + "fp32":["src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp", + "src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp"], + "fp16":["src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp", + "src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp"], "estate32": [ "src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a55r1.cpp", diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp index 81376fb029..6399ebbef4 100644 --- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp +++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,11 +28,10 @@ #include "arm_compute/core/Validate.h" #include "src/core/CPP/Validate.h" #include "src/core/NEON/NEFixedPoint.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" - -#include - +#include "src/cpu/kernels/gemm_matrix_add/list.h" namespace arm_compute { namespace cpu @@ -41,93 +40,26 @@ namespace kernels { namespace { -void matrix_addition_f32(const ITensor *src, ITensor *dst, const Window &window, float beta) +static const std::vector available_kernels = { - ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); - const float32x4_t beta_f32 = vdupq_n_f32(beta); - - constexpr int window_step_x = 16; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - - Window win = window.collapse_if_possible(window, Window::DimZ); - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(src, win); - Iterator out(dst, win); - - execute_window_loop(win, [&](const Coordinates &) { - const auto in_ptr = reinterpret_cast(in.ptr()); - const auto out_ptr = reinterpret_cast(out.ptr()); - - int x = window_start_x; - for(; x < (window_end_x - window_step_x); x += window_step_x) + "neon_fp32_gemm_matrix_add", + [](const DataTypeISASelectorData & data) { - float32x4x4_t alpha_ab = vld4q_f32(out_ptr + x); - const float32x4x4_t c = vld4q_f32(in_ptr + x); - - // Multiply matrix C by its weight and accumulate - alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32); - alpha_ab.val[1] = vmlaq_f32(alpha_ab.val[1], c.val[1], beta_f32); - alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32); - alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32); - - vst4q_f32(out_ptr + x, alpha_ab); - } - - // Left-over loop - for(; x < window_end_x; ++x) - { - *(out_ptr + x) += *(in_ptr + x) * beta; - } + return (data.dt == DataType::F32); + }, + REGISTER_FP32_NEON(neon_fp32_gemm_matrix_add) }, - in, out); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void matrix_addition_f16(const ITensor *src, ITensor *dst, const Window &window, float beta) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); - const float16x8_t beta_f16 = vdupq_n_f16(beta); - - constexpr int window_step_x = 16; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - - Window win = window.collapse_if_possible(window, Window::DimZ); - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(src, win); - Iterator out(dst, win); - - execute_window_loop(win, [&](const Coordinates &) { - const auto in_ptr = reinterpret_cast(in.ptr()); - const auto out_ptr = reinterpret_cast(out.ptr()); - - int x = window_start_x; - for(; x < (window_end_x - window_step_x); x += window_step_x) - { - float16x8x2_t alpha_ab = vld2q_f16(out_ptr + x); - const float16x8x2_t c = vld2q_f16(in_ptr + x); - // Multiply matrix C by its weight and accumulate - alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16)); - alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16)); - - vst2q_f16(out_ptr + x, alpha_ab); - } - - // Left-over loop - for(; x < window_end_x; ++x) + "neon_fp16_gemm_matrix_add", + [](const DataTypeISASelectorData & data) { - *(out_ptr + x) += *(in_ptr + x) * static_cast(beta); - } + return (data.dt == DataType::F16) && data.isa.fp16; + }, + REGISTER_FP16_NEON(neon_fp16_gemm_matrix_add) }, - in, out); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +}; } // namespace void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta) @@ -138,22 +70,10 @@ void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo // Perform validation step ARM_COMPUTE_ERROR_THROW_ON(CpuGemmMatrixAdditionKernel::validate(src, dst, beta)); - _beta = beta; - switch(src->data_type()) - { - case DataType::F32: - _func = &matrix_addition_f32; - break; - case DataType::F16: -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - _func = &matrix_addition_f16; - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - + _beta = beta; + const auto uk = CpuGemmMatrixAdditionKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + ARM_COMPUTE_ERROR_ON_NULLPTR(uk); + _func = uk->ukernel; // Configure kernel window Window win = calculate_max_window(*src, Steps()); ICPPKernel::configure(win); @@ -195,6 +115,11 @@ const char *CpuGemmMatrixAdditionKernel::name() const { return "CpuGemmMatrixAdditionKernel"; } + +const std::vector &CpuGemmMatrixAdditionKernel::get_available_kernels() +{ + return available_kernels; +} } // namespace kernels } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h index 4a748218d1..cbc5b53087 100644 --- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h +++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.h @@ -43,7 +43,16 @@ namespace kernels */ class CpuGemmMatrixAdditionKernel : public ICpuKernel { +private: + using GemmMatrixAddKernelPtr = std::add_pointer::type; + public: + struct GemmMatrixAddKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + GemmMatrixAddKernelPtr ukernel; + }; CpuGemmMatrixAdditionKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixAdditionKernel); /** Initialise the kernel's input and output. @@ -69,6 +78,8 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + static const std::vector &get_available_kernels(); + private: /** Common signature for all the matrix addition functions * @@ -77,10 +88,9 @@ private: * @param[in] window Region on which to execute the kernel. * @param[in] beta Weight of matrix C */ - using MatrixAdditionFunctionPtr = void (*)(const ITensor *src, ITensor *dst, const Window &window, float beta); /** Matrix addition function to use for the particular tensor types passed to configure() */ - MatrixAdditionFunctionPtr _func{ nullptr }; - float _beta{ 0.f }; + GemmMatrixAddKernelPtr _func{ nullptr }; + float _beta{ 0.f }; }; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp index 93ae90436a..03b372efd4 100644 --- a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp +++ b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,11 +29,10 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "src/core/CPP/Validate.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" -#include "src/core/utils/helpers/float_ops.h" - -#include +#include "src/cpu/kernels/gemm_matrix_mul/list.h" namespace arm_compute { @@ -43,985 +42,25 @@ namespace kernels { namespace { -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +static const std::vector available_kernels = { - const auto width_matrix_b = static_cast(dst->info()->dimension(0)); - const auto in_b_stride = static_cast(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size()); - const auto num_elems_vec_a = static_cast(lhs->info()->dimension(0)); - - // The implementation computes 32 elements per iteration - const int window_start_x = 32 * info.thread_id; - const int window_step_x = 32 * info.num_threads; - const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; - ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x"); - - Window win_out(window); - win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(rhs->info()->num_dimensions() >= 3) { - win_b = window; - } - win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Iterator ina(lhs, win_a); - Iterator inb(rhs, win_b); - Iterator out(dst, win_out); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float16x8_t alpha_f16 = vdupq_n_f16(alpha); - - execute_window_loop(win_out, [&](const Coordinates &) - { - int x = window_start_x; - // Here we don't check for x lower equal than (window_end_x - window_step_x) because of - // window_end_x is computed above which may cause out-of-bound writes to the dst. - for(; x < (window_end_x - window_step_x); x += window_step_x) - { - if(x > width_matrix_b) - { - return; - } - - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - - float16x8_t acc0 = vdupq_n_f16(0.f); - float16x8_t acc1 = vdupq_n_f16(0.f); - float16x8_t acc2 = vdupq_n_f16(0.f); - float16x8_t acc3 = vdupq_n_f16(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4);) - { - const float16x4_t a0l = vld1_f16(vec_a); - - float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); - float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); - float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); - float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); - - matrix_b += 2 * in_b_stride; - - b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); - b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); - b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); - b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); - acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); - acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); - acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); - acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); - - vec_a += 4; - matrix_b += 2 * in_b_stride; - } - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float16_t a0 = *vec_a; - const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); - const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); - const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); - const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); - - acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); - acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); - acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); - acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc0 = vmulq_f16(acc0, alpha_f16); - acc1 = vmulq_f16(acc1, alpha_f16); - acc2 = vmulq_f16(acc2, alpha_f16); - acc3 = vmulq_f16(acc3, alpha_f16); - } - - auto vec_out = reinterpret_cast(out.ptr()) + x; - - vst1q_f16(vec_out + 0, acc0); - vst1q_f16(vec_out + 8, acc1); - vst1q_f16(vec_out + 16, acc2); - vst1q_f16(vec_out + 24, acc3); - } - - for(; x < window_end_x; ++x) + "neon_fp32_gemm_matrix_mul", + [](const DataTypeISASelectorData & data) { - if(x > width_matrix_b) - { - return; - } - - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - - float16x4_t vacc = vdup_n_f16(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) - { - const float16x4_t a0l = vld1_f16(vec_a); - - const float16x4_t b_col = - { - *(matrix_b + 0 * in_b_stride), - *(matrix_b + 1 * in_b_stride), - *(matrix_b + 2 * in_b_stride), - *(matrix_b + 3 * in_b_stride), - }; - - vacc = vadd_f16(vacc, vmul_f16(a0l, b_col)); - - matrix_b += 4 * in_b_stride; - } - - float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3); - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float16_t a0 = *vec_a; - const float16_t b00 = *matrix_b; - - acc += b00 * a0; - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc *= static_cast(alpha); - } - - auto vec_out = reinterpret_cast(out.ptr()) + x; - - *(vec_out) = acc; - } + return (data.dt == DataType::F32); + }, + REGISTER_FP32_NEON(neon_fp32_gemm_matrix_mul) }, - ina, inb, out); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) -{ - const auto width_matrix_b = static_cast(dst->info()->dimension(0)); - const auto in_b_stride = static_cast(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type())); - const auto num_elems_vec_a = static_cast(lhs->info()->dimension(0)); - - // The implementation computes 16 elements per iteration - const int window_start_x = 16 * info.thread_id; - const int window_step_x = 16 * info.num_threads; - // Make sure (window_end_x - window_start_x) is a multiple of window_step_x - const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; - - Window win_out(window); - win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(rhs->info()->num_dimensions() >= 3) - { - win_b = window; - } - win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); - - Iterator ina(lhs, win_a); - Iterator inb(rhs, win_b); - Iterator out(dst, win_out); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float32x4_t alpha_f32 = vdupq_n_f32(alpha); - - execute_window_loop(win_out, [&](const Coordinates &) { - int x = window_start_x; - // Here we don't check for x lower equal than (window_end_x - window_step_x) because of - // window_end_x is computed above which may cause out-of-bound writes to the dst. - for(; x < (window_end_x - window_step_x); x += window_step_x) - { - if(x > width_matrix_b) - { - return; - } - - float32x4_t acc0 = vdupq_n_f32(0.f); - float32x4_t acc1 = vdupq_n_f32(0.f); - float32x4_t acc2 = vdupq_n_f32(0.f); - float32x4_t acc3 = vdupq_n_f32(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); -#endif /* __arm__ */ - - auto vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4);) - { - float32x2_t a0l = vld1_f32(vec_a); - - float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - - float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); - float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); - float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); - float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); -#endif /* __arm__ */ - - acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); - acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); - acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); - acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); - - acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); - acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); - acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); - acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); - - vec_a += 2; - matrix_b += 2 * in_b_stride; - - a0l = vld1_f32(vec_a); - - b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - - b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); - b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); - b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); - b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); - - acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); - acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); - acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); - acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); - - acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); - acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); - acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); - acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); - - vec_a += 2; - matrix_b += 2 * in_b_stride; - } - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float a0 = *vec_a; - - const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); - const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); - const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); - const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); - - acc0 = vmlaq_n_f32(acc0, b00, a0); - acc1 = vmlaq_n_f32(acc1, b01, a0); - acc2 = vmlaq_n_f32(acc2, b02, a0); - acc3 = vmlaq_n_f32(acc3, b03, a0); - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc0 = vmulq_f32(acc0, alpha_f32); - acc1 = vmulq_f32(acc1, alpha_f32); - acc2 = vmulq_f32(acc2, alpha_f32); - acc3 = vmulq_f32(acc3, alpha_f32); - } - - const auto vec_out = reinterpret_cast(out.ptr()) + x; - - vst1q_f32(vec_out + 0, acc0); - vst1q_f32(vec_out + 4, acc1); - vst1q_f32(vec_out + 8, acc2); - vst1q_f32(vec_out + 12, acc3); - } - - // Left-over loop - for(; x < window_end_x; ++x) + "neon_fp16_gemm_matrix_mul", + [](const DataTypeISASelectorData & data) { - if(x > width_matrix_b) - { - return; - } - - float32x4_t vacc = vdupq_n_f32(0.f); - - auto vec_a = reinterpret_cast(ina.ptr()); - auto matrix_b = reinterpret_cast(inb.ptr()) + x; - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); -#endif /* __arm__ */ - - auto vec_a_end_addr = vec_a + num_elems_vec_a; - for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) - { - const float32x4_t a0l = vld1q_f32(vec_a); - - const float32x4_t b_col = - { - *(matrix_b + 0 * in_b_stride), - *(matrix_b + 1 * in_b_stride), - *(matrix_b + 2 * in_b_stride), - *(matrix_b + 3 * in_b_stride), - }; - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); -#endif /* __arm__ */ - - vacc = vmlaq_f32(vacc, b_col, a0l); - - matrix_b += 4 * in_b_stride; - } - - float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3); - - for(; vec_a < vec_a_end_addr; ++vec_a) - { - const float a0 = *vec_a; - - const float b00 = *matrix_b; - - acc += b00 * a0; - - matrix_b += in_b_stride; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc *= alpha; - } - - const auto vec_out = reinterpret_cast(out.ptr()) + x; - - *vec_out = acc; - } + return (data.dt == DataType::F16) && data.isa.fp16; + }, + REGISTER_FP16_NEON(neon_fp16_gemm_matrix_mul) }, - ina, inb, out); -} - -void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) -{ - ARM_COMPUTE_UNUSED(info); - const int out_width = static_cast(dst->info()->dimension(0)); - const int out_height = static_cast(dst->info()->dimension(1)); - const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()); - const size_t out_stride1 = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type()); - const size_t out_stride2 = out_stride1 * 2; - const size_t out_stride3 = out_stride1 * 3; - const int num_elems_matrix_b_x = rhs->info()->dimension(0); - - // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(rhs->info()->num_dimensions() >= 3) - { - win_b = window; - } - // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the dst matrix - // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4 - win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride)); - win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); - - Iterator ina(lhs, win_a); - Iterator inb(rhs, win_b); - Iterator out(dst, window); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float32x4_t alpha_f32 = vdupq_n_f32(alpha); - - // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW - // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration - // All the values needed for computing a single 4x4 block will be read from consecutive memory positions - execute_window_loop(window, [&](const Coordinates & id) - { - auto mtx_a0 = reinterpret_cast(ina.ptr()); - auto mtx_b0 = reinterpret_cast(inb.ptr()); - auto mtx_b1 = mtx_b0 + in_b_stride; - - float32x4_t acc00 = vdupq_n_f32(0.f); - float32x4_t acc10 = vdupq_n_f32(0.f); - float32x4_t acc20 = vdupq_n_f32(0.f); - float32x4_t acc30 = vdupq_n_f32(0.f); - - float32x4_t acc01 = vdupq_n_f32(0.f); - float32x4_t acc11 = vdupq_n_f32(0.f); - float32x4_t acc21 = vdupq_n_f32(0.f); - float32x4_t acc31 = vdupq_n_f32(0.f); - -#if __arm__ - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - - auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; - for(; mtx_b0 <= (mtx_b0_end_addr - 32);) - { - float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); - float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); - float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); - float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); - - float32x4_t b00 = vld1q_f32(mtx_b0); - float32x4_t b10 = vld1q_f32(mtx_b1); - float32x4_t b01 = vld1q_f32(mtx_b0 + 4); - float32x4_t b11 = vld1q_f32(mtx_b1 + 4); - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4); - float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5); - float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6); - float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - - a0 = vld1q_dup_f32(mtx_a0 + 0); - a1 = vld1q_dup_f32(mtx_a0 + 1); - a2 = vld1q_dup_f32(mtx_a0 + 2); - a3 = vld1q_dup_f32(mtx_a0 + 3); - - b00 = vld1q_f32(mtx_b0); - b10 = vld1q_f32(mtx_b1); - b01 = vld1q_f32(mtx_b0 + 4); - b11 = vld1q_f32(mtx_b1 + 4); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - a4 = vld1q_dup_f32(mtx_a0 + 4); - a5 = vld1q_dup_f32(mtx_a0 + 5); - a6 = vld1q_dup_f32(mtx_a0 + 6); - a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - - a0 = vld1q_dup_f32(mtx_a0 + 0); - a1 = vld1q_dup_f32(mtx_a0 + 1); - a2 = vld1q_dup_f32(mtx_a0 + 2); - a3 = vld1q_dup_f32(mtx_a0 + 3); - b00 = vld1q_f32(mtx_b0); - b10 = vld1q_f32(mtx_b1); - b01 = vld1q_f32(mtx_b0 + 4); - b11 = vld1q_f32(mtx_b1 + 4); - -#if __arm__ - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - a4 = vld1q_dup_f32(mtx_a0 + 4); - a5 = vld1q_dup_f32(mtx_a0 + 5); - a6 = vld1q_dup_f32(mtx_a0 + 6); - a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - - a0 = vld1q_dup_f32(mtx_a0 + 0); - a1 = vld1q_dup_f32(mtx_a0 + 1); - a2 = vld1q_dup_f32(mtx_a0 + 2); - a3 = vld1q_dup_f32(mtx_a0 + 3); - b00 = vld1q_f32(mtx_b0); - b10 = vld1q_f32(mtx_b1); - b01 = vld1q_f32(mtx_b0 + 4); - b11 = vld1q_f32(mtx_b1 + 4); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - a4 = vld1q_dup_f32(mtx_a0 + 4); - a5 = vld1q_dup_f32(mtx_a0 + 5); - a6 = vld1q_dup_f32(mtx_a0 + 6); - a7 = vld1q_dup_f32(mtx_a0 + 7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b01, a4); - acc10 = vmlaq_f32(acc10, b01, a5); - acc20 = vmlaq_f32(acc20, b01, a6); - acc30 = vmlaq_f32(acc30, b01, a7); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b11, a4); - acc11 = vmlaq_f32(acc11, b11, a5); - acc21 = vmlaq_f32(acc21, b11, a6); - acc31 = vmlaq_f32(acc31, b11, a7); - - mtx_a0 += 8; - mtx_b0 += 8; - mtx_b1 += 8; - } - - for(; mtx_b0 < mtx_b0_end_addr;) - { - float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); - float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); - float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); - float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); - float32x4_t b00 = vld1q_f32(mtx_b0); - float32x4_t b10 = vld1q_f32(mtx_b1); - -#if __arm__ - asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_a0))); - asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b0))); - asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b1))); -#endif /* __arm__ */ - // 4x4 block 0 - acc00 = vmlaq_f32(acc00, b00, a0); - acc10 = vmlaq_f32(acc10, b00, a1); - acc20 = vmlaq_f32(acc20, b00, a2); - acc30 = vmlaq_f32(acc30, b00, a3); - - // 4x4 block 1 - acc01 = vmlaq_f32(acc01, b10, a0); - acc11 = vmlaq_f32(acc11, b10, a1); - acc21 = vmlaq_f32(acc21, b10, a2); - acc31 = vmlaq_f32(acc31, b10, a3); - - mtx_a0 += 4; - mtx_b0 += 4; - mtx_b1 += 4; - } - - // Multiply by the weight of matrix product (alpha) - if(multiply_alpha) - { - acc00 = vmulq_f32(acc00, alpha_f32); - acc10 = vmulq_f32(acc10, alpha_f32); - acc20 = vmulq_f32(acc20, alpha_f32); - acc30 = vmulq_f32(acc30, alpha_f32); - acc01 = vmulq_f32(acc01, alpha_f32); - acc11 = vmulq_f32(acc11, alpha_f32); - acc21 = vmulq_f32(acc21, alpha_f32); - acc31 = vmulq_f32(acc31, alpha_f32); - } - - const auto mtx_out0 = reinterpret_cast(out.ptr()); - const auto mtx_out1 = mtx_out0 + 4; - - if(id.x() < (out_width - 8)) - { - vst1q_f32(mtx_out0, acc00); - vst1q_f32(mtx_out1, acc01); - if(id.y() + 1 < out_height) - { - vst1q_f32(mtx_out0 + out_stride1, acc10); - vst1q_f32(mtx_out1 + out_stride1, acc11); - if(id.y() + 2 < out_height) - { - vst1q_f32(mtx_out0 + out_stride2, acc20); - vst1q_f32(mtx_out1 + out_stride2, acc21); - if(id.y() + 3 < out_height) - { - vst1q_f32(mtx_out0 + out_stride3, acc30); - vst1q_f32(mtx_out1 + out_stride3, acc31); - } - } - } - } - else if(id.x() < (out_width - 4)) - { - vst1q_f32(mtx_out0, acc00); - if(id.y() + 1 < out_height) - { - vst1q_f32(mtx_out0 + out_stride1, acc10); - if(id.y() + 2 < out_height) - { - vst1q_f32(mtx_out0 + out_stride2, acc20); - if(id.y() + 3 < out_height) - { - vst1q_f32(mtx_out0 + out_stride3, acc30); - } - } - } - // Left-over columns - const int columns_left = out_width - id.x() - 4; - for(auto x = 0; x < columns_left; ++x) - { - *(mtx_out1 + x) = acc01[x]; - if(id.y() + 1 < out_height) - { - *(mtx_out1 + x + out_stride1) = acc11[x]; - if(id.y() + 2 < out_height) - { - *(mtx_out1 + x + out_stride2) = acc21[x]; - if(id.y() + 3 < out_height) - { - *(mtx_out1 + x + out_stride3) = acc31[x]; - } - } - } - } - } - else - { - // Left-over columns - const int columns_left = out_width - id.x(); - for(int x = 0; x < columns_left; ++x) - { - *(mtx_out0 + x) = acc00[x]; - if(id.y() + 1 < out_height) - { - *(mtx_out0 + x + out_stride1) = acc10[x]; - if(id.y() + 2 < out_height) - { - *(mtx_out0 + x + out_stride2) = acc20[x]; - if(id.y() + 3 < out_height) - { - *(mtx_out0 + x + out_stride3) = acc30[x]; - } - } - } - } - } - }, - ina, inb, out); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) -{ - ARM_COMPUTE_UNUSED(info); - const int out_width = static_cast(dst->info()->dimension(0)); - const int out_height = static_cast(dst->info()->dimension(1)); - const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()); - const size_t out_stride = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type()); - const int num_elems_matrix_b_x = rhs->info()->dimension(0); - - // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix - Window win_a(window); - win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); - win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); - - Window win_b; - // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 - // This scenario can happen when the the matrix multiplication is used to perform a convolution operation - if(rhs->info()->num_dimensions() >= 3) - { - win_b = window; - } - // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix A has 8 times less the cols of the dst matrix - win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride)); - win_b.set(Window::DimY, Window::Dimension(0, 1, 0)); - - Iterator ina(lhs, win_a); - Iterator inb(rhs, win_b); - Iterator out(dst, window); - - const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); - - const float16x8_t alpha_f16 = vdupq_n_f16(alpha); - - execute_window_loop(window, [&](const Coordinates & id) - { - const auto *mtx_a0 = reinterpret_cast(ina.ptr()); - const auto *mtx_b0 = reinterpret_cast(inb.ptr()); - auto *mtx_out = reinterpret_cast(out.ptr()); - float16x8x4_t c = - { - { - vdupq_n_f16(0.f), - vdupq_n_f16(0.f), - vdupq_n_f16(0.f), - vdupq_n_f16(0.f) - } - }; - - /* - This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values) - |a00 a01 a02 a03 | a04 a05 a06 a07| - |a10 a11 a12 a13 | a14 a15 a16 a17| - |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ... - |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a15 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ... - |a40 a41 a42 a43 | a44 a45 a46 a47| - |a50 a51 a52 a53 | a54 a55 a56 a57| - |a60 a61 a62 a63 | a64 a65 a66 a67| - |a70 a71 a72 a73 | a74 a75 a76 a77| - - After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ] - - B Matrix has been transposed as shown below - - |b00 b01 b02 b03 b04 b05 b06 b07| - |b10 b11 b12 b13 b14 b15 b16 b17| - |b20 b21 b22 b23 b24 b25 b26 b27| - |b30 b31 b32 b33 b34 b35 b36 b37| - -------------------> - - |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37| - - c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30 - c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31 - - The size of the dst tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size. - */ - const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; - - for(; mtx_b0 <= (mtx_b0_end_addr - 32);) - - { - const float16x8_t p00 = vld1q_f16(mtx_a0); - const float16x8_t p02 = vld1q_f16(mtx_a0 + 8); - - const float16x8_t q00 = vld1q_f16(mtx_b0); - const float16x8_t q02 = vld1q_f16(mtx_b0 + 8); - const float16x8_t q04 = vld1q_f16(mtx_b0 + 16); - const float16x8_t q06 = vld1q_f16(mtx_b0 + 24); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3))); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7))); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3))); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7))); - - mtx_a0 += 16; - mtx_b0 += 32; - } - - for(; mtx_b0 < mtx_b0_end_addr;) - - { - const float16x4_t p00 = vld1_f16(mtx_a0); - const float16x8_t q00 = vld1q_f16(mtx_b0); - - c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0))); - c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1))); - c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2))); - c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3))); - - mtx_a0 += 4; - mtx_b0 += 8; - } - - if(multiply_alpha) - { - c.val[0] = vmulq_f16(c.val[0], alpha_f16); - c.val[1] = vmulq_f16(c.val[1], alpha_f16); - c.val[2] = vmulq_f16(c.val[2], alpha_f16); - c.val[3] = vmulq_f16(c.val[3], alpha_f16); - } - - if(id.x() < (out_width - 8)) - { - vst1q_f16(mtx_out, c.val[0]); - if(id.y() + 1 < out_height) - { - vst1q_f16(mtx_out + 1 * out_stride, c.val[1]); - if(id.y() + 2 < out_height) - { - vst1q_f16(mtx_out + 2 * out_stride, c.val[2]); - if(id.y() + 3 < out_height) - { - vst1q_f16(mtx_out + 3 * out_stride, c.val[3]); - } - } - } - } - else - { - // Left-over columns - const int columns_left = out_width - id.x(); - for(int x = 0; x < columns_left; ++x) - { - *(mtx_out + x) = c.val[0][x]; - if(id.y() + 1 < out_height) - { - *(mtx_out + x + 1 * out_stride) = c.val[1][x]; - if(id.y() + 2 < out_height) - { - *(mtx_out + x + 2 * out_stride) = c.val[2][x]; - if(id.y() + 3 < out_height) - { - *(mtx_out + x + 3 * out_stride) = c.val[3][x]; - } - } - } - } - } - }, - ina, inb, out); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +}; inline Status validate_arguments(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info) { @@ -1083,6 +122,7 @@ inline Status validate_arguments(const ITensorInfo *lhs, const ITensorInfo *rhs, return Status{}; } + } // namespace void CpuGemmMatrixMultiplyKernel::configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info) @@ -1120,26 +160,10 @@ void CpuGemmMatrixMultiplyKernel::configure(const ITensorInfo *lhs, const ITenso win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); } - switch(lhs->data_type()) - { - case DataType::F32: - { - _func = (is_dst_vector) ? vector_matrix_multiply_f32 : matrix_matrix_multiply_f32; - break; - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - case DataType::F16: - { - _func = (is_dst_vector) ? vector_matrix_multiply_f16 : matrix_matrix_multiply_f16; - break; - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - { - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - } + const auto uk = CpuGemmMatrixMultiplyKernel::get_implementation(DataTypeISASelectorData{ lhs->data_type(), CPUInfo::get().get_isa() }); + ARM_COMPUTE_ERROR_ON_NULLPTR(uk); + _func = uk->ukernel; + ICPPKernel::configure(win); } @@ -1162,13 +186,19 @@ void CpuGemmMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &win const ITensor *rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1); ITensor *dst = tensors.get_tensor(TensorType::ACL_DST); - (*_func)(lhs, rhs, dst, window, info, _alpha); + const bool is_dst_vector = (dst->info()->dimension(1) == 1); + (*_func)(lhs, rhs, dst, window, info, _alpha, is_dst_vector); } const char *CpuGemmMatrixMultiplyKernel::name() const { return "CpuGemmMatrixMultiplyKernel"; } + +const std::vector &CpuGemmMatrixMultiplyKernel::get_available_kernels() +{ + return available_kernels; +} } // namespace kernels } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h index 9c3dc8b1a0..a7dfec87bd 100644 --- a/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h +++ b/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h @@ -41,7 +41,17 @@ namespace kernels */ class CpuGemmMatrixMultiplyKernel : public ICpuKernel { +private: + using GemmMatrixMulKernelPtr = std::add_pointer::type; + public: + struct GemmMatrixMulKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + GemmMatrixMulKernelPtr ukernel; + }; + CpuGemmMatrixMultiplyKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmMatrixMultiplyKernel); /** Initialise the kernel's input and output. @@ -70,6 +80,8 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + static const std::vector &get_available_kernels(); + private: /** Common signature for all the matrix multiply functions * @@ -80,10 +92,10 @@ private: * @param[in] info Thread info metadata. * @param[in] alpha Weight of the matrix product. */ - using GemmFunctionPtr = void(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha); + /** Matrix multiply function to use for the particular tensor types passed to configure() */ - GemmFunctionPtr *_func{ nullptr }; - float _alpha{ 1.f }; + GemmMatrixMulKernelPtr _func{ nullptr }; + float _alpha{ 1.f }; }; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp b/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp new file mode 100644 index 0000000000..3098416ada --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#include "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_fp16_gemm_matrix_add(const ITensor *src, ITensor *dst, const Window &window, float beta) +{ + return matrix_addition_f16(src, dst, window, beta); +} +} // namespace cpu +} // namespace arm_compute +#endif //__ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp b/src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp new file mode 100644 index 0000000000..fa3f4de11f --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_fp32_gemm_matrix_add(const ITensor *src, ITensor *dst, const Window &window, float beta) +{ + return matrix_addition_f32(src, dst, window, beta); +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp b/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp new file mode 100644 index 0000000000..675ed1bbcb --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2016-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h" +#include + +namespace arm_compute +{ +namespace cpu +{ +void matrix_addition_f32(const ITensor *src, ITensor *dst, const Window &window, float beta) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + const float32x4_t beta_f32 = vdupq_n_f32(beta); + + constexpr int window_step_x = 16; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + Window win = window.collapse_if_possible(window, Window::DimZ); + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator in(src, win); + Iterator out(dst, win); + + execute_window_loop(win, [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast(in.ptr()); + const auto out_ptr = reinterpret_cast(out.ptr()); + + int x = window_start_x; + for(; x < (window_end_x - window_step_x); x += window_step_x) + { + float32x4x4_t alpha_ab = vld4q_f32(out_ptr + x); + const float32x4x4_t c = vld4q_f32(in_ptr + x); + + // Multiply matrix C by its weight and accumulate + alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32); + alpha_ab.val[1] = vmlaq_f32(alpha_ab.val[1], c.val[1], beta_f32); + alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32); + alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32); + + vst4q_f32(out_ptr + x, alpha_ab); + } + + // Left-over loop + for(; x < window_end_x; ++x) + { + *(out_ptr + x) += *(in_ptr + x) * beta; + } + }, + in, out); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void matrix_addition_f16(const ITensor *src, ITensor *dst, const Window &window, float beta) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + const float16x8_t beta_f16 = vdupq_n_f16(beta); + + constexpr int window_step_x = 16; + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + + Window win = window.collapse_if_possible(window, Window::DimZ); + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator in(src, win); + Iterator out(dst, win); + + execute_window_loop(win, [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast(in.ptr()); + const auto out_ptr = reinterpret_cast(out.ptr()); + + int x = window_start_x; + for(; x < (window_end_x - window_step_x); x += window_step_x) + { + float16x8x2_t alpha_ab = vld2q_f16(out_ptr + x); + const float16x8x2_t c = vld2q_f16(in_ptr + x); + // Multiply matrix C by its weight and accumulate + alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16)); + alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16)); + + vst2q_f16(out_ptr + x, alpha_ab); + } + + // Left-over loop + for(; x < window_end_x; ++x) + { + *(out_ptr + x) += *(in_ptr + x) * static_cast(beta); + } + }, + in, out); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h b/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h new file mode 100644 index 0000000000..ff35f28b11 --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_add/generic/neon/impl.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_KERNELS_GEMMMATRIXADD_IMPL_H +#define SRC_CORE_KERNELS_GEMMMATRIXADD_IMPL_H + +#include "arm_compute/core/Helpers.h" + +namespace arm_compute +{ +namespace cpu +{ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void matrix_addition_f16(const ITensor *src, ITensor *dst, const Window &window, float beta); +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void matrix_addition_f32(const ITensor *src, ITensor *dst, const Window &window, float beta); + +} // namespace cpu +} // namespace arm_compute +#endif //define SRC_CORE_KERNELS_GEMMMATRIXADD_IMPL_H diff --git a/src/cpu/kernels/gemm_matrix_add/list.h b/src/cpu/kernels/gemm_matrix_add/list.h new file mode 100644 index 0000000000..415b4c8321 --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_add/list.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_NEON_KERNELS_GEMMMATRIXADD_LIST_H +#define SRC_CORE_NEON_KERNELS_GEMMMATRIXADD_LIST_H +namespace arm_compute +{ +namespace cpu +{ +#define DECLARE_GEMMMATRIXADD_KERNEL(func_name) \ + void func_name(const ITensor *src, ITensor *dst, const Window &window, float beta) +DECLARE_GEMMMATRIXADD_KERNEL(neon_fp32_gemm_matrix_add); +DECLARE_GEMMMATRIXADD_KERNEL(neon_fp16_gemm_matrix_add); +#undef DECLARE_GEMMMATRIXADD_KERNEL +} // namespace cpu +} // namespace arm_compute +#endif //SRC_CORE_NEON_KERNELS_GEMMMATRIXADD_LIST_H diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp new file mode 100644 index 0000000000..1bd5a57fab --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#include "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_fp16_gemm_matrix_mul(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha, const bool is_dst_vector) +{ + return (is_dst_vector) ? vector_matrix_multiply_f16(lhs, rhs, dst, window, info, alpha) : matrix_matrix_multiply_f16(lhs, rhs, dst, window, info, alpha); +} +} // namespce cpu +} // namespace arm_compute +#endif //__ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp new file mode 100644 index 0000000000..9c1f6f3c0f --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h" + +namespace arm_compute +{ +namespace cpu +{ +void neon_fp32_gemm_matrix_mul(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha, const bool is_dst_vector) +{ + return (is_dst_vector) ? vector_matrix_multiply_f32(lhs, rhs, dst, window, info, alpha) : matrix_matrix_multiply_f32(lhs, rhs, dst, window, info, alpha); +} +} // namespce cpu +} // namespace arm_compute \ No newline at end of file diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp new file mode 100644 index 0000000000..37c296fe5e --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp @@ -0,0 +1,1016 @@ +/* + * Copyright (c) 2017-2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h" +#include "src/core/utils/helpers/float_ops.h" + +#include + +namespace arm_compute +{ +namespace cpu +{ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + const auto width_matrix_b = static_cast(dst->info()->dimension(0)); + const auto in_b_stride = static_cast(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size()); + const auto num_elems_vec_a = static_cast(lhs->info()->dimension(0)); + + // The implementation computes 32 elements per iteration + const int window_start_x = 32 * info.thread_id; + const int window_step_x = 32 * info.num_threads; + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x"); + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, win_out); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float16x8_t alpha_f16 = vdupq_n_f16(alpha); + + execute_window_loop(win_out, [&](const Coordinates &) + { + int x = window_start_x; + // Here we don't check for x lower equal than (window_end_x - window_step_x) because of + // window_end_x is computed above which may cause out-of-bound writes to the dst. + for(; x < (window_end_x - window_step_x); x += window_step_x) + { + if(x > width_matrix_b) + { + return; + } + + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + + float16x8_t acc0 = vdupq_n_f16(0.f); + float16x8_t acc1 = vdupq_n_f16(0.f); + float16x8_t acc2 = vdupq_n_f16(0.f); + float16x8_t acc3 = vdupq_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + const float16x4_t a0l = vld1_f16(vec_a); + + float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); + + matrix_b += 2 * in_b_stride; + + b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); + + vec_a += 4; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float16_t a0 = *vec_a; + const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); + acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); + acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); + acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f16(acc0, alpha_f16); + acc1 = vmulq_f16(acc1, alpha_f16); + acc2 = vmulq_f16(acc2, alpha_f16); + acc3 = vmulq_f16(acc3, alpha_f16); + } + + auto vec_out = reinterpret_cast(out.ptr()) + x; + + vst1q_f16(vec_out + 0, acc0); + vst1q_f16(vec_out + 8, acc1); + vst1q_f16(vec_out + 16, acc2); + vst1q_f16(vec_out + 24, acc3); + } + + for(; x < window_end_x; ++x) + { + if(x > width_matrix_b) + { + return; + } + + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + + float16x4_t vacc = vdup_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) + { + const float16x4_t a0l = vld1_f16(vec_a); + + const float16x4_t b_col = + { + *(matrix_b + 0 * in_b_stride), + *(matrix_b + 1 * in_b_stride), + *(matrix_b + 2 * in_b_stride), + *(matrix_b + 3 * in_b_stride), + }; + + vacc = vadd_f16(vacc, vmul_f16(a0l, b_col)); + + matrix_b += 4 * in_b_stride; + } + + float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3); + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float16_t a0 = *vec_a; + const float16_t b00 = *matrix_b; + + acc += b00 * a0; + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc *= static_cast(alpha); + } + + auto vec_out = reinterpret_cast(out.ptr()) + x; + + *(vec_out) = acc; + } + }, + ina, inb, out); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + const auto width_matrix_b = static_cast(dst->info()->dimension(0)); + const auto in_b_stride = static_cast(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type())); + const auto num_elems_vec_a = static_cast(lhs->info()->dimension(0)); + + // The implementation computes 16 elements per iteration + const int window_start_x = 16 * info.thread_id; + const int window_step_x = 16 * info.num_threads; + // Make sure (window_end_x - window_start_x) is a multiple of window_step_x + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(0, 1, 1)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, win_out); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float32x4_t alpha_f32 = vdupq_n_f32(alpha); + + execute_window_loop(win_out, [&](const Coordinates &) + { + int x = window_start_x; + // Here we don't check for x lower equal than (window_end_x - window_step_x) because of + // window_end_x is computed above which may cause out-of-bound writes to the dst. + for(; x < (window_end_x - window_step_x); x += window_step_x) + { + if(x > width_matrix_b) + { + return; + } + + float32x4_t acc0 = vdupq_n_f32(0.f); + float32x4_t acc1 = vdupq_n_f32(0.f); + float32x4_t acc2 = vdupq_n_f32(0.f); + float32x4_t acc3 = vdupq_n_f32(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); +#endif /* __arm__ */ + + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + float32x2_t a0l = vld1_f32(vec_a); + + float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + + float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); + float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); + float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); + float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); +#endif /* __arm__ */ + + acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); + acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); + acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); + acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); + + acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); + acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); + acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); + acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); + + vec_a += 2; + matrix_b += 2 * in_b_stride; + + a0l = vld1_f32(vec_a); + + b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + + b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride); + b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride); + b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride); + + acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0); + acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0); + acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0); + acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0); + + acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1); + acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1); + acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1); + acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1); + + vec_a += 2; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float a0 = *vec_a; + + const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride); + const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride); + const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride); + const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride); + + acc0 = vmlaq_n_f32(acc0, b00, a0); + acc1 = vmlaq_n_f32(acc1, b01, a0); + acc2 = vmlaq_n_f32(acc2, b02, a0); + acc3 = vmlaq_n_f32(acc3, b03, a0); + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f32(acc0, alpha_f32); + acc1 = vmulq_f32(acc1, alpha_f32); + acc2 = vmulq_f32(acc2, alpha_f32); + acc3 = vmulq_f32(acc3, alpha_f32); + } + + const auto vec_out = reinterpret_cast(out.ptr()) + x; + + vst1q_f32(vec_out + 0, acc0); + vst1q_f32(vec_out + 4, acc1); + vst1q_f32(vec_out + 8, acc2); + vst1q_f32(vec_out + 12, acc3); + } + + // Left-over loop + for(; x < window_end_x; ++x) + { + if(x > width_matrix_b) + { + return; + } + + float32x4_t vacc = vdupq_n_f32(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()) + x; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(matrix_b + in_b_stride))); +#endif /* __arm__ */ + + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4) + { + const float32x4_t a0l = vld1q_f32(vec_a); + + const float32x4_t b_col = + { + *(matrix_b + 0 * in_b_stride), + *(matrix_b + 1 * in_b_stride), + *(matrix_b + 2 * in_b_stride), + *(matrix_b + 3 * in_b_stride), + }; + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(vec_a))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 1 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 2 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 3 * in_b_stride))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(matrix_b + 4 * in_b_stride))); +#endif /* __arm__ */ + + vacc = vmlaq_f32(vacc, b_col, a0l); + + matrix_b += 4 * in_b_stride; + } + + float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3); + + for(; vec_a < vec_a_end_addr; ++vec_a) + { + const float a0 = *vec_a; + + const float b00 = *matrix_b; + + acc += b00 * a0; + + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc *= alpha; + } + + const auto vec_out = reinterpret_cast(out.ptr()) + x; + + *vec_out = acc; + } + }, + ina, inb, out); +} + +void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + ARM_COMPUTE_UNUSED(info); + const int out_width = static_cast(dst->info()->dimension(0)); + const int out_height = static_cast(dst->info()->dimension(1)); + const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()); + const size_t out_stride1 = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type()); + const size_t out_stride2 = out_stride1 * 2; + const size_t out_stride3 = out_stride1 * 3; + const int num_elems_matrix_b_x = rhs->info()->dimension(0); + + // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + // Set step_x and step_y for matrix B. Scale by a factor of 4 the X range as the input transposed matrix A has 4 times less the cols of the dst matrix + // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4 + win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride)); + win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, window); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float32x4_t alpha_f32 = vdupq_n_f32(alpha); + + // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW + // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration + // All the values needed for computing a single 4x4 block will be read from consecutive memory positions + execute_window_loop(window, [&](const Coordinates & id) + { + auto mtx_a0 = reinterpret_cast(ina.ptr()); + auto mtx_b0 = reinterpret_cast(inb.ptr()); + auto mtx_b1 = mtx_b0 + in_b_stride; + + float32x4_t acc00 = vdupq_n_f32(0.f); + float32x4_t acc10 = vdupq_n_f32(0.f); + float32x4_t acc20 = vdupq_n_f32(0.f); + float32x4_t acc30 = vdupq_n_f32(0.f); + + float32x4_t acc01 = vdupq_n_f32(0.f); + float32x4_t acc11 = vdupq_n_f32(0.f); + float32x4_t acc21 = vdupq_n_f32(0.f); + float32x4_t acc31 = vdupq_n_f32(0.f); + +#if __arm__ + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + + auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; + for(; mtx_b0 <= (mtx_b0_end_addr - 32);) + { + float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); + float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); + float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); + float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); + + float32x4_t b00 = vld1q_f32(mtx_b0); + float32x4_t b10 = vld1q_f32(mtx_b1); + float32x4_t b01 = vld1q_f32(mtx_b0 + 4); + float32x4_t b11 = vld1q_f32(mtx_b1 + 4); + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4); + float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5); + float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6); + float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + + a0 = vld1q_dup_f32(mtx_a0 + 0); + a1 = vld1q_dup_f32(mtx_a0 + 1); + a2 = vld1q_dup_f32(mtx_a0 + 2); + a3 = vld1q_dup_f32(mtx_a0 + 3); + + b00 = vld1q_f32(mtx_b0); + b10 = vld1q_f32(mtx_b1); + b01 = vld1q_f32(mtx_b0 + 4); + b11 = vld1q_f32(mtx_b1 + 4); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + a4 = vld1q_dup_f32(mtx_a0 + 4); + a5 = vld1q_dup_f32(mtx_a0 + 5); + a6 = vld1q_dup_f32(mtx_a0 + 6); + a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + + a0 = vld1q_dup_f32(mtx_a0 + 0); + a1 = vld1q_dup_f32(mtx_a0 + 1); + a2 = vld1q_dup_f32(mtx_a0 + 2); + a3 = vld1q_dup_f32(mtx_a0 + 3); + b00 = vld1q_f32(mtx_b0); + b10 = vld1q_f32(mtx_b1); + b01 = vld1q_f32(mtx_b0 + 4); + b11 = vld1q_f32(mtx_b1 + 4); + +#if __arm__ + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + a4 = vld1q_dup_f32(mtx_a0 + 4); + a5 = vld1q_dup_f32(mtx_a0 + 5); + a6 = vld1q_dup_f32(mtx_a0 + 6); + a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + + a0 = vld1q_dup_f32(mtx_a0 + 0); + a1 = vld1q_dup_f32(mtx_a0 + 1); + a2 = vld1q_dup_f32(mtx_a0 + 2); + a3 = vld1q_dup_f32(mtx_a0 + 3); + b00 = vld1q_f32(mtx_b0); + b10 = vld1q_f32(mtx_b1); + b01 = vld1q_f32(mtx_b0 + 4); + b11 = vld1q_f32(mtx_b1 + 4); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + a4 = vld1q_dup_f32(mtx_a0 + 4); + a5 = vld1q_dup_f32(mtx_a0 + 5); + a6 = vld1q_dup_f32(mtx_a0 + 6); + a7 = vld1q_dup_f32(mtx_a0 + 7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b01, a4); + acc10 = vmlaq_f32(acc10, b01, a5); + acc20 = vmlaq_f32(acc20, b01, a6); + acc30 = vmlaq_f32(acc30, b01, a7); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b11, a4); + acc11 = vmlaq_f32(acc11, b11, a5); + acc21 = vmlaq_f32(acc21, b11, a6); + acc31 = vmlaq_f32(acc31, b11, a7); + + mtx_a0 += 8; + mtx_b0 += 8; + mtx_b1 += 8; + } + + for(; mtx_b0 < mtx_b0_end_addr;) + { + float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0); + float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1); + float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2); + float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3); + float32x4_t b00 = vld1q_f32(mtx_b0); + float32x4_t b10 = vld1q_f32(mtx_b1); + +#if __arm__ + asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_a0))); + asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b0))); + asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast(mtx_b1))); +#endif /* __arm__ */ + // 4x4 block 0 + acc00 = vmlaq_f32(acc00, b00, a0); + acc10 = vmlaq_f32(acc10, b00, a1); + acc20 = vmlaq_f32(acc20, b00, a2); + acc30 = vmlaq_f32(acc30, b00, a3); + + // 4x4 block 1 + acc01 = vmlaq_f32(acc01, b10, a0); + acc11 = vmlaq_f32(acc11, b10, a1); + acc21 = vmlaq_f32(acc21, b10, a2); + acc31 = vmlaq_f32(acc31, b10, a3); + + mtx_a0 += 4; + mtx_b0 += 4; + mtx_b1 += 4; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc00 = vmulq_f32(acc00, alpha_f32); + acc10 = vmulq_f32(acc10, alpha_f32); + acc20 = vmulq_f32(acc20, alpha_f32); + acc30 = vmulq_f32(acc30, alpha_f32); + acc01 = vmulq_f32(acc01, alpha_f32); + acc11 = vmulq_f32(acc11, alpha_f32); + acc21 = vmulq_f32(acc21, alpha_f32); + acc31 = vmulq_f32(acc31, alpha_f32); + } + + const auto mtx_out0 = reinterpret_cast(out.ptr()); + const auto mtx_out1 = mtx_out0 + 4; + + if(id.x() < (out_width - 8)) + { + vst1q_f32(mtx_out0, acc00); + vst1q_f32(mtx_out1, acc01); + if(id.y() + 1 < out_height) + { + vst1q_f32(mtx_out0 + out_stride1, acc10); + vst1q_f32(mtx_out1 + out_stride1, acc11); + if(id.y() + 2 < out_height) + { + vst1q_f32(mtx_out0 + out_stride2, acc20); + vst1q_f32(mtx_out1 + out_stride2, acc21); + if(id.y() + 3 < out_height) + { + vst1q_f32(mtx_out0 + out_stride3, acc30); + vst1q_f32(mtx_out1 + out_stride3, acc31); + } + } + } + } + else if(id.x() < (out_width - 4)) + { + vst1q_f32(mtx_out0, acc00); + if(id.y() + 1 < out_height) + { + vst1q_f32(mtx_out0 + out_stride1, acc10); + if(id.y() + 2 < out_height) + { + vst1q_f32(mtx_out0 + out_stride2, acc20); + if(id.y() + 3 < out_height) + { + vst1q_f32(mtx_out0 + out_stride3, acc30); + } + } + } + // Left-over columns + const int columns_left = out_width - id.x() - 4; + for(auto x = 0; x < columns_left; ++x) + { + *(mtx_out1 + x) = acc01[x]; + if(id.y() + 1 < out_height) + { + *(mtx_out1 + x + out_stride1) = acc11[x]; + if(id.y() + 2 < out_height) + { + *(mtx_out1 + x + out_stride2) = acc21[x]; + if(id.y() + 3 < out_height) + { + *(mtx_out1 + x + out_stride3) = acc31[x]; + } + } + } + } + } + else + { + // Left-over columns + const int columns_left = out_width - id.x(); + for(int x = 0; x < columns_left; ++x) + { + *(mtx_out0 + x) = acc00[x]; + if(id.y() + 1 < out_height) + { + *(mtx_out0 + x + out_stride1) = acc10[x]; + if(id.y() + 2 < out_height) + { + *(mtx_out0 + x + out_stride2) = acc20[x]; + if(id.y() + 3 < out_height) + { + *(mtx_out0 + x + out_stride3) = acc30[x]; + } + } + } + } + } + }, + ina, inb, out); +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha) +{ + ARM_COMPUTE_UNUSED(info); + const int out_width = static_cast(dst->info()->dimension(0)); + const int out_height = static_cast(dst->info()->dimension(1)); + const size_t in_b_stride = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()); + const size_t out_stride = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type()); + const int num_elems_matrix_b_x = rhs->info()->dimension(0); + + // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the dst matrix + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(rhs->info()->num_dimensions() >= 3) + { + win_b = window; + } + // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix A has 8 times less the cols of the dst matrix + win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 0)); + + Iterator ina(lhs, win_a); + Iterator inb(rhs, win_b); + Iterator out(dst, window); + + const bool multiply_alpha = !(helpers::float_ops::is_one(alpha)); + + const float16x8_t alpha_f16 = vdupq_n_f16(alpha); + + execute_window_loop(window, [&](const Coordinates & id) + { + const auto *mtx_a0 = reinterpret_cast(ina.ptr()); + const auto *mtx_b0 = reinterpret_cast(inb.ptr()); + auto *mtx_out = reinterpret_cast(out.ptr()); + float16x8x4_t c = + { + { + vdupq_n_f16(0.f), + vdupq_n_f16(0.f), + vdupq_n_f16(0.f), + vdupq_n_f16(0.f) + } + }; + + /* + This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values) + |a00 a01 a02 a03 | a04 a05 a06 a07| + |a10 a11 a12 a13 | a14 a15 a16 a17| + |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ... + |a30 a31 a32 a33 | a34 a35 a36 a37| | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a15 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ... + |a40 a41 a42 a43 | a44 a45 a46 a47| + |a50 a51 a52 a53 | a54 a55 a56 a57| + |a60 a61 a62 a63 | a64 a65 a66 a67| + |a70 a71 a72 a73 | a74 a75 a76 a77| + + After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ] + + B Matrix has been transposed as shown below + + |b00 b01 b02 b03 b04 b05 b06 b07| + |b10 b11 b12 b13 b14 b15 b16 b17| + |b20 b21 b22 b23 b24 b25 b26 b27| + |b30 b31 b32 b33 b34 b35 b36 b37| + -------------------> + + |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37| + + c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30 + c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31 + + The size of the dst tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size. + */ + const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; + + for(; mtx_b0 <= (mtx_b0_end_addr - 32);) + + { + const float16x8_t p00 = vld1q_f16(mtx_a0); + const float16x8_t p02 = vld1q_f16(mtx_a0 + 8); + + const float16x8_t q00 = vld1q_f16(mtx_b0); + const float16x8_t q02 = vld1q_f16(mtx_b0 + 8); + const float16x8_t q04 = vld1q_f16(mtx_b0 + 16); + const float16x8_t q06 = vld1q_f16(mtx_b0 + 24); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3))); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7))); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3))); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7))); + + mtx_a0 += 16; + mtx_b0 += 32; + } + + for(; mtx_b0 < mtx_b0_end_addr;) + + { + const float16x4_t p00 = vld1_f16(mtx_a0); + const float16x8_t q00 = vld1q_f16(mtx_b0); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3))); + + mtx_a0 += 4; + mtx_b0 += 8; + } + + if(multiply_alpha) + { + c.val[0] = vmulq_f16(c.val[0], alpha_f16); + c.val[1] = vmulq_f16(c.val[1], alpha_f16); + c.val[2] = vmulq_f16(c.val[2], alpha_f16); + c.val[3] = vmulq_f16(c.val[3], alpha_f16); + } + + if(id.x() < (out_width - 8)) + { + vst1q_f16(mtx_out, c.val[0]); + if(id.y() + 1 < out_height) + { + vst1q_f16(mtx_out + 1 * out_stride, c.val[1]); + if(id.y() + 2 < out_height) + { + vst1q_f16(mtx_out + 2 * out_stride, c.val[2]); + if(id.y() + 3 < out_height) + { + vst1q_f16(mtx_out + 3 * out_stride, c.val[3]); + } + } + } + } + else + { + // Left-over columns + const int columns_left = out_width - id.x(); + for(int x = 0; x < columns_left; ++x) + { + *(mtx_out + x) = c.val[0][x]; + if(id.y() + 1 < out_height) + { + *(mtx_out + x + 1 * out_stride) = c.val[1][x]; + if(id.y() + 2 < out_height) + { + *(mtx_out + x + 2 * out_stride) = c.val[2][x]; + if(id.y() + 3 < out_height) + { + *(mtx_out + x + 3 * out_stride) = c.val[3][x]; + } + } + } + } + } + }, + ina, inb, out); +} +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + +} // namespace cpu + +} // namespace arm_compute diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h new file mode 100644 index 0000000000..6bf865a624 --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_KERNELS_GEMMMATRIXMUL_IMPL_H +#define SRC_CORE_KERNELS_GEMMMATRIXMUL_IMPL_H +#include "arm_compute/core/Helpers.h" +#include "src/core/CPP/Validate.h" + +namespace arm_compute +{ +namespace cpu +{ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha); + +void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha); + +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha); + +void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha); + +} // namespace cpu +} // namespace arm_compute +#endif //define SRC_CORE_KERNELS_GEMMMATRIXMUL_IMPL_H diff --git a/src/cpu/kernels/gemm_matrix_mul/list.h b/src/cpu/kernels/gemm_matrix_mul/list.h new file mode 100644 index 0000000000..9cdb58ae06 --- /dev/null +++ b/src/cpu/kernels/gemm_matrix_mul/list.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CORE_NEON_KERNELS_GEMMMATRIXMUL_LIST_H +#define SRC_CORE_NEON_KERNELS_GEMMMATRIXMUL_LIST_H +namespace arm_compute +{ +namespace cpu +{ +#define DECLARE_GEMMMATRIXMUL_KERNEL(func_name) \ + void func_name(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha, const bool is_dst_vector) +DECLARE_GEMMMATRIXMUL_KERNEL(neon_fp32_gemm_matrix_mul); +DECLARE_GEMMMATRIXMUL_KERNEL(neon_fp16_gemm_matrix_mul); +#undef DECLARE_GEMMMATRIXMUL_KERNEL +} // namespace cpu +} // namespace arm_compute +#endif //SRC_CORE_NEON_KERNELS_GEMMMATRIXMUL_LIST_H diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp index 079047328a..5f6e75b705 100644 --- a/tests/validation/NEON/GEMM.cpp +++ b/tests/validation/NEON/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -200,6 +200,40 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL) } } +TEST_SUITE(KERNEL_SELECTION) +DATA_TEST_CASE(KernelSelection_mul_and_add, framework::DatasetMode::ALL, + combine(framework::dataset::make("CpuExt", std::string("NEON")), + framework::dataset::make("DataType", { DataType::F32, + DataType::F16 + })), + cpu_ext, data_type) +{ + using namespace cpu::kernels; + + cpuinfo::CpuIsaInfo cpu_isa{}; + cpu_isa.neon = (cpu_ext == "NEON"); + cpu_isa.fp16 = (data_type == DataType::F16); + + const auto *selected_impl_mul = CpuGemmMatrixMultiplyKernel::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred); + + ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl_mul); + + std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_gemm_matrix_mul"; + std::string actual = selected_impl_mul->name; + + ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS); + + const auto *selected_impl_add = CpuGemmMatrixAdditionKernel::get_implementation(DataTypeISASelectorData{ data_type, cpu_isa }, cpu::KernelSelectionType::Preferred); + + ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl_add); + + expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_gemm_matrix_add"; + actual = selected_impl_add->name; + + ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS); +} +TEST_SUITE_END() // KERNEL_SELECTION + TEST_SUITE(TRANSPOSE_1XW) using CpuGemmTranspose1xW = NESynthetizeFunctionWithZeroConstantKernelBorder; DATA_TEST_CASE(ValidateZeroPadding, framework::DatasetMode::ALL, zip( -- cgit v1.2.1