From 71ac9037abce1c6c4af42c485d5395dd6fd79a5a Mon Sep 17 00:00:00 2001
From: Michalis Spyrou
Date: Thu, 14 Nov 2019 14:31:44 +0000
Subject: COMPMID-2923 Integrate arm_gemm per channel quantization

Signed-off-by: Michalis Spyrou
Change-Id: I8667e75843fdd6ac75bd8272a86a348b830da28d
Reviewed-on: https://review.mlplatform.org/c/2548
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
---
 src/core/NEON/kernels/arm_gemm/quantized.cpp | 265 ++++++++++++++++++---------
 1 file changed, 176 insertions(+), 89 deletions(-)

(limited to 'src/core/NEON/kernels/arm_gemm/quantized.cpp')

diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index bffb7ddcb3..00b42cf422 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -47,13 +47,19 @@ namespace {
  * applied to negative values being shifted right to make sure they round
  * properly - if negative values are never output (e.g. fused ReLU) this is
  * unnecessary.
+ *
+ * The 'per_channel' template parameter selects between per channel and per
+ * layer requantization - in the former case we need to load vectors of
+ * shifts and multipliers for each column.  A separate vector for each
+ * column is set up in any case (and it is hoped that the compiler can elide
+ * the needless movs in the per-layer case).
  */
-template<bool do_shift_correction>
-void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+template<bool do_shift_correction, bool per_channel>
+void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                              const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                              const int32_t *row_bias, const int32_t *col_bias) {
-    const int32x4_t v_mul = vdupq_n_s32(qp.requant_mul);
-    const int32x4_t v_shift = vdupq_n_s32(qp.requant_shift);
+    const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
+    const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift);
     const int32x4_t v_minval = vdupq_n_s32(qp.minval);
     const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
     const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
@@ -70,6 +76,8 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
         unsigned int odds=(width % 4);
 
         const int32_t *colptr = col_bias;
+        const int32_t *perch_mul_ptr = qp.per_channel_muls;
+        const int32_t *perch_shift_ptr = qp.per_channel_shifts;
 
         const int32_t *in_ptr = input + (row * in_stride);
         int8_t *out_ptr = output + (row * out_stride);
@@ -93,6 +101,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
         const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1);
 
         while (blocks--) {
+            int32x4_t v_mul0;
+            int32x4_t v_mul1;
+            int32x4_t v_mul2;
+            int32x4_t v_mul3;
+
+            int32x4_t v_shf0;
+            int32x4_t v_shf1;
+            int32x4_t v_shf2;
+            int32x4_t v_shf3;
+
+            if (per_channel) {
+                v_mul0 = vld1q_s32(perch_mul_ptr);
+                v_mul1 = vld1q_s32(perch_mul_ptr + 4);
+                v_mul2 = vld1q_s32(perch_mul_ptr + 8);
+                v_mul3 = vld1q_s32(perch_mul_ptr + 12);
+                perch_mul_ptr += 16;
+
+                v_shf0 = vld1q_s32(perch_shift_ptr);
+                v_shf1 = vld1q_s32(perch_shift_ptr + 4);
+                v_shf2 = vld1q_s32(perch_shift_ptr + 8);
+                v_shf3 = vld1q_s32(perch_shift_ptr + 12);
+                perch_shift_ptr += 16;
+            } else {
+                v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
+                v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+            }
+
             // Load column pointers
             int32x4_t v_col0 = vld1q_s32(colptr);
             int32x4_t v_col1 = vld1q_s32(colptr + 4);
@@ -136,27 +171,27 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
             v_in13 = vaddq_s32(v_in13, v_col3);
 
             // Quantize - start with multiply
-            v_in00 = vqrdmulhq_s32(v_in00, v_mul);
-            v_in01 = vqrdmulhq_s32(v_in01, v_mul);
-            v_in02 = vqrdmulhq_s32(v_in02, v_mul);
-            v_in03 = vqrdmulhq_s32(v_in03, v_mul);
+            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
+            v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
+            v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
+            v_in03 = vqrdmulhq_s32(v_in03, v_mul3);
 
-            v_in10 = vqrdmulhq_s32(v_in10, v_mul);
-            v_in11 = vqrdmulhq_s32(v_in11, v_mul);
-            v_in12 = vqrdmulhq_s32(v_in12, v_mul);
-            v_in13 = vqrdmulhq_s32(v_in13, v_mul);
+            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
+            v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
+            v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
+            v_in13 = vqrdmulhq_s32(v_in13, v_mul3);
 
             // Compute and add on corrective offset
             if (do_shift_correction) {
-                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
-                int32x4_t v_temp01 = vandq_s32(v_in01, v_shift);
-                int32x4_t v_temp02 = vandq_s32(v_in02, v_shift);
-                int32x4_t v_temp03 = vandq_s32(v_in03, v_shift);
+                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
+                int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
+                int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
+                int32x4_t v_temp03 = vandq_s32(v_in03, v_shf3);
 
-                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
-                int32x4_t v_temp11 = vandq_s32(v_in11, v_shift);
-                int32x4_t v_temp12 = vandq_s32(v_in12, v_shift);
-                int32x4_t v_temp13 = vandq_s32(v_in13, v_shift);
+                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
+                int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
+                int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
+                int32x4_t v_temp13 = vandq_s32(v_in13, v_shf3);
 
                 v_temp00 = vshrq_n_s32(v_temp00, 31);
                 v_temp01 = vshrq_n_s32(v_temp01, 31);
@@ -179,15 +214,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
                 v_in13 = vqaddq_s32(v_in13, v_temp13);
             }
 
-            v_in00 = vrshlq_s32(v_in00, v_shift);
-            v_in01 = vrshlq_s32(v_in01, v_shift);
-            v_in02 = vrshlq_s32(v_in02, v_shift);
-            v_in03 = vrshlq_s32(v_in03, v_shift);
+            v_in00 = vrshlq_s32(v_in00, v_shf0);
+            v_in01 = vrshlq_s32(v_in01, v_shf1);
+            v_in02 = vrshlq_s32(v_in02, v_shf2);
+            v_in03 = vrshlq_s32(v_in03, v_shf3);
 
-            v_in10 = vrshlq_s32(v_in10, v_shift);
-            v_in11 = vrshlq_s32(v_in11, v_shift);
-            v_in12 = vrshlq_s32(v_in12, v_shift);
-            v_in13 = vrshlq_s32(v_in13, v_shift);
+            v_in10 = vrshlq_s32(v_in10, v_shf0);
+            v_in11 = vrshlq_s32(v_in11, v_shf1);
+            v_in12 = vrshlq_s32(v_in12, v_shf2);
+            v_in13 = vrshlq_s32(v_in13, v_shf3);
 
             v_in00 = vaddq_s32(v_in00, v_c_offset);
             v_in01 = vaddq_s32(v_in01, v_c_offset);
@@ -235,6 +270,20 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
         }
 
         while (regs--) {
+            int32x4_t v_mul0;
+            int32x4_t v_shf0;
+
+            if (per_channel) {
+                v_mul0 = vld1q_s32(perch_mul_ptr);
+                perch_mul_ptr += 4;
+
+                v_shf0 = vld1q_s32(perch_shift_ptr);
+                perch_shift_ptr += 4;
+            } else {
+                v_mul0=v_mul;
+                v_shf0=v_shift;
+            }
+
             // Load column pointers
             int32x4_t v_col0 = vld1q_s32(colptr);
             colptr += 4;
@@ -258,15 +307,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
             v_in10 = vaddq_s32(v_in10, v_col0);
 
             // Quantize - start with multiply
-            v_in00 = vqrdmulhq_s32(v_in00, v_mul);
+            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
 
-            v_in10 = vqrdmulhq_s32(v_in10, v_mul);
+            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
 
             // Compute and add on corrective offset
             if (do_shift_correction) {
-                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
+                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
 
-                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
+                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
 
                 v_temp00 = vshrq_n_s32(v_temp00, 31);
 
@@ -277,9 +326,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
                 v_in10 = vqaddq_s32(v_in10, v_temp10);
             }
 
-            v_in00 = vrshlq_s32(v_in00, v_shift);
+            v_in00 = vrshlq_s32(v_in00, v_shf0);
 
-            v_in10 = vrshlq_s32(v_in10, v_shift);
+            v_in10 = vrshlq_s32(v_in10, v_shf0);
 
             v_in00 = vaddq_s32(v_in00, v_c_offset);
 
@@ -307,21 +356,40 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
             int32x4_t v_col0 = vdupq_n_s32(0);
             int32x4_t v_in00 = vdupq_n_s32(0);
             int32x4_t v_in10 = vdupq_n_s32(0);
+            int32x4_t v_mul0 = vdupq_n_s32(0);
+            int32x4_t v_shf0 = vdupq_n_s32(0);
+
+            if (!per_channel) {
+                v_mul0 = v_mul;
+                v_shf0 = v_shift;
+            }
 
             do {
                 v_col0 = vld1q_lane_s32(colptr, v_col0, 0);
                 v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0);
                 v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0);
+                if (per_channel) {
+                    v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
+                    v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+                }
                 if (odds == 1) { break; }
 
                 v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1);
                 v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1);
                 v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1);
+                if (per_channel) {
+                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
+                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+                }
                 if (odds == 2) { break; }
 
                 v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2);
                 v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2);
                 v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2);
+                if (per_channel) {
+                    v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
+                    v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+                }
             } while (0);
 
             // Add on row sum and bias constant
@@ -335,15 +403,15 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
             v_in10 = vaddq_s32(v_in10, v_col0);
 
             // Quantize - start with multiply
-            v_in00 = vqrdmulhq_s32(v_in00, v_mul);
+            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
 
-            v_in10 = vqrdmulhq_s32(v_in10, v_mul);
+            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
 
             // Compute and add on corrective offset
             if (do_shift_correction) {
-                int32x4_t v_temp00 = vandq_s32(v_in00, v_shift);
+                int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
 
-                int32x4_t v_temp10 = vandq_s32(v_in10, v_shift);
+                int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
 
                 v_temp00 = vshrq_n_s32(v_temp00, 31);
 
@@ -354,9 +422,9 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
                 v_in10 = vqaddq_s32(v_in10, v_temp10);
             }
 
-            v_in00 = vrshlq_s32(v_in00, v_shift);
+            v_in00 = vrshlq_s32(v_in00, v_shf0);
 
-            v_in10 = vrshlq_s32(v_in10, v_shift);
+            v_in10 = vrshlq_s32(v_in10, v_shf0);
 
             v_in00 = vaddq_s32(v_in00, v_c_offset);
 
@@ -391,23 +459,33 @@ void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, u
 
 } // anonymous namespace
 
 template<typename Tin, typename Tout>
-void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                          const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
                          const int32_t *row_bias, const int32_t *col_bias) {
-    if (qp.minval >= qp.c_offset) {
-        requantize_block_32_int<false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                       reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+    if (qp.per_channel_requant) {
+        if (qp.minval >= qp.c_offset) {
+            requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+        } else {
+            requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+        }
     } else {
-        requantize_block_32_int<true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                      reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+        if (qp.minval >= qp.c_offset) {
+            requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                  reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+        } else {
+            requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+        }
     }
 }
 
-template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                   const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                                   const int32_t *row_bias, const int32_t *col_bias);
 
-template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                   const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
                                   const int32_t *row_bias, const int32_t *col_bias);
@@ -448,7 +526,7 @@ template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int wid
  */
 namespace {
     struct row_sum_helpers {
-        const ARequantizeLayer32 &qp;
+        const Requantize32 &qp;
 
         /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */
         template<typename T>
@@ -571,7 +649,7 @@ namespace {
             }
         }
 
-        row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { }
+        row_sum_helpers(const Requantize32 &qp) : qp(qp) { }
     };
 
     template<>
@@ -612,8 +690,14 @@ namespace {
 }
 
 template<typename T>
-void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height,
+void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height,
                       const T *input, unsigned int in_stride, int32_t *row_bias) {
+    /* If the 'b' offset is zero, just skip this entirely. */
+    if (qp.b_offset == 0) {
+        memset(row_bias, 0, height * sizeof(int32_t));
+        return;
+    }
+
     row_sum_helpers thehelpers(qp);
 
     const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset);
@@ -663,8 +747,8 @@ void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned
 }
 
 /* Instantiate the two versions for uint8_t and int8_t. */
-template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
-template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);
+template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *);
+template void compute_row_sums(const Requantize32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *);
 
 template
 inline void add_block(const T *input, unsigned int in_stride, int32_t *output);
@@ -739,41 +823,44 @@ inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *outp
  * in cases where we are not computing the first columns of the output (i.e.
  * in multithreaded cases where we divide columns across threads) */
 template<typename T>
-void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col) {
-    memset(reinterpret_cast(col_bias), 0, width * sizeof(int32_t));
-
-    for (unsigned int row=0; row(input + row * in_stride + col, in_stride, col_bias + col);
-                break;
-
-            case 2:
-                add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
-                break;
-
-            case 3:
-                add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
-                break;
-
-            case 4:
-                add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
-                break;
-            }
-        } else {
-            for (; col
+        memset(reinterpret_cast(col_bias), 0, width * sizeof(int32_t));
+
+        for (unsigned int row=0; row(input + row * in_stride + col, in_stride, col_bias + col);
+                    break;
+
+                case 2:
+                    add_block<2>(input + row * in_stride + col, in_stride, col_bias + col);
+                    break;
+
+                case 3:
+                    add_block<3>(input + row * in_stride + col, in_stride, col_bias + col);
+                    break;
+
+                case 4:
+                    add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
+                    break;
+                }
+            } else {
+                for (; col
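
Note (not part of the patch): the requantize kernel above is easier to follow with the scalar computation it vectorises written out. The sketch below is a minimal model, assuming the field names that appear in the diff (per_layer_mul/per_layer_shift, per_channel_muls/per_channel_shifts, c_offset, minval, maxval); the helper names themselves (requantize_one, rounding_doubling_mul_high, rounding_right_shift) are invented for illustration and are not arm_gemm APIs. Each output element is the 32-bit accumulator plus its row and column bias corrections, scaled by a fixed-point multiplier, rounding-shifted right, offset by c_offset and clamped; the per_channel template parameter only changes where the multiplier and shift for a column come from.

// Illustrative, stand-alone scalar model of one lane of the NEON requantize
// path above.  Helper names are invented for this sketch.
#include <algorithm>
#include <cstdint>

// Scalar equivalent of vqrdmulhq_s32 on one lane: saturating rounding
// "doubling multiply returning the high half".
static int32_t rounding_doubling_mul_high(int32_t a, int32_t b) {
    if (a == INT32_MIN && b == INT32_MIN) {
        return INT32_MAX;  // the only overflowing case saturates
    }
    const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    return static_cast<int32_t>((ab + (int64_t{1} << 30)) >> 31);
}

// Scalar equivalent of vrshlq_s32 with a negative shift operand:
// rounding arithmetic shift right by 'n' bits.
static int32_t rounding_right_shift(int32_t v, int n) {
    if (n <= 0) {
        return v;
    }
    return static_cast<int32_t>((static_cast<int64_t>(v) + (int64_t{1} << (n - 1))) >> n);
}

// One output element.  'mul' and 'shift' are either the per-layer values
// (qp.per_layer_mul / qp.per_layer_shift) or this column's entries of
// qp.per_channel_muls / qp.per_channel_shifts; 'shift' is negative for a
// right shift, matching the kernel's convention.
static int8_t requantize_one(int32_t acc, int32_t row_bias, int32_t col_bias,
                             int32_t mul, int32_t shift, int32_t c_offset,
                             int32_t minval, int32_t maxval,
                             bool do_shift_correction) {
    int32_t v = acc + row_bias + col_bias;
    v = rounding_doubling_mul_high(v, mul);
    // The vand/vshr/vqadd trio in the kernel: nudge negative values down by
    // one before the rounding right shift so they round the same way as
    // positive values; skipped when the output can never be negative.
    if (do_shift_correction && shift < 0 && v < 0) {
        v -= 1;
    }
    v = rounding_right_shift(v, -shift);
    v += c_offset;
    v = std::max(minval, std::min(maxval, v));
    return static_cast<int8_t>(v);
}

In the per-channel path the main loop consumes 16 multipliers and 16 shifts per block of 16 columns (four int32x4_t vectors of each), which is why perch_mul_ptr and perch_shift_ptr advance by 16 per iteration, by 4 in the single-register loop, and lane by lane for the odd trailing columns.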