From af56d520b486079e5f773a530582a0a710f7f376 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 1 Jul 2020 12:35:30 +0100 Subject: COMPMID-3324: Fix per-channel quantization on N blocking Pass the starting column into the requantization code so that the per-channel multipliers and shifts are indexed from the correct offset when the N dimension is blocked. Signed-off-by: Georgios Pinitas Change-Id: I8231e0b541c6b1b76becf349a1d6ddf973ade9e2 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3488 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- .../NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp | 2 +- src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp | 2 +- src/core/NEON/kernels/arm_gemm/quantized.cpp | 20 ++++++++++---------- src/core/NEON/kernels/arm_gemm/quantized.hpp | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp index d9b1a71ea8..2b936d0b8f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_quantized.hpp @@ -228,7 +228,7 @@ public: requantize_block_32(_qp, (nmax - n0), (m_end - m_start), result_buffer, (nmax - n0), this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc, - local_row_sums, col_bias + (multi * _Nsize) + n0); + local_row_sums, col_bias + (multi * _Nsize) + n0, n0); } } while (p.next_dim0()); } diff --git a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp index 18f030fec0..995716575a 100644 --- a/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantize_wrapper.hpp @@ -115,7 +115,7 @@ private: _args._Nsize, this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (first_row * this->_ldc), this->_ldc, _row_sums + (multi * _args._nbatches * _args._Msize) + (batch * _args._Msize) 
+ first_row, - _col_sums + (multi * _args._Nsize)); + _col_sums + (multi * _args._Nsize), 0); } } } diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp index 00b42cf422..53e5527a8d 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.cpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -57,7 +57,7 @@ namespace { template void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height, const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, - const int32_t *row_bias, const int32_t *col_bias) { + const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) { const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul); const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift); const int32x4_t v_minval = vdupq_n_s32(qp.minval); @@ -76,8 +76,8 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne unsigned int odds=(width % 4); const int32_t *colptr = col_bias; - const int32_t *perch_mul_ptr = qp.per_channel_muls; - const int32_t *perch_shift_ptr = qp.per_channel_shifts; + const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col; + const int32_t *perch_shift_ptr = qp.per_channel_shifts + start_col; const int32_t *in_ptr = input + (row * in_stride); int8_t *out_ptr = output + (row * out_stride); @@ -461,33 +461,33 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, - const int32_t *row_bias, const int32_t *col_bias) { + const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) { if (qp.per_channel_requant) { if (qp.minval >= qp.c_offset) { requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, - reinterpret_cast(output), out_stride, row_bias, col_bias); + 
reinterpret_cast(output), out_stride, row_bias, col_bias, start_col); } else { requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, - reinterpret_cast(output), out_stride, row_bias, col_bias); + reinterpret_cast(output), out_stride, row_bias, col_bias, start_col); } } else { if (qp.minval >= qp.c_offset) { requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, - reinterpret_cast(output), out_stride, row_bias, col_bias); + reinterpret_cast(output), out_stride, row_bias, col_bias, start_col); } else { requantize_block_32_int(qp, width, height, reinterpret_cast(input), in_stride, - reinterpret_cast(output), out_stride, row_bias, col_bias); + reinterpret_cast(output), out_stride, row_bias, col_bias, start_col); } } } template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, - const int32_t *row_bias, const int32_t *col_bias); + const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col); template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride, - const int32_t *row_bias, const int32_t *col_bias); + const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col); /* * Routine (and helpers) to compute row sums needed for offset correction. 
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.hpp b/src/core/NEON/kernels/arm_gemm/quantized.hpp index a91a888ad9..b0e0c3b580 100644 --- a/src/core/NEON/kernels/arm_gemm/quantized.hpp +++ b/src/core/NEON/kernels/arm_gemm/quantized.hpp @@ -28,7 +28,7 @@ namespace arm_gemm { template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height, const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, - const int32_t *row_bias, const int32_t *col_bias); + const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col); template void compute_row_sums(const Requantize32 &qp, unsigned int width, unsigned int height, -- cgit v1.2.1