From bdb6b0bb156588dc39fd5084d4c91d05b5148610 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 30 Jun 2017 12:21:00 +0100 Subject: COMPMID-433 - Port NEGEMM to support 16 bit fixed point Change-Id: I82de74d7027bbc8a00a4d6671e968785280d5f6c Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79498 Reviewed-by: Georgios Pinitas Tested-by: Kaizen Reviewed-by: Moritz Pflanzer Reviewed-by: Anthony Barbier --- .../NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp | 272 ++++++++++++++++++++- 1 file changed, 267 insertions(+), 5 deletions(-) (limited to 'src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp') diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp index bff16ec329..b81be6cee9 100644 --- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp +++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp @@ -474,6 +474,135 @@ void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, IT ina, inb, out); } +template +void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) +{ + const auto width_matrix_b = static_cast(output->info()->dimension(0)); + const auto in_b_stride = static_cast(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type())); + const auto num_elems_vec_a = static_cast(input0->info()->dimension(0)); + const int fixed_point_position = input0->info()->fixed_point_position(); + + // The implementation computes 16 elements per iteration + const int window_start_x = 16 * window.thread_id(); + const int window_step_x = 16 * window.num_threads(); + // Make sure (window_end_x - window_start_x) is a multiple of window_step_x + const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x; + ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x"); + + Window win_out(window); + win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + win_out.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(input1->info()->num_dimensions() >= 3) + { + win_b = window; + } + win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x)); + win_b.set(Window::DimY, Window::Dimension(0, 1, 1)); + + Iterator ina(input0, win_a); + Iterator inb(input1, win_b); + Iterator out(output, win_out); + + execute_window_loop(win_out, [&](const Coordinates & id) + { + if(id.x() > width_matrix_b) + { + return; + } + + // Reset accumulators + qint32x4_t acc00_qs32 = vdupq_n_qs32(0); + qint32x4_t acc01_qs32 = vdupq_n_qs32(0); + qint32x4_t acc02_qs32 = vdupq_n_qs32(0); + qint32x4_t acc03_qs32 = vdupq_n_qs32(0); + + auto vec_a = reinterpret_cast(ina.ptr()); + auto matrix_b = reinterpret_cast(inb.ptr()); + + auto vec_a_end_addr = vec_a + num_elems_vec_a; + for(; vec_a <= (vec_a_end_addr - 2);) + { + const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0); + const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1); + + const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride); + const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride); + const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride); + const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride); + const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride); + const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride); + const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride); + const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride); + + // First accumulation + acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position); + acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position); + acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position); + acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position); + + // Second accumulation + acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position); + acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position); + acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position); + acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position); + + vec_a += 2; + matrix_b += 2 * in_b_stride; + } + + for(; vec_a < vec_a_end_addr;) + { + const qint16x4_t a0 = vld1_dup_qs16(vec_a); + + const qint16x4_t b00 = vld1_qs16(matrix_b + 0); + const qint16x4_t b01 = vld1_qs16(matrix_b + 4); + const qint16x4_t b02 = vld1_qs16(matrix_b + 8); + const qint16x4_t b03 = vld1_qs16(matrix_b + 12); + + acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position); + acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position); + acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position); + acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position); + + vec_a += 1; + matrix_b += in_b_stride; + } + + // Convert back to qint16x4_t and saturate + qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32); + qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32); + qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32); + qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32); + + // Multiply by the weight of the matrix product (alpha) + if(multiply_alpha) + { + const qint16x4_t alpha_qs16 = vdup_n_qs16(scvt_qs16_f32(alpha, fixed_point_position)); + acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position); + acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position); + acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position); + acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position); + } + + const auto mtx_out0 = reinterpret_cast(out.ptr()); + + // Store 16x4 output elements + vst1_qs16(mtx_out0 + 0, acc00_qs16); + vst1_qs16(mtx_out0 + 4, acc01_qs16); + vst1_qs16(mtx_out0 + 8, acc02_qs16); + vst1_qs16(mtx_out0 + 12, acc03_qs16); + }, + ina, inb, out); +} + template void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) { @@ -1153,6 +1282,120 @@ void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, IT ina, inb, out); } +template +void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha) +{ + const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()); + const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type()); + const size_t out_stride2 = out_stride1 * 2; + const size_t out_stride3 = out_stride1 * 3; + const int num_elems_matrix_b_x = input1->info()->dimension(0); + const int fixed_point_position = input0->info()->fixed_point_position(); + const qint16x4_t alpha_qs16 = vdup_n_qs16(scvt_qs16_f32(alpha, fixed_point_position)); + ARM_COMPUTE_UNUSED(alpha_qs16); + + // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix + Window win_a(window); + win_a.set(Window::DimX, Window::Dimension(0, 0, 0)); + win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1)); + + Window win_b; + // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2 + // This scenario can happen when the the matrix multiplication is used to perform a convolution operation + if(input1->info()->num_dimensions() >= 3) + { + win_b = window; + } + // Set step_x and step_y for matrix B. Scale by a factor of 16 the X range as the input transposed matrix A has 16 times less the cols of the output matrix + win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride)); + win_b.set(Window::DimY, Window::Dimension(0, 0, 0)); + + Iterator ina(input0, win_a); + Iterator inb(input1, win_b); + Iterator out(output, window); + + // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW + // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 8x4 elements per iteration + // All the values needed for computing a single 8x4 block will be read from consecutive memory positions + execute_window_loop(window, [&](const Coordinates & id) + { + auto mtx_a0 = reinterpret_cast(ina.ptr()); + auto mtx_b0 = reinterpret_cast(inb.ptr()); + auto mtx_b1 = mtx_b0 + in_b_stride; + + qint32x4_t acc00_qs32 = vdupq_n_qs32(0); + qint32x4_t acc10_qs32 = vdupq_n_qs32(0); + qint32x4_t acc20_qs32 = vdupq_n_qs32(0); + qint32x4_t acc30_qs32 = vdupq_n_qs32(0); + + qint32x4_t acc01_qs32 = vdupq_n_qs32(0); + qint32x4_t acc11_qs32 = vdupq_n_qs32(0); + qint32x4_t acc21_qs32 = vdupq_n_qs32(0); + qint32x4_t acc31_qs32 = vdupq_n_qs32(0); + + // This for loop performs 1 accumulation + for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8) + { + const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0); + const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1); + const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2); + const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3); + + const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0); + const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4); + + acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position); + acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position); + acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position); + acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position); + acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position); + acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position); + acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position); + acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position); + + mtx_a0 += 4; + mtx_b0 += 8; + mtx_b1 += 8; + } + + // Convert back to qint16x4_t and saturate + qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32); + qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32); + qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32); + qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32); + + qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32); + qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32); + qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32); + qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32); + + // Multiply by the weight of the matrix product (alpha) + if(multiply_alpha) + { + acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position); + acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position); + acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position); + acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position); + acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position); + acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position); + acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position); + acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position); + } + + const auto mtx_out0 = reinterpret_cast(out.ptr()); + + // Store 8x4 output elements + vst1_qs16(mtx_out0 + 0, acc00_qs16); + vst1_qs16(mtx_out0 + 4, acc01_qs16); + vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16); + vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16); + vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16); + vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16); + vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16); + vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16); + }, + ina, inb, out); +} } // namespace NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel() @@ -1162,10 +1405,7 @@ NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel() void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32, DataType::QS8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output); ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output); @@ -1197,6 +1437,11 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor num_elems_processed_per_iteration_x = 32; break; } + case DataType::QS16: + { + num_elems_processed_per_iteration_x = 16; + break; + } #ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { @@ -1241,6 +1486,11 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor num_elems_processed_per_iteration_x = 32; break; } + case DataType::QS16: + { + num_elems_processed_per_iteration_x = 8; + break; + } #ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { @@ -1278,7 +1528,7 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window) bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f; - // Check if the output tensor is a vector and the data type is F32. If so,the kernel runs the vector-matrix multiplication + // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication if((_output->info()->dimension(1) == 1)) { switch(_input0->info()->data_type()) @@ -1295,6 +1545,12 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window) vector_matrix_multiply_qs8(_input0, _input1, _output, window, _alpha); break; } + case DataType::QS16: + { + multiply_alpha ? vector_matrix_multiply_qs16(_input0, _input1, _output, window, _alpha) : + vector_matrix_multiply_qs16(_input0, _input1, _output, window, _alpha); + break; + } #ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { @@ -1326,6 +1582,12 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window) matrix_matrix_multiply_qs8(_input0, _input1, _output, window, _alpha); break; } + case DataType::QS16: + { + multiply_alpha ? matrix_matrix_multiply_qs16(_input0, _input1, _output, window, _alpha) : + matrix_matrix_multiply_qs16(_input0, _input1, _output, window, _alpha); + break; + } #ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: { -- cgit v1.2.1