aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/gemmlowp.cl
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-07-26 15:50:09 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitc50da386dc26a6e0a1690a47e72d5fa766e7dba2 (patch)
tree5ca2cc9f92b9f3a388e4f7f85098e92b093fc6d2 /src/core/CL/cl_kernels/gemmlowp.cl
parenta4658ae6cac4694fe28df5837bc4f4c154ab7204 (diff)
downloadComputeLibrary-c50da386dc26a6e0a1690a47e72d5fa766e7dba2.tar.gz
COMPMID-1431 Use either arm_dot or arm_dot_acc for CLGEMMLowp based on what is supported
Change-Id: I4c5121e0f000d5ee94a8c8c5326272806f643e35 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141520 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/gemmlowp.cl')
-rw-r--r--src/core/CL/cl_kernels/gemmlowp.cl124
1 files changed, 66 insertions, 58 deletions
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index 0ee7c27350..cd8b269ae2 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -24,6 +24,14 @@
#include "helpers.h"
#include "helpers_asymm.h"
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val = arm_dot_acc((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3), val);
+#else // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#define ARM_DOT(x0, x1, x2, x3, y0, y1, y2, y3, val) val += arm_dot((uchar4)(x0, x1, x2, x3), (uchar4)(y0, y1, y2, y3));
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED)
+#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED)
+
#if defined(COLS_B) && defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(TRANSPOSE1XW_WIDTH_STEP)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
* Matrix A and matrix B must be reshaped respectively with @ref CLGEMMInterleave4x4Kernel and @ref CLGEMMTranspose1xWKernel before running the matrix multiplication
@@ -493,25 +501,25 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(
uchar4 b3 = vload4(0, src_addr_b + 12 * TRANSPOSE1XW_WIDTH_STEP);
// Accumulate
- c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
- c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
- c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
- c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
-
- c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
- c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
- c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
- c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
-
- c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
- c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
- c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
- c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
-
- c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
- c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
- c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
- c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03);
+
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13);
+
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23);
+
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33);
// Load values from matrix A (interleaved) and matrix B (transposed)
a0 = vload16(0, src_addr_a + 16);
@@ -521,25 +529,25 @@ __kernel void gemmlowp_mm_interleaved_transposed_bifrost_dot8(IMAGE_DECLARATION(
b3 = vload4(0, src_addr_b + 28 * TRANSPOSE1XW_WIDTH_STEP);
// Accumulate
- c00 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c00);
- c01 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c01);
- c02 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c02);
- c03 = arm_dot_acc((uchar4)(a0.s0, a0.s4, a0.s8, a0.sC), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c03);
-
- c10 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c10);
- c11 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c11);
- c12 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c12);
- c13 = arm_dot_acc((uchar4)(a0.s1, a0.s5, a0.s9, a0.sD), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c13);
-
- c20 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c20);
- c21 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c21);
- c22 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c22);
- c23 = arm_dot_acc((uchar4)(a0.s2, a0.s6, a0.sA, a0.sE), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c23);
-
- c30 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), c30);
- c31 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), c31);
- c32 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), c32);
- c33 = arm_dot_acc((uchar4)(a0.s3, a0.s7, a0.sB, a0.sF), (uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), c33);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s0, b1.s0, b2.s0, b3.s0, c00);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s1, b1.s1, b2.s1, b3.s1, c01);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s2, b1.s2, b2.s2, b3.s2, c02);
+ ARM_DOT(a0.s0, a0.s4, a0.s8, a0.sC, b0.s3, b1.s3, b2.s3, b3.s3, c03);
+
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s0, b1.s0, b2.s0, b3.s0, c10);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s1, b1.s1, b2.s1, b3.s1, c11);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s2, b1.s2, b2.s2, b3.s2, c12);
+ ARM_DOT(a0.s1, a0.s5, a0.s9, a0.sD, b0.s3, b1.s3, b2.s3, b3.s3, c13);
+
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s0, b1.s0, b2.s0, b3.s0, c20);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s1, b1.s1, b2.s1, b3.s1, c21);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s2, b1.s2, b2.s2, b3.s2, c22);
+ ARM_DOT(a0.s2, a0.s6, a0.sA, a0.sE, b0.s3, b1.s3, b2.s3, b3.s3, c23);
+
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s0, b1.s0, b2.s0, b3.s0, c30);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s1, b1.s1, b2.s1, b3.s1, c31);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s2, b1.s2, b2.s2, b3.s2, c32);
+ ARM_DOT(a0.s3, a0.s7, a0.sB, a0.sF, b0.s3, b1.s3, b2.s3, b3.s3, c33);
}
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
@@ -1180,45 +1188,45 @@ __kernel void gemmlowp_mm_bifrost_dot8(IMAGE_DECLARATION(src0),
{
// Accumulate
- acc00 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a0, acc00);
- acc01 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a0, acc01);
- acc02 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a0, acc02);
- acc03 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a0, acc03);
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a0.s0, a0.s1, a0.s2, a0.s3, acc00);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a0.s0, a0.s1, a0.s2, a0.s3, acc01);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a0.s0, a0.s1, a0.s2, a0.s3, acc02);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a0.s0, a0.s1, a0.s2, a0.s3, acc03);
}
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
{
// Accumulate
- acc10 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a1, acc10);
- acc11 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a1, acc11);
- acc12 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a1, acc12);
- acc13 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a1, acc13);
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a1.s0, a1.s1, a1.s2, a1.s3, acc10);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a1.s0, a1.s1, a1.s2, a1.s3, acc11);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a1.s0, a1.s1, a1.s2, a1.s3, acc12);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a1.s0, a1.s1, a1.s2, a1.s3, acc13);
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
{
// Accumulate
- acc20 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a2, acc20);
- acc21 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a2, acc21);
- acc22 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a2, acc22);
- acc23 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a2, acc23);
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a2.s0, a2.s1, a2.s2, a2.s3, acc20);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a2.s0, a2.s1, a2.s2, a2.s3, acc21);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a2.s0, a2.s1, a2.s2, a2.s3, acc22);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a2.s0, a2.s1, a2.s2, a2.s3, acc23);
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
{
// Accumulate
- acc30 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a3, acc30);
- acc31 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a3, acc31);
- acc32 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a3, acc32);
- acc33 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a3, acc33);
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a3.s0, a3.s1, a3.s2, a3.s3, acc30);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a3.s0, a3.s1, a3.s2, a3.s3, acc31);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a3.s0, a3.s1, a3.s2, a3.s3, acc32);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a3.s0, a3.s1, a3.s2, a3.s3, acc33);
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
{
// Accumulate
- acc40 = arm_dot_acc((uchar4)(b0.s0, b1.s0, b2.s0, b3.s0), a4, acc40);
- acc41 = arm_dot_acc((uchar4)(b0.s1, b1.s1, b2.s1, b3.s1), a4, acc41);
- acc42 = arm_dot_acc((uchar4)(b0.s2, b1.s2, b2.s2, b3.s2), a4, acc42);
- acc43 = arm_dot_acc((uchar4)(b0.s3, b1.s3, b2.s3, b3.s3), a4, acc43);
+ ARM_DOT(b0.s0, b1.s0, b2.s0, b3.s0, a4.s0, a4.s1, a4.s2, a4.s3, acc40);
+ ARM_DOT(b0.s1, b1.s1, b2.s1, b3.s1, a4.s0, a4.s1, a4.s2, a4.s3, acc41);
+ ARM_DOT(b0.s2, b1.s2, b2.s2, b3.s2, a4.s0, a4.s1, a4.s2, a4.s3, acc42);
+ ARM_DOT(b0.s3, b1.s3, b2.s3, b3.s3, a4.s0, a4.s1, a4.s2, a4.s3, acc43);
}
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}