-rw-r--r--   src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp   127
1 file changed, 79 insertions(+), 48 deletions(-)
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
index 9bd1eae663..9a099bd1b6 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021,2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -295,6 +295,7 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
             }
             // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+            // 4 x u/int32x4_t = 16 column accumulators
             typename wrapper::traits::neon_bitvector<TAcc, wrapper::traits::BitWidth::W128>::type sum_col[4] = {
                 wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
                 wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
@@ -308,61 +309,91 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
             asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b + in_b_stride));
 #endif /* __arm__ */

-            int i = 0;
-            // This for loop performs 4 accumulations
-            for (; i <= (_k - 4); i += 4)
+            // If we have less than 16 columns left, we can't use the main unrolled loop
+            if ((width_matrix_b - id.x()) >= 16)
             {
-                const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
-                const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
-                const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
-                const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
+                // Row index
+                int i = 0;
+                // 4 x u/int32x4_t = 16 columns unrolled across 4 rows
+                for (; i <= (_k - 4); i += 4)
+                {
+                    // Load 4 rows of 16 columns of 8bit elements
+                    // (|                |)
+                    // (|                |)
+                    // (|                |)
+                    // (|                |)
+                    const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+                    const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
+                    const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
+                    const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);

 #if __arm__
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
 #endif /* __arm__ */

-                // Partial accumulation in 16bit
-                typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] = {
-                    wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
-                    wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
-
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
-
-                // Accumulate to 32bit
-                sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
-                sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
-                sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
-                sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
-
-                matrix_b += 4 * in_b_stride;
-            }
+                    // Partial accumulation to 16bit (4 rows => 2 rows)
+                    // (|        |        |)
+                    // (|        |        |)
+                    typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] =
+                        {wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
+                         wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
+
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
+
+                    // Accumulate to 32bit (2 rows => 1 row)
+                    // (|    |    |    |    |)
+                    sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
+                    sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
+                    sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
+                    sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
+
+                    matrix_b += 4 * in_b_stride;
+                }

-            // This for loop perfoms the leftover accumulations
-            for (; i < _k; ++i)
-            {
-                const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+                // This for loop accumulates the rows left over from the 4x unrolling above
+                for (; i < _k; ++i)
+                {
+                    const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);

-                // Convert S8 to S16
-                const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type b0_b16[2]{
-                    wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};
+                    // Convert 8bit => 16bit
+                    const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type
+                        b0_b16[2]{wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};

-                // Accumulate to 32bit
-                sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
-                sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
-                sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
-                sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
+                    // Accumulate to 32bit
+                    sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
+                    sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
+                    sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
+                    sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));

-                matrix_b += in_b_stride;
+                    matrix_b += in_b_stride;
+                }
+            }
+            else
+            {
+                // Accumulate left over columns to sum_cols
+                for (int i = 0; i < _k; ++i) // row loop
+                {
+                    auto left_over_cols = width_matrix_b - id.x();
+                    auto l = left_over_cols;
+                    for (auto k = 0; k < 4 && l; ++k)
+                    {
+                        for (auto j = 0; j < 4 && l; ++j, --l)
+                        {
+                            sum_col[k][j] += matrix_b[left_over_cols - l];
+                        }
+                    }
+                    matrix_b += in_b_stride;
+                }
             }

             // Multiply by scalar if necessary
@@ -375,7 +406,7 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
             }

             auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
-            if (id.x() + 16 < width_matrix_b)
+            if ((width_matrix_b - id.x()) >= 16)
             {
                 wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
                 wrapper::vstore(vector_sum_col + 4, wrapper::vreinterpret(sum_col[1]));
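
For context on the unchanged arithmetic: the unrolled loop stages the reduction as u8 -> u16 -> u32, summing each group of four rows into 16-bit lanes with wrapper::vaddw before folding them into the four 32-bit column accumulators; four u8 values sum to at most 4 * 255 = 1020, so the 16-bit intermediate cannot overflow. A minimal scalar sketch of the same staging, in plain C++ rather than the library's wrapper API (the function name and row-major layout are illustrative assumptions, not library code):

    #include <cstddef>
    #include <cstdint>

    // Scalar model of the kernel's staged column reduction: sum k rows of
    // 16 uint8_t columns into 32-bit column sums. Each group of 4 rows is
    // first accumulated in 16 bits (at most 4 * 255 = 1020 per lane), then
    // folded into the 32-bit accumulators, mirroring vaddw u8 -> u16 -> u32.
    void sum_columns_staged(const std::uint8_t *matrix_b, int k, std::size_t in_b_stride, std::uint32_t sum_col[16])
    {
        int i = 0;
        for (; i <= k - 4; i += 4) // 4x unrolled body, as in the main loop
        {
            std::uint16_t tmp_sum[16] = {};
            for (int r = 0; r < 4; ++r)
            {
                for (int c = 0; c < 16; ++c)
                {
                    tmp_sum[c] += matrix_b[r * in_b_stride + c]; // u8 widened to u16
                }
            }
            for (int c = 0; c < 16; ++c)
            {
                sum_col[c] += tmp_sum[c]; // u16 widened to u32
            }
            matrix_b += 4 * in_b_stride;
        }
        for (; i < k; ++i) // rows left over from the 4x unrolling
        {
            for (int c = 0; c < 16; ++c)
            {
                sum_col[c] += matrix_b[c];
            }
            matrix_b += in_b_stride;
        }
    }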
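
The new else branch is the tail for windows with fewer than 16 columns left, where, as the added comment notes, the main unrolled loop (and its 16-byte wrapper::vloadq) can't be used; its sum_col[k][j] lane writes put remainder column c into lane (c / 4, c % 4), the same lane the vector path would have produced. A minimal scalar sketch of that mapping, assuming the accumulators are modelled as a plain 4x4 array (sum_leftover_columns is a hypothetical stand-in, not library code):

    #include <cstddef>
    #include <cstdint>

    // Scalar model of the leftover-column tail: with left_over_cols < 16
    // columns remaining, column c is accumulated into lane (c / 4, c % 4)
    // of the four 4-lane 32-bit accumulators.
    void sum_leftover_columns(const std::uint8_t *matrix_b, int k, std::size_t in_b_stride,
                              int left_over_cols, std::uint32_t sum_col[4][4])
    {
        for (int i = 0; i < k; ++i) // row loop
        {
            for (int c = 0; c < left_over_cols; ++c)
            {
                sum_col[c / 4][c % 4] += matrix_b[c];
            }
            matrix_b += in_b_stride;
        }
    }

The store-side change is consistent with this split: the old strict comparison id.x() + 16 < width_matrix_b sent a window with exactly 16 columns remaining down the partial-store path, while the new (width_matrix_b - id.x()) >= 16 lets it take the four full vstore calls.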