Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/quantized.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/quantized.cpp | 769
1 file changed, 769 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp new file mode 100644 index 0000000000..dd4eb31ea3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp @@ -0,0 +1,769 @@ +/* + * Copyright (c) 2019 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "arm_gemm.hpp" + +#include <arm_neon.h> + +namespace arm_gemm { + +namespace { + +/* Requantize a block of data, using the requantize parameters in 'qp'. + * + * row_bias and col_bias are assumed to be precomputed values which include + * any externally supplied bias, plus the row/column contibution sums, plus + * the overall constant offset (A_offset * B_offset * depth). + * + * Note that this function works equally well for uint8_t output: just set + * minval/maxval appropriately and cast the output pointer. It is caller's + * responsibility to ensure that minval/maxval are representable in the + * target type - the downcast to (u)int8_t is done by simply extracting the + * LSB. + * + * The 'do_shift_correction' template parameter turns on the correction + * applied to negative values being shifted right to make sure they round + * properly - if negative values are never output (e.g. fused ReLU) this is + * unnecessary. + */ +template<bool do_shift_correction> +void requantize_block_32_int(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, + const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, + const int32_t *row_bias, const int32_t *col_bias) { + const int32x4_t v_mul = vdupq_n_s32(qp.requant_mul); + const int32x4_t v_shift = vdupq_n_s32(qp.requant_shift); + const int32x4_t v_minval = vdupq_n_s32(qp.minval); + const int32x4_t v_maxval = vdupq_n_s32(qp.maxval); + const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset); + + /* To make sure we have plenty of accumulators, compute two rows at a + * time. If the number of rows is odd, compute the bottom row twice to + * avoid needing a duplicate codepath. */ + for (unsigned int row=0; row<height; row+=2) { + /* Prefer to do 4 vectors (16 values) at once as this collapses + * neatly to a single vector of output, failing that a vector at a + * time and then the odd ones out at the end. 
*/ + unsigned int blocks=(width / 16); + unsigned int regs=(width % 16) / 4; + unsigned int odds=(width % 4); + + const int32_t *colptr = col_bias; + + const int32_t *in_ptr = input + (row * in_stride); + int8_t *out_ptr = output + (row * out_stride); + int32_t row_sum = row_bias[row]; + + const int32_t *in_ptr1; + int8_t *out_ptr1; + int32_t row_sum1; + + if (row == height-1) { + in_ptr1 = in_ptr; + out_ptr1 = out_ptr; + row_sum1 = row_sum; + } else { + in_ptr1 = in_ptr + in_stride; + out_ptr1 = out_ptr + out_stride; + row_sum1 = row_bias[row+1]; + } + + const int32x4_t v_row_sum = vdupq_n_s32(row_sum); + const int32x4_t v_row_sum1 = vdupq_n_s32(row_sum1); + + while (blocks--) { + // Load column pointers + int32x4_t v_col0 = vld1q_s32(colptr); + int32x4_t v_col1 = vld1q_s32(colptr + 4); + int32x4_t v_col2 = vld1q_s32(colptr + 8); + int32x4_t v_col3 = vld1q_s32(colptr + 12); + colptr += 16; + + // Load input data (row 0); + int32x4_t v_in00 = vld1q_s32(in_ptr); + int32x4_t v_in01 = vld1q_s32(in_ptr + 4); + int32x4_t v_in02 = vld1q_s32(in_ptr + 8); + int32x4_t v_in03 = vld1q_s32(in_ptr + 12); + in_ptr += 16; + + // Load input data (row 1); + int32x4_t v_in10 = vld1q_s32(in_ptr1); + int32x4_t v_in11 = vld1q_s32(in_ptr1 + 4); + int32x4_t v_in12 = vld1q_s32(in_ptr1 + 8); + int32x4_t v_in13 = vld1q_s32(in_ptr1 + 12); + in_ptr1 += 16; + + // Add on row bias and column bias + v_in00 = vaddq_s32(v_in00, v_row_sum); + v_in01 = vaddq_s32(v_in01, v_row_sum); + v_in02 = vaddq_s32(v_in02, v_row_sum); + v_in03 = vaddq_s32(v_in03, v_row_sum); + + v_in10 = vaddq_s32(v_in10, v_row_sum1); + v_in11 = vaddq_s32(v_in11, v_row_sum1); + v_in12 = vaddq_s32(v_in12, v_row_sum1); + v_in13 = vaddq_s32(v_in13, v_row_sum1); + + v_in00 = vaddq_s32(v_in00, v_col0); + v_in01 = vaddq_s32(v_in01, v_col1); + v_in02 = vaddq_s32(v_in02, v_col2); + v_in03 = vaddq_s32(v_in03, v_col3); + + v_in10 = vaddq_s32(v_in10, v_col0); + v_in11 = vaddq_s32(v_in11, v_col1); + v_in12 = vaddq_s32(v_in12, v_col2); + v_in13 = vaddq_s32(v_in13, v_col3); + + // Quantize - start with multiply + v_in00 = vqrdmulhq_s32(v_in00, v_mul); + v_in01 = vqrdmulhq_s32(v_in01, v_mul); + v_in02 = vqrdmulhq_s32(v_in02, v_mul); + v_in03 = vqrdmulhq_s32(v_in03, v_mul); + + v_in10 = vqrdmulhq_s32(v_in10, v_mul); + v_in11 = vqrdmulhq_s32(v_in11, v_mul); + v_in12 = vqrdmulhq_s32(v_in12, v_mul); + v_in13 = vqrdmulhq_s32(v_in13, v_mul); + + // Compute and add on corrective offset + if (do_shift_correction) { + int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + int32x4_t v_temp01 = vandq_s32(v_in01, v_shift); + int32x4_t v_temp02 = vandq_s32(v_in02, v_shift); + int32x4_t v_temp03 = vandq_s32(v_in03, v_shift); + + int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + int32x4_t v_temp11 = vandq_s32(v_in11, v_shift); + int32x4_t v_temp12 = vandq_s32(v_in12, v_shift); + int32x4_t v_temp13 = vandq_s32(v_in13, v_shift); + + v_temp00 = vshrq_n_s32(v_temp00, 31); + v_temp01 = vshrq_n_s32(v_temp01, 31); + v_temp02 = vshrq_n_s32(v_temp02, 31); + v_temp03 = vshrq_n_s32(v_temp03, 31); + + v_temp10 = vshrq_n_s32(v_temp10, 31); + v_temp11 = vshrq_n_s32(v_temp11, 31); + v_temp12 = vshrq_n_s32(v_temp12, 31); + v_temp13 = vshrq_n_s32(v_temp13, 31); + + v_in00 = vqaddq_s32(v_in00, v_temp00); + v_in01 = vqaddq_s32(v_in01, v_temp01); + v_in02 = vqaddq_s32(v_in02, v_temp02); + v_in03 = vqaddq_s32(v_in03, v_temp03); + + v_in10 = vqaddq_s32(v_in10, v_temp10); + v_in11 = vqaddq_s32(v_in11, v_temp11); + v_in12 = vqaddq_s32(v_in12, v_temp12); + v_in13 = vqaddq_s32(v_in13, v_temp13); + } + + v_in00 
= vrshlq_s32(v_in00, v_shift); + v_in01 = vrshlq_s32(v_in01, v_shift); + v_in02 = vrshlq_s32(v_in02, v_shift); + v_in03 = vrshlq_s32(v_in03, v_shift); + + v_in10 = vrshlq_s32(v_in10, v_shift); + v_in11 = vrshlq_s32(v_in11, v_shift); + v_in12 = vrshlq_s32(v_in12, v_shift); + v_in13 = vrshlq_s32(v_in13, v_shift); + + v_in00 = vaddq_s32(v_in00, v_c_offset); + v_in01 = vaddq_s32(v_in01, v_c_offset); + v_in02 = vaddq_s32(v_in02, v_c_offset); + v_in03 = vaddq_s32(v_in03, v_c_offset); + + v_in10 = vaddq_s32(v_in10, v_c_offset); + v_in11 = vaddq_s32(v_in11, v_c_offset); + v_in12 = vaddq_s32(v_in12, v_c_offset); + v_in13 = vaddq_s32(v_in13, v_c_offset); + + v_in00 = vmaxq_s32(v_in00, v_minval); + v_in01 = vmaxq_s32(v_in01, v_minval); + v_in02 = vmaxq_s32(v_in02, v_minval); + v_in03 = vmaxq_s32(v_in03, v_minval); + + v_in10 = vmaxq_s32(v_in10, v_minval); + v_in11 = vmaxq_s32(v_in11, v_minval); + v_in12 = vmaxq_s32(v_in12, v_minval); + v_in13 = vmaxq_s32(v_in13, v_minval); + + v_in00 = vminq_s32(v_in00, v_maxval); + v_in01 = vminq_s32(v_in01, v_maxval); + v_in02 = vminq_s32(v_in02, v_maxval); + v_in03 = vminq_s32(v_in03, v_maxval); + + v_in10 = vminq_s32(v_in10, v_maxval); + v_in11 = vminq_s32(v_in11, v_maxval); + v_in12 = vminq_s32(v_in12, v_maxval); + v_in13 = vminq_s32(v_in13, v_maxval); + + int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in01)); + int16x8_t v_uz01 = vuzp1q_s16(vreinterpretq_s16_s32(v_in02), vreinterpretq_s16_s32(v_in03)); + + int16x8_t v_uz10 = vuzp1q_s16(vreinterpretq_s16_s32(v_in10), vreinterpretq_s16_s32(v_in11)); + int16x8_t v_uz11 = vuzp1q_s16(vreinterpretq_s16_s32(v_in12), vreinterpretq_s16_s32(v_in13)); + + int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz01)); + int8x16_t v_uz1 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz10), vreinterpretq_s8_s16(v_uz11)); + + vst1q_s8(out_ptr, v_uz0); + out_ptr += 16; + vst1q_s8(out_ptr1, v_uz1); + out_ptr1 += 16; + } + + while (regs--) { + // Load column pointers + int32x4_t v_col0 = vld1q_s32(colptr); + colptr += 4; + + // Load input data (row 0); + int32x4_t v_in00 = vld1q_s32(in_ptr); + in_ptr += 4; + + // Load input data (row 1); + int32x4_t v_in10 = vld1q_s32(in_ptr1); + in_ptr1 += 4; + + // Add on row sum and bias constant + v_in00 = vaddq_s32(v_in00, v_row_sum); + + v_in10 = vaddq_s32(v_in10, v_row_sum1); + + // Subtract col sum * a_offset + v_in00 = vaddq_s32(v_in00, v_col0); + + v_in10 = vaddq_s32(v_in10, v_col0); + + // Quantize - start with multiply + v_in00 = vqrdmulhq_s32(v_in00, v_mul); + + v_in10 = vqrdmulhq_s32(v_in10, v_mul); + + // Compute and add on corrective offset + if (do_shift_correction) { + int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + + int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + + v_temp00 = vshrq_n_s32(v_temp00, 31); + + v_temp10 = vshrq_n_s32(v_temp10, 31); + + v_in00 = vqaddq_s32(v_in00, v_temp00); + + v_in10 = vqaddq_s32(v_in10, v_temp10); + } + + v_in00 = vrshlq_s32(v_in00, v_shift); + + v_in10 = vrshlq_s32(v_in10, v_shift); + + v_in00 = vaddq_s32(v_in00, v_c_offset); + + v_in10 = vaddq_s32(v_in10, v_c_offset); + + v_in00 = vmaxq_s32(v_in00, v_minval); + + v_in10 = vmaxq_s32(v_in10, v_minval); + + v_in00 = vminq_s32(v_in00, v_maxval); + + v_in10 = vminq_s32(v_in10, v_maxval); + + int16x8_t v_uz00 = vuzp1q_s16(vreinterpretq_s16_s32(v_in00), vreinterpretq_s16_s32(v_in10)); + + int8x16_t v_uz0 = vuzp1q_s8(vreinterpretq_s8_s16(v_uz00), vreinterpretq_s8_s16(v_uz00)); + + vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr), 
vreinterpretq_s32_s8(v_uz0), 0); + out_ptr += 4; + vst1q_lane_s32(reinterpret_cast<int32_t *>(out_ptr1), vreinterpretq_s32_s8(v_uz0), 1); + out_ptr1 += 4; + } + + if (odds) { + int32x4_t v_col0 = vdupq_n_s32(0); + int32x4_t v_in00 = vdupq_n_s32(0); + int32x4_t v_in10 = vdupq_n_s32(0); + + do { + v_col0 = vld1q_lane_s32(colptr, v_col0, 0); + v_in00 = vld1q_lane_s32(in_ptr, v_in00, 0); + v_in10 = vld1q_lane_s32(in_ptr1, v_in10, 0); + if (odds == 1) { break; } + + v_col0 = vld1q_lane_s32(colptr + 1, v_col0, 1); + v_in00 = vld1q_lane_s32(in_ptr + 1, v_in00, 1); + v_in10 = vld1q_lane_s32(in_ptr1 + 1, v_in10, 1); + if (odds == 2) { break; } + + v_col0 = vld1q_lane_s32(colptr + 2, v_col0, 2); + v_in00 = vld1q_lane_s32(in_ptr + 2, v_in00, 2); + v_in10 = vld1q_lane_s32(in_ptr1 + 2, v_in10, 2); + } while (0); + + // Add on row sum and bias constant + v_in00 = vaddq_s32(v_in00, v_row_sum); + + v_in10 = vaddq_s32(v_in10, v_row_sum1); + + // Subtract col sum * a_offset + v_in00 = vaddq_s32(v_in00, v_col0); + + v_in10 = vaddq_s32(v_in10, v_col0); + + // Quantize - start with multiply + v_in00 = vqrdmulhq_s32(v_in00, v_mul); + + v_in10 = vqrdmulhq_s32(v_in10, v_mul); + + // Compute and add on corrective offset + if (do_shift_correction) { + int32x4_t v_temp00 = vandq_s32(v_in00, v_shift); + + int32x4_t v_temp10 = vandq_s32(v_in10, v_shift); + + v_temp00 = vshrq_n_s32(v_temp00, 31); + + v_temp10 = vshrq_n_s32(v_temp10, 31); + + v_in00 = vqaddq_s32(v_in00, v_temp00); + + v_in10 = vqaddq_s32(v_in10, v_temp10); + } + + v_in00 = vrshlq_s32(v_in00, v_shift); + + v_in10 = vrshlq_s32(v_in10, v_shift); + + v_in00 = vaddq_s32(v_in00, v_c_offset); + + v_in10 = vaddq_s32(v_in10, v_c_offset); + + v_in00 = vmaxq_s32(v_in00, v_minval); + + v_in10 = vmaxq_s32(v_in10, v_minval); + + v_in00 = vminq_s32(v_in00, v_maxval); + + v_in10 = vminq_s32(v_in10, v_maxval); + + do { + vst1q_lane_s8(out_ptr, vreinterpretq_s8_s32(v_in00), 0); + vst1q_lane_s8(out_ptr1, vreinterpretq_s8_s32(v_in10), 0); + + if (odds==1) { break; } + + vst1q_lane_s8(out_ptr + 1, vreinterpretq_s8_s32(v_in00), 4); + vst1q_lane_s8(out_ptr1 + 1, vreinterpretq_s8_s32(v_in10), 4); + + if (odds==2) { break; } + + vst1q_lane_s8(out_ptr + 2, vreinterpretq_s8_s32(v_in00), 8); + vst1q_lane_s8(out_ptr1 + 2, vreinterpretq_s8_s32(v_in10), 8); + } while(0); + } + } +} + +} // anonymous namespace + +template<typename Tin, typename Tout> +void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, + const Tin *input, unsigned int in_stride, Tout *output, unsigned int out_stride, + const int32_t *row_bias, const int32_t *col_bias) { + if (qp.minval >= qp.c_offset) { + requantize_block_32_int<false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, + reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + } else { + requantize_block_32_int<true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride, + reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias); + } +} + +template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, + const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride, + const int32_t *row_bias, const int32_t *col_bias); + +template void requantize_block_32(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, + const uint32_t *input, unsigned int in_stride, uint8_t *output, unsigned int out_stride, + const int32_t *row_bias, const int32_t *col_bias); + +/* + * Routine 
(and helpers) to compute row sums needed for offset correction. + * + * This is often needed for a lot of short rows (e.g. Syrax 5 - 6400 rows + * of length 27), therefore it's important not to sacrifice performance on + * odd length rows. + * + * To minimize performance loss in these cases, this routine will overread + * by up to 7 bytes. + * + * This is handled via "mask" and "mask mode" parameters to the inner + * routines; mask mode == 1 indicates that are between 1 and 8 bytes + * (inclusive) needed at the end; in these cases we always read 8 bytes. + * mask mode == 2 indicates that there are between 9 and 15 bytes needed at + * the end, and in this case we always read 16 bytes. In both cases the + * 'mask' vector is set up so that the read value can be masked off to clear + * the overread lanes. This is handled by 'accumulate_masked_8' and + * 'accumulate_masked_16' above. + * + * This routine is templated on the type to be accumulated, because the + * innermost instruction used needs to be of the correct signedness. + * However, beyond this point we always use signed values in both cases. + * The instructions that need to be different are therefore wrapped in + * helper functions below. + */ + +namespace { + struct row_sum_helpers { + const ARequantizeLayer32 &qp; + + /* Load a full 16 byte vector, pairwise accumulate into 'sum' with uadalp or sadalp */ + template<typename T> + inline int16x8_t accumulate_16(const T *ptr, int16x8_t sum); + + /* Load a full 16 byte vector, but mask before accumulation (see above). */ + template<typename T> + inline int16x8_t accumulate_masked_16(const T *ptr, int16x8_t sum, uint64x2_t mask); + + /* Load 8 bytes and mask before accumulation. */ + template<typename T> + inline int16x8_t accumulate_masked_8(const T *ptr, int16x8_t sum, uint64x2_t mask); + + /* This function does the actual work for up to 4 rows at a time. + * It's pulled out so we can template on the row count to generate + * the 4 different cases. 4 rows are computed at a time as this + * reduces to a single vector write. */ + template<unsigned int rows, typename T> + void compute_some_rows(unsigned int blocks, const T *input, unsigned int in_stride, int32_t *row_bias, unsigned int mask_mode, uint64x2_t mask, int32x4_t offset_mul) { + int16x8_t sums[rows]; + int32x4_t finalsums[rows]; + + for (unsigned int i=0; i<rows; i++) { + sums[i] = vdupq_n_s16(0); + } + + for (unsigned int i=0; i<blocks; i++) { + for (unsigned int r=0; r<rows; r++) { + sums[r] = accumulate_16(input + (r * in_stride) + (i * 16), sums[r]); + } + } + + /* Handle the final masked read if needed. */ + if (mask_mode > 0) { + for (unsigned int r=0; r<rows; r++) { + if (mask_mode == 1) { + sums[r] = accumulate_masked_8(input + (r * in_stride) + (blocks * 16), sums[r], mask); + } else { + sums[r] = accumulate_masked_16(input + (r * in_stride) + (blocks * 16), sums[r], mask); + } + } + } + + for (unsigned int i=0; i<rows; i++) { + finalsums[i] = vpaddlq_s16(sums[i]); + } + + int32x4_t t0, t1; + int32x2_t t2; + + /* Result writeback - need to write back one value per row + * processed. Multiply all the final totals by -b_offset so + * that the terms can simply be added in the requantize code. + * */ + switch (rows) { + case 1: + /* If we only have one output, just use ADDV. Multiply + * the offset into all four components separately so it + * can stay in the SIMD register file. 
*/ + t0 = vmulq_s32(finalsums[0], offset_mul); + *row_bias = vaddvq_s32(t0); + break; + + case 2: + /* For two outputs, two rounds of pairwise adds will + * generate the result in a 2-vector we can store in one + * go. */ + t0 = vpaddq_s32(finalsums[0], finalsums[1]); + t0 = vpaddq_s32(t0, t0); + t2 = vmul_s32(vget_low_s32(t0), vget_low_s32(offset_mul)); + vst1_s32(row_bias, t2); + break; + + case 3: + /* Three rows - need to store the low two words plus the odd value from lane 2 */ + t0 = vpaddq_s32(finalsums[0], finalsums[1]); + t1 = vpaddq_s32(finalsums[2], finalsums[2]); + + t0 = vpaddq_s32(t0, t1); + t0 = vmulq_s32(t0, offset_mul); + + vst1_s32(row_bias, vget_low_s32(t0)); + row_bias[2] = vgetq_lane_s32(t0, 2); + break; + + case 4: + /* Four rows (most common case) - reduce to a single + * vector with pairwise adds. */ + t0 = vpaddq_s32(finalsums[0], finalsums[1]); + t1 = vpaddq_s32(finalsums[2], finalsums[3]); + + t0 = vpaddq_s32(t0, t1); + t0 = vmulq_s32(t0, offset_mul); + + vst1q_s32(row_bias, t0); + break; + default: + break; + } + } + + row_sum_helpers(const ARequantizeLayer32 &qp) : qp(qp) { } + }; + + template<> + int16x8_t row_sum_helpers::accumulate_16(const uint8_t *ptr, int16x8_t sum) { + return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), vld1q_u8(ptr))); + } + + template<> + int16x8_t row_sum_helpers::accumulate_16(const int8_t *ptr, int16x8_t sum) { + return vpadalq_s8(sum, vld1q_s8(ptr)); + } + + template<> + int16x8_t row_sum_helpers::accumulate_masked_16(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) { + int8x16_t v = vandq_s8(vld1q_s8(ptr), vreinterpretq_s8_u64(mask)); + return vpadalq_s8(sum, v); + } + + template<> + int16x8_t row_sum_helpers::accumulate_masked_16(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) { + uint8x16_t v = vandq_u8(vld1q_u8(ptr), vreinterpretq_u8_u64(mask)); + return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v)); + } + + template<> + int16x8_t row_sum_helpers::accumulate_masked_8(const int8_t *ptr, int16x8_t sum, uint64x2_t mask) { + int8x16_t v = vcombine_s8(vld1_s8(ptr), vdup_n_s8(0)); + v = vreinterpretq_s8_u64(vandq_u64(mask, vreinterpretq_u64_s8(v))); + return vpadalq_s8(sum, v); + } + + template<> + int16x8_t row_sum_helpers::accumulate_masked_8(const uint8_t *ptr, int16x8_t sum, uint64x2_t mask) { + uint8x16_t v = vcombine_u8(vld1_u8(ptr), vdup_n_u8(0)); + v = vreinterpretq_u8_u64(vandq_u64(mask, vreinterpretq_u64_u8(v))); + return vreinterpretq_s16_u16(vpadalq_u8(vreinterpretq_u16_s16(sum), v)); + } +} + +template<typename T> +void compute_row_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, + const T *input, unsigned int in_stride, int32_t *row_bias) { + row_sum_helpers thehelpers(qp); + + const int32x4_t offset_mul = vdupq_n_s32(-qp.b_offset); + + /* Work out how many full vectors of 16 bytes we will read, and how many + * odd bytes at the end */ + unsigned int blocks = (width / 16); + const unsigned int odds = width % 16; + + /* Generate a mask to use on the last iteration, if necessary. */ + uint64x2_t mask; + unsigned int mask_mode = 0; + + if (odds > 0 && odds <= 8) { + /* 1-8 odds: mask in the low lane, 0 in the top */ + uint64_t maskval = (~0ULL) >> (8 * (8-odds)); + + mask = vsetq_lane_u64(maskval, vdupq_n_u64(0), 0); + + mask_mode = 1; + } else if (odds > 8) { + /* 9-15 odds: mask in the top lane, all 1s in the bottom. 
*/ + uint64_t maskval = (~0ULL) >> (8 * (16-odds)); + + mask = vsetq_lane_u64(maskval, vdupq_n_u64(~0ULL), 1); + + mask_mode = 2; + } + + for (unsigned int row=0; row<height; row+=4) { + switch(height-row) { + default: + case 4: + thehelpers.compute_some_rows<4>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul); + break; + case 3: + thehelpers.compute_some_rows<3>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul); + break; + case 2: + thehelpers.compute_some_rows<2>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul); + break; + case 1: + thehelpers.compute_some_rows<1>(blocks, input + (row * in_stride), in_stride, row_bias + row, mask_mode, mask, offset_mul); + break; + } + } +} + +/* Instantiate the two versions for uint8_t and int8_t. */ +template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const int8_t *, unsigned int, int32_t *); +template void compute_row_sums(const ARequantizeLayer32 &, unsigned int, unsigned int, const uint8_t *, unsigned int, int32_t *); + +template<unsigned int active_rows, typename T> +inline void add_block(const T *input, unsigned int in_stride, int32_t *output); + +template<unsigned int active_rows> +inline void add_block(const uint8_t *input, unsigned int in_stride, int32_t *output) { + uint8x16_t inputs[4]; + + for (unsigned int i=0; i<4; i++) { + if (i < active_rows) { + inputs[i] = vld1q_u8(input + i * in_stride); + } else { + inputs[i] = vdupq_n_u8(0); + } + } + + int16x8_t sums_16b[4]; + + // Two adds for the low pairs + sums_16b[0]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[0]), vget_low_u8(inputs[1]))); + sums_16b[1]=vreinterpretq_s16_u16(vaddl_u8(vget_low_u8(inputs[2]), vget_low_u8(inputs[3]))); + // Two adds for the high pairs + sums_16b[2]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[0], inputs[1])); + sums_16b[3]=vreinterpretq_s16_u16(vaddl_high_u8(inputs[2], inputs[3])); + + int32x4_t sums_32b[4]; + + sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1])); + sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]); + sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3])); + sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]); + + for (unsigned int i=0; i<4; i++) { + vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i))); + } +} + +template<unsigned int active_rows> +inline void add_block(const int8_t *input, unsigned int in_stride, int32_t *output) { + int8x16_t inputs[4]; + + for (unsigned int i=0; i<4; i++) { + if (i < active_rows) { + inputs[i] = vld1q_s8(input + i * in_stride); + } else { + inputs[i] = vdupq_n_s8(0); + } + } + + int16x8_t sums_16b[4]; + + // Two adds for the low pairs + sums_16b[0]=vaddl_s8(vget_low_s8(inputs[0]), vget_low_s8(inputs[1])); + sums_16b[1]=vaddl_s8(vget_low_s8(inputs[2]), vget_low_s8(inputs[3])); + // Two adds for the high pairs + sums_16b[2]=vaddl_high_s8(inputs[0], inputs[1]); + sums_16b[3]=vaddl_high_s8(inputs[2], inputs[3]); + + int32x4_t sums_32b[4]; + + sums_32b[0]=vaddl_s16(vget_low_s16(sums_16b[0]), vget_low_s16(sums_16b[1])); + sums_32b[1]=vaddl_high_s16(sums_16b[0], sums_16b[1]); + sums_32b[2]=vaddl_s16(vget_low_s16(sums_16b[2]), vget_low_s16(sums_16b[3])); + sums_32b[3]=vaddl_high_s16(sums_16b[2], sums_16b[3]); + + for (unsigned int i=0; i<4; i++) { + vst1q_s32(output + 4*i, vaddq_s32(sums_32b[i], vld1q_s32(output + 4*i))); + } +} + + +/* "first_col" parameter is used to offset the read 
into the qp.bias array, + * in cases where we are not computing the first columns of the output (i.e. + * in multithreaded cases where we divide columns across threads) */ +template<typename T> +void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const T *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col) { + memset(reinterpret_cast<void *>(col_bias), 0, width * sizeof(int32_t)); + + for (unsigned int row=0; row<height; row+=4) { + unsigned int numrows=std::min(height-row, 4u); + + for (unsigned int col=0; col<width; col+=16) { + unsigned int numcols=std::min(width-col, 16u); + + if (numcols==16) { + switch(numrows) { + case 1: + add_block<1>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 2: + add_block<2>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 3: + add_block<3>(input + row * in_stride + col, in_stride, col_bias + col); + break; + + case 4: + add_block<4>(input + row * in_stride + col, in_stride, col_bias + col); + break; + default: + break; + } + } else { + for (; col<width; col++) { + int32_t sum=0; + for (unsigned int r=0; r<numrows; r++) { + sum += input[(row + r)*in_stride + col]; + } + col_bias[col] += sum; + } + } + } + } + + for (unsigned int col=0; col<width; col++) { + int32_t result = col_bias[col]; + + result = (qp.a_offset * qp.b_offset * depth) - (result * qp.a_offset); + + if (qp.bias != nullptr) { + result += qp.bias[col + first_col]; + } + + col_bias[col] = result; + } +} + +template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const int8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col); +template void compute_col_sums(const ARequantizeLayer32 &qp, unsigned int width, unsigned int height, const uint8_t *input, unsigned int in_stride, int32_t *col_bias, unsigned int depth, unsigned int first_col); + +} // namespace arm_gemm |
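
For readers following the vector arithmetic in requantize_block_32_int above, the per-element computation reduces to the scalar sketch below. This is an illustrative reconstruction, not code from the patch: it assumes qp.requant_shift is zero or negative (a right shift, as the vrshlq_s32 usage implies), ignores the saturation edge cases handled by vqrdmulhq/vqaddq, and the helper name requantize_one_scalar is hypothetical.

#include <algorithm>
#include <cstdint>

#include "arm_gemm.hpp"

namespace arm_gemm {

// Hypothetical scalar mirror of one lane of the NEON requantize path above.
// 'acc' is one 32-bit accumulator; row_bias and col_bias are the precomputed
// per-row / per-column corrections described in the block comment.
static inline int8_t requantize_one_scalar(int32_t acc, int32_t row_bias, int32_t col_bias,
                                           const ARequantizeLayer32 &qp) {
    int32_t v = acc + row_bias + col_bias;

    // vqrdmulhq_s32: rounding doubling multiply, keeping the high 32 bits.
    int64_t prod   = static_cast<int64_t>(v) * qp.requant_mul;
    int32_t scaled = static_cast<int32_t>((prod + (int64_t(1) << 30)) >> 31);

    // Shift correction (the vand/vshr/vqadd sequence on the do_shift_correction
    // path): nudge negative values down by one before the rounding right shift.
    if (qp.requant_shift < 0 && scaled < 0) {
        scaled--;
    }

    // vrshlq_s32 with a negative shift amount is a rounding right shift.
    int shift = -qp.requant_shift;
    if (shift > 0) {
        scaled = (scaled + (1 << (shift - 1))) >> shift;
    }

    // Add the output offset and clamp to [minval, maxval].
    scaled = std::max(qp.minval, std::min(qp.maxval, scaled + qp.c_offset));

    // The narrowing stores keep only the least significant byte of each lane.
    return static_cast<int8_t>(scaled);
}

} // namespace arm_gemm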
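
The three entry points are intended to be used together: compute_row_sums and compute_col_sums fill the row_bias and col_bias arrays that requantize_block_32 then consumes. The sketch below shows one way a caller might combine them; it is hypothetical glue rather than the library's actual driver code, it assumes row-major int8 operands (A is height x depth, B is depth x width) and packed int32 accumulators, and the "quantized.hpp" include is an assumed location for the prototypes.

#include <cstdint>
#include <vector>

#include "arm_gemm.hpp"
#include "quantized.hpp"   // assumed: declarations of the helpers in this file

// Hypothetical driver: compute the bias arrays, then requantize the
// 32-bit accumulators 'acc' (height x width, stride 'width') into 'out'.
void requantize_result(const arm_gemm::ARequantizeLayer32 &qp,
                       unsigned int height, unsigned int width, unsigned int depth,
                       const int8_t *a, unsigned int a_stride,   // LHS: height x depth
                       const int8_t *b, unsigned int b_stride,   // RHS: depth x width
                       const int32_t *acc,
                       int8_t *out, unsigned int out_stride) {
    std::vector<int32_t> row_bias(height);
    std::vector<int32_t> col_bias(width);

    // Per-row sums of the LHS, pre-multiplied by -b_offset.
    arm_gemm::compute_row_sums(qp, depth, height, a, a_stride, row_bias.data());

    // Per-column sums of the RHS, folded with the constant term
    // (a_offset * b_offset * depth) and any externally supplied bias.
    arm_gemm::compute_col_sums(qp, width, depth, b, b_stride, col_bias.data(), depth, 0);

    // Apply multiplier, shift, output offset and clamping to the accumulators.
    arm_gemm::requantize_block_32(qp, width, height, acc, width, out, out_stride,
                                  row_bias.data(), col_bias.data());
}

Splitting the offset corrections into per-row and per-column arrays this way means the inner requantize loop only performs two vector adds per block of accumulators, which is why the file precomputes them rather than folding the offsets into the GEMM itself.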