From c0b6f76561580414f08633a804fc548ccad65659 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Mon, 2 Nov 2020 01:37:17 +0000
Subject: COMPMID-3776: Indirect GEMM

Signed-off-by: Georgios Pinitas
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins
Reviewed-by: Sang-Hoon Park
Reviewed-by: Michele Di Giorgio
Comments-Addressed: Arm Jenkins
---
 src/core/NEON/kernels/arm_gemm/quantized.cpp | 173 +++++++++++++++++++++++++++
 1 file changed, 173 insertions(+)

diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index cac02cf28e..111d01ed3a 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -301,6 +301,179 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             out_ptr1 += 16;
         }
 
+        // We are often quantizing one block of interleaved kernel output at a time - these are three registers
+        // wide.  Special case that here.
+        if (regs==3) {
+            regs -= 3;
+
+            int32x4_t v_mul0;
+            int32x4_t v_mul1;
+            int32x4_t v_mul2;
+
+            int32x4_t v_shf0;
+            int32x4_t v_shf1;
+            int32x4_t v_shf2;
+
+            int32x4_t v_shf0l;
+            int32x4_t v_shf1l;
+            int32x4_t v_shf2l;
+
+            if (per_channel) {
+                v_mul0 = vld1q_s32(perch_mul_ptr);
+                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
+                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
+                perch_mul_ptr += 12;
+
+                v_shf0 = vld1q_s32(perch_shift_ptr);
+                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
+                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
+                perch_shift_ptr += 12;
+
+                if (do_left_shift) {
+                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+                    perch_shiftl_ptr += 12;
+                }
+            } else {
+                v_mul0=v_mul1=v_mul2=v_mul;
+                v_shf0=v_shf1=v_shf2=v_right_shift;
+                v_shf0l=v_shf1l=v_shf2l=v_left_shift;
+            }
+
+            // Load column pointers
+            int32x4_t v_col0 = vld1q_s32(colptr);
+            int32x4_t v_col1 = vld1q_s32(colptr + 4);
+            int32x4_t v_col2 = vld1q_s32(colptr + 8);
+            colptr += 12;
+
+            // Load input data (row 0);
+            int32x4_t v_in00 = vld1q_s32(in_ptr);
+            int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
+            int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
+            in_ptr += 12;
+
+            // Load input data (row 1);
+            int32x4_t v_in10 = vld1q_s32(in_ptr1);
+            int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
+            int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
+            in_ptr1 += 12;
+
+            // Add on row bias and column bias
+            v_in00 = vaddq_s32(v_in00, v_row_sum);
+            v_in01 = vaddq_s32(v_in01, v_row_sum);
+            v_in02 = vaddq_s32(v_in02, v_row_sum);
+
+            v_in10 = vaddq_s32(v_in10, v_row_sum1);
+            v_in11 = vaddq_s32(v_in11, v_row_sum1);
+            v_in12 = vaddq_s32(v_in12, v_row_sum1);
+
+            v_in00 = vaddq_s32(v_in00, v_col0);
+            v_in01 = vaddq_s32(v_in01, v_col1);
+            v_in02 = vaddq_s32(v_in02, v_col2);
+
+            v_in10 = vaddq_s32(v_in10, v_col0);
+            v_in11 = vaddq_s32(v_in11, v_col1);
+            v_in12 = vaddq_s32(v_in12, v_col2);
+
+            // Quantize
+
+            // If a left shift is needed it needs to happen first.
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+                v_in01 = vrshlq_s32(v_in01, v_shf1l);
+                v_in02 = vrshlq_s32(v_in02, v_shf2l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+                v_in11 = vrshlq_s32(v_in11, v_shf1l);
+                v_in12 = vrshlq_s32(v_in12, v_shf2l);
+            }
+
+            // Multiply
+            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
+            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
+            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
+
+            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
+            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
+            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
+
+            // Compute and add on corrective offset
+            if (do_shift_correction) {
+                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
+                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
+                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
+
+                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
+                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
+                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
+
+                v_temp00 = vshrq_n_s32(v_temp00, 31);
+                v_temp01 = vshrq_n_s32(v_temp01, 31);
+                v_temp02 = vshrq_n_s32(v_temp02, 31);
+
+                v_temp10 = vshrq_n_s32(v_temp10, 31);
+                v_temp11 = vshrq_n_s32(v_temp11, 31);
+                v_temp12 = vshrq_n_s32(v_temp12, 31);
+
+                v_in00 = vqaddq_s32(v_in00, v_temp00);
+                v_in01 = vqaddq_s32(v_in01, v_temp01);
+                v_in02 = vqaddq_s32(v_in02, v_temp02);
+
+                v_in10 = vqaddq_s32(v_in10, v_temp10);
+                v_in11 = vqaddq_s32(v_in11, v_temp11);
+                v_in12 = vqaddq_s32(v_in12, v_temp12);
+            }
+
+            v_in00 = vrshlq_s32(v_in00, v_shf0);
+            v_in01 = vrshlq_s32(v_in01, v_shf1);
+            v_in02 = vrshlq_s32(v_in02, v_shf2);
+
+            v_in10 = vrshlq_s32(v_in10, v_shf0);
+            v_in11 = vrshlq_s32(v_in11, v_shf1);
+            v_in12 = vrshlq_s32(v_in12, v_shf2);
+
+            v_in00 = vaddq_s32(v_in00, v_c_offset);
+            v_in01 = vaddq_s32(v_in01, v_c_offset);
+            v_in02 = vaddq_s32(v_in02, v_c_offset);
+
+            v_in10 = vaddq_s32(v_in10, v_c_offset);
+            v_in11 = vaddq_s32(v_in11, v_c_offset);
+            v_in12 = vaddq_s32(v_in12, v_c_offset);
+
+            v_in00 = vmaxq_s32(v_in00, v_minval);
+            v_in01 = vmaxq_s32(v_in01, v_minval);
+            v_in02 = vmaxq_s32(v_in02, v_minval);
+
+            v_in10 = vmaxq_s32(v_in10, v_minval);
+            v_in11 = vmaxq_s32(v_in11, v_minval);
+            v_in12 = vmaxq_s32(v_in12, v_minval);
+
+            v_in00 = vminq_s32(v_in00, v_maxval);
+            v_in01 = vminq_s32(v_in01, v_maxval);
+            v_in02 = vminq_s32(v_in02, v_maxval);
+
+            v_in10 = vminq_s32(v_in10, v_maxval);
+            v_in11 = vminq_s32(v_in11, v_maxval);
+            v_in12 = vminq_s32(v_in12, v_maxval);
+
+            int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
+            int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));
+
+            int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
+            int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));
+
+            int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
+            int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));
+
+            vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
+            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
+            out_ptr += 12;
+            vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
+            vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
+            out_ptr1 += 12;
+        }
+
         while (regs--) {
             int32x4_t v_mul0;
             int32x4_t v_shf0;
--
cgit v1.2.1
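
A note on the arithmetic, for reviewers less familiar with the intrinsic sequence. The new regs==3 path applies the same per-lane requantisation as the existing vector loop: add the row and column bias, optionally left shift, do a saturating rounding doubling high multiply (vqrdmulhq_s32), apply the sign fixup and rounding right shift, add the output offset, clamp, and narrow to int8. The scalar sketch below is illustrative only and is not part of the patch: the helper name requantize_lane and its parameter list are invented, the vqrdmulhq_s32 saturation corner case (INT32_MIN * INT32_MIN) is ignored, and the shifts are taken as non-negative counts, whereas the kernel keeps the right-shift amounts as negative values for vrshlq_s32.

    #include <algorithm>
    #include <cstdint>

    // Hypothetical scalar model of one output lane of the regs==3 path.
    // Parameter names mirror the kernel's variables; the function itself is illustrative.
    static int8_t requantize_lane(int32_t acc, int32_t row_sum, int32_t col_sum,
                                  int32_t left_shift, int32_t mul, int32_t right_shift,
                                  int32_t c_offset, int32_t minval, int32_t maxval)
    {
        // Row and column bias (the vaddq_s32 steps).
        int32_t v = acc + row_sum + col_sum;

        // Optional left shift first (vrshlq_s32 with a positive shift amount); wraps like a 32-bit lane.
        v = static_cast<int32_t>(static_cast<uint32_t>(v) << left_shift);

        // vqrdmulhq_s32: saturating rounding doubling multiply returning the high half,
        // i.e. (v * mul + 2^30) >> 31 computed in 64-bit arithmetic (saturation case ignored here).
        int64_t prod = static_cast<int64_t>(v) * static_cast<int64_t>(mul);
        int32_t x   = static_cast<int32_t>((prod + (int64_t{1} << 30)) >> 31);

        if (right_shift > 0) {
            // Corrective offset applied when do_shift_correction is set: AND with the (negative)
            // shift value, arithmetic shift right by 31, then add - i.e. subtract 1 from negative
            // lanes before the rounding shift. The kernel uses vqaddq_s32, which saturates where
            // this plain add would wrap.
            int32_t fixup = (x & -right_shift) >> 31;
            x += fixup;

            // Rounding right shift (vrshlq_s32 with a negative shift amount).
            x = (x + (1 << (right_shift - 1))) >> right_shift;
        }

        // Output offset, clamp to [minval, maxval], then narrow to int8
        // (done with vuzp1q_s16 / vuzp1q_s8 and lane stores in the NEON code).
        x += c_offset;
        x = std::max(minval, std::min(x, maxval));
        return static_cast<int8_t>(x);
    }

The AND / shift-by-31 / saturating-add fixup before the rounding right shift mirrors the trick used in gemmlowp's RoundingDivideByPOT: vrshlq_s32 rounds exact halves towards +infinity, and the fixup nudges negative lanes so they round the other way, keeping the result consistent with the reference requantisation.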