Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantized.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/quantized.cpp | 378
1 file changed, 348 insertions(+), 30 deletions(-)
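The patch below adds optional left shifts to the fixed-point requantization path (a third `do_left_shift` template parameter plus a `start_col` argument so per-channel multipliers and shifts can be indexed per block), a special case for three-register-wide blocks of interleaved kernel output, and a new `dequantize_block_32` routine. As a reading aid, here is a minimal scalar sketch of the arithmetic the vectorized `requantize_block_32_int` path performs per element, assuming only the `Requantize32` fields visible in the diff; `requantize_scalar` and its parameter names are illustrative, not part of arm_gemm.

#include <algorithm>
#include <cstdint>

// Illustrative scalar model of one requantized output element (not arm_gemm API).
// 'acc' is the int32 GEMM accumulator; row/column bias values are precomputed.
// left_shift/right_shift are non-negative here; the NEON code instead stores the
// right shift as a negative operand for vrshlq_s32.
inline int8_t requantize_scalar(int32_t acc, int32_t row_bias, int32_t col_bias,
                                int32_t left_shift, int32_t mul, int32_t right_shift,
                                int32_t c_offset, int32_t minval, int32_t maxval,
                                bool do_shift_correction) {
    int64_t v = static_cast<int64_t>(acc) + row_bias + col_bias;

    // Optional left shift happens first (vrshlq_s32 with the left-shift vector).
    v <<= left_shift;

    // Saturating rounding doubling multiply returning the high half, as in
    // vqrdmulhq_s32 (the INT32_MIN * INT32_MIN saturation corner is ignored here).
    v = (v * mul + (int64_t{1} << 30)) >> 31;

    // Optional correction: subtract 1 from negative values so the rounding right
    // shift below rounds ties away from zero (the vand/vshr/vqadd trick in the diff).
    if (do_shift_correction && right_shift > 0 && v < 0) {
        v -= 1;
    }

    // Rounding right shift (vrshlq_s32 with a negative shift amount).
    if (right_shift > 0) {
        v = (v + (int64_t{1} << (right_shift - 1))) >> right_shift;
    }

    // Add the output offset and clamp to the representable range.
    v += c_offset;
    v = std::max<int64_t>(v, minval);
    v = std::min<int64_t>(v, maxval);
    return static_cast<int8_t>(v);
}

On the `qp.minval >= qp.c_offset` test in the dispatch below: the correction only ever changes elements whose intermediate value is negative, and those end up at or below `c_offset` after the offset is added, so when `minval >= c_offset` they are clamped to `minval` either way and the correction can be skipped. That is my reading of the check, not a statement made by the patch itself.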
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index 00b42cf422..6da9f4be0e 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,6 +24,7 @@
 #ifdef __aarch64__
 
 #include "arm_gemm.hpp"
+#include "utils.hpp"
 
 #include <arm_neon.h>
 
@@ -54,15 +55,16 @@ namespace {
  * column is set up in any case (and it is hoped that the compiler can elide
  * the needless movs in the per-layer case).
  */
-template<bool do_shift_correction, bool per_channel>
+template<bool do_shift_correction, bool per_channel, bool do_left_shift>
 void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                              const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
-                             const int32_t *row_bias, const int32_t *col_bias) {
-    const int32x4_t v_mul      = vdupq_n_s32(qp.per_layer_mul);
-    const int32x4_t v_shift    = vdupq_n_s32(qp.per_layer_shift);
-    const int32x4_t v_minval   = vdupq_n_s32(qp.minval);
-    const int32x4_t v_maxval   = vdupq_n_s32(qp.maxval);
-    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
+                             const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
+    const int32x4_t v_mul         = vdupq_n_s32(qp.per_layer_mul);
+    const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
+    const int32x4_t v_left_shift  = vdupq_n_s32(qp.per_layer_left_shift);
+    const int32x4_t v_minval      = vdupq_n_s32(qp.minval);
+    const int32x4_t v_maxval      = vdupq_n_s32(qp.maxval);
+    const int32x4_t v_c_offset    = vdupq_n_s32(qp.c_offset);
 
     /* To make sure we have plenty of accumulators, compute two rows at a
      * time. If the number of rows is odd, compute the bottom row twice to
@@ -76,8 +78,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
         unsigned int odds=(width % 4);
 
         const int32_t *colptr = col_bias;
-        const int32_t *perch_mul_ptr = qp.per_channel_muls;
-        const int32_t *perch_shift_ptr = qp.per_channel_shifts;
+        const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
+        const int32_t *perch_shift_ptr = qp.per_channel_right_shifts + start_col;
+        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;
 
         const int32_t *in_ptr = input + (row * in_stride);
         int8_t *out_ptr = output + (row * out_stride);
@@ -111,6 +114,11 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             int32x4_t v_shf2;
             int32x4_t v_shf3;
 
+            int32x4_t v_shf0l;
+            int32x4_t v_shf1l;
+            int32x4_t v_shf2l;
+            int32x4_t v_shf3l;
+
             if (per_channel) {
                 v_mul0 = vld1q_s32(perch_mul_ptr);
                 v_mul1 = vld1q_s32(perch_mul_ptr + 4);
@@ -123,9 +131,18 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                 v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                 perch_shift_ptr += 16;
+
+                if (do_left_shift) {
+                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
+                    perch_shiftl_ptr += 16;
+                }
             } else {
                 v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
-                v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
+                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
             }
 
             // Load column pointers
@@ -170,7 +187,22 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             v_in12 = vaddq_s32(v_in12, v_col2);
             v_in13 = vaddq_s32(v_in13, v_col3);
 
-            // Quantize - start with multiply
+            // Quantize
+
+            // If a left shift is needed it needs to happen first.
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+                v_in01 = vrshlq_s32(v_in01, v_shf1l);
+                v_in02 = vrshlq_s32(v_in02, v_shf2l);
+                v_in03 = vrshlq_s32(v_in03, v_shf3l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+                v_in11 = vrshlq_s32(v_in11, v_shf1l);
+                v_in12 = vrshlq_s32(v_in12, v_shf2l);
+                v_in13 = vrshlq_s32(v_in13, v_shf3l);
+            }
+
+            // Multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
             v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
             v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
@@ -269,9 +301,183 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 out_ptr1 += 16;
             }
 
+            // We are often quantizing one block of interleaved kernel output at a time - these are three registers
+            // wide. Special case that here.
+            if (regs==3) {
+                regs -= 3;
+
+                int32x4_t v_mul0;
+                int32x4_t v_mul1;
+                int32x4_t v_mul2;
+
+                int32x4_t v_shf0;
+                int32x4_t v_shf1;
+                int32x4_t v_shf2;
+
+                int32x4_t v_shf0l;
+                int32x4_t v_shf1l;
+                int32x4_t v_shf2l;
+
+                if (per_channel) {
+                    v_mul0 = vld1q_s32(perch_mul_ptr);
+                    v_mul1 = vld1q_s32(perch_mul_ptr + 4);
+                    v_mul2 = vld1q_s32(perch_mul_ptr + 8);
+                    perch_mul_ptr += 12;
+
+                    v_shf0 = vld1q_s32(perch_shift_ptr);
+                    v_shf1 = vld1q_s32(perch_shift_ptr + 4);
+                    v_shf2 = vld1q_s32(perch_shift_ptr + 8);
+                    perch_shift_ptr += 12;
+
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                        v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+                        v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+                        perch_shiftl_ptr += 12;
+                    }
+                } else {
+                    v_mul0=v_mul1=v_mul2=v_mul;
+                    v_shf0=v_shf1=v_shf2=v_right_shift;
+                    v_shf0l=v_shf1l=v_shf2l=v_left_shift;
+                }
+
+                // Load column pointers
+                int32x4_t v_col0 = vld1q_s32(colptr);
+                int32x4_t v_col1 = vld1q_s32(colptr + 4);
+                int32x4_t v_col2 = vld1q_s32(colptr + 8);
+                colptr += 12;
+
+                // Load input data (row 0);
+                int32x4_t v_in00 = vld1q_s32(in_ptr);
+                int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
+                int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
+                in_ptr += 12;
+
+                // Load input data (row 1);
+                int32x4_t v_in10 = vld1q_s32(in_ptr1);
+                int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
+                int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
+                in_ptr1 += 12;
+
+                // Add on row bias and column bias
+                v_in00 = vaddq_s32(v_in00, v_row_sum);
+                v_in01 = vaddq_s32(v_in01, v_row_sum);
+                v_in02 = vaddq_s32(v_in02, v_row_sum);
+
+                v_in10 = vaddq_s32(v_in10, v_row_sum1);
+                v_in11 = vaddq_s32(v_in11, v_row_sum1);
+                v_in12 = vaddq_s32(v_in12, v_row_sum1);
+
+                v_in00 = vaddq_s32(v_in00, v_col0);
+                v_in01 = vaddq_s32(v_in01, v_col1);
+                v_in02 = vaddq_s32(v_in02, v_col2);
+
+                v_in10 = vaddq_s32(v_in10, v_col0);
+                v_in11 = vaddq_s32(v_in11, v_col1);
+                v_in12 = vaddq_s32(v_in12, v_col2);
+
+                // Quantize
+
+                // If a left shift is needed it needs to happen first.
+                if (do_left_shift) {
+                    v_in00 = vrshlq_s32(v_in00, v_shf0l);
+                    v_in01 = vrshlq_s32(v_in01, v_shf1l);
+                    v_in02 = vrshlq_s32(v_in02, v_shf2l);
+
+                    v_in10 = vrshlq_s32(v_in10, v_shf0l);
+                    v_in11 = vrshlq_s32(v_in11, v_shf1l);
+                    v_in12 = vrshlq_s32(v_in12, v_shf2l);
+                }
+
+                // Multiply
+                v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
+                v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
+                v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
+
+                v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
+                v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
+                v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
+
+                // Compute and add on corrective offset
+                if (do_shift_correction) {
+                    int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
+                    int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
+                    int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
+
+                    int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
+                    int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
+                    int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
+
+                    v_temp00 = vshrq_n_s32(v_temp00, 31);
+                    v_temp01 = vshrq_n_s32(v_temp01, 31);
+                    v_temp02 = vshrq_n_s32(v_temp02, 31);
+
+                    v_temp10 = vshrq_n_s32(v_temp10, 31);
+                    v_temp11 = vshrq_n_s32(v_temp11, 31);
+                    v_temp12 = vshrq_n_s32(v_temp12, 31);
+
+                    v_in00 = vqaddq_s32(v_in00, v_temp00);
+                    v_in01 = vqaddq_s32(v_in01, v_temp01);
+                    v_in02 = vqaddq_s32(v_in02, v_temp02);
+
+                    v_in10 = vqaddq_s32(v_in10, v_temp10);
+                    v_in11 = vqaddq_s32(v_in11, v_temp11);
+                    v_in12 = vqaddq_s32(v_in12, v_temp12);
+                }
+
+                v_in00 = vrshlq_s32(v_in00, v_shf0);
+                v_in01 = vrshlq_s32(v_in01, v_shf1);
+                v_in02 = vrshlq_s32(v_in02, v_shf2);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0);
+                v_in11 = vrshlq_s32(v_in11, v_shf1);
+                v_in12 = vrshlq_s32(v_in12, v_shf2);
+
+                v_in00 = vaddq_s32(v_in00, v_c_offset);
+                v_in01 = vaddq_s32(v_in01, v_c_offset);
+                v_in02 = vaddq_s32(v_in02, v_c_offset);
+
+                v_in10 = vaddq_s32(v_in10, v_c_offset);
+                v_in11 = vaddq_s32(v_in11, v_c_offset);
+                v_in12 = vaddq_s32(v_in12, v_c_offset);
+
+                v_in00 = vmaxq_s32(v_in00, v_minval);
+                v_in01 = vmaxq_s32(v_in01, v_minval);
+                v_in02 = vmaxq_s32(v_in02, v_minval);
+
+                v_in10 = vmaxq_s32(v_in10, v_minval);
+                v_in11 = vmaxq_s32(v_in11, v_minval);
+                v_in12 = vmaxq_s32(v_in12, v_minval);
+
+                v_in00 = vminq_s32(v_in00, v_maxval);
+                v_in01 = vminq_s32(v_in01, v_maxval);
+                v_in02 = vminq_s32(v_in02, v_maxval);
+
+                v_in10 = vminq_s32(v_in10, v_maxval);
+                v_in11 = vminq_s32(v_in11, v_maxval);
+                v_in12 = vminq_s32(v_in12, v_maxval);
+
+                int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
+                int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));
+
+                int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
+                int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));
+
+                int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
+                int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));
+
+                vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
+                vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
+                out_ptr += 12;
+                vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
+                vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
+                out_ptr1 += 12;
+            }
+
             while (regs--) {
                 int32x4_t v_mul0;
                 int32x4_t v_shf0;
+                int32x4_t v_shf0l;
 
                 if (per_channel) {
                     v_mul0 = vld1q_s32(perch_mul_ptr);
@@ -279,11 +485,16 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                     v_shf0 = vld1q_s32(perch_shift_ptr);
                     perch_shift_ptr += 4;
+
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                        perch_shiftl_ptr += 4;
+                    }
                 } else {
                     v_mul0=v_mul;
-                    v_shf0=v_shift;
+                    v_shf0=v_right_shift;
+                    v_shf0l=v_left_shift;
                 }
-
                 // Load column pointers
                 int32x4_t v_col0 = vld1q_s32(colptr);
                 colptr += 4;
@@ -306,7 +517,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
 
                 v_in10 = vaddq_s32(v_in10, v_col0);
 
-                // Quantize - start with multiply
+                // Quantize - start with (optional) left shift
+                if (do_left_shift) {
+                    v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+                    v_in10 = vrshlq_s32(v_in10, v_shf0l);
+                }
+
+                // Then multiply
                 v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
 
                 v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -358,10 +576,12 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
             int32x4_t v_in10 = vdupq_n_s32(0);
             int32x4_t v_mul0 = vdupq_n_s32(0);
             int32x4_t v_shf0 = vdupq_n_s32(0);
+            int32x4_t v_shf0l = vdupq_n_s32(0);
 
             if (!per_channel) {
                 v_mul0 = v_mul;
-                v_shf0 = v_shift;
+                v_shf0 = v_right_shift;
+                v_shf0l = v_left_shift;
             }
 
             do {
@@ -371,6 +591,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
+                    }
                 }
                 if (odds == 1) { break; }
 
@@ -380,6 +603,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
+                    }
                 }
                 if (odds == 2) { break; }
 
@@ -389,6 +615,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
+                    }
                 }
             } while (0);
 
@@ -402,7 +631,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
 
             v_in10 = vaddq_s32(v_in10, v_col0);
 
-            // Quantize - start with multiply
+            // Quantize - start with (optional) left shift
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+            }
+
+            // Then multiply
            v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
 
            v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -461,33 +697,53 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
 template<typename Tin, typename Tout>
 void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                          const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
-                         const int32_t *row_bias, const int32_t *col_bias) {
+                         const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
     if (qp.per_channel_requant) {
         if (qp.minval >= qp.c_offset) {
-            requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+            if (qp.per_channel_left_shifts) {
+                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         } else {
-            requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+            if (qp.per_channel_left_shifts) {
+                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                          reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
        }
    } else {
        if (qp.minval >= qp.c_offset) {
-            requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                  reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+            if (qp.per_layer_left_shift > 0) {
+                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
        } else {
-            requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+            if (qp.per_layer_left_shift > 0) {
+                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                           reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                                            reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
        }
    }
 }
 
 template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                   const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
-                                  const int32_t *row_bias, const int32_t *col_bias);
+                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
 
 template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
                                   const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
-                                  const int32_t *row_bias, const int32_t *col_bias);
+                                  const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
 
 /*
  * Routine (and helpers) to compute row sums needed for offset correction.
@@ -604,7 +860,6 @@ namespace {
      * that the terms can simply be added in the requantize code.
      * */
     switch (rows) {
-        default:
         case 1:
             /* If we only have one output, just use ADDV.  Multiply
              * the offset into all four components separately so it
@@ -646,6 +901,9 @@ namespace {
 
             vst1q_s32(row_bias, t0);
             break;
+
+        default:
+            UNREACHABLE("Impossible.");
    }
 }
 
@@ -836,7 +1094,6 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
 
             if (numcols==16) {
                 switch(numrows) {
-                    default:
                     case 1:
                         add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
                         break;
@@ -852,6 +1109,9 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
                     case 4:
                         add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
                         break;
+
+                    default:
+                        UNREACHABLE("Impossible.");
                 }
             } else {
                 for (; col<width; col++) {
@@ -882,6 +1142,64 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
 
 template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
 template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
 
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+                         const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride,
+                         const float* bias_ptr, bool accumulate, const Activation &act)
+{
+    const float32x4_t vscale = vdupq_n_f32(qp.scale);
+    float maxval = std::numeric_limits<float>::infinity();
+    float minval = -std::numeric_limits<float>::infinity();
+
+    switch(act.type) {
+        default:
+        case Activation::Type::None:
+            break;
+        case Activation::Type::BoundedReLU:
+            maxval = static_cast<float>(act.param1);
+            /* fall through */
+        case Activation::Type::ReLU:
+            minval = 0;
+            break;
+    }
+
+    const float32x4_t vmin = vdupq_n_f32(minval);
+    const float32x4_t vmax = vdupq_n_f32(maxval);
+
+    for(unsigned int row=0; row<height; row++) {
+        auto row_in_ptr = in_ptr + (row * in_stride);
+        auto row_out_ptr = out_ptr + (row * out_stride);
+        unsigned int col=0;
+        if (width >= 4) {
+            for(; col <= (width - 4); col+= 4) {
+                const int32x4_t vin = vld1q_s32(row_in_ptr + col);
+                float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale);
+                if(bias_ptr) {
+                    const float32x4_t bin = vld1q_f32(bias_ptr + col);
+                    vdeq = vaddq_f32(vdeq, bin);
+                }
+                if(accumulate) {
+                    vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col));
+                }
+                vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax);
+                vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq);
+            }
+        }
+        // left-over elements
+        for(; col < width; ++col) {
+            const int32_t val = *(row_in_ptr + col);
+            float res = static_cast<float>(val * qp.scale);
+            if(bias_ptr) {
+                res += static_cast<float>(*(bias_ptr + col));
+            }
+            if(accumulate) {
+                res += *(row_out_ptr + col);
+            }
+            res = std::min(std::max(res, minval), maxval);
+            *(row_out_ptr + col) = res;
+        }
+    }
+}
+
 } // namespace arm_gemm
 
 #endif // __aarch64__
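A note on the three-register special case added above: because the 32-bit lanes have already been clamped to [minval, maxval], which lies inside the int8 range, the code narrows by simply discarding the upper bytes with `vuzp1q_s16`/`vuzp1q_s8` rather than using a saturating narrow, and then stores the 12 valid bytes as one 8-byte and one 4-byte lane store. The following standalone AArch64 sketch demonstrates that truncating narrow on made-up, already-clamped values; it is illustrative only and not taken from arm_gemm.

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

int main() {
    // Eight int32 values that already fit in int8 (as they would after clamping).
    const int32_t in0[4] = {-128, -1, 0, 1};
    const int32_t in1[4] = {2, 63, 100, 127};
    int32x4_t v0 = vld1q_s32(in0);
    int32x4_t v1 = vld1q_s32(in1);

    // Keep the low 16 bits of each 32-bit lane, then the low 8 bits of each
    // 16-bit lane - the same vuzp1 pattern used in the diff.
    int16x8_t lo16 = vuzp1q_s16(vreinterpretq_s16_s32(v0), vreinterpretq_s16_s32(v1));
    int8x16_t lo8  = vuzp1q_s8(vreinterpretq_s8_s16(lo16), vreinterpretq_s8_s16(lo16));

    int8_t out[8];
    vst1_s8(out, vget_low_s8(lo8));
    for (int i = 0; i < 8; i++) {
        printf("%d ", out[i]);   // expected: -128 -1 0 1 2 63 100 127
    }
    printf("\n");
    return 0;
}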
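The new `dequantize_block_32` converts a block of int32 accumulators to float: scale by `qp.scale`, optionally add a per-column float bias, optionally accumulate into the existing output, then clamp according to the activation (None, ReLU, or BoundedReLU via `act.param1`). Its NEON loop and scalar tail both implement the per-element formula sketched below; this reference function is illustrative only and is not part of the arm_gemm API.

#include <algorithm>
#include <cstdint>

// Illustrative per-element reference for dequantize_block_32 (not arm_gemm API).
// minval/maxval are derived from the activation: None -> (-inf, +inf),
// ReLU -> [0, +inf), BoundedReLU -> [0, act.param1].
inline float dequantize_element(int32_t in, float scale, const float *bias,
                                float prev_out, bool accumulate,
                                float minval, float maxval) {
    float res = static_cast<float>(in) * scale;
    if (bias) {
        res += *bias;        // per-column bias, when supplied
    }
    if (accumulate) {
        res += prev_out;     // add the value already stored in the output
    }
    return std::min(std::max(res, minval), maxval);
}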