aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/quantized.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantized.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/quantized.cpp378
1 files changed, 348 insertions, 30 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index 00b42cf422..6da9f4be0e 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#ifdef __aarch64__
#include "arm_gemm.hpp"
+#include "utils.hpp"
#include <arm_neon.h>
@@ -54,15 +55,16 @@ namespace {
* column is set up in any case (and it is hoped that the compiler can elide
* the needless movs in the per-layer case).
*/
-template<bool do_shift_correction, bool per_channel>
+template<bool do_shift_correction, bool per_channel, bool do_left_shift>
void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias) {
- const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
- const int32x4_t v_shift = vdupq_n_s32(qp.per_layer_shift);
- const int32x4_t v_minval = vdupq_n_s32(qp.minval);
- const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
- const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
+ const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
+ const int32x4_t v_mul = vdupq_n_s32(qp.per_layer_mul);
+ const int32x4_t v_right_shift = vdupq_n_s32(qp.per_layer_right_shift);
+ const int32x4_t v_left_shift = vdupq_n_s32(qp.per_layer_left_shift);
+ const int32x4_t v_minval = vdupq_n_s32(qp.minval);
+ const int32x4_t v_maxval = vdupq_n_s32(qp.maxval);
+ const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
/* To make sure we have plenty of accumulators, compute two rows at a
* time. If the number of rows is odd, compute the bottom row twice to
@@ -76,8 +78,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
unsigned int odds=(width % 4);
const int32_t *colptr = col_bias;
- const int32_t *perch_mul_ptr = qp.per_channel_muls;
- const int32_t *perch_shift_ptr = qp.per_channel_shifts;
+ const int32_t *perch_mul_ptr = qp.per_channel_muls + start_col;
+ const int32_t *perch_shift_ptr = qp.per_channel_right_shifts + start_col;
+ const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;
const int32_t *in_ptr = input + (row * in_stride);
int8_t *out_ptr = output + (row * out_stride);
@@ -111,6 +114,11 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
int32x4_t v_shf2;
int32x4_t v_shf3;
+ int32x4_t v_shf0l;
+ int32x4_t v_shf1l;
+ int32x4_t v_shf2l;
+ int32x4_t v_shf3l;
+
if (per_channel) {
v_mul0 = vld1q_s32(perch_mul_ptr);
v_mul1 = vld1q_s32(perch_mul_ptr + 4);
@@ -123,9 +131,18 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_shf2 = vld1q_s32(perch_shift_ptr + 8);
v_shf3 = vld1q_s32(perch_shift_ptr + 12);
perch_shift_ptr += 16;
+
+ if (do_left_shift) {
+ v_shf0l = vld1q_s32(perch_shiftl_ptr);
+ v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+ v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+ v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
+ perch_shiftl_ptr += 16;
+ }
} else {
v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
- v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+ v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
+ v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
}
// Load column pointers
@@ -170,7 +187,22 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_in12 = vaddq_s32(v_in12, v_col2);
v_in13 = vaddq_s32(v_in13, v_col3);
- // Quantize - start with multiply
+ // Quantize
+
+ // If a left shift is needed it needs to happen first.
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+ v_in01 = vrshlq_s32(v_in01, v_shf1l);
+ v_in02 = vrshlq_s32(v_in02, v_shf2l);
+ v_in03 = vrshlq_s32(v_in03, v_shf3l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ v_in11 = vrshlq_s32(v_in11, v_shf1l);
+ v_in12 = vrshlq_s32(v_in12, v_shf2l);
+ v_in13 = vrshlq_s32(v_in13, v_shf3l);
+ }
+
+ // Multiply
v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
@@ -269,9 +301,183 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
out_ptr1 += 16;
}
+ // We are often quantizing one block of interleaved kernel output at a time - these are three registers
+ // wide. Special case that here.
+ if (regs==3) {
+ regs -= 3;
+
+ int32x4_t v_mul0;
+ int32x4_t v_mul1;
+ int32x4_t v_mul2;
+
+ int32x4_t v_shf0;
+ int32x4_t v_shf1;
+ int32x4_t v_shf2;
+
+ int32x4_t v_shf0l;
+ int32x4_t v_shf1l;
+ int32x4_t v_shf2l;
+
+ if (per_channel) {
+ v_mul0 = vld1q_s32(perch_mul_ptr);
+ v_mul1 = vld1q_s32(perch_mul_ptr + 4);
+ v_mul2 = vld1q_s32(perch_mul_ptr + 8);
+ perch_mul_ptr += 12;
+
+ v_shf0 = vld1q_s32(perch_shift_ptr);
+ v_shf1 = vld1q_s32(perch_shift_ptr + 4);
+ v_shf2 = vld1q_s32(perch_shift_ptr + 8);
+ perch_shift_ptr += 12;
+
+ if (do_left_shift) {
+ v_shf0l = vld1q_s32(perch_shiftl_ptr);
+ v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+ v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+ perch_shiftl_ptr += 12;
+ }
+ } else {
+ v_mul0=v_mul1=v_mul2=v_mul;
+ v_shf0=v_shf1=v_shf2=v_right_shift;
+ v_shf0l=v_shf1l=v_shf2l=v_left_shift;
+ }
+
+ // Load column pointers
+ int32x4_t v_col0 = vld1q_s32(colptr);
+ int32x4_t v_col1 = vld1q_s32(colptr + 4);
+ int32x4_t v_col2 = vld1q_s32(colptr + 8);
+ colptr += 12;
+
+ // Load input data (row 0);
+ int32x4_t v_in00 = vld1q_s32(in_ptr);
+ int32x4_t v_in01 = vld1q_s32(in_ptr + 4);
+ int32x4_t v_in02 = vld1q_s32(in_ptr + 8);
+ in_ptr += 12;
+
+ // Load input data (row 1);
+ int32x4_t v_in10 = vld1q_s32(in_ptr1);
+ int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4);
+ int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8);
+ in_ptr1 += 12;
+
+ // Add on row bias and column bias
+ v_in00 = vaddq_s32(v_in00, v_row_sum);
+ v_in01 = vaddq_s32(v_in01, v_row_sum);
+ v_in02 = vaddq_s32(v_in02, v_row_sum);
+
+ v_in10 = vaddq_s32(v_in10, v_row_sum1);
+ v_in11 = vaddq_s32(v_in11, v_row_sum1);
+ v_in12 = vaddq_s32(v_in12, v_row_sum1);
+
+ v_in00 = vaddq_s32(v_in00, v_col0);
+ v_in01 = vaddq_s32(v_in01, v_col1);
+ v_in02 = vaddq_s32(v_in02, v_col2);
+
+ v_in10 = vaddq_s32(v_in10, v_col0);
+ v_in11 = vaddq_s32(v_in11, v_col1);
+ v_in12 = vaddq_s32(v_in12, v_col2);
+
+ // Quantize
+
+ // If a left shift is needed it needs to happen first.
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+ v_in01 = vrshlq_s32(v_in01, v_shf1l);
+ v_in02 = vrshlq_s32(v_in02, v_shf2l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ v_in11 = vrshlq_s32(v_in11, v_shf1l);
+ v_in12 = vrshlq_s32(v_in12, v_shf2l);
+ }
+
+ // Multiply
+ v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
+ v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
+ v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
+
+ v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
+ v_in11 = vqrdmulhq_s32(v_in11, v_mul1);
+ v_in12 = vqrdmulhq_s32(v_in12, v_mul2);
+
+ // Compute and add on corrective offset
+ if (do_shift_correction) {
+ int32x4_t v_temp00 = vandq_s32(v_in00, v_shf0);
+ int32x4_t v_temp01 = vandq_s32(v_in01, v_shf1);
+ int32x4_t v_temp02 = vandq_s32(v_in02, v_shf2);
+
+ int32x4_t v_temp10 = vandq_s32(v_in10, v_shf0);
+ int32x4_t v_temp11 = vandq_s32(v_in11, v_shf1);
+ int32x4_t v_temp12 = vandq_s32(v_in12, v_shf2);
+
+ v_temp00 = vshrq_n_s32(v_temp00, 31);
+ v_temp01 = vshrq_n_s32(v_temp01, 31);
+ v_temp02 = vshrq_n_s32(v_temp02, 31);
+
+ v_temp10 = vshrq_n_s32(v_temp10, 31);
+ v_temp11 = vshrq_n_s32(v_temp11, 31);
+ v_temp12 = vshrq_n_s32(v_temp12, 31);
+
+ v_in00 = vqaddq_s32(v_in00, v_temp00);
+ v_in01 = vqaddq_s32(v_in01, v_temp01);
+ v_in02 = vqaddq_s32(v_in02, v_temp02);
+
+ v_in10 = vqaddq_s32(v_in10, v_temp10);
+ v_in11 = vqaddq_s32(v_in11, v_temp11);
+ v_in12 = vqaddq_s32(v_in12, v_temp12);
+ }
+
+ v_in00 = vrshlq_s32(v_in00, v_shf0);
+ v_in01 = vrshlq_s32(v_in01, v_shf1);
+ v_in02 = vrshlq_s32(v_in02, v_shf2);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0);
+ v_in11 = vrshlq_s32(v_in11, v_shf1);
+ v_in12 = vrshlq_s32(v_in12, v_shf2);
+
+ v_in00 = vaddq_s32(v_in00, v_c_offset);
+ v_in01 = vaddq_s32(v_in01, v_c_offset);
+ v_in02 = vaddq_s32(v_in02, v_c_offset);
+
+ v_in10 = vaddq_s32(v_in10, v_c_offset);
+ v_in11 = vaddq_s32(v_in11, v_c_offset);
+ v_in12 = vaddq_s32(v_in12, v_c_offset);
+
+ v_in00 = vmaxq_s32(v_in00, v_minval);
+ v_in01 = vmaxq_s32(v_in01, v_minval);
+ v_in02 = vmaxq_s32(v_in02, v_minval);
+
+ v_in10 = vmaxq_s32(v_in10, v_minval);
+ v_in11 = vmaxq_s32(v_in11, v_minval);
+ v_in12 = vmaxq_s32(v_in12, v_minval);
+
+ v_in00 = vminq_s32(v_in00, v_maxval);
+ v_in01 = vminq_s32(v_in01, v_maxval);
+ v_in02 = vminq_s32(v_in02, v_maxval);
+
+ v_in10 = vminq_s32(v_in10, v_maxval);
+ v_in11 = vminq_s32(v_in11, v_maxval);
+ v_in12 = vminq_s32(v_in12, v_maxval);
+
+ int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01));
+ int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in02));
+
+ int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11));
+ int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in12));
+
+ int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01));
+ int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11));
+
+ vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr), vreinterpretq_s64_s8(v_uz0), 0);
+ vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr + 8), vreinterpretq_s32_s8(v_uz0), 2);
+ out_ptr += 12;
+ vst1q_lane_s64(reinterpret_cast<int64_t *>(out_ptr1), vreinterpretq_s64_s8(v_uz1), 0);
+ vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1 + 8), vreinterpretq_s32_s8(v_uz1), 2);
+ out_ptr1 += 12;
+ }
+
while (regs--) {
int32x4_t v_mul0;
int32x4_t v_shf0;
+ int32x4_t v_shf0l;
if (per_channel) {
v_mul0 = vld1q_s32(perch_mul_ptr);
@@ -279,11 +485,16 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_shf0 = vld1q_s32(perch_shift_ptr);
perch_shift_ptr += 4;
+
+ if (do_left_shift) {
+ v_shf0l = vld1q_s32(perch_shiftl_ptr);
+ perch_shiftl_ptr += 4;
+ }
} else {
v_mul0=v_mul;
- v_shf0=v_shift;
+ v_shf0=v_right_shift;
+ v_shf0l=v_left_shift;
}
-
// Load column pointers
int32x4_t v_col0 = vld1q_s32(colptr);
colptr += 4;
@@ -306,7 +517,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_in10 = vaddq_s32(v_in10, v_col0);
- // Quantize - start with multiply
+ // Quantize - start with (optional) left shift
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ }
+
+ // Then multiply
v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -358,10 +576,12 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
int32x4_t v_in10 = vdupq_n_s32(0);
int32x4_t v_mul0 = vdupq_n_s32(0);
int32x4_t v_shf0 = vdupq_n_s32(0);
+ int32x4_t v_shf0l = vdupq_n_s32(0);
if (!per_channel) {
v_mul0 = v_mul;
- v_shf0 = v_shift;
+ v_shf0 = v_right_shift;
+ v_shf0l = v_left_shift;
}
do {
@@ -371,6 +591,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
if (per_channel) {
v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+ if (do_left_shift) {
+ v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
+ }
}
if (odds == 1) { break; }
@@ -380,6 +603,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
if (per_channel) {
v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+ if (do_left_shift) {
+ v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
+ }
}
if (odds == 2) { break; }
@@ -389,6 +615,9 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
if (per_channel) {
v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+ if (do_left_shift) {
+ v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
+ }
}
} while (0);
@@ -402,7 +631,14 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
v_in10 = vaddq_s32(v_in10, v_col0);
- // Quantize - start with multiply
+ // Quantize - start with (optional) left shift
+ if (do_left_shift) {
+ v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+ v_in10 = vrshlq_s32(v_in10, v_shf0l);
+ }
+
+ // Then multiply
v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -461,33 +697,53 @@ void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigne
template<typename Tin, typename Tout>
void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias) {
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
if (qp.per_channel_requant) {
if (qp.minval >= qp.c_offset) {
- requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ if (qp.per_channel_left_shifts) {
+ requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
} else {
- requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ if (qp.per_channel_left_shifts) {
+ requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
}
} else {
if (qp.minval >= qp.c_offset) {
- requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ if (qp.per_layer_left_shift > 0) {
+ requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
} else {
- requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
- reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias);
+ if (qp.per_layer_left_shift > 0) {
+ requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ } else {
+ requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+ reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+ }
}
}
}
template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias);
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
template void requantize_block_32(const Requantize32 &qp, unsigned int width, unsigned int height,
const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride,
- const int32_t *row_bias, const int32_t *col_bias);
+ const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col);
/*
* Routine (and helpers) to compute row sums needed for offset correction.
@@ -604,7 +860,6 @@ namespace {
* that the terms can simply be added in the requantize code.
* */
switch (rows) {
- default:
case 1:
/* If we only have one output, just use ADDV. Multiply
* the offset into all four components separately so it
@@ -646,6 +901,9 @@ namespace {
vst1q_s32(row_bias, t0);
break;
+
+ default:
+ UNREACHABLE("Impossible.");
}
}
@@ -836,7 +1094,6 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
if (numcols==16) {
switch(numrows) {
- default:
case 1:
add_block<1>(input + row * in_stride + col, in_stride, col_bias + col);
break;
@@ -852,6 +1109,9 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
case 4:
add_block<4>(input + row * in_stride + col, in_stride, col_bias + col);
break;
+
+ default:
+ UNREACHABLE("Impossible.");
}
} else {
for (; col<width; col++) {
@@ -882,6 +1142,64 @@ void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int h
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
template void compute_col_sums(const Requantize32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int multi, unsigned int first_col);
+void dequantize_block_32(const DequantizeFloat &qp, unsigned int width, unsigned int height,
+ const int32_t* in_ptr, unsigned int in_stride, float *out_ptr, unsigned int out_stride,
+ const float* bias_ptr, bool accumulate, const Activation &act)
+{
+ const float32x4_t vscale = vdupq_n_f32(qp.scale);
+ float maxval = std::numeric_limits<float>::infinity();
+ float minval = -std::numeric_limits<float>::infinity();
+
+ switch(act.type) {
+ default:
+ case Activation::Type::None:
+ break;
+ case Activation::Type::BoundedReLU:
+ maxval = static_cast<float>(act.param1);
+ /* fall through */
+ case Activation::Type::ReLU:
+ minval = 0;
+ break;
+ }
+
+ const float32x4_t vmin = vdupq_n_f32(minval);
+ const float32x4_t vmax = vdupq_n_f32(maxval);
+
+ for(unsigned int row=0; row<height; row++) {
+ auto row_in_ptr = in_ptr + (row * in_stride);
+ auto row_out_ptr = out_ptr + (row * out_stride);
+ unsigned int col=0;
+ if (width >= 4) {
+ for(; col <= (width - 4); col+= 4) {
+ const int32x4_t vin = vld1q_s32(row_in_ptr + col);
+ float32x4_t vdeq = vmulq_f32(vcvtq_f32_s32(vin), vscale);
+ if(bias_ptr) {
+ const float32x4_t bin = vld1q_f32(bias_ptr + col);
+ vdeq = vaddq_f32(vdeq, bin);
+ }
+ if(accumulate) {
+ vdeq = vaddq_f32(vdeq, vld1q_f32(row_out_ptr + col));
+ }
+ vdeq = vminq_f32(vmaxq_f32(vdeq, vmin), vmax);
+ vst1q_f32(reinterpret_cast<float *>(row_out_ptr + col), vdeq);
+ }
+ }
+ // left-over elements
+ for(; col < width; ++col) {
+ const int32_t val = *(row_in_ptr + col);
+ float res = static_cast<float>(val * qp.scale);
+ if(bias_ptr) {
+ res += static_cast<float>(*(bias_ptr + col));
+ }
+ if(accumulate) {
+ res += *(row_out_ptr + col);
+ }
+ res = std::min(std::max(res, minval), maxval);
+ *(row_out_ptr + col) = res;
+ }
+ }
+}
+
} // namespace arm_gemm
#endif // __aarch64__