From 0bc80daf319ea3219ca6a6fa200118dc859ee460 Mon Sep 17 00:00:00 2001
From: morgolock
Date: Mon, 10 Aug 2020 16:44:18 +0100
Subject: MLCE-229: Support for negative shifts in asm kernels

Change-Id: I2c5e98aae7698963f106d7423df0e65cd00ee2a9
Signed-off-by: morgolock
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3710
Tested-by: Arm Jenkins
Reviewed-by: Sheri Zhang
Comments-Addressed: Arm Jenkins
---
 src/core/NEON/kernels/arm_gemm/quantized.cpp       | 126 +++++++++++++++++----
 src/core/NEON/kernels/assembly/arm_gemm.hpp        |  41 +++----
 .../NEON/functions/NEGEMMAssemblyDispatch.cpp      |  31 +++--
 .../functions/NEGEMMLowpMatrixMultiplyCore.cpp     |  29 +----
 4 files changed, 152 insertions(+), 75 deletions(-)

diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index e50dca7f1f..201bd9dc2c 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -55,15 +55,16 @@ namespace {
  * column is set up in any case (and it is hoped that the compiler can elide
  * the needless movs in the per-layer case).
  */
-template<bool do_shift_correction, bool per_channel>
+template<bool do_shift_correction, bool per_channel, bool do_left_shift>
 void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                              const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                              const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
-    const int32x4_t v_mul      = vdupq_n_s32(qp.per_layer_mul);
-    const int32x4_t v_shift    = vdupq_n_s32(qp.per_layer_shift);
-    const int32x4_t v_minval   = vdupq_n_s32(qp.minval);
-    const int32x4_t v_maxval   = vdupq_n_s32(qp.maxval);
-    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
+    const int32x4_t v_mul         = vdupq_n_s32(qp.per_layer_mul);
+    const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
+    const int32x4_t v_left_shift  = vdupq_n_s32(qp.per_layer_left_shift);
+    const int32x4_t v_minval      = vdupq_n_s32(qp.minval);
+    const int32x4_t v_maxval      = vdupq_n_s32(qp.maxval);
+    const int32x4_t v_c_offset    = vdupq_n_s32(qp.c_offset);

     /* To make sure we have plenty of accumulators, compute two rows at a
      * time.  If the number of rows is odd, compute the bottom row twice to
@@ -77,8 +78,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
         unsigned int odds=(width % 4);

         const int32_t *colptr = col_bias;
-        const int32_t *perch_mul_ptr   = qp.per_channel_muls + start_col;
-        const int32_t *perch_shift_ptr = qp.per_channel_shifts + start_col;
+        const int32_t *perch_mul_ptr    = qp.per_channel_muls + start_col;
+        const int32_t *perch_shift_ptr  = qp.per_channel_right_shifts + start_col;
+        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;

         const int32_t *in_ptr = input + (row * in_stride);
         int8_t *out_ptr = output + (row * out_stride);
@@ -112,6 +114,11 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             int32x4_t v_shf2;
             int32x4_t v_shf3;

+            int32x4_t v_shf0l;
+            int32x4_t v_shf1l;
+            int32x4_t v_shf2l;
+            int32x4_t v_shf3l;
+
             if (per_channel) {
                 v_mul0 = vld1q_s32(perch_mul_ptr);
                 v_mul1 = vld1q_s32(perch_mul_ptr + 4);
@@ -124,9 +131,17 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                 v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                 perch_shift_ptr += 16;
+
+                if (do_left_shift) {
+                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
+                }
             } else {
                 v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
-                v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
+                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
             }

             // Load column pointers
@@ -171,7 +186,22 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             v_in12 = vaddq_s32(v_in12, v_col2);
             v_in13 = vaddq_s32(v_in13, v_col3);

-            // Quantize - start with multiply
+            // Quantize
+
+            // If a left shift is needed it needs to happen first.
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+                v_in01 = vrshlq_s32(v_in01, v_shf1l);
+                v_in02 = vrshlq_s32(v_in02, v_shf2l);
+                v_in03 = vrshlq_s32(v_in03, v_shf3l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+                v_in11 = vrshlq_s32(v_in11, v_shf1l);
+                v_in12 = vrshlq_s32(v_in12, v_shf2l);
+                v_in13 = vrshlq_s32(v_in13, v_shf3l);
+            }
+
+            // Multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
             v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
             v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
@@ -273,6 +303,7 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
         while (regs--) {
             int32x4_t v_mul0;
             int32x4_t v_shf0;
+            int32x4_t v_shf0l;

             if (per_channel) {
                 v_mul0 = vld1q_s32(perch_mul_ptr);
@@ -280,9 +311,15 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 v_shf0 = vld1q_s32(perch_shift_ptr);
                 perch_shift_ptr += 4;
+
+                if (do_left_shift) {
+                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                    perch_shiftl_ptr += 4;
+                }
             } else {
                 v_mul0=v_mul;
-                v_shf0=v_shift;
+                v_shf0=v_right_shift;
+                v_shf0l=v_left_shift;
             }

             // Load column pointers
             int32x4_t v_col0 = vld1q_s32(colptr);
@@ -306,7 +343,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             v_in10 = vaddq_s32(v_in10, v_col0);

-            // Quantize - start with multiply
+            // Quantize - start with (optional) left shift
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+            }
+
+            // Then multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

             v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

@@ -358,10 +402,12 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             int32x4_t v_in10 = vdupq_n_s32(0);
             int32x4_t v_mul0 = vdupq_n_s32(0);
             int32x4_t v_shf0 = vdupq_n_s32(0);
+            int32x4_t v_shf0l = vdupq_n_s32(0);

             if (!per_channel) {
                 v_mul0 = v_mul;
-                v_shf0 = v_shift;
+                v_shf0 = v_right_shift;
+                v_shf0l = v_left_shift;
             }

             do {
@@ -371,6 +417,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
+                    }
                 }
                 if (odds == 1) { break; }

@@ -380,6 +429,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
+                    }
                 }
                 if (odds == 2) { break; }

@@ -389,6 +441,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
+                    }
                 }
             } while (0);

@@ -402,7 +457,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             v_in10 = vaddq_s32(v_in10, v_col0);

-            // Quantize - start with multiply
+            // Quantize - start with (optional) left shift
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+            }
+
+            // Then multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);

             v_in10 = vqrdmulhq_s32(v_in10, v_mul0);

@@ -464,19 +526,39 @@ void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned in
                          const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
     if (qp.per_channel_requant) {
         if (qp.minval >= qp.c_offset) {
-            requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_channel_left_shifts) {
+                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         } else {
-            requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_channel_left_shifts) {
+                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                          reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         }
     } else {
         if (qp.minval >= qp.c_offset) {
-            requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                  reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_layer_left_shift > 0) {
+                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         } else {
-            requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_layer_left_shift > 0) {
+                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         }
     }
 }
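For readers tracing the NEON intrinsics above, the per-element arithmetic can be modelled in scalar C++ roughly as follows. This is an illustrative sketch only, not arm_gemm code: the helper names (rounding_shift, sat_rdmulh, requantize_value) are invented here, the shift-correction path selected by the other template parameter is omitted, and two's-complement shift semantics on signed values are assumed (guaranteed from C++20, true on the compilers the library targets).

#include <algorithm>
#include <cstdint>
#include <limits>

// Rounding shift modelled on vrshlq_s32: a positive shift moves bits left,
// a negative shift performs a rounding arithmetic shift right.
int32_t rounding_shift(int32_t v, int32_t shift)
{
    if(shift >= 0)
    {
        return v << shift;
    }
    const int32_t n = -shift;
    return static_cast<int32_t>((static_cast<int64_t>(v) + (int64_t(1) << (n - 1))) >> n);
}

// Saturating rounding doubling high multiply modelled on vqrdmulhq_s32.
int32_t sat_rdmulh(int32_t a, int32_t b)
{
    if(a == std::numeric_limits<int32_t>::min() && b == std::numeric_limits<int32_t>::min())
    {
        return std::numeric_limits<int32_t>::max(); // the only input pair that saturates
    }
    const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    return static_cast<int32_t>((2 * ab + (int64_t(1) << 31)) >> 32);
}

// One output element of the path above: optional pre-multiply left shift,
// fixed-point multiply, post-multiply right shift, output offset, clamp.
int8_t requantize_value(int32_t acc, int32_t mul,
                        int32_t left_shift,  // >= 0, e.g. qp.per_layer_left_shift
                        int32_t right_shift, // <= 0, e.g. qp.per_layer_right_shift
                        int32_t c_offset, int32_t minval, int32_t maxval)
{
    int32_t v = rounding_shift(acc, left_shift);
    v         = sat_rdmulh(v, mul);
    v         = rounding_shift(v, right_shift);
    v         = std::min(std::max(v + c_offset, minval), maxval);
    return static_cast<int8_t>(v);
}

The point of the change is the ordering: any left shift is applied to the accumulator before the fixed-point multiply, while the non-positive right shift is applied after it, which is why Requantize32 now carries the two exponents separately.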
diff --git a/src/core/NEON/kernels/assembly/arm_gemm.hpp b/src/core/NEON/kernels/assembly/arm_gemm.hpp
index 58db511547..f6421c12ab 100644
--- a/src/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/src/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -122,38 +122,41 @@ public:
 struct Requantize32
 {
 public:
-    const int32_t *bias = nullptr;
-    size_t bias_multi_stride = 0;
-    int32_t a_offset = 0;
-    int32_t b_offset = 0;
-    int32_t c_offset = 0;
-    bool per_channel_requant = false;
-    int32_t per_layer_shift = 0;
-    int32_t per_layer_mul = 0;
-    const int32_t *per_channel_shifts = nullptr;
-    const int32_t *per_channel_muls = nullptr;
-    int32_t minval = 0;
-    int32_t maxval = 0;
+    const int32_t *bias                     = nullptr;
+    size_t         bias_multi_stride        = 0;
+    int32_t        a_offset                 = 0;
+    int32_t        b_offset                 = 0;
+    int32_t        c_offset                 = 0;
+    bool           per_channel_requant      = false;
+    int32_t        per_layer_left_shift     = 0;
+    int32_t        per_layer_right_shift    = 0;
+    int32_t        per_layer_mul            = 0;
+    const int32_t *per_channel_left_shifts  = nullptr;
+    const int32_t *per_channel_right_shifts = nullptr;
+    const int32_t *per_channel_muls         = nullptr;
+    int32_t        minval                   = 0;
+    int32_t        maxval                   = 0;

     Requantize32() = default;

     // Constructor for per-tensor quantization
     Requantize32(const int32_t *bias, size_t bias_multi_stride,
                  int32_t a_offset, int32_t b_offset, int32_t c_offset,
-                 int32_t requant_shift, int32_t requant_mul,
-                 int32_t minv, int32_t maxv)
-        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
-          minval(minv), maxval(maxv)
+                 int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv)
+        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max(requant_shift, int32_t(0))),
+          per_layer_right_shift(std::min(requant_shift, int32_t(0))), per_layer_mul(requant_mul), minval(minv), maxval(maxv)
     {
     }

     // Constructor for per-channel quantization
     Requantize32(const int32_t *bias, size_t bias_multi_stride,
                  int32_t a_offset, int32_t b_offset, int32_t c_offset,
-                 const int32_t *requant_shifts, const int32_t *requant_muls,
+                 const int32_t *requant_left_shifts,
+                 const int32_t *requant_right_shifts,
+                 const int32_t *requant_muls,
                  int32_t minv, int32_t maxv)
-        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_shifts(requant_shifts),
-          per_channel_muls(requant_muls), minval(minv), maxval(maxv)
+        : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts),
+          per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv)
     {
     }
 };
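As a usage illustration of the updated struct, the sketch below constructs both flavours directly. The offsets, multipliers and shift values are made up, and the include path is the in-tree header location, which may differ in other build setups. A positive exponent means a pre-multiply left shift and a negative one a post-multiply right shift, so the per-tensor constructor can split a single signed value with std::max/std::min as shown in the diff above.

#include <cstdint>
#include <vector>

#include "arm_gemm.hpp" // in-tree: src/core/NEON/kernels/assembly/arm_gemm.hpp

int main()
{
    // Per-tensor: one signed exponent; the constructor splits it into
    // per_layer_left_shift = max(shift, 0) and per_layer_right_shift = min(shift, 0).
    arm_gemm::Requantize32 per_tensor(nullptr, 0, // no bias
                                      0, 0, 10,   // a_offset, b_offset, c_offset
                                      -3,         // requant_shift: negative = shift right after the multiply
                                      1518500250, // requant_mul, an example Q0.31 multiplier (~0.707)
                                      -128, 127); // clamp range

    // Per-channel: the caller provides the already-split arrays.
    const std::vector<int32_t> left  = { 0, 1, 0, 2 };                                     // >= 0, applied before the multiply
    const std::vector<int32_t> right = { -4, 0, -3, 0 };                                   // <= 0, applied after the multiply
    const std::vector<int32_t> muls  = { 1518500250, 1431655765, 1717986918, 1288490189 }; // example Q0.31 multipliers
    arm_gemm::Requantize32 per_channel(nullptr, 0, 0, 0, 10,
                                       left.data(), right.data(), muls.data(),
                                       -128, 127);

    (void)per_tensor;
    (void)per_channel;
    return 0;
}

For the per-tensor case, requant_shift = -3 ends up as per_layer_left_shift = 0 and per_layer_right_shift = -3; a value of +2 would give 2 and 0 respectively.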
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 3b9dde2bf7..eeea3a45ee 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -182,8 +182,8 @@ public:
      *
      * @return A tuple with the pointers to the shift and multiplier data respectively
      */
-    std::tuple<const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
-                                                                     const std::vector<int32_t> &multipliers);
+    std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
+                                                                                            const std::vector<int32_t> &multipliers);

     // Inherited methods overridden:
     void run() override;
@@ -235,18 +235,29 @@ private:
     arm_gemm::KernelDescription _kernel_info{};
     /** Per channel quantization shifts */
     std::vector<int32_t> _shifts{};
+    std::vector<int32_t> right_shifts{};
+    std::vector<int32_t> left_shifts{};
     /** Per channel quantization multipliers */
     std::vector<int32_t> _multipliers{};
 };

 template <typename TypeInput, typename TypeOutput, class OutputStage>
-std::tuple<const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
-                                                                                                               const std::vector<int32_t> &multipliers)
+std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+                                                                                                                                      const std::vector<int32_t> &multipliers)
 {
-    _multipliers = multipliers;
-    _shifts      = shifts;
-    std::transform(_shifts.begin(), _shifts.end(), _shifts.begin(), std::negate<int32_t>());
-    return std::make_tuple(_shifts.data(), _multipliers.data());
+    _multipliers   = multipliers;
+    _shifts        = shifts;
+    bool need_left = false;
+    for(const auto s : _shifts)
+    {
+        left_shifts.push_back(std::max(-s, int32_t(0)));
+        right_shifts.push_back(std::min(-s, int32_t(0)));
+        if(s > 0 && !need_left)
+        {
+            need_left = true;
+        }
+    }
+    return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
 }

 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -498,7 +509,9 @@ void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &a
         const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);

         gemm_requant_info = arm_gemm::Requantize32(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset,
-                                                   std::get<0>(requantize_data), std::get<1>(requantize_data),
+                                                   (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr,
+                                                   std::get<2>(requantize_data),
+                                                   std::get<3>(requantize_data),
                                                    os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
     }
     else
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index dada6d16da..83db146a8a 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -117,18 +117,8 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
     {
         if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
         {
-            // Result shifts < 0 are not supported by asm kernels
-            const std::vector<int32_t> &shifts           = info.gemmlowp_output_stage().gemmlowp_shifts;
-            const bool                  is_asm_supported = info.gemmlowp_output_stage().gemmlowp_shift >= 0
-                                                           && std::all_of(shifts.cbegin(), shifts.cend(), [](int32_t val)
-            {
-                return val >= 0;
-            });
-            if(is_asm_supported)
-            {
-                _asm_glue.configure(a_to_use, b, c, output, gemm_info);
-                _fused_assembly_path = _asm_glue.is_configured();
-            }
+            _asm_glue.configure(a_to_use, b, c, output, gemm_info);
+            _fused_assembly_path = _asm_glue.is_configured();
         }
         else
         {
@@ -339,19 +329,8 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
     bool run_optimised_requantized = false;
     if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
     {
-        // Result shifts < 0 are not supported by asm kernels
-        const std::vector<int32_t> &shifts           = info.gemmlowp_output_stage().gemmlowp_shifts;
-        const bool                  is_asm_supported = info.gemmlowp_output_stage().gemmlowp_shift >= 0
-                                                       && std::all_of(shifts.cbegin(), shifts.cend(), [](int32_t val)
-        {
-            return val >= 0;
-        });
-
-        if(is_asm_supported)
-        {
-            run_optimised             = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info));
-            run_optimised_requantized = run_optimised;
-        }
+        run_optimised             = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info));
+        run_optimised_requantized = run_optimised;
     }
     else
     {
--
cgit v1.2.1
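On the Compute Library side, gemmlowp result shifts use the opposite sign convention (positive means shift right), so the dispatch code negates each value and splits it into the two arrays consumed by Requantize32. Below is a standalone sketch of that conversion; the names are illustrative only, and the in-tree logic lives in Fallback::set_requantize_data shown above.

#include <algorithm>
#include <cstdint>
#include <vector>

// Convert ACL-convention result shifts (positive = shift right) into the
// arm_gemm convention (negative = shift right), split into a non-negative
// left-shift array and a non-positive right-shift array.
struct SplitShifts
{
    std::vector<int32_t> left;
    std::vector<int32_t> right;
    bool                 need_left = false; // true if any channel needs a pre-multiply left shift
};

SplitShifts split_acl_shifts(const std::vector<int32_t> &acl_shifts)
{
    SplitShifts out;
    out.left.reserve(acl_shifts.size());
    out.right.reserve(acl_shifts.size());
    for(const int32_t s : acl_shifts)
    {
        out.left.push_back(std::max(-s, int32_t(0)));
        out.right.push_back(std::min(-s, int32_t(0)));
    }
    out.need_left = std::any_of(out.left.cbegin(), out.left.cend(), [](int32_t v) { return v != 0; });
    return out;
}

With acl_shifts = {2, -1, 0} this yields left = {0, 1, 0}, right = {-2, 0, 0} and need_left = true, since only the second channel requires a pre-multiply left shift.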