diff options
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp | 123 |
1 files changed, 24 insertions, 99 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp index 81376fb029..6399ebbef4 100644 --- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp +++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,11 +28,10 @@ #include "arm_compute/core/Validate.h" #include "src/core/CPP/Validate.h" #include "src/core/NEON/NEFixedPoint.h" +#include "src/core/common/Registrars.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" - -#include <arm_neon.h> - +#include "src/cpu/kernels/gemm_matrix_add/list.h" namespace arm_compute { namespace cpu @@ -41,93 +40,26 @@ namespace kernels { namespace { -void matrix_addition_f32(const ITensor *src, ITensor *dst, const Window &window, float beta) +static const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> available_kernels = { - ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); - const float32x4_t beta_f32 = vdupq_n_f32(beta); - - constexpr int window_step_x = 16; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - Window win = window.collapse_if_possible(window, Window::DimZ); - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(src, win); - Iterator out(dst, win); - - execute_window_loop(win, [&](const Coordinates &) { - const auto in_ptr = reinterpret_cast<const float *>(in.ptr()); - const auto out_ptr = reinterpret_cast<float *>(out.ptr()); - - int x = window_start_x; - for(; x < (window_end_x - window_step_x); x += window_step_x) + "neon_fp32_gemm_matrix_add", + [](const DataTypeISASelectorData & data) { - float32x4x4_t alpha_ab = vld4q_f32(out_ptr + x); - const float32x4x4_t c = vld4q_f32(in_ptr + x); - - // Multiply matrix C by its weight and accumulate - alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32); - alpha_ab.val[1] = vmlaq_f32(alpha_ab.val[1], c.val[1], beta_f32); - alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32); - alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32); - - vst4q_f32(out_ptr + x, alpha_ab); - } - - // Left-over loop - for(; x < window_end_x; ++x) - { - *(out_ptr + x) += *(in_ptr + x) * beta; - } + return (data.dt == DataType::F32); + }, + REGISTER_FP32_NEON(neon_fp32_gemm_matrix_add) }, - in, out); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void matrix_addition_f16(const ITensor *src, ITensor *dst, const Window &window, float beta) -{ - ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); - const float16x8_t beta_f16 = vdupq_n_f16(beta); - - constexpr int window_step_x = 16; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - Window win = window.collapse_if_possible(window, Window::DimZ); - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator in(src, win); - Iterator out(dst, win); - - execute_window_loop(win, [&](const Coordinates &) { - const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr()); - const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr()); - - int x = window_start_x; - for(; x < (window_end_x - window_step_x); x += window_step_x) - { - float16x8x2_t alpha_ab = vld2q_f16(out_ptr + x); - const float16x8x2_t c = vld2q_f16(in_ptr + x); - // Multiply matrix C by its weight and accumulate - alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16)); - alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16)); - - vst2q_f16(out_ptr + x, alpha_ab); - } - - // Left-over loop - for(; x < window_end_x; ++x) + "neon_fp16_gemm_matrix_add", + [](const DataTypeISASelectorData & data) { - *(out_ptr + x) += *(in_ptr + x) * static_cast<float16_t>(beta); - } + return (data.dt == DataType::F16) && data.isa.fp16; + }, + REGISTER_FP16_NEON(neon_fp16_gemm_matrix_add) }, - in, out); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +}; } // namespace void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta) @@ -138,22 +70,10 @@ void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo // Perform validation step ARM_COMPUTE_ERROR_THROW_ON(CpuGemmMatrixAdditionKernel::validate(src, dst, beta)); - _beta = beta; - switch(src->data_type()) - { - case DataType::F32: - _func = &matrix_addition_f32; - break; - case DataType::F16: -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - _func = &matrix_addition_f16; - break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - default: - ARM_COMPUTE_ERROR("Data type not supported"); - break; - } - + _beta = beta; + const auto uk = CpuGemmMatrixAdditionKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + ARM_COMPUTE_ERROR_ON_NULLPTR(uk); + _func = uk->ukernel; // Configure kernel window Window win = calculate_max_window(*src, Steps()); ICPPKernel::configure(win); @@ -195,6 +115,11 @@ const char *CpuGemmMatrixAdditionKernel::name() const { return "CpuGemmMatrixAdditionKernel"; } + +const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> &CpuGemmMatrixAdditionKernel::get_available_kernels() +{ + return available_kernels; +} } // namespace kernels } // namespace cpu } // namespace arm_compute |