path: root/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
author    Gian Marco Iodice <gianmarco.iodice@arm.com>  2017-06-30 12:21:00 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-09-17 14:15:39 +0100
commit    bdb6b0bb156588dc39fd5084d4c91d05b5148610 (patch)
tree      bb3c3645dd9abbf20462dace7828bb7ec459dc4d /src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
parent    ac69aa137e360340fe9f148f019d93af6c3d8336 (diff)
download  ComputeLibrary-bdb6b0bb156588dc39fd5084d4c91d05b5148610.tar.gz
COMPMID-433 - Port NEGEMM to support 16 bit fixed point
Change-Id: I82de74d7027bbc8a00a4d6671e968785280d5f6c
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79498
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp | 272
1 file changed, 267 insertions, 5 deletions
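The port mirrors the existing QS8 path: 16-bit fixed-point operands are widened into 32-bit accumulators, every product is renormalised by the tensor's fixed_point_position, and the accumulators are narrowed and saturated back to 16 bits only after the whole dot product is complete. Below is a rough scalar sketch of the saturating multiply-accumulate that vqmlal_qs16 vectorises, assuming the library's QSn convention that a raw value r encodes r * 2^-fixed_point_position; the function name and the exact rounding behaviour here are illustrative, not the library API.

    #include <algorithm>
    #include <cstdint>

    // Scalar model of one saturating fixed-point multiply-accumulate step.
    // 'fp' is the fixed point position: a raw QS16 value r encodes r * 2^-fp.
    int32_t qmlal_qs16_scalar(int32_t acc, int16_t a, int16_t b, int fp)
    {
        int64_t prod = static_cast<int64_t>(a) * static_cast<int64_t>(b);
        if(fp > 0)
        {
            // The raw product carries a doubled fixed point position (2*fp);
            // a rounding shift right by fp renormalises it back to QS(fp)
            prod = (prod + (int64_t{1} << (fp - 1))) >> fp;
        }
        // Saturate the running sum to the 32-bit accumulator range
        const int64_t sum = static_cast<int64_t>(acc) + prod;
        return static_cast<int32_t>(std::clamp<int64_t>(sum, INT32_MIN, INT32_MAX));
    }

Deferring the narrowing back to QS16 (vqmovn_qs32 in the kernels below) until after the full accumulation confines intermediate overflow to the saturating 32-bit adds.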
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index bff16ec329..b81be6cee9 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -475,6 +475,135 @@ void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, IT
}
template <bool multiply_alpha>
+void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
+{
+ const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
+ const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+ const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
+ const int fixed_point_position = input0->info()->fixed_point_position();
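+ // A raw QS16 value r represents r * 2^-fixed_point_position; each 16-bit product below is renormalised with a rounding right shift by this amount as it is accumulated into 32 bits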
+
+ // The implementation computes 16 elements per iteration
+ const int window_start_x = 16 * window.thread_id();
+ const int window_step_x = 16 * window.num_threads();
+ // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
+ const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
+ ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");
+
+ Window win_out(window);
+ win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Window win_b;
+ // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the matrix multiplication is used to perform a convolution operation
+ if(input1->info()->num_dimensions() >= 3)
+ {
+ win_b = window;
+ }
+ win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ Iterator ina(input0, win_a);
+ Iterator inb(input1, win_b);
+ Iterator out(output, win_out);
+
+ execute_window_loop(win_out, [&](const Coordinates & id)
+ {
+ if(id.x() > width_matrix_b)
+ {
+ return;
+ }
+
+ // Reset accumulators
+ qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc02_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc03_qs32 = vdupq_n_qs32(0);
+
+ auto vec_a = reinterpret_cast<const qint16_t *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const qint16_t *>(inb.ptr());
+
+ auto vec_a_end_addr = vec_a + num_elems_vec_a;
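+ // Main loop unrolled by 2 along the dot-product dimension: each iteration consumes 2 elements of vector A and 2 rows (16 elements each) of matrix B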
+ for(; vec_a <= (vec_a_end_addr - 2);)
+ {
+ const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0);
+ const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1);
+
+ const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride);
+ const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride);
+ const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride);
+ const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride);
+ const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride);
+ const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride);
+ const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride);
+ const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride);
+
+ // First accumulation
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
+ acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
+ acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
+
+ // Second accumulation
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position);
+ acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position);
+ acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position);
+
+ vec_a += 2;
+ matrix_b += 2 * in_b_stride;
+ }
+
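+ // Leftover loop: processes the final element of vector A when num_elems_vec_a is odd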
+ for(; vec_a < vec_a_end_addr;)
+ {
+ const qint16x4_t a0 = vld1_dup_qs16(vec_a);
+
+ const qint16x4_t b00 = vld1_qs16(matrix_b + 0);
+ const qint16x4_t b01 = vld1_qs16(matrix_b + 4);
+ const qint16x4_t b02 = vld1_qs16(matrix_b + 8);
+ const qint16x4_t b03 = vld1_qs16(matrix_b + 12);
+
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
+ acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
+ acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
+
+ vec_a += 1;
+ matrix_b += in_b_stride;
+ }
+
+ // Convert back to qint16x4_t and saturate
+ qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
+ qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
+ qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32);
+ qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32);
+
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ const qint16x4_t alpha_qs16 = vdup_n_qs16(scvt_qs16_f32(alpha, fixed_point_position));
+ acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
+ acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
+ acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position);
+ acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position);
+ }
+
+ const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
+
+ // Store 16x4 output elements
+ vst1_qs16(mtx_out0 + 0, acc00_qs16);
+ vst1_qs16(mtx_out0 + 4, acc01_qs16);
+ vst1_qs16(mtx_out0 + 8, acc02_qs16);
+ vst1_qs16(mtx_out0 + 12, acc03_qs16);
+ },
+ ina, inb, out);
+}
+
+template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
@@ -1153,6 +1282,120 @@ void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, IT
ina, inb, out);
}
+template <bool multiply_alpha>
+void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
+{
+ const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
+ const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+ const size_t out_stride2 = out_stride1 * 2;
+ const size_t out_stride3 = out_stride1 * 3;
+ const int num_elems_matrix_b_x = input1->info()->dimension(0);
+ const int fixed_point_position = input0->info()->fixed_point_position();
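+ // Convert alpha to QS16 once up front; ARM_COMPUTE_UNUSED silences the unused-variable warning when the template is instantiated with multiply_alpha == false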
+ const qint16x4_t alpha_qs16 = vdup_n_qs16(scvt_qs16_f32(alpha, fixed_point_position));
+ ARM_COMPUTE_UNUSED(alpha_qs16);
+
+ // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
+
+ Window win_b;
+ // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the matrix multiplication is used to perform a convolution operation
+ if(input1->info()->num_dimensions() >= 3)
+ {
+ win_b = window;
+ }
+ // Set step_x and step_y for matrix B. Scale the X range by a factor of 8, as the transposed input matrix B has 8 times fewer rows than the output matrix has columns
+ win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
+ win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator ina(input0, win_a);
+ Iterator inb(input1, win_b);
+ Iterator out(output, window);
+
+ // The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
+ // The reshaping makes the implementation cache friendly and avoids the data re-arrangements needed to compute 8x4 elements per iteration
+ // All the values needed to compute a single 8x4 block are read from consecutive memory positions
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ auto mtx_a0 = reinterpret_cast<const qint16_t *>(ina.ptr());
+ auto mtx_b0 = reinterpret_cast<const qint16_t *>(inb.ptr());
+ auto mtx_b1 = mtx_b0 + in_b_stride;
+
+ qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc10_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc20_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc30_qs32 = vdupq_n_qs32(0);
+
+ qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc11_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc21_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc31_qs32 = vdupq_n_qs32(0);
+
+ // Each iteration of this loop performs one multiply-accumulate per accumulator, consuming 4 elements of the interleaved matrix A and 8 of the transposed matrix B
+ for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8)
+ {
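+ // The four vld1_dup loads broadcast 4 consecutive values of the 4x4-interleaved matrix A, one per row of the 8x4 output block, against the same 8 columns of matrix B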
+ const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0);
+ const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1);
+ const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2);
+ const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3);
+
+ const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0);
+ const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4);
+
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
+ acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position);
+ acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position);
+ acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
+ acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position);
+ acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position);
+ acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position);
+
+ mtx_a0 += 4;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+ }
+
+ // Convert back to qint16x4_t and saturate
+ qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
+ qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32);
+ qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32);
+ qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32);
+
+ qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
+ qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32);
+ qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32);
+ qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32);
+
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
+ acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position);
+ acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position);
+ acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position);
+ acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
+ acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position);
+ acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position);
+ acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position);
+ }
+
+ const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
+
+ // Store 8x4 output elements
+ vst1_qs16(mtx_out0 + 0, acc00_qs16);
+ vst1_qs16(mtx_out0 + 4, acc01_qs16);
+ vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16);
+ vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16);
+ vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16);
+ vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16);
+ vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16);
+ vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16);
+ },
+ ina, inb, out);
+}
} // namespace
NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
@@ -1162,10 +1405,7 @@ NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
@@ -1197,6 +1437,11 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
num_elems_processed_per_iteration_x = 32;
break;
}
+ case DataType::QS16:
+ {
+ num_elems_processed_per_iteration_x = 16;
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
@@ -1241,6 +1486,11 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
num_elems_processed_per_iteration_x = 32;
break;
}
+ case DataType::QS16:
+ {
+ num_elems_processed_per_iteration_x = 8;
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
@@ -1278,7 +1528,7 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;
- // Check if the output tensor is a vector and the data type is F32. If so,the kernel runs the vector-matrix multiplication
+ // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
if((_output->info()->dimension(1) == 1))
{
switch(_input0->info()->data_type())
@@ -1295,6 +1545,12 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
break;
}
+ case DataType::QS16:
+ {
+ multiply_alpha ? vector_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
+ vector_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
@@ -1326,6 +1582,12 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
break;
}
+ case DataType::QS16:
+ {
+ multiply_alpha ? matrix_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
+ matrix_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{