author     Gian Marco Iodice <gianmarco.iodice@arm.com>    2017-06-30 12:21:00 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>       2018-09-17 14:15:39 +0100
commit     bdb6b0bb156588dc39fd5084d4c91d05b5148610 (patch)
tree       bb3c3645dd9abbf20462dace7828bb7ec459dc4d
parent     ac69aa137e360340fe9f148f019d93af6c3d8336 (diff)
download   ComputeLibrary-bdb6b0bb156588dc39fd5084d4c91d05b5148610.tar.gz
COMPMID-433 - Port NEGEMM to support 16 bit fixed point
Change-Id: I82de74d7027bbc8a00a4d6671e968785280d5f6c
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79498
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h   |   4
-rw-r--r--  arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h  |   4
-rw-r--r--  arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h  |   2
-rw-r--r--  arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h    |   2
-rw-r--r--  arm_compute/runtime/NEON/functions/NEGEMM.h                 |   2
-rw-r--r--  src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp         |   3
-rw-r--r--  src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp        |  32
-rw-r--r--  src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp        | 272
-rw-r--r--  src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp          |   3
-rw-r--r--  src/runtime/NEON/functions/NEGEMM.cpp                       |  36
-rw-r--r--  tests/TensorLibrary.h                                       |   1
-rw-r--r--  tests/TypePrinter.h                                         |   3
-rw-r--r--  tests/validation/NEON/GEMM.cpp                              |  12
13 files changed, 324 insertions(+), 52 deletions(-)
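
For context on the arithmetic the new QS16 paths rely on: a raw 16-bit value v with fixed_point_position n represents v / 2^n. Below is a minimal scalar sketch of that arithmetic; the names qmlal_qs16 and qmul_qs16 are illustrative stand-ins for the library's vqmlal_qs16 / vqmul_qs16 NEON helpers used in this patch, and the exact rounding and saturation behaviour of the real implementation may differ.

#include <algorithm>
#include <cstdint>

// Scalar sketch of QS16 fixed-point arithmetic with a runtime fixed_point_position n.
// Not the library's NEON code; rounding/saturation details may differ.
using qs16_t = int16_t;

inline int32_t saturate_to_s32(int64_t v)
{
    return static_cast<int32_t>(std::min<int64_t>(std::max<int64_t>(v, INT32_MIN), INT32_MAX));
}

inline int16_t saturate_to_s16(int32_t v)
{
    return static_cast<int16_t>(std::min(std::max(v, -32768), 32767));
}

// Widening multiply-accumulate: acc += (a * b) >> n, with saturation (cf. vqmlal_qs16)
inline int32_t qmlal_qs16(int32_t acc, qs16_t a, qs16_t b, int fixed_point_position)
{
    const int64_t prod = (static_cast<int64_t>(a) * static_cast<int64_t>(b)) >> fixed_point_position;
    return saturate_to_s32(static_cast<int64_t>(acc) + prod);
}

// Saturating QS16 multiply: (a * b) >> n, narrowed back to 16 bits (cf. vqmul_qs16)
inline qs16_t qmul_qs16(qs16_t a, qs16_t b, int fixed_point_position)
{
    const int32_t prod = static_cast<int32_t>((static_cast<int64_t>(a) * b) >> fixed_point_position);
    return saturate_to_s16(prod);
}
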
diff --git a/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h b/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h
index b9884ffb57..84b82d0ffc 100644
--- a/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h
+++ b/arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h
@@ -56,7 +56,7 @@ public:
NEGEMMInterleave4x4Kernel();
/** Initialise the kernel's input and output.
*
- * @param[in] input Input tensor. Data types supported: U8/S8/QS8/U16/S16/F16/U32/S32/F32
+ * @param[in] input Input tensor. Data types supported: U8/S8/QS8/QS16/U16/S16/F16/U32/S32/F32
* @param[out] output Output tensor which stores the interleaved matrix. Data type supported: same as @p input.
*/
void configure(const ITensor *input, ITensor *output);
@@ -67,7 +67,7 @@ public:
private:
/** Common signature for all the transpose functions
*
- * @param[in] input An input tensor. Data types supported: U8/S8/QS8/U16/S16/F16/U32/S32/F32
+ * @param[in] input An input tensor. Data types supported: U8/S8/QS8/U16/S16/QS16/F16/U32/S32/F32
* @param[out] output The output tensor. Data type supported: same as @p input
* @param[in] window Region on which to execute the kernel.
*/
diff --git a/arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h b/arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h
index 1ab52fa2f2..5cdcc95ee9 100644
--- a/arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h
+++ b/arm_compute/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h
@@ -55,7 +55,7 @@ public:
*
* @note The input and output tensor must have the same dimensions
*
- * @param[in] input Input tensor (Matrix C). Data types supported: QS8/F16/F32
+ * @param[in] input Input tensor (Matrix C). Data types supported: QS8/QS16/F16/F32
* @param[in, out] output Output tensor. If this kernel is used to finalize the GEMM result, output contains the result obtained by the kernel @ref NEGEMMMatrixMultiplyKernel. Data type supported: the same as @p input.
* @param[in] beta Weight of matrix C
*/
@@ -67,7 +67,7 @@ public:
private:
/** Common signature for all the matrix addition functions
*
- * @param[in] input An input tensor. Data types supported: QS8/F16/F32
+ * @param[in] input An input tensor. Data types supported: QS8/QS16/F16/F32
* @param[out] output The output tensor. Data type supported: same as @p input
* @param[in] window Region on which to execute the kernel.
* @param[in] beta Weight of matrix C
diff --git a/arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h b/arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h
index a684945828..e82fc6f5d7 100644
--- a/arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h
+++ b/arm_compute/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h
@@ -54,7 +54,7 @@ public:
* @note If the output tensor is a matrix, the input matrices @p input0 and @p input1 should be the output of the kernels: @ref NEGEMMInterleave4x4Kernel and @ref NEGEMMTranspose1xWKernel
* These two kernels change the layout of the original matrices to be more cache-friendly.
*
- * @param[in] input0 Input tensor containing the interleaved Matrix A or the vector A. Data types supported: F16/F32
+ * @param[in] input0 Input tensor containing the interleaved Matrix A or the vector A. Data types supported: QS8/QS16/F16/F32
* @param[in] input1 Input tensor containing the transposed Matrix B if the first input tensor A is not a vector.
* If the output tensor is a vector, input1 must contain the matrix B not reshaped. Data type supported: same as @p input0
* @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
diff --git a/arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h b/arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h
index 5d8a3697cb..22c07e5c9a 100644
--- a/arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h
+++ b/arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h
@@ -70,7 +70,7 @@ class NEGEMMTranspose1xWKernel : public INESimpleKernel
public:
/** Initialise the kernel's input and output.
*
- * @param[in] input Input tensor. Data types supported: U8/S8/QS8/U16/S16/F16/U32/S32/F32
+ * @param[in] input Input tensor. Data types supported: U8/S8/QS8/U16/S16/QS16/F16/U32/S32/F32
* @param[out] output Output tensor. Data type supported: same as @p input.
*/
void configure(const ITensor *input, ITensor *output);
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index a40aa910a5..3c8d7cf9b7 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -52,7 +52,7 @@ public:
* @note GEMM: General Matrix Multiply - [alpha * A * B + beta * C].
* @note GEMM: The tensors a, b, c, d must have the same data type. You should not mix data types when calling this function.
*
- * @param[in] a First input tensor (Matrix A or Vector A). Data type supported: QS8/F16/F32
+ * @param[in] a First input tensor (Matrix A or Vector A). Data type supported: QS8/QS16/F16/F32
* @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a
* @param[in] c Third input tensor (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a
* @param[out] d Output tensor. Data type supported: same as @p a
diff --git a/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp b/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp
index 4505dcb363..40ece9faab 100644
--- a/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.cpp
@@ -132,7 +132,8 @@ NEGEMMInterleave4x4Kernel::NEGEMMInterleave4x4Kernel()
void NEGEMMInterleave4x4Kernel::configure(const ITensor *input, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::U8, DataType::S8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::U8, DataType::S8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16,
+ DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
TensorShape output_shape = input->info()->tensor_shape();
diff --git a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
index 57d2807b8a..91fbe6f962 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.cpp
@@ -114,6 +114,31 @@ void matrix_addition_qs8(const ITensor *input, ITensor *output, const Window &wi
},
in, out);
}
+
+void matrix_addition_qs16(const ITensor *input, ITensor *output, const Window &window, float beta)
+{
+ const int fixed_point_position = input->info()->fixed_point_position();
+ const qint16x8_t beta_qs16 = vdupq_n_qs16(scvt_qs16_f32(beta, fixed_point_position));
+
+ Iterator in(input, window);
+ Iterator out(output, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ const auto in_ptr = reinterpret_cast<const qint16_t *>(in.ptr());
+ const auto out_ptr = reinterpret_cast<qint16_t *>(out.ptr());
+
+ qint16x8x2_t alpha_ab = vld2q_s16(out_ptr);
+ const qint16x8x2_t c = vld2q_s16(in_ptr);
+
+ // Multiply matrix C by its weight and accumulate
+ alpha_ab.val[0] = vqmlaq_qs16(alpha_ab.val[0], c.val[0], beta_qs16, fixed_point_position);
+ alpha_ab.val[1] = vqmlaq_qs16(alpha_ab.val[1], c.val[1], beta_qs16, fixed_point_position);
+
+ vst2q_s16(out_ptr, alpha_ab);
+ },
+ in, out);
+}
} // namespace
NEGEMMMatrixAdditionKernel::NEGEMMMatrixAdditionKernel()
@@ -123,8 +148,8 @@ NEGEMMMatrixAdditionKernel::NEGEMMMatrixAdditionKernel()
void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output, float beta)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0));
@@ -138,6 +163,9 @@ void NEGEMMMatrixAdditionKernel::configure(const ITensor *input, ITensor *output
case DataType::QS8:
_func = &matrix_addition_qs8;
break;
+ case DataType::QS16:
+ _func = &matrix_addition_qs16;
+ break;
case DataType::F16:
#ifdef ARM_COMPUTE_ENABLE_FP16
_func = &matrix_addition_f16;
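
The addition kernel finalises GEMM by accumulating the weighted C matrix into the already computed alpha*A*B result. The following is a scalar sketch of what the new matrix_addition_qs16 computes per element, assuming the QS16 convention above; the function name and the float-to-fixed rounding are illustrative, not library API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Scalar reference: out = out + beta * c, all values QS16 (raw / 2^fixed_point_position)
void matrix_addition_qs16_ref(const int16_t *c, int16_t *out, std::size_t n, float beta, int fixed_point_position)
{
    const int16_t beta_q = static_cast<int16_t>(std::lround(beta * (1 << fixed_point_position)));
    for(std::size_t i = 0; i < n; ++i)
    {
        const int32_t prod = (static_cast<int32_t>(c[i]) * beta_q) >> fixed_point_position;
        const int32_t sum  = static_cast<int32_t>(out[i]) + prod;
        out[i]             = static_cast<int16_t>(std::min(std::max(sum, -32768), 32767));
    }
}
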
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index bff16ec329..b81be6cee9 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -475,6 +475,135 @@ void vector_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, IT
}
template <bool multiply_alpha>
+void vector_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
+{
+ const auto width_matrix_b = static_cast<int>(output->info()->dimension(0));
+ const auto in_b_stride = static_cast<int>(input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type()));
+ const auto num_elems_vec_a = static_cast<int>(input0->info()->dimension(0));
+ const int fixed_point_position = input0->info()->fixed_point_position();
+
+ // The implementation computes 16 elements per iteration
+ const int window_start_x = 16 * window.thread_id();
+ const int window_step_x = 16 * window.num_threads();
+ // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
+ const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
+ ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be a multiple of window_step_x");
+
+ Window win_out(window);
+ win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Window win_b;
+ // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the matrix multiplication is used to perform a convolution operation
+ if(input1->info()->num_dimensions() >= 3)
+ {
+ win_b = window;
+ }
+ win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
+ win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
+
+ Iterator ina(input0, win_a);
+ Iterator inb(input1, win_b);
+ Iterator out(output, win_out);
+
+ execute_window_loop(win_out, [&](const Coordinates & id)
+ {
+ if(id.x() > width_matrix_b)
+ {
+ return;
+ }
+
+ // Reset accumulators
+ qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc02_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc03_qs32 = vdupq_n_qs32(0);
+
+ auto vec_a = reinterpret_cast<const qint16_t *>(ina.ptr());
+ auto matrix_b = reinterpret_cast<const qint16_t *>(inb.ptr());
+
+ auto vec_a_end_addr = vec_a + num_elems_vec_a;
+ for(; vec_a <= (vec_a_end_addr - 2);)
+ {
+ const qint16x4_t a0 = vld1_dup_qs16(vec_a + 0);
+ const qint16x4_t a1 = vld1_dup_qs16(vec_a + 1);
+
+ const qint16x4_t b00 = vld1_qs16(matrix_b + 0 + 0 * in_b_stride);
+ const qint16x4_t b01 = vld1_qs16(matrix_b + 4 + 0 * in_b_stride);
+ const qint16x4_t b02 = vld1_qs16(matrix_b + 8 + 0 * in_b_stride);
+ const qint16x4_t b03 = vld1_qs16(matrix_b + 12 + 0 * in_b_stride);
+ const qint16x4_t b10 = vld1_qs16(matrix_b + 0 + 1 * in_b_stride);
+ const qint16x4_t b11 = vld1_qs16(matrix_b + 4 + 1 * in_b_stride);
+ const qint16x4_t b12 = vld1_qs16(matrix_b + 8 + 1 * in_b_stride);
+ const qint16x4_t b13 = vld1_qs16(matrix_b + 12 + 1 * in_b_stride);
+
+ // First accumulation
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
+ acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
+ acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
+
+ // Second accumulation
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b10, a1, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b11, a1, fixed_point_position);
+ acc02_qs32 = vqmlal_qs16(acc02_qs32, b12, a1, fixed_point_position);
+ acc03_qs32 = vqmlal_qs16(acc03_qs32, b13, a1, fixed_point_position);
+
+ vec_a += 2;
+ matrix_b += 2 * in_b_stride;
+ }
+
+ for(; vec_a < vec_a_end_addr;)
+ {
+ const qint16x4_t a0 = vld1_dup_qs16(vec_a);
+
+ const qint16x4_t b00 = vld1_qs16(matrix_b + 0);
+ const qint16x4_t b01 = vld1_qs16(matrix_b + 4);
+ const qint16x4_t b02 = vld1_qs16(matrix_b + 8);
+ const qint16x4_t b03 = vld1_qs16(matrix_b + 12);
+
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
+ acc02_qs32 = vqmlal_qs16(acc02_qs32, b02, a0, fixed_point_position);
+ acc03_qs32 = vqmlal_qs16(acc03_qs32, b03, a0, fixed_point_position);
+
+ vec_a += 1;
+ matrix_b += in_b_stride;
+ }
+
+ // Convert back to qint16x4_t and saturate
+ qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
+ qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
+ qint16x4_t acc02_qs16 = vqmovn_qs32(acc02_qs32);
+ qint16x4_t acc03_qs16 = vqmovn_qs32(acc03_qs32);
+
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ const qint16x4_t alpha_qs16 = vdup_n_qs16(scvt_qs16_f32(alpha, fixed_point_position));
+ acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
+ acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
+ acc02_qs16 = vqmul_qs16(acc02_qs16, alpha_qs16, fixed_point_position);
+ acc03_qs16 = vqmul_qs16(acc03_qs16, alpha_qs16, fixed_point_position);
+ }
+
+ const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
+
+ // Store 16x4 output elements
+ vst1_qs16(mtx_out0 + 0, acc00_qs16);
+ vst1_qs16(mtx_out0 + 4, acc01_qs16);
+ vst1_qs16(mtx_out0 + 8, acc02_qs16);
+ vst1_qs16(mtx_out0 + 12, acc03_qs16);
+ },
+ ina, inb, out);
+}
+
+template <bool multiply_alpha>
void matrix_matrix_multiply_f32(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
{
const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
@@ -1153,6 +1282,120 @@ void matrix_matrix_multiply_qs8(const ITensor *input0, const ITensor *input1, IT
ina, inb, out);
}
+template <bool multiply_alpha>
+void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, ITensor *output, const Window &window, float alpha)
+{
+ const size_t in_b_stride = input1->info()->strides_in_bytes()[1] / data_size_from_type(input1->info()->data_type());
+ const size_t out_stride1 = output->info()->strides_in_bytes()[1] / data_size_from_type(output->info()->data_type());
+ const size_t out_stride2 = out_stride1 * 2;
+ const size_t out_stride3 = out_stride1 * 3;
+ const int num_elems_matrix_b_x = input1->info()->dimension(0);
+ const int fixed_point_position = input0->info()->fixed_point_position();
+ const qint16x4_t alpha_qs16 = vdup_n_qs16(scvt_qs16_f32(alpha, fixed_point_position));
+ ARM_COMPUTE_UNUSED(alpha_qs16);
+
+ // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
+ Window win_a(window);
+ win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
+ win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));
+
+ Window win_b;
+ // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
+ // This scenario can happen when the matrix multiplication is used to perform a convolution operation
+ if(input1->info()->num_dimensions() >= 3)
+ {
+ win_b = window;
+ }
+ // Set step_x and step_y for matrix B. Scale by a factor of 8 the X range as the input transposed matrix B has 8 times less the cols of the output matrix
+ win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
+ win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator ina(input0, win_a);
+ Iterator inb(input1, win_b);
+ Iterator out(output, window);
+
+ // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
+ // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 8x4 elements per iteration
+ // All the values needed for computing a single 8x4 block will be read from consecutive memory positions
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ auto mtx_a0 = reinterpret_cast<const qint16_t *>(ina.ptr());
+ auto mtx_b0 = reinterpret_cast<const qint16_t *>(inb.ptr());
+ auto mtx_b1 = mtx_b0 + in_b_stride;
+
+ qint32x4_t acc00_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc10_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc20_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc30_qs32 = vdupq_n_qs32(0);
+
+ qint32x4_t acc01_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc11_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc21_qs32 = vdupq_n_qs32(0);
+ qint32x4_t acc31_qs32 = vdupq_n_qs32(0);
+
+ // This for loop performs 1 accumulation
+ for(int k = 0; k <= (num_elems_matrix_b_x - 8); k += 8)
+ {
+ const qint16x4_t a0 = vld1_dup_qs16(mtx_a0 + 0);
+ const qint16x4_t a1 = vld1_dup_qs16(mtx_a0 + 1);
+ const qint16x4_t a2 = vld1_dup_qs16(mtx_a0 + 2);
+ const qint16x4_t a3 = vld1_dup_qs16(mtx_a0 + 3);
+
+ const qint16x4_t b00 = vld1_qs16(mtx_b0 + 0);
+ const qint16x4_t b01 = vld1_qs16(mtx_b0 + 4);
+
+ acc00_qs32 = vqmlal_qs16(acc00_qs32, b00, a0, fixed_point_position);
+ acc10_qs32 = vqmlal_qs16(acc10_qs32, b00, a1, fixed_point_position);
+ acc20_qs32 = vqmlal_qs16(acc20_qs32, b00, a2, fixed_point_position);
+ acc30_qs32 = vqmlal_qs16(acc30_qs32, b00, a3, fixed_point_position);
+ acc01_qs32 = vqmlal_qs16(acc01_qs32, b01, a0, fixed_point_position);
+ acc11_qs32 = vqmlal_qs16(acc11_qs32, b01, a1, fixed_point_position);
+ acc21_qs32 = vqmlal_qs16(acc21_qs32, b01, a2, fixed_point_position);
+ acc31_qs32 = vqmlal_qs16(acc31_qs32, b01, a3, fixed_point_position);
+
+ mtx_a0 += 4;
+ mtx_b0 += 8;
+ mtx_b1 += 8;
+ }
+
+ // Convert back to qint16x4_t and saturate
+ qint16x4_t acc00_qs16 = vqmovn_qs32(acc00_qs32);
+ qint16x4_t acc10_qs16 = vqmovn_qs32(acc10_qs32);
+ qint16x4_t acc20_qs16 = vqmovn_qs32(acc20_qs32);
+ qint16x4_t acc30_qs16 = vqmovn_qs32(acc30_qs32);
+
+ qint16x4_t acc01_qs16 = vqmovn_qs32(acc01_qs32);
+ qint16x4_t acc11_qs16 = vqmovn_qs32(acc11_qs32);
+ qint16x4_t acc21_qs16 = vqmovn_qs32(acc21_qs32);
+ qint16x4_t acc31_qs16 = vqmovn_qs32(acc31_qs32);
+
+ // Multiply by the weight of the matrix product (alpha)
+ if(multiply_alpha)
+ {
+ acc00_qs16 = vqmul_qs16(acc00_qs16, alpha_qs16, fixed_point_position);
+ acc10_qs16 = vqmul_qs16(acc10_qs16, alpha_qs16, fixed_point_position);
+ acc20_qs16 = vqmul_qs16(acc20_qs16, alpha_qs16, fixed_point_position);
+ acc30_qs16 = vqmul_qs16(acc30_qs16, alpha_qs16, fixed_point_position);
+ acc01_qs16 = vqmul_qs16(acc01_qs16, alpha_qs16, fixed_point_position);
+ acc11_qs16 = vqmul_qs16(acc11_qs16, alpha_qs16, fixed_point_position);
+ acc21_qs16 = vqmul_qs16(acc21_qs16, alpha_qs16, fixed_point_position);
+ acc31_qs16 = vqmul_qs16(acc31_qs16, alpha_qs16, fixed_point_position);
+ }
+
+ const auto mtx_out0 = reinterpret_cast<qint16_t *>(out.ptr());
+
+ // Store 8x4 output elements
+ vst1_qs16(mtx_out0 + 0, acc00_qs16);
+ vst1_qs16(mtx_out0 + 4, acc01_qs16);
+ vst1_qs16(mtx_out0 + out_stride1 + 0, acc10_qs16);
+ vst1_qs16(mtx_out0 + out_stride1 + 4, acc11_qs16);
+ vst1_qs16(mtx_out0 + out_stride2 + 0, acc20_qs16);
+ vst1_qs16(mtx_out0 + out_stride2 + 4, acc21_qs16);
+ vst1_qs16(mtx_out0 + out_stride3 + 0, acc30_qs16);
+ vst1_qs16(mtx_out0 + out_stride3 + 4, acc31_qs16);
+ },
+ ina, inb, out);
+}
} // namespace
NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
@@ -1162,10 +1405,7 @@ NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
@@ -1197,6 +1437,11 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
num_elems_processed_per_iteration_x = 32;
break;
}
+ case DataType::QS16:
+ {
+ num_elems_processed_per_iteration_x = 16;
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
@@ -1241,6 +1486,11 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
num_elems_processed_per_iteration_x = 32;
break;
}
+ case DataType::QS16:
+ {
+ num_elems_processed_per_iteration_x = 8;
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
@@ -1278,7 +1528,7 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
bool multiply_alpha = std::abs(1.0f - _alpha) > 0.00001f;
- // Check if the output tensor is a vector and the data type is F32. If so,the kernel runs the vector-matrix multiplication
+ // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
if((_output->info()->dimension(1) == 1))
{
switch(_input0->info()->data_type())
@@ -1295,6 +1545,12 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
vector_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
break;
}
+ case DataType::QS16:
+ {
+ multiply_alpha ? vector_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
+ vector_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
@@ -1326,6 +1582,12 @@ void NEGEMMMatrixMultiplyKernel::run(const Window &window)
matrix_matrix_multiply_qs8<false>(_input0, _input1, _output, window, _alpha);
break;
}
+ case DataType::QS16:
+ {
+ multiply_alpha ? matrix_matrix_multiply_qs16<true>(_input0, _input1, _output, window, _alpha) :
+ matrix_matrix_multiply_qs16<false>(_input0, _input1, _output, window, _alpha);
+ break;
+ }
#ifdef ARM_COMPUTE_ENABLE_FP16
case DataType::F16:
{
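
The matrix-matrix path above assumes A has been reshaped by NEGEMMInterleave4x4 and B by NEGEMMTranspose1xW, so that each 8x4 output block reads consecutive memory. Below is a rough sketch of those two reshapes for a 16-bit element type (W = 16 bytes / sizeof(qint16_t) = 8, which is why the QS16 kernel scales its X window by 1/8 and its Y window by 1/4); this is plain C++ for illustration only, and tail padding / edge handling in the library may differ.

#include <cstddef>
#include <cstdint>
#include <vector>

// NEGEMMInterleave4x4 sketch: rows r..r+3 of A are zipped element by element into one output row.
std::vector<int16_t> interleave_4x4(const std::vector<int16_t> &a, int rows, int cols)
{
    std::vector<int16_t> out(static_cast<std::size_t>((rows + 3) / 4) * 4 * cols, 0);
    for(int r = 0; r < rows; ++r)
    {
        for(int c = 0; c < cols; ++c)
        {
            out[(r / 4) * (4 * cols) + c * 4 + (r % 4)] = a[r * cols + c];
        }
    }
    return out;
}

// NEGEMMTranspose1xW sketch (W = 8 for 16-bit data): each group of 8 consecutive elements of a
// row of B becomes one contiguous 8-element block in the output row for that column group.
std::vector<int16_t> transpose_1x8(const std::vector<int16_t> &b, int rows, int cols)
{
    std::vector<int16_t> out(static_cast<std::size_t>(rows) * ((cols + 7) / 8) * 8, 0);
    for(int r = 0; r < rows; ++r)
    {
        for(int c = 0; c < cols; ++c)
        {
            out[(c / 8) * (8 * rows) + r * 8 + (c % 8)] = b[r * cols + c];
        }
    }
    return out;
}
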
diff --git a/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp b/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp
index f6cf2d1f8d..881ef122a1 100644
--- a/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMTranspose1xWKernel.cpp
@@ -43,7 +43,8 @@ using namespace arm_compute;
void NEGEMMTranspose1xWKernel::configure(const ITensor *input, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::U8, DataType::S8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::U8, DataType::S8, DataType::U16, DataType::S16, DataType::U32, DataType::S32, DataType::F16,
+ DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
TensorShape output_shape{ input->info()->tensor_shape() };
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 73c5f548c9..dfcb3954ea 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -43,11 +43,13 @@ NEGEMM::NEGEMM()
void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16, DataType::QS8);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16, DataType::QS8, DataType::QS16);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b, d);
+ ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
if(c != nullptr)
{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(c, 1, DataType::F32, DataType::F16, DataType::QS8, DataType::QS16);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, c);
ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(1) != c->info()->dimension(1), "The C matrix must have the same number of rows as the matrix A");
ARM_COMPUTE_ERROR_ON_MSG(b->info()->dimension(0) != c->info()->dimension(0), "The C matrix must have the same number of columns as the matrix B");
@@ -55,8 +57,6 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
ARM_COMPUTE_ERROR_ON_MSG(c->info()->dimension(1) != d->info()->dimension(1), "The C matrix must have the same number of columns as the output matrix");
}
- ARM_COMPUTE_ERROR_ON_MSG(a->info()->dimension(0) != b->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
-
// Check if the first input tensor is a vector. If so, all the kernels for reshaping the tensors can be skipped
if((a->info()->dimension(1) == 1))
{
@@ -75,33 +75,9 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
shape_tmp_a.set(0, a->info()->dimension(0) * 4);
shape_tmp_a.set(1, std::ceil(a->info()->dimension(1) / 4.0f));
- switch(a->info()->data_type())
- {
- case DataType::F32:
- {
- shape_tmp_b.set(0, b->info()->dimension(1) * 4);
- shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 4.0f));
- break;
- }
- case DataType::F16:
-#ifdef ARM_COMPUTE_ENABLE_FP16
- {
- shape_tmp_b.set(0, b->info()->dimension(1) * 8);
- shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 8.0f));
- break;
- }
-#endif /* ARM_COMPUTE_ENABLE_FP16 */
- case DataType::QS8:
- {
- shape_tmp_b.set(0, b->info()->dimension(1) * 16);
- shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 16.0f));
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR_ON("Data type not supported");
- }
- }
+ const unsigned int transpose_w = 16 / data_size_from_type(b->info()->data_type());
+ shape_tmp_b.set(0, b->info()->dimension(1) * transpose_w);
+ shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / static_cast<float>(transpose_w)));
TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type(), a->info()->fixed_point_position());
TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type(), a->info()->fixed_point_position());
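
The refactor above replaces the per-type switch with a single rule: the 1xW transpose always packs one 16-byte vector per group, so W = 16 / element size. A small illustration of the resulting widths, assuming the usual element sizes (F32 = 4, F16 = 2, QS8 = 1, QS16 = 2 bytes):

#include <cstdio>

// Worked example of the transpose width rule used to size shape_tmp_b
int main()
{
    const struct { const char *name; unsigned int size; } types[] = { { "F32", 4 }, { "F16", 2 }, { "QS8", 1 }, { "QS16", 2 } };
    for(const auto &t : types)
    {
        std::printf("%-4s -> transpose_w = %u\n", t.name, 16 / t.size);
    }
    return 0;
}
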
diff --git a/tests/TensorLibrary.h b/tests/TensorLibrary.h
index 6c079b6872..5b2c5b6fd5 100644
--- a/tests/TensorLibrary.h
+++ b/tests/TensorLibrary.h
@@ -469,6 +469,7 @@ void TensorLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
break;
}
case DataType::S16:
+ case DataType::QS16:
{
std::uniform_int_distribution<int16_t> distribution_s16(std::numeric_limits<int16_t>::lowest(), std::numeric_limits<int16_t>::max());
fill(tensor, distribution_s16, seed_offset);
diff --git a/tests/TypePrinter.h b/tests/TypePrinter.h
index 4fb3b64d42..ff9863e1fb 100644
--- a/tests/TypePrinter.h
+++ b/tests/TypePrinter.h
@@ -311,6 +311,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type)
case DataType::S16:
os << "S16";
break;
+ case DataType::QS16:
+ os << "QS16";
+ break;
case DataType::U32:
os << "U32";
break;
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index 4174de0cc8..0b608902a3 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -50,7 +50,7 @@ using namespace arm_compute::test::validation;
namespace
{
const float tolerance_f32 = 1e-03f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-const float tolerance_qs8 = 1.0f; /**< Tolerance value for comparing reference's output against implementation's output for DataType::QS8 */
+const float tolerance_q = 1.0f; /**< Tolerance value for comparing reference's output against implementation's output for fixed point data types */
Tensor compute_gemm(const TensorShape &src_shape1, const TensorShape &src_shape2, const TensorShape &src_shape3,
const TensorShape &out_shape, float alpha, float beta, DataType dt, int fixed_point_position = 0)
@@ -104,7 +104,7 @@ BOOST_AUTO_TEST_SUITE(GEMM)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
BOOST_DATA_TEST_CASE(Configuration,
- SmallGEMMDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8 }),
+ SmallGEMMDataset() * boost::unit_test::data::make({ DataType::F32, DataType::QS8, DataType::QS16 }),
gemm_set, dt)
{
// Set fixed point position data type allowed
@@ -187,7 +187,7 @@ BOOST_AUTO_TEST_SUITE_END()
BOOST_AUTO_TEST_SUITE(Quantized)
BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
-BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make(DataType::QS8) * boost::unit_test::data::xrange(1, 7),
+BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(1, 7),
gemm_set, dt, fixed_point_position)
{
// Compute reference
@@ -197,11 +197,11 @@ BOOST_DATA_TEST_CASE(SmallGEMM, SmallGEMMDataset() * boost::unit_test::data::mak
Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
// Validate output
- validate(NEAccessor(dst), ref_dst, tolerance_qs8);
+ validate(NEAccessor(dst), ref_dst, tolerance_q);
}
BOOST_TEST_DECORATOR(*boost::unit_test::label("nightly"))
-BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::make(DataType::QS8) * boost::unit_test::data::xrange(1, 7),
+BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::make({ DataType::QS8, DataType::QS16 }) * boost::unit_test::data::xrange(1, 7),
gemm_set, dt, fixed_point_position)
{
// Compute reference
@@ -211,7 +211,7 @@ BOOST_DATA_TEST_CASE(LargeGEMM, LargeGEMMDataset() * boost::unit_test::data::mak
Tensor dst = compute_gemm(gemm_set.shape_a, gemm_set.shape_b, gemm_set.shape_c, gemm_set.shape_d, gemm_set.alpha, gemm_set.beta, dt, fixed_point_position);
// Validate output
- validate(NEAccessor(dst), ref_dst, tolerance_qs8);
+ validate(NEAccessor(dst), ref_dst, tolerance_q);
}
BOOST_AUTO_TEST_SUITE_END()
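
For reference, the xrange(1, 7) values passed as fixed_point_position in the fixed-point tests trade resolution against representable range. A small sketch, assuming the same convention as elsewhere in the patch (value = raw / 2^n):

#include <cmath>
#include <cstdio>

// Resolution and largest representable value of QS16 for each tested fixed_point_position
int main()
{
    for(int n = 1; n < 7; ++n)
    {
        const double resolution = std::ldexp(1.0, -n); // 2^-n
        const double max_value  = 32767.0 * resolution;
        std::printf("fixed_point_position=%d  resolution=%g  max=%g\n", n, resolution, max_value);
    }
    return 0;
}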