aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-27 09:23:15 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-30 08:28:43 +0000
commit0c17aa25a4f7bc812707150b91930f0cf8e75294 (patch)
tree29088e00bd7ba443dc122ad3436b0a4ef369a102 /src
parent40958adf8bad8fd9fefe591ee55a381f7bbb6fea (diff)
downloadComputeLibrary-0c17aa25a4f7bc812707150b91930f0cf8e75294.tar.gz
COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16
Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1991 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl161
-rw-r--r--src/core/CL/cl_kernels/gemm_helpers.h88
-rw-r--r--src/core/CL/cl_kernels/helpers.h17
-rw-r--r--src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp20
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp1
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp15
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp1
7 files changed, 264 insertions, 39 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index c35d160689..57a5af8ec2 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1676,8 +1676,66 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
}
#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
-#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
+#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR) && defined(M) && defined(N)
+#if defined(MIXED_PRECISION)
+#if K0 == 2
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c += a.s0 * b.s0; \
+ c += a.s1 * b.s1; \
+ })
+#elif K0 == 3 // K0 == 3
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c += a.s0 * b.s0; \
+ c += a.s1 * b.s1; \
+ c += a.s2 * b.s2; \
+ })
+#elif K0 == 4 // K0 == 4
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c += a.s0 * b.s0; \
+ c += a.s1 * b.s1; \
+ c += a.s2 * b.s2; \
+ c += a.s3 * b.s3; \
+ })
+#elif K0 == 8 // K0 == 8
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c += a.s0 * b.s0; \
+ c += a.s1 * b.s1; \
+ c += a.s2 * b.s2; \
+ c += a.s3 * b.s3; \
+ c += a.s4 * b.s4; \
+ c += a.s5 * b.s5; \
+ c += a.s6 * b.s6; \
+ c += a.s7 * b.s7; \
+ })
+#elif K0 == 16 // K0 == 16
+#define ARM_DOT_K0(a, b, c) \
+ ({ \
+ c += a.s0 * b.s0; \
+ c += a.s1 * b.s1; \
+ c += a.s2 * b.s2; \
+ c += a.s3 * b.s3; \
+ c += a.s4 * b.s4; \
+ c += a.s5 * b.s5; \
+ c += a.s6 * b.s6; \
+ c += a.s7 * b.s7; \
+ c += a.s8 * b.s8; \
+ c += a.s9 * b.s9; \
+ c += a.sA * b.sA; \
+ c += a.sB * b.sB; \
+ c += a.sC * b.sC; \
+ c += a.sD * b.sD; \
+ c += a.sE * b.sE; \
+ c += a.sF * b.sF; \
+ })
+#else // K0 not supported
+#error "K0 value not supported"
+#endif // K0 conditions
+#else // defined(MIXED_PRECISION)
#if K0 == 2
#define ARM_DOT_K0(a, b, c) \
({ \
@@ -1734,6 +1792,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
#else // K0 not supported
#error "K0 value not supported"
#endif // K0 conditions
+#endif // defined(MIXED_PRECISION)
#if N0 == 2
#define ARM_DOT_K0XN0(a, b, c) \
@@ -1796,6 +1855,9 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
* The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
* The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
*
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
+ * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
+ * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
* @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
* @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
* @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
@@ -1917,7 +1979,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#endif // defined(MATRIX_B_DEPTH)
// Initialize the accumulators
- REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
@@ -2003,7 +2065,12 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#endif // UNIT_BIAS
// c = c + bias[broadcasted]
+#if defined(MIXED_PRECISION)
+ CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
+ ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
+#else // defined(MIXED_PRECISION)
ADD_BLOCK_BROADCAST(M0, c, bias0);
+#endif // defined(MIXED_PRECISION)
#else // defined(BROADCAST_BIAS)
__global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
@@ -2016,17 +2083,26 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#endif // UNIT_BIAS
// c = c + bias
+#if defined(MIXED_PRECISION)
+ CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
+ ADD_BLOCK(M0, c, bias_hp);
+#else // defined(MIXED_PRECISION)
ADD_BLOCK(M0, c, bias);
+#endif // defined(MIXED_PRECISION)
#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)
#if defined(ACTIVATION_TYPE)
- ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
+ ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)
// Store output block
+#if defined(MIXED_PRECISION)
+ CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
+#else // defined(MIXED_PRECISION)
STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
+#endif // defined(MIXED_PRECISION)
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
@@ -2040,38 +2116,50 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
-#if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(a, b, c) c += (a) * (b);
+#if defined(MIXED_PRECISION)
+
+#if(GPU_ARCH == GPU_ARCH_MIDGARD)
+#define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0)));
+#else // GPU_ARCH == GPU_ARCH_MIDGARD
+#define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c));
+#endif // GPU_ARCH == GPU_ARCH_MIDGARD
+
+#else // defined(MIXED_PRECISION
+
+#if(GPU_ARCH == GPU_ARCH_MIDGARD)
+#define ARM_VFMA(N0, a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
+#define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
+#endif // defined(MIXED_PRECISION)
+
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
})
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
})
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
})
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
})
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
- ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+ ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
})
// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
@@ -2261,7 +2349,7 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
#endif // defined(MATRIX_B_DEPTH)
// Initialize the accumulators
- REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
@@ -2455,7 +2543,12 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
#endif // UNIT_BIAS
// c = c + bias[broadcasted]
+#if defined(MIXED_PRECISION)
+ CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
+ ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
+#else // defined(MIXED_PRECISION)
ADD_BLOCK_BROADCAST(M0, c, bias0);
+#endif // defined(MIXED_PRECISION)
#else // defined(BROADCAST_BIAS)
__global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
@@ -2466,8 +2559,12 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BIAS
- // c = c + bias
+#if defined(MIXED_PRECISION)
+ CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
+ ADD_BLOCK(M0, c, bias_hp);
+#else // defined(MIXED_PRECISION)
ADD_BLOCK(M0, c, bias);
+#endif // defined(MIXED_PRECISION)
#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)
@@ -2477,7 +2574,11 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
#endif // defined(ACTIVATION_TYPE)
// Store output block
+#if defined(MIXED_PRECISION)
+ CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
+#else // defined(MIXED_PRECISION)
STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
+#endif // defined(MIXED_PRECISION)
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 4715fb737f..fd8c773444 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -689,4 +689,90 @@
/** Apply activation to the variables BASENAME##0... BASENAME##(n-1)
* Supported cases N=1,2,3..16, for variables BASENAME[0..N]
*/
-#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \ No newline at end of file
+#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)
+
+#define CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##0 = CONVERT(BASENAME_SRC##0, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##1 = CONVERT(BASENAME_SRC##1, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##2 = CONVERT(BASENAME_SRC##2, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##3 = CONVERT(BASENAME_SRC##3, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##4 = CONVERT(BASENAME_SRC##4, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##5 = CONVERT(BASENAME_SRC##5, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##6 = CONVERT(BASENAME_SRC##6, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##7 = CONVERT(BASENAME_SRC##7, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##8 = CONVERT(BASENAME_SRC##8, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##9 = CONVERT(BASENAME_SRC##9, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##A = CONVERT(BASENAME_SRC##A, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##B = CONVERT(BASENAME_SRC##B, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##C = CONVERT(BASENAME_SRC##C, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##D = CONVERT(BASENAME_SRC##D, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##E = CONVERT(BASENAME_SRC##E, VEC_DATA_TYPE(DATA_TYPE, N));
+
+#define CONVERT_ROW_16(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
+ VEC_DATA_TYPE(DATA_TYPE, N) \
+ BASENAME_DST##F = CONVERT(BASENAME_SRC##F, VEC_DATA_TYPE(DATA_TYPE, N));
+
+// CONVERT_ROW_m apply convert to the variables BASENAME_SRC##0... BASENAME_SRC##(n-1)
+#define CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_ROW_##M(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)
+/** Apply convert_<data_type> to the variables BASENAME_SRC##0... BASENAME_SRC##(m-1)
+ * Supported cases N=1,2,3..16, for variables BASENAME_SRC[0..N]
+ */
+#define CONVERT_BLOCK(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ No newline at end of file
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index f501077a40..6f51b87bc6 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -70,6 +70,23 @@
#define vload1(OFFSET, PTR) *(OFFSET + PTR)
#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA
+// Convert built-in functions with _sat modifier are not supported in floating point so we create defines
+// without _sat to overcome this issue
+#define convert_float_sat convert_float
+#define convert_float1_sat convert_float
+#define convert_float2_sat convert_float2
+#define convert_float3_sat convert_float3
+#define convert_float4_sat convert_float4
+#define convert_float8_sat convert_float8
+#define convert_float16_sat convert_float16
+#define convert_half_sat convert_float
+#define convert_half1_sat convert_half
+#define convert_half2_sat convert_half2
+#define convert_half3_sat convert_half3
+#define convert_half4_sat convert_half4
+#define convert_half8_sat convert_half8
+#define convert_half16_sat convert_half16
+
#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index 0c2942a184..0ffbe78449 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,6 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
-
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
// Configurations for Mali-G76
@@ -65,9 +63,23 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
switch(_target)
{
case GPUTarget::G76:
- return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+ if (gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
+ {
+ return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported data type");
+ }
default:
- return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
+ if (gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
+ {
+ return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported data type");
+ }
}
}
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
index b00faedb2f..a390e34a34 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
@@ -68,6 +68,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr)
&& (!gemm_info.broadcast_bias),
"Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index f77ab02810..9b9eb12214 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -77,6 +77,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr)
&& (!gemm_info.broadcast_bias),
"Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision && (input0->data_type() == DataType::F32), "Mixed precision only supported for F16 data type");
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
@@ -240,9 +241,11 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
+ const bool enable_mixed_precision = gemm_info.fp_mixed_precision;
+ const DataType data_type = input0->info()->data_type();
+
// Create build options
CLBuildOptions build_opts;
- build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha));
build_opts.add_option_if(_input2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta));
build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA");
@@ -255,6 +258,12 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons
build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE");
build_opts.add_option_if(lhs_info.transpose, "-DLHS_TRANSPOSE");
build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS");
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+ build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+ build_opts.add_option_if(enable_mixed_precision, "-DMIXED_PRECISION");
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
+ build_opts.add_option("-DDATA_TYPE_ACCUMULATOR=" + (enable_mixed_precision ? get_cl_type_from_data_type(DataType::F32) : get_cl_type_from_data_type(data_type)));
build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m));
build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n));
build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));
@@ -262,9 +271,6 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons
build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0));
build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0));
build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
- build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
std::string kernel_name("gemm_mm_reshaped_");
kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_";
@@ -282,6 +288,7 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons
_config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : "");
_config_id += lower_string(string_from_data_type(input0->info()->data_type()));
_config_id += "_";
+ _config_id += (enable_mixed_precision ? "mixed_precision_" : "");
_config_id += support::cpp11::to_string(output->info()->dimension(1));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(0));
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index fff4da6076..3d5e1486a6 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -68,6 +68,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr)
&& (!gemm_info.broadcast_bias),
"Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;