From 221f38176b0d4dbc212441779d9bbac3cc0eecfa Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Wed, 28 Jun 2017 17:27:56 +0100 Subject: COMPMID-421: Fixed FP16 support in Neon GEMM. Fixed GEMM FP16 problem with matrices that are not multiple of 32. Added a new test suite NEON/GEMM/Float16/SmallGEMM. Implemented FP16 function to multiply vector by a matrix. Change-Id: Ie6c692885a48d0206bd6fe748332fa83bc286d67 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79118 Tested-by: Kaizen Reviewed-by: Moritz Pflanzer --- .../NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp | 197 +++++++++++++++++++-- 1 file changed, 186 insertions(+), 11 deletions(-) (limited to 'src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp') diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp index 1db025723c..101c5c8132 100644 --- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp +++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp @@ -49,6 +49,147 @@ class Coordinates; namespace { +template +void vector_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) +{ +#ifdef ARM_COMPUTE_ENABLE_FP16 + const auto width_matrix_b = static_cast(output->info()->dimension(0)); + const auto in_b_stride = static_cast(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type())); + const auto num_elems_vec_a = static_cast(input0->info()->dimension(0)); + + // The implementation computes 32 elements per iteration + const int window_start_x = 32 * window.thread_id(); + const int window_step_x = 32 * window.num_threads(); + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x"); + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(input1->info()->num_dimensions() >= 3) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Iterator ina(input0, win_a); + Iterator inb(input1, win_b); + Iterator out(output, win_out); + + const float16x8_t alpha_f16 = vdupq_n_f16(alpha); + ARM_COMPUTE_UNUSED(alpha_f16); + + execute_window_loop(win_out, [&](const Coordinates & id) + { + if(id.x() > width_matrix_b) + { + return; + } + + float16x8_t acc0 = vdupq_n_f16(0.f); + float16x8_t acc1 = vdupq_n_f16(0.f); + float16x8_t acc2 = vdupq_n_f16(0.f); + float16x8_t acc3 = vdupq_n_f16(0.f); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()); + + const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 4);) + { + const float16x4_t a0l = vld1_f16(vec_a); + + float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1)); + + matrix_b += 2 * in_b_stride; + + b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride); + b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride); + b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride); + b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2)); + acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3)); + acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3)); + acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3)); + acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3)); + + vec_a += 4; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr;) + { + const float16_t a0 = *vec_a; + const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride); + const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride); + const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride); + const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride); + + acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0)); + acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0)); + acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0)); + acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0)); + + vec_a += 1; + matrix_b += in_b_stride; + } + + // Multiply by the weight of matrix product (alpha) + if(multiply_alpha) + { + acc0 = vmulq_f16(acc0, alpha_f16); + acc1 = vmulq_f16(acc1, alpha_f16); + acc2 = vmulq_f16(acc2, alpha_f16); + acc3 = vmulq_f16(acc3, alpha_f16); + } + + const auto vec_out = reinterpret_cast(out.ptr()); + + vst1q_f16(vec_out + 0, acc0); + vst1q_f16(vec_out + 8, acc1); + vst1q_f16(vec_out + 16, acc2); + vst1q_f16(vec_out + 24, acc3); + + }, + ina, inb, out); +#else /* ARM_COMPUTE_ENABLE_FP16 */ + ARM_COMPUTE_ERROR("Not implemented"); +#endif /* ARM_COMPUTE_ENABLE_FP16 */ +} + template void vector_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) { @@ -639,9 +780,9 @@ template void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) { #ifdef ARM_COMPUTE_ENABLE_FP16 - - const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); - const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); + const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); + const size_t out_stride = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); + const int num_elems_matrix_b_x = input1->info()->dimension(0); // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix Window win_a(window); @@ -663,9 +804,6 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT Iterator inb(input1, win_b); Iterator out(output, window); - // Number of iterations of inner loop. Since 8 is the number of accumulations per loop, num_it = (width_mtx_b / 4) / 8 - const size_t num_it = ((input1->info()->dimension(0)) >> 2) >> 3; - const float16x8_t alpha_f16 = vdupq_n_f16(alpha); execute_window_loop(window, [&](const Coordinates & id) @@ -711,10 +849,14 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT The size of the output tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size. */ - for(size_t k = num_it; k > 0; mtx_a0 += 16, mtx_b0 += 32, --k) + const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x; + + for(; mtx_b0 <= (mtx_b0_end_addr - 32);) + { const float16x8_t p00 = vld1q_f16(mtx_a0); const float16x8_t p02 = vld1q_f16(mtx_a0 + 8); + const float16x8_t q00 = vld1q_f16(mtx_b0); const float16x8_t q02 = vld1q_f16(mtx_b0 + 8); const float16x8_t q04 = vld1q_f16(mtx_b0 + 16); @@ -739,6 +881,24 @@ void matrix_matrix_multiply_f16(const ITensor *input0, const ITensor *input1, IT c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5))); c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6))); c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7))); + + mtx_a0 += 16; + mtx_b0 += 32; + } + + for(; mtx_b0 < mtx_b0_end_addr;) + + { + const float16x4_t p00 = vld1_f16(mtx_a0); + const float16x8_t q00 = vld1q_f16(mtx_b0); + + c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0))); + c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1))); + c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2))); + c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3))); + + mtx_a0 += 4; + mtx_b0 += 8; } if(multiply_alpha) @@ -1037,6 +1197,13 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor num_elems_processed_per_iteration_x = 32; break; } +#ifdef ARM_COMPUTE_ENABLE_FP16 + case DataType::F16: + { + num_elems_processed_per_iteration_x = 32; + break; + } +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: { ARM_COMPUTE_ERROR("Data type not supported"); @@ -1074,13 +1241,13 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor num_elems_processed_per_iteration_x = 32; break; } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { -#ifdef ARM_COMPUTE_ENABLE_FP16 num_elems_processed_per_iteration_x = 8; break; -#endif } +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: { ARM_COMPUTE_ERROR("Data type not supported"); @@ -1128,6 +1295,14 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window) vector_matrix_multiply_qs8(_input0, _input1, _output, window, _alpha); break; } +#ifdef ARM_COMPUTE_ENABLE_FP16 + case DataType::F16: + { + multiply_alpha ? vector_matrix_multiply_f16(_input0, _input1, _output, window, _alpha) : + vector_matrix_multiply_f16(_input0, _input1, _output, window, _alpha); + break; + } +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: { ARM_COMPUTE_ERROR("Data type not supported"); @@ -1151,14 +1326,14 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window) matrix_matrix_multiply_qs8(_input0, _input1, _output, window, _alpha); break; } +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { -#ifdef ARM_COMPUTE_ENABLE_FP16 multiply_alpha ? matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha) : matrix_matrix_multiply_f16(_input0, _input1, _output, window, _alpha); break; -#endif } +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: { ARM_COMPUTE_ERROR("Data type not supported"); -- cgit v1.2.1