Diffstat (limited to 'src/core/CL/cl_kernels/gemm.cl')
-rw-r--r-- src/core/CL/cl_kernels/gemm.cl | 257
1 file changed, 173 insertions(+), 84 deletions(-)
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 8e628e8d01..c35d160689 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2041,79 +2041,37 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
#if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b);
+#define ARM_VFMA(a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA_1(a, b, c) \
- ({ \
- c = fma((a), (b), (c)); \
- })
-#define ARM_VFMA_2(a, b, c) \
- ({ \
- (c).s0 = fma((a).s0, (b).s0, (c).s0); \
- (c).s1 = fma((a).s1, (b).s1, (c).s1); \
- })
-#define ARM_VFMA_3(a, b, c) \
- ({ \
- ARM_VFMA_2(a, b, c); \
- (c).s2 = fma((a).s2, (b).s2, (c).s2); \
- })
-#define ARM_VFMA_4(a, b, c) \
- ({ \
- ARM_VFMA_3(a, b, c); \
- (c).s3 = fma((a).s3, (b).s3, (c).s3); \
- })
-#define ARM_VFMA_8(a, b, c) \
- ({ \
- ARM_VFMA_4(a, b, c); \
- (c).s4 = fma((a).s4, (b).s4, (c).s4); \
- (c).s5 = fma((a).s5, (b).s5, (c).s5); \
- (c).s6 = fma((a).s6, (b).s6, (c).s6); \
- (c).s7 = fma((a).s7, (b).s7, (c).s7); \
- })
-#define ARM_VFMA_16(a, b, c) \
- ({ \
- ARM_VFMA_8(a, b, c); \
- (c).s8 = fma((a).s8, (b).s8, (c).s8); \
- (c).s9 = fma((a).s9, (b).s9, (c).s9); \
- (c).sA = fma((a).sA, (b).sA, (c).sA); \
- (c).sB = fma((a).sB, (b).sB, (c).sB); \
- (c).sC = fma((a).sC, (b).sC, (c).sC); \
- (c).sD = fma((a).sD, (b).sD, (c).sD); \
- (c).sE = fma((a).sE, (b).sE, (c).sE); \
- (c).sF = fma((a).sF, (b).sF, (c).sF); \
- })
-
-// Factory macro for the vector FMA
-#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))
-
+#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD
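The single remaining ARM_VFMA definition covers every vector width because the OpenCL built-in fma() is defined component-wise on vector types, which is what makes the per-size ARM_VFMA_1..ARM_VFMA_16 variants removed above redundant. A minimal standalone sketch of the same idea (a hypothetical kernel, not part of gemm.cl):

    __kernel void vfma_sketch(__global const float4 *a,
                              __global const float4 *b,
                              __global float4 *c)
    {
        const size_t i = get_global_id(0);
        // One fma() call updates all four lanes of the accumulator.
        c[i] = fma(a[i], b[i], c[i]);
    }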
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
})
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
})
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
})
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
})
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
})
// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
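The ARM_VVM_T_NT_*xN0x1 family above implements a rank-1 (outer-product style) update: each lane of the transposed LHS column a is broadcast across an N0-wide vector, multiplied with the RHS row b, and accumulated into its own accumulator row. A hand-expanded sketch for M0 = 2, N0 = 4 and DATA_TYPE = float, using the c0/c1 names implied by the C##i pattern (illustrative only, not generated code):

    float2 a;        // one transposed LHS column (lanes a.s0, a.s1)
    float4 b;        // one RHS row
    float4 c0, c1;   // accumulator rows, one per LHS lane

    // ARM_VVM_T_NT_2xN0x1(4, float, a, b, c) with the simplified ARM_VFMA:
    c0 = fma((float4)(a.s0), b, c0); // broadcast a.s0 across b, accumulate
    c1 = fma((float4)(a.s1), b, c1); // broadcast a.s1 across b, accumulate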
@@ -2172,7 +2130,8 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-matrix macro K0 times
// A, B and C are matrices
-#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
+#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
+ CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
(M0, N0, TYPE, A, B, C)
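The rewrapped ARM_MM_T_NT above still dispatches purely by token pasting: CONCAT(ARM_MM_T_NT_M0xN0x, K0) glues the literal K0 onto the macro-name prefix before the trailing argument list is applied. An illustrative expansion for K0 = 2 (assuming CONCAT is the usual two-level paste helper used by these kernels):

    // ARM_MM_T_NT(4, 8, 2, float, a, b, c)
    //   -> CONCAT(ARM_MM_T_NT_M0xN0x, 2)(4, 8, float, a, b, c)
    //   -> ARM_MM_T_NT_M0xN0x2(4, 8, float, a, b, c)
    // which, per the comment above, calls the vector-by-matrix macro twice,
    // once per unrolled K step.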
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
@@ -2272,11 +2231,9 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
-#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
-#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)
const uint x = get_global_id(0);
@@ -2306,28 +2263,160 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
// Initialize the accumulators
REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
- REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
+ __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
+ __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
+
for(int i = 0; i < k; i += K0)
{
- // Supported cases (K0, M0):
- // 1,2 - 2,2 - 3,2 - 4,2 - 5,2 - 6,2 - 7,2 - 8,2
- // 1,3 - 2,3 - 3,3 - 4,3 - 5,3 - 6,3 - 7,3 - 8,3
- // 1,4 - 2,4 - 3,4 - 4,4 - 5,4 - 6,4 - 7,4 - 8,4
- // 1,8 - 2,8 - 3,8 - 4,8 - 5,8 - 6,8 - 7,8 - 8,8
- // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16
- // Load values from LHS matrix
- LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+ VEC_DATA_TYPE(DATA_TYPE, M0)
+ a0 = VLOAD(M0)(0, lhs);
+ VEC_DATA_TYPE(DATA_TYPE, N0)
+ b0 = VLOAD(N0)(0, rhs);
- // Load values from RHS matrix
- LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+#if K0 > 1
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 1
+
+#if K0 > 2
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 2
+
+#if K0 > 3
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 3
+
+#if K0 > 4
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 4
+
+#if K0 > 8
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 8
- // Perform the partial matrix multiplication
- ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);
+#ifndef LHS_INTERLEAVE
+ lhs += (M0 * K0 * (V0 - 1));
+#endif // LHS_INTERLEAVE
- lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
- rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
+#ifndef RHS_INTERLEAVE
+ rhs += (N0 * K0 * (H0 - 1));
+#endif // RHS_INTERLEAVE
}
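The unrolled body above replaces the old LOAD_BLOCK/ARM_MM_T_NT(K0) pair with K0 explicit single-k steps: each step loads one transposed LHS column (M0 values) and one RHS row (N0 values), performs a rank-1 accumulation through ARM_MM_T_NT(..., 1, ...), and advances both pointers by their per-step strides. A compact sketch of the equivalent rolled loop, shown only for readability (the committed code unrolls it with the #if K0 > n guards so every step is resolved at compile time):

    for(int k0 = 0; k0 < K0; ++k0)
    {
        VEC_DATA_TYPE(DATA_TYPE, M0)
        a0 = VLOAD(M0)(0, lhs); // one transposed LHS column
        VEC_DATA_TYPE(DATA_TYPE, N0)
        b0 = VLOAD(N0)(0, rhs); // one RHS row

        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); // rank-1 update into c0..c(M0-1)

        lhs += LHS_STEP_X;
        rhs += RHS_STEP_X;
    }

The sketch redeclares a0/b0 inside the loop for clarity; the unrolled version declares them once and reassigns them at each step.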
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);