path: root/src/core/NEON
author    Michele Di Giorgio <michele.digiorgio@arm.com>    2020-03-12 19:34:33 +0000
committer Michele Di Giorgio <michele.digiorgio@arm.com>    2020-03-16 09:42:36 +0000
commit    a602f03f4c66e5ee2480f1a3fc66847968fc1076 (patch)
tree      a2752ca0de84f7920dd7296151d14e5edc8cacc0 /src/core/NEON
parent    0ec53a0e54ae0be0ed9c4e4c14a5fd10ed5f48a8 (diff)
download  ComputeLibrary-a602f03f4c66e5ee2480f1a3fc66847968fc1076.tar.gz
COMPMID-3237: Extend GEMMLowpReduction kernels to multiply reductions by a scalar value
Change-Id: If2a242f52aea753591525d30a4cb64c1a766bf8d
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2881
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
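
Note on the API change in this patch: the positional configure/validate parameters (num_mtx_a_cols/is_interleaved4x4 for the A-reduction and num_mtx_b_rows/is_transposed1xW for the B-reduction) are replaced by a single GEMMLowpReductionKernelInfo descriptor carrying k, is_reshaped, scalar and mul_by_scalar. The sketch below shows how a caller might now configure the matrix A reduction kernel; it only uses the descriptor fields visible in this diff, and the helper function, parameter names and the way the descriptor is populated are illustrative assumptions, not taken from the patch.

    // Minimal sketch, assuming GEMMLowpReductionKernelInfo is default-constructible
    // and exposes the public fields referenced in this patch (k, is_reshaped,
    // scalar, mul_by_scalar); header paths match this version of the library.
    #include "arm_compute/core/ITensor.h"
    #include "arm_compute/core/KernelDescriptors.h"
    #include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"

    using namespace arm_compute;

    // Hypothetical helper: configure a kernel that produces the per-row sums of
    // mtx_a, optionally multiplying each sum by 'scalar'.
    void configure_row_sums(const ITensor *mtx_a, ITensor *vector_sum_row,
                            NEGEMMLowpMatrixAReductionKernel &kernel,
                            int32_t k, int32_t scalar)
    {
        GEMMLowpReductionKernelInfo info{};
        info.k             = k;            // accumulation depth (columns of matrix A)
        info.is_reshaped   = false;        // matrix A is not interleaved 4x4
        info.scalar        = scalar;       // value each row sum is multiplied by
        info.mul_by_scalar = scalar != 0;  // skip the multiply when it is a no-op

        kernel.configure(mtx_a, vector_sum_row, info);
    }
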
Diffstat (limited to 'src/core/NEON')
-rw-r--r--    src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp    98
1 file changed, 73 insertions, 25 deletions
diff --git a/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp
index 374005d897..b7e862c81f 100644
--- a/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpReductionKernel.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
@@ -37,26 +38,29 @@
#include <cstddef>
#include <cstdint>
-using namespace arm_compute;
-
namespace arm_compute
{
-class Coordinates;
-} // namespace arm_compute
-
namespace
{
Status validate_arguments_matrix_a_reduction(const ITensorInfo *input, const ITensorInfo *output)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(0) != input->dimension(1), "Output vector must have length equal to the number of rows of the input matrix");
+ }
return Status{};
}
std::pair<Status, Window> validate_and_configure_window_matrix_a_reduction(ITensorInfo *input, ITensorInfo *output, bool is_reshaped)
{
const unsigned int num_elems_processed_per_iteration = is_reshaped ? 4 : 1;
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(*output, TensorShape(input->dimension(1)), 1, DataType::S32);
+
Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
AccessWindowStatic input_access(input, 0, 0, ceil_to_multiple(input->dimension(0), 16), input->dimension(1));
@@ -72,9 +76,14 @@ std::pair<Status, Window> validate_and_configure_window_matrix_a_reduction(ITens
Status validate_arguments_matrix_b_reduction(const ITensorInfo *input, const ITensorInfo *output)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ if(output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(0) != input->dimension(0), "Output vector must have length equal to the number of columns of the input matrix");
+ }
return Status{};
}
@@ -82,6 +91,9 @@ std::pair<Status, Window> validate_and_configure_window_matrix_b_reduction(ITens
{
constexpr unsigned int num_elems_processed_per_iteration = 16;
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(*output, TensorShape(input->dimension(0)), 1, DataType::S32);
+
// Configure kernel window
Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
@@ -98,20 +110,22 @@ std::pair<Status, Window> validate_and_configure_window_matrix_b_reduction(ITens
} // namespace
INEGEMMLowpReductionKernel::INEGEMMLowpReductionKernel()
- : _input(), _output(), _k(0), _is_reshaped(false)
+ : _input(), _output(), _k(0), _is_reshaped(false), _scalar(0), _mul_by_scalar(false)
{
}
-void NEGEMMLowpMatrixAReductionKernel::configure(const ITensor *mtx_a, ITensor *vector_sum_row, int32_t num_mtx_a_cols, bool is_interleaved4x4)
+void NEGEMMLowpMatrixAReductionKernel::configure(const ITensor *mtx_a, ITensor *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
{
// Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_a, vector_sum_row);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_matrix_a_reduction(mtx_a->info(), vector_sum_row->info()));
- _input = mtx_a;
- _output = vector_sum_row;
- _k = num_mtx_a_cols;
- _is_reshaped = is_interleaved4x4;
+ _input = mtx_a;
+ _output = vector_sum_row;
+ _k = info.k;
+ _is_reshaped = info.is_reshaped;
+ _scalar = info.scalar;
+ _mul_by_scalar = info.mul_by_scalar;
// Configure kernel window
auto win_config = validate_and_configure_window_matrix_a_reduction(_input->info(), _output->info(), _is_reshaped);
@@ -119,11 +133,10 @@ void NEGEMMLowpMatrixAReductionKernel::configure(const ITensor *mtx_a, ITensor *
INEKernel::configure(win_config.second);
}
-Status NEGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, int32_t num_mtx_a_cols, bool is_interleaved4x4)
+Status NEGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
{
- ARM_COMPUTE_UNUSED(num_mtx_a_cols);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_a_reduction(mtx_a, vector_sum_row));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_a_reduction(mtx_a->clone().get(), vector_sum_row->clone().get(), is_interleaved4x4).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_a_reduction(mtx_a->clone().get(), vector_sum_row->clone().get(), info.is_reshaped).first);
return Status{};
}
@@ -145,11 +158,12 @@ void NEGEMMLowpMatrixAReductionKernel::run_internal(const arm_compute::Window &w
Iterator in(_input, win_input);
Iterator out(_output, collapsed_window);
+ const auto vec_scalar = wrapper::vdup_n(static_cast<TAcc>(_scalar), wrapper::traits::vector_128_tag{});
+
if(_is_reshaped)
{
execute_window_loop(collapsed_window, [&](const Coordinates & id)
{
- // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
auto sum_row = wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{});
const T *matrix_a = reinterpret_cast<const T *>((in.ptr() + (id.x() / 4) * _input->info()->strides_in_bytes()[1] + id.y() * _input->info()->strides_in_bytes()[2]));
@@ -194,6 +208,12 @@ void NEGEMMLowpMatrixAReductionKernel::run_internal(const arm_compute::Window &w
sum_row = wrapper::vaddw(sum_row, a0_d16);
}
+ // Multiply by scalar if necessary
+ if(_mul_by_scalar)
+ {
+ sum_row = wrapper::vmul(sum_row, vec_scalar);
+ }
+
auto vector_sum_row = reinterpret_cast<int32_t *>(out.ptr());
wrapper::vstore(vector_sum_row, wrapper::vreinterpret(sum_row));
@@ -243,6 +263,12 @@ void NEGEMMLowpMatrixAReductionKernel::run_internal(const arm_compute::Window &w
sum_row += wrapper::vgetlane(tmp, 0);
#endif // __aarch64__
+ // Multiply by scalar if necessary
+ if(_mul_by_scalar)
+ {
+ sum_row *= _scalar;
+ }
+
*(reinterpret_cast<int *>(out.ptr())) = static_cast<int32_t>(sum_row);
},
in, out);
@@ -269,15 +295,17 @@ void NEGEMMLowpMatrixAReductionKernel::run(const Window &window, const ThreadInf
}
}
-void NEGEMMLowpMatrixBReductionKernel::configure(const ITensor *mtx_b, ITensor *vector_sum_col, int32_t num_mtx_b_rows, bool is_transposed1xW)
+void NEGEMMLowpMatrixBReductionKernel::configure(const ITensor *mtx_b, ITensor *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_b, vector_sum_col);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_matrix_b_reduction(mtx_b->info(), vector_sum_col->info()));
- _input = mtx_b;
- _output = vector_sum_col;
- _k = num_mtx_b_rows;
- _is_reshaped = is_transposed1xW;
+ _input = mtx_b;
+ _output = vector_sum_col;
+ _k = info.k;
+ _is_reshaped = info.is_reshaped;
+ _scalar = info.scalar;
+ _mul_by_scalar = info.mul_by_scalar;
// Configure kernel window
auto win_config = validate_and_configure_window_matrix_b_reduction(_input->info(), _output->info());
@@ -285,10 +313,9 @@ void NEGEMMLowpMatrixBReductionKernel::configure(const ITensor *mtx_b, ITensor *
INEKernel::configure(win_config.second);
}
-Status NEGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, int32_t num_mtx_b_rows, bool is_transposed1xW)
+Status NEGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
{
- ARM_COMPUTE_UNUSED(num_mtx_b_rows);
- ARM_COMPUTE_UNUSED(is_transposed1xW);
+ ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_b_reduction(mtx_b, vector_sum_col));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_b_reduction(mtx_b->clone().get(), vector_sum_col->clone().get()).first);
@@ -304,6 +331,8 @@ void NEGEMMLowpMatrixBReductionKernel::run_internal(const Window &window, const
Window collapsed_window = window.collapse_if_possible(IKernel::window(), Window::DimY);
+ const auto vec_scalar = wrapper::vdup_n(static_cast<TAcc>(_scalar), wrapper::traits::vector_128_tag{});
+
if(_is_reshaped)
{
Window win_input(collapsed_window);
@@ -350,6 +379,15 @@ void NEGEMMLowpMatrixBReductionKernel::run_internal(const Window &window, const
sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
}
+ // Multiply by scalar if necessary
+ if(_mul_by_scalar)
+ {
+ sum_col[0] = wrapper::vmul(sum_col[0], vec_scalar);
+ sum_col[1] = wrapper::vmul(sum_col[1], vec_scalar);
+ sum_col[2] = wrapper::vmul(sum_col[2], vec_scalar);
+ sum_col[3] = wrapper::vmul(sum_col[3], vec_scalar);
+ }
+
auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
@@ -465,6 +503,15 @@ void NEGEMMLowpMatrixBReductionKernel::run_internal(const Window &window, const
matrix_b += in_b_stride;
}
+ // Multiply by scalar if necessary
+ if(_mul_by_scalar)
+ {
+ sum_col[0] = wrapper::vmul(sum_col[0], vec_scalar);
+ sum_col[1] = wrapper::vmul(sum_col[1], vec_scalar);
+ sum_col[2] = wrapper::vmul(sum_col[2], vec_scalar);
+ sum_col[3] = wrapper::vmul(sum_col[3], vec_scalar);
+ }
+
auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
@@ -495,3 +542,4 @@ void NEGEMMLowpMatrixBReductionKernel::run(const Window &window, const ThreadInf
ARM_COMPUTE_ERROR("Unsupported data type");
}
}
+} // namespace arm_compute
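
The scalar path added above simply scales each reduction result after accumulation (sum_row *= _scalar in the scalar tail, wrapper::vmul against a broadcast vec_scalar in the vectorised paths). The following plain C++ reference sketch mirrors that intended semantics for the matrix A case (row sums of a QASYMM8 matrix, optionally multiplied by a scalar); it is not the NEON kernel, and the function name and container choices are illustrative only.

    #include <cstdint>
    #include <vector>

    // Reference semantics sketch (not the NEON implementation): for each row of a
    // k-column matrix, accumulate the elements into a 32-bit sum and, when
    // requested, multiply that sum by the scalar before storing it.
    std::vector<int32_t> reduce_rows_reference(const std::vector<uint8_t> &mtx_a,
                                               int32_t rows, int32_t k,
                                               int32_t scalar, bool mul_by_scalar)
    {
        std::vector<int32_t> vector_sum_row(rows, 0);
        for(int32_t r = 0; r < rows; ++r)
        {
            int32_t sum_row = 0;
            for(int32_t c = 0; c < k; ++c)
            {
                sum_row += static_cast<int32_t>(mtx_a[r * k + c]);
            }
            if(mul_by_scalar)
            {
                sum_row *= scalar;
            }
            vector_sum_row[r] = sum_row;
        }
        return vector_sum_row;
    }
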