From c50da386dc26a6e0a1690a47e72d5fa766e7dba2 Mon Sep 17 00:00:00 2001
From: Giorgio Arena
Date: Thu, 26 Jul 2018 15:50:09 +0100
Subject: COMPMID-1431 Use either arm_dot or arm_dot_acc for CLGEMMLowp based
 on what is supported

Change-Id: I4c5121e0f000d5ee94a8c8c5326272806f643e35
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141520
Tested-by: Jenkins
Reviewed-by: Georgios Pinitas
---
 src/core/CL/cl_kernels/gemmlowp.cl | 124 ++++++++++++++++++++-----------------
 1 file changed, 66 insertions(+), 58 deletions(-)

(limited to 'src/core/CL/cl_kernels/gemmlowp.cl')

diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 0ee7c27350..cd8b269ae2 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -24,6 +24,14 @@
 #include "helpers.h"
 #include "helpers_asymm.h"
 
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val = arm_dot_acc((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3), val);
+#else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val += arm_dot((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3));
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
 #if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
 /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
  * Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
@@ -493,25 +501,25 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(
         uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
 
         // Accumulate
-        c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
-        c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
-        c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
-        c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
-
-        c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
-        c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
-        c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
-        c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
-
-        c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
-        c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
-        c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
-        c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
-
-        c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
-        c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
-        c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
-        c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03);
+
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10);
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11);
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12);
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13);
+
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20);
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21);
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22);
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23);
+
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30);
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31);
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32);
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
         a0 = vload16(0, src_addr_a + 16);
@@ -521,25 +529,25 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(
         b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
 
         // Accumulate
-        c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
-        c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
-        c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
-        c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
-
-        c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
-        c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
-        c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
-        c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
-
-        c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
-        c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
-        c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
-        c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
-
-        c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
-        c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
-        c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
-        c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02);
+        ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03);
+
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10);
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11);
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12);
+        ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13);
+
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20);
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21);
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22);
+        ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23);
+
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30);
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31);
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32);
+        ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33);
     }
 #endif // MULT_INTERLEAVE4X4_HEIGHT == 1
 
@@ -1180,45 +1188,45 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
 
         {
             // Accumulate
-            acc00 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0, acc00);
-            acc01 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0, acc01);
-            acc02 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0, acc02);
-            acc03 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0, acc03);
+            ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a0.s0, a0.s1, a0.s2, a0.s3, acc00);
+            ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a0.s0, a0.s1, a0.s2, a0.s3, acc01);
+            ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a0.s0, a0.s1, a0.s2, a0.s3, acc02);
+            ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a0.s0, a0.s1, a0.s2, a0.s3, acc03);
         }
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
         {
             // Accumulate
-            acc10 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1, acc10);
-            acc11 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1, acc11);
-            acc12 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1, acc12);
-            acc13 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1, acc13);
+            ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a1.s0, a1.s1, a1.s2, a1.s3, acc10);
+            ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a1.s0, a1.s1, a1.s2, a1.s3, acc11);
+            ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a1.s0, a1.s1, a1.s2, a1.s3, acc12);
+            ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a1.s0, a1.s1, a1.s2, a1.s3, acc13);
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
         {
             // Accumulate
-            acc20 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2, acc20);
-            acc21 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2, acc21);
-            acc22 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2, acc22);
-            acc23 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2, acc23);
+            ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a2.s0, a2.s1, a2.s2, a2.s3, acc20);
+            ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a2.s0, a2.s1, a2.s2, a2.s3, acc21);
+            ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a2.s0, a2.s1, a2.s2, a2.s3, acc22);
+            ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a2.s0, a2.s1, a2.s2, a2.s3, acc23);
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         {
             // Accumulate
-            acc30 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3, acc30);
-            acc31 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3, acc31);
-            acc32 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3, acc32);
-            acc33 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3, acc33);
+            ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a3.s0, a3.s1, a3.s2, a3.s3, acc30);
+            ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a3.s0, a3.s1, a3.s2, a3.s3, acc31);
+            ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a3.s0, a3.s1, a3.s2, a3.s3, acc32);
+            ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a3.s0, a3.s1, a3.s2, a3.s3, acc33);
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
         {
             // Accumulate
-            acc40 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4, acc40);
-            acc41 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4, acc41);
-            acc42 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4, acc42);
-            acc43 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4, acc43);
+            ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a4.s0, a4.s1, a4.s2, a4.s3, acc40);
+            ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a4.s0, a4.s1, a4.s2, a4.s3, acc41);
+            ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a4.s0, a4.s1, a4.s2, a4.s3, acc42);
+            ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a4.s0, a4.s1, a4.s2, a4.s3, acc43);
         }
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
     }
-- 
cgit v1.2.1
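Note (not part of the patch above): the kernel-side ARM_DOT macro only selects between arm_dot_acc and the "+= arm_dot(...)" fallback; which of ARM_COMPUTE_OPENCL_DOT8_ENABLED / ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED ends up defined has to be decided on the host when the OpenCL program is built. The C sketch below illustrates one way that decision could be made, assuming the flags are derived from the cl_arm_integer_dot_product_int8 and cl_arm_integer_dot_product_accumulate_int8 device extensions; the helper names and option strings here are illustrative, not necessarily the Compute Library's actual build path.

/*
 * Hypothetical host-side sketch: append the -D options that gemmlowp.cl checks,
 * based on which Arm 8-bit dot-product extensions the device reports.
 */
#include <CL/cl.h>
#include <stdlib.h>
#include <string.h>

/* Return 1 if 'name' appears in the device's CL_DEVICE_EXTENSIONS string. */
static int device_has_extension(cl_device_id device, const char *name)
{
    size_t size = 0;
    if (clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &size) != CL_SUCCESS || size == 0)
    {
        return 0;
    }
    char *extensions = (char *)malloc(size);
    if (extensions == NULL)
    {
        return 0;
    }
    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, size, extensions, NULL);
    const int found = (strstr(extensions, name) != NULL);
    free(extensions);
    return found;
}

/* Append the dot8 build options expected by the kernel to 'options'. */
static void append_dot8_build_options(cl_device_id device, char *options, size_t options_size)
{
    /* arm_dot(uchar4, uchar4) is advertised via cl_arm_integer_dot_product_int8 (assumed mapping). */
    if (device_has_extension(device, "cl_arm_integer_dot_product_int8"))
    {
        strncat(options, " -DARM_COMPUTE_OPENCL_DOT8_ENABLED", options_size - strlen(options) - 1);
    }
    /* arm_dot_acc(uchar4, uchar4, uint) via cl_arm_integer_dot_product_accumulate_int8 (assumed mapping). */
    if (device_has_extension(device, "cl_arm_integer_dot_product_accumulate_int8"))
    {
        strncat(options, " -DARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED", options_size - strlen(options) - 1);
    }
}

With options assembled this way, a device exposing only the plain dot product still builds the dot8 kernels and accumulates with "+= arm_dot(...)", while one that also exposes the accumulating form takes the arm_dot_acc path, which is exactly the choice the ARM_DOT macro in the patch encodes.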