aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp123
1 files changed, 24 insertions, 99 deletions
diff --git a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
index 81376fb029..6399ebbef4 100644
--- a/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,11 +28,10 @@
#include "arm_compute/core/Validate.h"
#include "src/core/CPP/Validate.h"
#include "src/core/NEON/NEFixedPoint.h"
+#include "src/core/common/Registrars.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
-
-#include <arm_neon.h>
-
+#include "src/cpu/kernels/gemm_matrix_add/list.h"
namespace arm_compute
{
namespace cpu
@@ -41,93 +40,26 @@ namespace kernels
{
namespace
{
-void matrix_addition_f32(const ITensor *src, ITensor *dst, const Window &window, float beta)
+static const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> available_kernels =
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
- const float32x4_t beta_f32 = vdupq_n_f32(beta);
-
- constexpr int window_step_x = 16;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- Window win = window.collapse_if_possible(window, Window::DimZ);
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator out(dst, win);
-
- execute_window_loop(win, [&](const Coordinates &)
{
- const auto in_ptr = reinterpret_cast<const float *>(in.ptr());
- const auto out_ptr = reinterpret_cast<float *>(out.ptr());
-
- int x = window_start_x;
- for(; x < (window_end_x - window_step_x); x += window_step_x)
+ "neon_fp32_gemm_matrix_add",
+ [](const DataTypeISASelectorData & data)
{
- float32x4x4_t alpha_ab = vld4q_f32(out_ptr + x);
- const float32x4x4_t c = vld4q_f32(in_ptr + x);
-
- // Multiply matrix C by its weight and accumulate
- alpha_ab.val[0] = vmlaq_f32(alpha_ab.val[0], c.val[0], beta_f32);
- alpha_ab.val[1] = vmlaq_f32(alpha_ab.val[1], c.val[1], beta_f32);
- alpha_ab.val[2] = vmlaq_f32(alpha_ab.val[2], c.val[2], beta_f32);
- alpha_ab.val[3] = vmlaq_f32(alpha_ab.val[3], c.val[3], beta_f32);
-
- vst4q_f32(out_ptr + x, alpha_ab);
- }
-
- // Left-over loop
- for(; x < window_end_x; ++x)
- {
- *(out_ptr + x) += *(in_ptr + x) * beta;
- }
+ return (data.dt == DataType::F32);
+ },
+ REGISTER_FP32_NEON(neon_fp32_gemm_matrix_add)
},
- in, out);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void matrix_addition_f16(const ITensor *src, ITensor *dst, const Window &window, float beta)
-{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
- const float16x8_t beta_f16 = vdupq_n_f16(beta);
-
- constexpr int window_step_x = 16;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- Window win = window.collapse_if_possible(window, Window::DimZ);
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator in(src, win);
- Iterator out(dst, win);
-
- execute_window_loop(win, [&](const Coordinates &)
{
- const auto in_ptr = reinterpret_cast<const float16_t *>(in.ptr());
- const auto out_ptr = reinterpret_cast<float16_t *>(out.ptr());
-
- int x = window_start_x;
- for(; x < (window_end_x - window_step_x); x += window_step_x)
- {
- float16x8x2_t alpha_ab = vld2q_f16(out_ptr + x);
- const float16x8x2_t c = vld2q_f16(in_ptr + x);
- // Multiply matrix C by its weight and accumulate
- alpha_ab.val[0] = vaddq_f16(alpha_ab.val[0], vmulq_f16(c.val[0], beta_f16));
- alpha_ab.val[1] = vaddq_f16(alpha_ab.val[1], vmulq_f16(c.val[1], beta_f16));
-
- vst2q_f16(out_ptr + x, alpha_ab);
- }
-
- // Left-over loop
- for(; x < window_end_x; ++x)
+ "neon_fp16_gemm_matrix_add",
+ [](const DataTypeISASelectorData & data)
{
- *(out_ptr + x) += *(in_ptr + x) * static_cast<float16_t>(beta);
- }
+ return (data.dt == DataType::F16) && data.isa.fp16;
+ },
+ REGISTER_FP16_NEON(neon_fp16_gemm_matrix_add)
},
- in, out);
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+};
} // namespace
void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta)
@@ -138,22 +70,10 @@ void CpuGemmMatrixAdditionKernel::configure(const ITensorInfo *src, ITensorInfo
// Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(CpuGemmMatrixAdditionKernel::validate(src, dst, beta));
- _beta = beta;
- switch(src->data_type())
- {
- case DataType::F32:
- _func = &matrix_addition_f32;
- break;
- case DataType::F16:
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- _func = &matrix_addition_f16;
- break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
- }
-
+ _beta = beta;
+ const auto uk = CpuGemmMatrixAdditionKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() });
+ ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
+ _func = uk->ukernel;
// Configure kernel window
Window win = calculate_max_window(*src, Steps());
ICPPKernel::configure(win);
@@ -195,6 +115,11 @@ const char *CpuGemmMatrixAdditionKernel::name() const
{
return "CpuGemmMatrixAdditionKernel";
}
+
+const std::vector<CpuGemmMatrixAdditionKernel::GemmMatrixAddKernel> &CpuGemmMatrixAdditionKernel::get_available_kernels()
+{
+ return available_kernels;
+}
} // namespace kernels
} // namespace cpu
} // namespace arm_compute