Diffstat (limited to 'src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp')
-rw-r--r--  src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp | 316
1 file changed, 169 insertions(+), 147 deletions(-)
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
index 534076b97c..9bd1eae663 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
@@ -26,9 +26,10 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/TensorInfo.h"
-#include "src/core/NEON/wrapper/wrapper.h"
+
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
+#include "src/core/NEON/wrapper/wrapper.h"
namespace arm_compute
{
@@ -38,37 +39,49 @@ namespace kernels
{
namespace
{
-Status validate_arguments_matrix_a_reduction(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info)
+Status validate_arguments_matrix_a_reduction(const ITensorInfo *src,
+ const ITensorInfo *dst,
+ const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_ERROR_ON_MSG(info.is_reshaped == true, "Not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL);
- if(dst->total_size() > 0)
+ if (dst->total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != src->dimension(1), "Output vector must have length equal to the number of rows of the input matrix");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ dst->dimension(0) != src->dimension(1),
+ "Output vector must have length equal to the number of rows of the input matrix");
}
return Status{};
}
-Status validate_arguments_matrix_b_reduction(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info)
+Status validate_arguments_matrix_b_reduction(const ITensorInfo *src,
+ const ITensorInfo *dst,
+ const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_ERROR_ON_MSG(info.is_reshaped == true, "Not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL);
- if(dst->total_size() > 0)
+ if (dst->total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->dimension(0) != src->dimension(0), "Output vector must have length equal to the number of columns of the input matrix");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ dst->dimension(0) != src->dimension(0),
+ "Output vector must have length equal to the number of columns of the input matrix");
}
return Status{};
}
} // namespace
-void CpuGemmLowpMatrixAReductionKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info)
+void CpuGemmLowpMatrixAReductionKernel::configure(const ITensorInfo *src,
+ ITensorInfo *dst,
+ const GEMMLowpReductionKernelInfo &info)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
@@ -77,7 +90,7 @@ void CpuGemmLowpMatrixAReductionKernel::configure(const ITensorInfo *src, ITenso
_scalar = info.scalar;
_mul_by_scalar = info.mul_by_scalar;
- switch(src->data_type())
+ switch (src->data_type())
{
case DataType::QASYMM8:
_func = &CpuGemmLowpMatrixAReductionKernel::run_internal<uint8_t>;
@@ -98,14 +111,18 @@ void CpuGemmLowpMatrixAReductionKernel::configure(const ITensorInfo *src, ITenso
ICpuKernel::configure(win);
}
-Status CpuGemmLowpMatrixAReductionKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info)
+Status CpuGemmLowpMatrixAReductionKernel::validate(const ITensorInfo *src,
+ const ITensorInfo *dst,
+ const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_a_reduction(src, dst, info));
return Status{};
}
template <typename T>
-void CpuGemmLowpMatrixAReductionKernel::run_internal(const ITensor *src, ITensor *dst, const arm_compute::Window &window)
+void CpuGemmLowpMatrixAReductionKernel::run_internal(const ITensor *src,
+ ITensor *dst,
+ const arm_compute::Window &window)
{
// Intermediate and final accumulator types
using TIAcc = wrapper::traits::promote_t<T>;
@@ -121,55 +138,58 @@ void CpuGemmLowpMatrixAReductionKernel::run_internal(const ITensor *src, ITensor
Iterator in(src, win_input);
Iterator out(dst, collapsed_window);
- execute_window_loop(collapsed_window, [&](const Coordinates & id)
- {
- auto vsum_row = wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{});
- TAcc sum_row = 0;
+ execute_window_loop(
+ collapsed_window,
+ [&](const Coordinates &id)
+ {
+ auto vsum_row = wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{});
+ TAcc sum_row = 0;
- const T *matrix_a = reinterpret_cast<const T *>((in.ptr() + id.x() * src->info()->strides_in_bytes()[1] + id.y() * src->info()->strides_in_bytes()[2]));
+ const T *matrix_a = reinterpret_cast<const T *>(
+ (in.ptr() + id.x() * src->info()->strides_in_bytes()[1] + id.y() * src->info()->strides_in_bytes()[2]));
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(matrix_a));
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_a));
#endif /* __arm__ */
- int i = 0;
- // This for loop performs 16 accumulations
- for(; i <= (_k - 16); i += 16)
- {
- const auto a0_d8 = wrapper::vloadq(matrix_a + i);
+ int i = 0;
+ // This for loop performs 16 accumulations
+ for (; i <= (_k - 16); i += 16)
+ {
+ const auto a0_d8 = wrapper::vloadq(matrix_a + i);
- // Partial accumulations in U16
- const auto tmp_sum0 = wrapper::vaddl(wrapper::vgetlow(a0_d8), wrapper::vgethigh(a0_d8));
+ // Partial accumulations in U16
+ const auto tmp_sum0 = wrapper::vaddl(wrapper::vgetlow(a0_d8), wrapper::vgethigh(a0_d8));
- // Accumulate to U32
- vsum_row = wrapper::vadd(vsum_row, wrapper::vpaddl(tmp_sum0));
- }
+ // Accumulate to U32
+ vsum_row = wrapper::vadd(vsum_row, wrapper::vpaddl(tmp_sum0));
+ }
- // This for loop performs the leftover accumulations
- for(; i < _k; ++i)
- {
- sum_row += static_cast<TAcc>(matrix_a[i]);
- }
+ // This for loop performs the leftover accumulations
+ for (; i < _k; ++i)
+ {
+ sum_row += static_cast<TAcc>(matrix_a[i]);
+ }
#if defined(__aarch64__)
- // Reduction operation available on 64 bit architectures only
- sum_row += wrapper::vaddv(vsum_row);
+ // Reduction operation available on 64 bit architectures only
+ sum_row += wrapper::vaddv(vsum_row);
#else // __aarch64__
- auto tmp = wrapper::vpadd(wrapper::vgethigh(vsum_row), wrapper::vgetlow(vsum_row));
- tmp = wrapper::vpadd(tmp, tmp);
+ auto tmp = wrapper::vpadd(wrapper::vgethigh(vsum_row), wrapper::vgetlow(vsum_row));
+ tmp = wrapper::vpadd(tmp, tmp);
- sum_row += wrapper::vgetlane(tmp, 0);
+ sum_row += wrapper::vgetlane(tmp, 0);
#endif // __aarch64__
- // Multiply by scalar if necessary
- if(_mul_by_scalar)
- {
- sum_row *= _scalar;
- }
+ // Multiply by scalar if necessary
+ if (_mul_by_scalar)
+ {
+ sum_row *= _scalar;
+ }
- *(reinterpret_cast<int *>(out.ptr())) = static_cast<int32_t>(sum_row);
- },
- in, out);
+ *(reinterpret_cast<int *>(out.ptr())) = static_cast<int32_t>(sum_row);
+ },
+ in, out);
}
void CpuGemmLowpMatrixAReductionKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
@@ -189,7 +209,9 @@ const char *CpuGemmLowpMatrixAReductionKernel::name() const
return "CpuGemmLowpMatrixAReductionKernel";
}
-void CpuGemmLowpMatrixBReductionKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info)
+void CpuGemmLowpMatrixBReductionKernel::configure(const ITensorInfo *src,
+ ITensorInfo *dst,
+ const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_matrix_b_reduction(src, dst, info));
@@ -201,7 +223,7 @@ void CpuGemmLowpMatrixBReductionKernel::configure(const ITensorInfo *src, ITenso
// Configure kernel window
constexpr unsigned int num_elems_processed_per_iteration = 16;
- switch(src->data_type())
+ switch (src->data_type())
{
case DataType::QASYMM8:
_func = &CpuGemmLowpMatrixBReductionKernel::run_internal<uint8_t>;
@@ -223,14 +245,19 @@ void CpuGemmLowpMatrixBReductionKernel::configure(const ITensorInfo *src, ITenso
ICpuKernel::configure(win);
}
-Status CpuGemmLowpMatrixBReductionKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info)
+Status CpuGemmLowpMatrixBReductionKernel::validate(const ITensorInfo *src,
+ const ITensorInfo *dst,
+ const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_b_reduction(src, dst, info));
return Status{};
}
template <typename T>
-void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src, ITensor *dst, const Window &window, const ThreadInfo &info)
+void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src,
+ ITensor *dst,
+ const Window &window,
+ const ThreadInfo &info)
{
// Intermediate and final accumulator types
using TIAcc = wrapper::traits::promote_t<T>;
@@ -258,121 +285,116 @@ void CpuGemmLowpMatrixBReductionKernel::run_internal(const ITensor *src, ITensor
Iterator inb(src, win_in);
Iterator out(dst, win_out);
- execute_window_loop(win_out, [&](const Coordinates & id)
- {
- if(id.x() > width_matrix_b)
+ execute_window_loop(
+ win_out,
+ [&](const Coordinates &id)
{
- return;
- }
+ if (id.x() > width_matrix_b)
+ {
+ return;
+ }
- // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
- typename wrapper::traits::neon_bitvector<TAcc, wrapper::traits::BitWidth::W128>::type sum_col[4] =
- {
- wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
- wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
- wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
- wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{})
- };
+ // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+ typename wrapper::traits::neon_bitvector<TAcc, wrapper::traits::BitWidth::W128>::type sum_col[4] = {
+ wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
+ wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
+ wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
+ wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{})};
- const auto *matrix_b = reinterpret_cast<const T *>(inb.ptr() + id.y() * src->info()->strides_in_bytes()[2]);
+ const auto *matrix_b = reinterpret_cast<const T *>(inb.ptr() + id.y() * src->info()->strides_in_bytes()[2]);
#if __arm__
- asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b));
- asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b + in_b_stride));
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b));
+ asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b + in_b_stride));
#endif /* __arm__ */
- int i = 0;
- // This for loop performs 4 accumulations
- for(; i <= (_k - 4); i += 4)
- {
- const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
- const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
- const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
- const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
+ int i = 0;
+ // This for loop performs 4 accumulations
+ for (; i <= (_k - 4); i += 4)
+ {
+ const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+ const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
+ const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
+ const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
#if __arm__
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
- asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
+ asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
#endif /* __arm__ */
- // Partial accumulation in 16bit
- typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] =
- {
- wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
- wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})
- };
-
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
- tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
- tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
-
- // Accumulate to 32bit
- sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
- sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
- sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
- sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
-
- matrix_b += 4 * in_b_stride;
- }
-
- // This for loop perfoms the leftover accumulations
- for(; i < _k; ++i)
- {
- const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+ // Partial accumulation in 16bit
+ typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] = {
+ wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
+ wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
+
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
+ tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
+ tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
+
+ // Accumulate to 32bit
+ sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
+ sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
+ sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
+ sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
+
+ matrix_b += 4 * in_b_stride;
+ }
- // Convert S8 to S16
- const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type b0_b16[2]
+ // This for loop performs the leftover accumulations
+ for (; i < _k; ++i)
{
- wrapper::vmovl(wrapper::vgetlow(b0_b8)),
- wrapper::vmovl(wrapper::vgethigh(b0_b8))
- };
+ const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
- // Accumulate to 32bit
- sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
- sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
- sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
- sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
+ // Convert S8 to S16
+ const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type b0_b16[2]{
+ wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};
- matrix_b += in_b_stride;
- }
+ // Accumulate to 32bit
+ sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
+ sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
+ sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
+ sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
- // Multiply by scalar if necessary
- if(_mul_by_scalar)
- {
- sum_col[0] = wrapper::vmul(sum_col[0], vec_scalar);
- sum_col[1] = wrapper::vmul(sum_col[1], vec_scalar);
- sum_col[2] = wrapper::vmul(sum_col[2], vec_scalar);
- sum_col[3] = wrapper::vmul(sum_col[3], vec_scalar);
- }
-
- auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
- if(id.x() + 16 < width_matrix_b)
- {
- wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
- wrapper::vstore(vector_sum_col + 4, wrapper::vreinterpret(sum_col[1]));
- wrapper::vstore(vector_sum_col + 8, wrapper::vreinterpret(sum_col[2]));
- wrapper::vstore(vector_sum_col + 12, wrapper::vreinterpret(sum_col[3]));
- }
- else
- {
- auto left_over = width_matrix_b - id.x();
- for(auto k = 0; k < 4 && left_over; ++k)
+ matrix_b += in_b_stride;
+ }
+
+ // Multiply by scalar if necessary
+ if (_mul_by_scalar)
+ {
+ sum_col[0] = wrapper::vmul(sum_col[0], vec_scalar);
+ sum_col[1] = wrapper::vmul(sum_col[1], vec_scalar);
+ sum_col[2] = wrapper::vmul(sum_col[2], vec_scalar);
+ sum_col[3] = wrapper::vmul(sum_col[3], vec_scalar);
+ }
+
+ auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
+ if (id.x() + 16 < width_matrix_b)
+ {
+ wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
+ wrapper::vstore(vector_sum_col + 4, wrapper::vreinterpret(sum_col[1]));
+ wrapper::vstore(vector_sum_col + 8, wrapper::vreinterpret(sum_col[2]));
+ wrapper::vstore(vector_sum_col + 12, wrapper::vreinterpret(sum_col[3]));
+ }
+ else
{
- for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ auto left_over = width_matrix_b - id.x();
+ for (auto k = 0; k < 4 && left_over; ++k)
{
- *(vector_sum_col + k * 4 + j) = sum_col[k][j];
+ for (auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(vector_sum_col + k * 4 + j) = sum_col[k][j];
+ }
}
}
- }
- },
- inb, out);
+ },
+ inb, out);
}
void CpuGemmLowpMatrixBReductionKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
@@ -393,4 +415,4 @@ const char *CpuGemmLowpMatrixBReductionKernel::name() const
}
} // namespace kernels
} // namespace cpu
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
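For reference, the two kernels touched above compute per-row sums of a quantized matrix A and per-column sums of a quantized matrix B, widened to int32 and optionally multiplied by a scalar; these sums are consumed later in the GEMMLowp pipeline (e.g. for the offset contribution). Below is a minimal scalar sketch of that arithmetic, not the library's NEON implementation; the function and parameter names are illustrative only, and row-major storage is assumed.

#include <cstdint>
#include <vector>

// Row sums of an M x K quantized matrix (what CpuGemmLowpMatrixAReductionKernel
// produces): out[m] = sum over k of a[m][k], widened to int32, optionally scaled.
// T is uint8_t for QASYMM8 and int8_t for QASYMM8_SIGNED / QSYMM8.
// Illustrative sketch only, not the library's kernel.
template <typename T>
std::vector<int32_t> reduce_rows(const std::vector<T> &a, int m, int k, bool mul_by_scalar, int32_t scalar)
{
    std::vector<int32_t> sums(m, 0);
    for (int row = 0; row < m; ++row)
    {
        int32_t sum_row = 0;
        for (int col = 0; col < k; ++col)
        {
            sum_row += static_cast<int32_t>(a[row * k + col]); // widen each element to 32 bit
        }
        sums[row] = mul_by_scalar ? sum_row * scalar : sum_row;
    }
    return sums;
}

// Column sums of a K x N quantized matrix (what CpuGemmLowpMatrixBReductionKernel
// produces): out[n] = sum over k of b[k][n], widened to int32, optionally scaled.
// Illustrative sketch only, not the library's kernel.
template <typename T>
std::vector<int32_t> reduce_columns(const std::vector<T> &b, int k, int n, bool mul_by_scalar, int32_t scalar)
{
    std::vector<int32_t> sums(n, 0);
    for (int row = 0; row < k; ++row)
    {
        for (int col = 0; col < n; ++col)
        {
            sums[col] += static_cast<int32_t>(b[row * n + col]);
        }
    }
    if (mul_by_scalar)
    {
        for (auto &s : sums)
        {
            s *= scalar;
        }
    }
    return sums;
}

The NEON code in the diff reaches the same result while processing 16 elements of a row (matrix A) or 4 rows of 16 columns (matrix B) per iteration: it accumulates in 8-bit lanes, widens through 16-bit partial sums into 32-bit accumulators, and handles the leftover elements in a scalar tail loop.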