-rw-r--r--   src/core/CL/CLKernelLibrary.cpp                              3
-rw-r--r--   src/core/CL/cl_kernels/gemmlowp.cl                         373
-rw-r--r--   src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp      12
-rw-r--r--   src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp    9
4 files changed, 385 insertions, 12 deletions
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 22a328bcaf..6695881d09 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -235,7 +235,8 @@ const std::map<std::string, std::string> CLKernelLibrary::_kernel_program_map =
{ "gemm_transpose1x4", "gemm.cl" },
{ "gemmlowp_matrix_a_reduction", "gemmlowp.cl" },
{ "gemmlowp_matrix_b_reduction", "gemmlowp.cl" },
- { "gemmlowp_mm", "gemmlowp.cl" },
+ { "gemmlowp_mm_bifrost", "gemmlowp.cl" },
+ { "gemmlowp_mm_midgard", "gemmlowp.cl" },
{ "gemmlowp_mm_interleaved_transposed", "gemmlowp.cl" },
{ "gemmlowp_offset_contribution", "gemmlowp.cl" },
{ "gemmlowp_output_stage_quantize_down", "gemmlowp.cl" },
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index a92881320e..d724600cdd 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -140,9 +140,9 @@ __kernel void gemmlowp_mm_interleaved_transposed(IMAGE_DECLARATION(src0),
* @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
*/
-__kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
- IMAGE_DECLARATION(src1),
- IMAGE_DECLARATION(dst))
+__kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst))
{
int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -167,6 +167,9 @@ __kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
VECTOR_UINT acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ VECTOR_UINT acc4 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
{
@@ -181,6 +184,9 @@ __kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
uchar2 a3 = vload2(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar2 a4 = vload2(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
// Load values from matrix B
VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
VECTOR_UCHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1 + src1_stride_y);
@@ -200,6 +206,10 @@ __kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3.s0;
acc3 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4.s0;
+ acc4 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a4.s1;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
@@ -215,6 +225,9 @@ __kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
// Load values from matrix B
VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
@@ -229,6 +242,9 @@ __kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
// Compute destination address
@@ -249,6 +265,355 @@ __kernel void gemmlowp_mm(IMAGE_DECLARATION(src0),
VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
(CONVERT(acc3, VECTOR_INT), 0, (__global int *)(offset(&dst, 0, 3)));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+ (CONVERT(acc4, VECTOR_INT), 0, (__global int *)(offset(&dst, 0, 4)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+}
+
+/** OpenCL kernel optimized for Bifrost architectures that computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
+ *
+ * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
+ *
+ * @param[in] src0_ptr Pointer to the source matrix. Supported data type: QASYMM8
+ * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in] src1_ptr Pointer to the source matrix. Supported data type: same as @p src0_ptr
+ * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
+ * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
+ * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: S32
+ * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
+ * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ */
+__kernel void gemmlowp_mm_bifrost(IMAGE_DECLARATION(src0),
+ IMAGE_DECLARATION(src1),
+ IMAGE_DECLARATION(dst))
+{
+ int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
+
+ // Compute starting address for matrix A and Matrix B
+ int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
+
+ // Update address for the matrix A
+ src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
+
+ // Update address for the matrix B
+ src_addr.s1 += idx;
+
+ int end_row_vec_a = src_addr.s0 + COLS_A;
+
+ uint acc00 = 0;
+ uint acc01 = 0;
+ uint acc02 = 0;
+ uint acc03 = 0;
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uint acc10 = 0;
+ uint acc11 = 0;
+ uint acc12 = 0;
+ uint acc13 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uint acc20 = 0;
+ uint acc21 = 0;
+ uint acc22 = 0;
+ uint acc23 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uint acc30 = 0;
+ uint acc31 = 0;
+ uint acc32 = 0;
+ uint acc33 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uint acc40 = 0;
+ uint acc41 = 0;
+ uint acc42 = 0;
+ uint acc43 = 0;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+
+ for(; src_addr.s0 <= (end_row_vec_a - 4); src_addr += (int2)(4, 4 * src1_stride_y))
+ {
+ // Load values from matrix A
+ uchar4 a0 = vload4(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uchar4 a1 = vload4(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uchar4 a2 = vload4(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uchar4 a3 = vload4(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar4 a4 = vload4(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Load values from matrix B
+ uchar4 b0 = vload4(0, src1_ptr + src_addr.s1 + 0 * src1_stride_y);
+ uchar4 b1 = vload4(0, src1_ptr + src_addr.s1 + 1 * src1_stride_y);
+ uchar4 b2 = vload4(0, src1_ptr + src_addr.s1 + 2 * src1_stride_y);
+ uchar4 b3 = vload4(0, src1_ptr + src_addr.s1 + 3 * src1_stride_y);
+
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a0.s0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a0.s0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a0.s0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a0.s0;
+
+ ushort tmp4 = (ushort)b1.s0 * (ushort)a0.s1;
+ ushort tmp5 = (ushort)b1.s1 * (ushort)a0.s1;
+ ushort tmp6 = (ushort)b1.s2 * (ushort)a0.s1;
+ ushort tmp7 = (ushort)b1.s3 * (ushort)a0.s1;
+
+ ushort tmp8 = (ushort)b2.s0 * (ushort)a0.s2;
+ ushort tmp9 = (ushort)b2.s1 * (ushort)a0.s2;
+ ushort tmpA = (ushort)b2.s2 * (ushort)a0.s2;
+ ushort tmpB = (ushort)b2.s3 * (ushort)a0.s2;
+
+ ushort tmpC = (ushort)b3.s0 * (ushort)a0.s3;
+ ushort tmpD = (ushort)b3.s1 * (ushort)a0.s3;
+ ushort tmpE = (ushort)b3.s2 * (ushort)a0.s3;
+ ushort tmpF = (ushort)b3.s3 * (ushort)a0.s3;
+
+ acc00 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
+ acc01 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
+ acc02 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
+ acc03 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+ }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a1.s0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a1.s0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a1.s0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a1.s0;
+
+ ushort tmp4 = (ushort)b1.s0 * (ushort)a1.s1;
+ ushort tmp5 = (ushort)b1.s1 * (ushort)a1.s1;
+ ushort tmp6 = (ushort)b1.s2 * (ushort)a1.s1;
+ ushort tmp7 = (ushort)b1.s3 * (ushort)a1.s1;
+
+ ushort tmp8 = (ushort)b2.s0 * (ushort)a1.s2;
+ ushort tmp9 = (ushort)b2.s1 * (ushort)a1.s2;
+ ushort tmpA = (ushort)b2.s2 * (ushort)a1.s2;
+ ushort tmpB = (ushort)b2.s3 * (ushort)a1.s2;
+
+ ushort tmpC = (ushort)b3.s0 * (ushort)a1.s3;
+ ushort tmpD = (ushort)b3.s1 * (ushort)a1.s3;
+ ushort tmpE = (ushort)b3.s2 * (ushort)a1.s3;
+ ushort tmpF = (ushort)b3.s3 * (ushort)a1.s3;
+
+ acc10 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
+ acc11 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
+ acc12 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
+ acc13 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a2.s0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a2.s0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a2.s0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a2.s0;
+
+ ushort tmp4 = (ushort)b1.s0 * (ushort)a2.s1;
+ ushort tmp5 = (ushort)b1.s1 * (ushort)a2.s1;
+ ushort tmp6 = (ushort)b1.s2 * (ushort)a2.s1;
+ ushort tmp7 = (ushort)b1.s3 * (ushort)a2.s1;
+
+ ushort tmp8 = (ushort)b2.s0 * (ushort)a2.s2;
+ ushort tmp9 = (ushort)b2.s1 * (ushort)a2.s2;
+ ushort tmpA = (ushort)b2.s2 * (ushort)a2.s2;
+ ushort tmpB = (ushort)b2.s3 * (ushort)a2.s2;
+
+ ushort tmpC = (ushort)b3.s0 * (ushort)a2.s3;
+ ushort tmpD = (ushort)b3.s1 * (ushort)a2.s3;
+ ushort tmpE = (ushort)b3.s2 * (ushort)a2.s3;
+ ushort tmpF = (ushort)b3.s3 * (ushort)a2.s3;
+
+ acc20 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
+ acc21 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
+ acc22 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
+ acc23 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a3.s0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a3.s0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a3.s0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a3.s0;
+
+ ushort tmp4 = (ushort)b1.s0 * (ushort)a3.s1;
+ ushort tmp5 = (ushort)b1.s1 * (ushort)a3.s1;
+ ushort tmp6 = (ushort)b1.s2 * (ushort)a3.s1;
+ ushort tmp7 = (ushort)b1.s3 * (ushort)a3.s1;
+
+ ushort tmp8 = (ushort)b2.s0 * (ushort)a3.s2;
+ ushort tmp9 = (ushort)b2.s1 * (ushort)a3.s2;
+ ushort tmpA = (ushort)b2.s2 * (ushort)a3.s2;
+ ushort tmpB = (ushort)b2.s3 * (ushort)a3.s2;
+
+ ushort tmpC = (ushort)b3.s0 * (ushort)a3.s3;
+ ushort tmpD = (ushort)b3.s1 * (ushort)a3.s3;
+ ushort tmpE = (ushort)b3.s2 * (ushort)a3.s3;
+ ushort tmpF = (ushort)b3.s3 * (ushort)a3.s3;
+
+ acc30 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
+ acc31 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
+ acc32 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
+ acc33 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a4.s0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a4.s0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a4.s0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a4.s0;
+
+ ushort tmp4 = (ushort)b1.s0 * (ushort)a4.s1;
+ ushort tmp5 = (ushort)b1.s1 * (ushort)a4.s1;
+ ushort tmp6 = (ushort)b1.s2 * (ushort)a4.s1;
+ ushort tmp7 = (ushort)b1.s3 * (ushort)a4.s1;
+
+ ushort tmp8 = (ushort)b2.s0 * (ushort)a4.s2;
+ ushort tmp9 = (ushort)b2.s1 * (ushort)a4.s2;
+ ushort tmpA = (ushort)b2.s2 * (ushort)a4.s2;
+ ushort tmpB = (ushort)b2.s3 * (ushort)a4.s2;
+
+ ushort tmpC = (ushort)b3.s0 * (ushort)a4.s3;
+ ushort tmpD = (ushort)b3.s1 * (ushort)a4.s3;
+ ushort tmpE = (ushort)b3.s2 * (ushort)a4.s3;
+ ushort tmpF = (ushort)b3.s3 * (ushort)a4.s3;
+
+ acc40 += ((uint)tmp0 + (uint)tmp4 + (uint)tmp8 + (uint)tmpC);
+ acc41 += ((uint)tmp1 + (uint)tmp5 + (uint)tmp9 + (uint)tmpD);
+ acc42 += ((uint)tmp2 + (uint)tmp6 + (uint)tmpA + (uint)tmpE);
+ acc43 += ((uint)tmp3 + (uint)tmp7 + (uint)tmpB + (uint)tmpF);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ }
+
+ for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
+ {
+ // Load values from matrix A
+ uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ // Load values from matrix B
+ uchar4 b0 = vload4(0, src1_ptr + src_addr.s1);
+
+ // Accumulate
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a0;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a0;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a0;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a0;
+
+ acc00 += ((uint)tmp0);
+ acc01 += ((uint)tmp1);
+ acc02 += ((uint)tmp2);
+ acc03 += ((uint)tmp3);
+ }
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a1;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a1;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a1;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a1;
+
+ acc10 += ((uint)tmp0);
+ acc11 += ((uint)tmp1);
+ acc12 += ((uint)tmp2);
+ acc13 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a2;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a2;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a2;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a2;
+
+ acc20 += ((uint)tmp0);
+ acc21 += ((uint)tmp1);
+ acc22 += ((uint)tmp2);
+ acc23 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a3;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a3;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a3;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a3;
+
+ acc30 += ((uint)tmp0);
+ acc31 += ((uint)tmp1);
+ acc32 += ((uint)tmp2);
+ acc33 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ {
+ // Accumulate
+ ushort tmp0 = (ushort)b0.s0 * (ushort)a4;
+ ushort tmp1 = (ushort)b0.s1 * (ushort)a4;
+ ushort tmp2 = (ushort)b0.s2 * (ushort)a4;
+ ushort tmp3 = (ushort)b0.s3 * (ushort)a4;
+
+ acc40 += ((uint)tmp0);
+ acc41 += ((uint)tmp1);
+ acc42 += ((uint)tmp2);
+ acc43 += ((uint)tmp3);
+ }
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ }
+
+ // Compute destination address
+ Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+ // Store the result
+ vstore4((int4)(acc00, acc01, acc02, acc03), 0, (__global int *)(offset(&dst, 0, 0)));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+ vstore4((int4)(acc10, acc11, acc12, acc13), 0, (__global int *)(offset(&dst, 0, 1)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+ vstore4((int4)(acc20, acc21, acc22, acc23), 0, (__global int *)(offset(&dst, 0, 2)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+ vstore4((int4)(acc30, acc31, acc32, acc33), 0, (__global int *)(offset(&dst, 0, 3)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
+ vstore4((int4)(acc40, acc41, acc42, acc43), 0, (__global int *)(offset(&dst, 0, 4)));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
#endif // defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
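
For reference, the arithmetic behind the Bifrost accumulation blocks added above: each uchar-by-uchar product fits in 16 bits (at most 255 * 255 = 65025), so every partial product is computed as a ushort and only widened to uint when the four partials of a 4-wide step are summed into the running accumulator. The following is a minimal, hypothetical scalar C++ sketch of the same scheme (dot_u8 is an illustrative helper, not part of the library):

#include <cstddef>
#include <cstdint>

// Illustrative only: mirrors the ushort/uint accumulation pattern of gemmlowp_mm_bifrost.
// Each uint8 x uint8 product fits in 16 bits; the four partials of a step are widened
// to 32 bits before being added to the accumulator.
uint32_t dot_u8(const uint8_t *a, const uint8_t *b, size_t n)
{
    uint32_t acc = 0;
    size_t   i   = 0;
    for(; i + 4 <= n; i += 4)
    {
        uint16_t tmp0 = (uint16_t)a[i + 0] * (uint16_t)b[i + 0];
        uint16_t tmp1 = (uint16_t)a[i + 1] * (uint16_t)b[i + 1];
        uint16_t tmp2 = (uint16_t)a[i + 2] * (uint16_t)b[i + 2];
        uint16_t tmp3 = (uint16_t)a[i + 3] * (uint16_t)b[i + 3];
        acc += (uint32_t)tmp0 + (uint32_t)tmp1 + (uint32_t)tmp2 + (uint32_t)tmp3;
    }
    for(; i < n; ++i) // leftover columns, handled like the kernel's second loop
    {
        acc += (uint32_t)a[i] * (uint32_t)b[i];
    }
    return acc;
}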
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index 423592b79c..2f96724210 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#include "arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.h"
#include "arm_compute/core/AccessWindowStatic.h"
+#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/OpenCL.h"
@@ -94,8 +95,8 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITe
else
{
// Special case for 1xN, 2xN, 3xN and 4xN input0 tensor. num_elems_processed_per_iteration_x
- num_elems_processed_per_iteration_x = 16;
- num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
+ num_elems_processed_per_iteration_x = 4;
+ num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 5);
// Configure window
win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
@@ -132,6 +133,9 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
ElementsProcessed num_elements_processed{};
+ // Get target architecture
+ GPUTarget arch_target = get_arch_from_target(get_target());
+
// Configure kernel window
auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, num_elements_processed);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
@@ -150,7 +154,7 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
- kernel_name = "gemmlowp_mm";
+ kernel_name = "gemmlowp_mm_" + string_from_target(arch_target);
}
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
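
The host-side change above picks the kernel variant from the detected GPU family: string_from_target(arch_target) resolves the name to gemmlowp_mm_bifrost on Bifrost devices and gemmlowp_mm_midgard otherwise, matching the two kernels registered in CLKernelLibrary.cpp. A hedged, self-contained C++ sketch of that selection (Arch and select_gemmlowp_kernel are illustrative names, not the library's API):

#include <string>

// Illustrative sketch of the name resolution performed by
// kernel_name = "gemmlowp_mm_" + string_from_target(arch_target);
enum class Arch { MIDGARD, BIFROST };

std::string select_gemmlowp_kernel(Arch arch)
{
    return std::string("gemmlowp_mm_") + (arch == Arch::BIFROST ? "bifrost" : "midgard");
}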
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index ddcab6a256..2cd426b82d 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,7 +52,10 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
_b_offset = b->info()->quantization_info().offset;
// If the input tensor has less than 16 rows, we run a special version of GEMMLowp without reshaping the input tensors
- _is_interleaved_transposed = a->info()->dimension(1) > 16;
+ _is_interleaved_transposed = (a->info()->dimension(1)) > 16 && (CLScheduler::get().target() != GPUTarget::BIFROST);
+
+ // Set the target for the matrix multiply kernel
+ _mm_kernel.set_target(CLScheduler::get().target());
const ICLTensor *matrix_a = a;
const ICLTensor *matrix_b = b;
@@ -138,7 +141,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
int32_t a_offset = a->quantization_info().offset;
int32_t b_offset = b->quantization_info().offset;
- bool is_interleaved_transposed = a->dimension(1) > 16;
+ bool is_interleaved_transposed = (a->dimension(1)) > 16 && (CLScheduler::get().target() != GPUTarget::BIFROST);
if(is_interleaved_transposed)
{