Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantized.cpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/quantized.cpp  173
1 file changed, 173 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index cac02cf28e..111d01ed3a 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -301,6 +301,179 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
out_ptr1 += 16;
}
+ // We are often quantizing one block of interleaved kernel output at a time - such a block
+ // is three registers (twelve int32 values) wide. Special case that here.
+ if (regs==3) {
+ regs -= 3;
+
+ int32x4_t v_mul0;
+ int32x4_t v_mul1;
+ int32x4_t v_mul2;
+
+ int32x4_t v_shf0;
+ int32x4_t v_shf1;
+ int32x4_t v_shf2;
+
+ int32x4_t v_shf0l;
+ int32x4_t v_shf1l;
+ int32x4_t v_shf2l;
+
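+ // Per-channel quantization needs a fresh multiplier and shift for each of the
+ // twelve columns; per-tensor quantization broadcasts the same multiplier and
+ // shifts to all three registers.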
+ if (per_channel) {
+ v_mul0 = vld1q_s32(perch_mul_ptr);
+ v_mul1 = vld1q_s32(perch_mul_ptr + 4);
+ v_mul2 = vld1q_s32(perch_mul_ptr + 8);
+ perch_mul_ptr += 12;
+
+ v_shf0 = vld1q_s32(perch_shift_ptr);
+ v_shf1 = vld1q_s32(perch_shift_ptr + 4);
+ v_shf2 = vld1q_s32(perch_shift_ptr + 8);
+ perch_shift_ptr += 12;
+
+ if (do_left_shift) {
+ v_shf0l = vld1q_s32(perch_shiftl_ptr);
+ v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+ v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+ perch_shiftl_ptr += 12;
+ }
+ } else {
+ v_mul0 = v_mul1 = v_mul2 = v_mul;
+ v_shf0 = v_shf1 = v_shf2 = v_right_shift;
+ v_shf0l = v_shf1l = v_shf2l = v_left_shift;
+ }
+
+ // Load column bias values
+ int32x4_t v_col0 = vld1q_s32(colptr);
+ int32x4_t v_col1 = vld1q_s32(colptr + 4);
+ int32x4_t v_col2 = vld1q_s32(colptr + 8);
+ colptr += 12;
+
+ // Load input data (row 0)
+ int32x4_t v_in00 = vld1q_s32(in_ptr);
+ int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
+ int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
+ in_ptr += 12;
+
+ // Load input data (row 1)
+ int32x4_t v_in10 = vld1q_s32(in_ptr1);
+ int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
+ int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
+ in_ptr1 += 12;
+
+ // Add on row bias and column bias
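+ // The row and column sums are assumed to arrive pre-scaled by the opposing
+ // tensor's quantization offset, so these additions fold the offset cross-terms
+ // of (A - a_offset) * (B - b_offset) into the accumulators.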
+ v_in00 = vaddq_s32(v_in00, v_row_sum);
+ v_in01 = vaddq_s32(v_in01, v_row_sum);
+ v_in02 = vaddq_s32(v_in02, v_row_sum);
+
+ v_in10 = vaddq_s32(v_in10, v_row_sum1);
+ v_in11 = vaddq_s32(v_in11, v_row_sum1);
+ v_in12 = vaddq_s32(v_in12, v_row_sum1);
+
+ v_in00 = vaddq_s32(v_in00, v_col0);
+ v_in01 = vaddq_s32(v_in01, v_col1);
+ v_in02 = vaddq_s32(v_in02, v_col2);
+
+ v_in10 = vaddq_s32(v_in10, v_col0);
+ v_in11 = vaddq_s32(v_in11, v_col1);
+ v_in12 = vaddq_s32(v_in12, v_col2);
+
+ // Quantize
+
+ // If a left shift is needed, it must be applied first.
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+ v_in01 = vrshlq_s32(v_in01, v_shf1l);
+ v_in02 = vrshlq_s32(v_in02, v_shf2l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ v_in11 = vrshlq_s32(v_in11, v_shf1l);
+ v_in12 = vrshlq_s32(v_in12, v_shf2l);
+ }
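+
+ // The requantization exponent is split in two: the left-shift part is applied
+ // above, before the fixed-point multiply, and the right-shift part after it.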
+
+ // Multiply
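+ // (vqrdmulhq_s32 computes (2*a*b + (1 << 31)) >> 32 with saturation, i.e. a
+ // round-to-nearest multiply by a Q31 fixed-point multiplier.)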
+ v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
+ v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
+ v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
+
+ v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
+ v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
+ v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
+
+ // Compute and add on corrective offset
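+ // The shift counts in v_shf are zero or negative, so AND-ing the value with the
+ // shift picks up the value's sign bit whenever a right shift will actually occur
+ // (and gives zero when the shift is zero). Arithmetic-shifting that result by 31
+ // produces -1 for negative values and 0 otherwise; saturating-adding it nudges
+ // negative values down by one so the rounding right shift below resolves ties
+ // away from zero rather than towards +infinity.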
+ if (do_shift_correction) {
+ int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
+ int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
+ int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
+
+ int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
+ int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
+ int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
+
+ v_temp00 = vshrq_n_s32(v_temp00, 31);
+ v_temp01 = vshrq_n_s32(v_temp01, 31);
+ v_temp02 = vshrq_n_s32(v_temp02, 31);
+
+ v_temp10 = vshrq_n_s32(v_temp10, 31);
+ v_temp11 = vshrq_n_s32(v_temp11, 31);
+ v_temp12 = vshrq_n_s32(v_temp12, 31);
+
+ v_in00 = vqaddq_s32(v_in00, v_temp00);
+ v_in01 = vqaddq_s32(v_in01, v_temp01);
+ v_in02 = vqaddq_s32(v_in02, v_temp02);
+
+ v_in10 = vqaddq_s32(v_in10, v_temp10);
+ v_in11 = vqaddq_s32(v_in11, v_temp11);
+ v_in12 = vqaddq_s32(v_in12, v_temp12);
+ }
+
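+ // vrshlq_s32 shifts left by a signed per-lane count, so the negative counts here
+ // perform a rounding right shift.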
+ v_in00 = vrshlq_s32(v_in00, v_shf0);
+ v_in01 = vrshlq_s32(v_in01, v_shf1);
+ v_in02 = vrshlq_s32(v_in02, v_shf2);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0);
+ v_in11 = vrshlq_s32(v_in11, v_shf1);
+ v_in12 = vrshlq_s32(v_in12, v_shf2);
+
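+ // Add on the output offset (zero point).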
+ v_in00 = vaddq_s32(v_in00, v_c_offset);
+ v_in01 = vaddq_s32(v_in01, v_c_offset);
+ v_in02 = vaddq_s32(v_in02, v_c_offset);
+
+ v_in10 = vaddq_s32(v_in10, v_c_offset);
+ v_in11 = vaddq_s32(v_in11, v_c_offset);
+ v_in12 = vaddq_s32(v_in12, v_c_offset);
+
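+ // Clamp to the valid output range before narrowing to 8 bits.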
+ v_in00 = vmaxq_s32(v_in00, v_minval);
+ v_in01 = vmaxq_s32(v_in01, v_minval);
+ v_in02 = vmaxq_s32(v_in02, v_minval);
+
+ v_in10 = vmaxq_s32(v_in10, v_minval);
+ v_in11 = vmaxq_s32(v_in11, v_minval);
+ v_in12 = vmaxq_s32(v_in12, v_minval);
+
+ v_in00 = vminq_s32(v_in00, v_maxval);
+ v_in01 = vminq_s32(v_in01, v_maxval);
+ v_in02 = vminq_s32(v_in02, v_maxval);
+
+ v_in10 = vminq_s32(v_in10, v_maxval);
+ v_in11 = vminq_s32(v_in11, v_maxval);
+ v_in12 = vminq_s32(v_in12, v_maxval);
+
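+ // Narrow to 8 bits: after clamping, each result fits in the low byte of its
+ // 32-bit lane, so two uzp1 passes (even 16-bit halves, then even bytes) pack the
+ // values. The third register is deliberately paired with itself; only the first
+ // twelve bytes of each packed vector get stored.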
+ int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
+ int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));
+
+ int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
+ int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));
+
+ int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
+ int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));
+
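+ // Store the twelve bytes of each row as one 64-bit lane plus one 32-bit lane.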
+ vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
+ vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
+ out_ptr += 12;
+ vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
+ vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
+ out_ptr1 += 12;
+ }
+
while (regs--) {
int32x4_t v_mul0;
int32x4_t v_shf0;