Diffstat (limited to 'src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp')
-rw-r--r--  src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp  127
1 file changed, 79 insertions, 48 deletions
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
index 9bd1eae663..9a099bd1b6 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021,2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -295,6 +295,7 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
}
// Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+ // 4 x u/int32x4_t = 16 column accumulators
typename wrapper::traits::neon_bitvector<TAcc, wrapper::traits::BitWidth::W128>::type sum_col[4] = {
wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
@@ -308,61 +309,91 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b + in_b_stride));
#endif /* __arm__ */
- int i = 0;
- // This for loop performs 4 accumulations
- for (; i <= (_k - 4); i += 4)
+ // If we have less than 16 columns left, we can't use the main unrolled loop
+ if ((width_matrix_b - id.x()) >= 16)
{
- const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
- const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
- const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
- const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
+ // Row index
+ int i = 0;
+ // 4 x u/int32x4_t = 16 columns unrolled across 4 rows
+ for (; i <= (_k - 4); i += 4)
+ {
+ // Load 4 rows of 16 columns of 8bit elements
+ // (| | )
+ // (| | )
+ // (| | )
+ // (| | )
+ const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+ const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
+ const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
+ const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
#if __arm__
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
#endif /* __arm__ */
- // Partial accumulation in 16bit
- typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] = {
- wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
- wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
-
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
-
- // Accumulate to 32bit
- sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
- sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
- sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
- sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
-
- matrix_b += 4 * in_b_stride;
- }
+ // Partial accumulation to 16bit (4 rows => 2 rows)
+ // (| | | )
+ // (| | | )
+ typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] =
+ {wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
+ wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
+
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
+
+ // Accumulate to 32bit (2 rows => 1 row)
+ // (| | | | | )
+ sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
+ sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
+ sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
+ sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
+
+ matrix_b += 4 * in_b_stride;
+ }
- // This for loop perfoms the leftover accumulations
- for (; i < _k; ++i)
- {
- const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+ // This for loop accumulates the rows left over from the 4x unrolling above
+ for (; i < _k; ++i)
+ {
+ const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
- // Convert S8 to S16
- const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type b0_b16[2]{
- wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};
+ // Convert 8bit => 16bit
+ const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type
+ b0_b16[2]{wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};
- // Accumulate to 32bit
- sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
- sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
- sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
- sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
+ // Accumulate to 32bit
+ sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
+ sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
+ sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
+ sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
- matrix_b += in_b_stride;
+ matrix_b += in_b_stride;
+ }
+ }
+ else
+ {
+ // Accumulate left over columns to sum_cols
+ for (int i = 0; i < _k; ++i) // row loop
+ {
+ auto left_over_cols = width_matrix_b - id.x();
+ auto l = left_over_cols;
+ for (auto k = 0; k < 4 && l; ++k)
+ {
+ for (auto j = 0; j < 4 && l; ++j, --l)
+ {
+ sum_col[k][j] += matrix_b[left_over_cols - l];
+ }
+ }
+ matrix_b += in_b_stride;
+ }
}
// Multiply by scalar if necessary
@@ -375,7 +406,7 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
}
auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
- if (id.x() + 16 < width_matrix_b)
+ if ((width_matrix_b - id.x()) >= 16)
{
wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
wrapper::vstore(vector_sum_col + 4, wrapper::vreinterpret(sum_col[1]));
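
Note: the unrolled path above sums 16 unsigned 8-bit columns across 4 rows per iteration by first accumulating into 16-bit lanes (safe, since at most 4 x 255 = 1020 is added before widening) and then folding those partial sums into four 32-bit accumulators, while the scalar else-branch handles tails of fewer than 16 columns. For reference, a minimal sketch of the same widening column-sum pattern written with plain NEON intrinsics instead of the library's wrapper:: templates might look as follows; the function name column_sum_16 and its parameters are illustrative assumptions, not part of the Compute Library API.

    #include <arm_neon.h>
    #include <cstdint>

    // Sketch: sum 'k' rows of 16 unsigned 8-bit columns of 'b' (row stride
    // 'stride' elements) into 16 signed 32-bit column sums written to 'out'.
    void column_sum_16(const uint8_t *b, int k, int stride, int32_t *out)
    {
        uint32x4_t sum_col[4] = {vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0)};

        int i = 0;
        for (; i <= k - 4; i += 4)
        {
            // Load 4 rows of 16 columns each
            const uint8x16_t b0 = vld1q_u8(b + 0 * stride);
            const uint8x16_t b1 = vld1q_u8(b + 1 * stride);
            const uint8x16_t b2 = vld1q_u8(b + 2 * stride);
            const uint8x16_t b3 = vld1q_u8(b + 3 * stride);

            // Partial accumulation in 16 bit: at most 4 * 255 = 1020 per lane, no overflow
            uint16x8_t tmp_lo = vdupq_n_u16(0);
            uint16x8_t tmp_hi = vdupq_n_u16(0);
            tmp_lo = vaddw_u8(tmp_lo, vget_low_u8(b0));
            tmp_lo = vaddw_u8(tmp_lo, vget_low_u8(b1));
            tmp_lo = vaddw_u8(tmp_lo, vget_low_u8(b2));
            tmp_lo = vaddw_u8(tmp_lo, vget_low_u8(b3));
            tmp_hi = vaddw_u8(tmp_hi, vget_high_u8(b0));
            tmp_hi = vaddw_u8(tmp_hi, vget_high_u8(b1));
            tmp_hi = vaddw_u8(tmp_hi, vget_high_u8(b2));
            tmp_hi = vaddw_u8(tmp_hi, vget_high_u8(b3));

            // Widen the 16-bit partial sums into the four 32-bit column accumulators
            sum_col[0] = vaddw_u16(sum_col[0], vget_low_u16(tmp_lo));
            sum_col[1] = vaddw_u16(sum_col[1], vget_high_u16(tmp_lo));
            sum_col[2] = vaddw_u16(sum_col[2], vget_low_u16(tmp_hi));
            sum_col[3] = vaddw_u16(sum_col[3], vget_high_u16(tmp_hi));

            b += 4 * stride;
        }

        // Leftover rows: widen 8 bit -> 16 bit, then accumulate straight into 32 bit
        for (; i < k; ++i)
        {
            const uint8x16_t b0 = vld1q_u8(b);
            const uint16x8_t lo = vmovl_u8(vget_low_u8(b0));
            const uint16x8_t hi = vmovl_u8(vget_high_u8(b0));
            sum_col[0] = vaddw_u16(sum_col[0], vget_low_u16(lo));
            sum_col[1] = vaddw_u16(sum_col[1], vget_high_u16(lo));
            sum_col[2] = vaddw_u16(sum_col[2], vget_low_u16(hi));
            sum_col[3] = vaddw_u16(sum_col[3], vget_high_u16(hi));
            b += stride;
        }

        for (int v = 0; v < 4; ++v)
        {
            vst1q_s32(out + 4 * v, vreinterpretq_s32_u32(sum_col[v]));
        }
    }

Staging through 16 bit is what motivates the tmp_sum pair in the kernel: per unrolled iteration it needs 8 narrow vaddw_u8 plus only 4 adds into the 32-bit accumulators, instead of the 16 widening adds a direct per-row 8-to-32-bit accumulation would require.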