From 0c17aa25a4f7bc812707150b91930f0cf8e75294 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 27 Sep 2019 09:23:15 +0100 Subject: COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16 Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1991 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../kernels/CLGEMMMatrixMultiplyReshapedKernel.h | 4 + src/core/CL/cl_kernels/gemm.cl | 161 ++++++++++++++++---- src/core/CL/cl_kernels/gemm_helpers.h | 88 ++++++++++- src/core/CL/cl_kernels/helpers.h | 17 +++ .../CLGEMMReshapedKernelConfigurationBifrost.cpp | 20 ++- .../kernels/CLGEMMMatrixMultiplyNativeKernel.cpp | 1 + .../kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 15 +- .../CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 1 + tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 162 ++++++++++++++++++--- tests/validation/fixtures/GEMMFixture.h | 24 ++- tests/validation/reference/GEMM.cpp | 55 ++++++- tests/validation/reference/GEMM.h | 5 +- 12 files changed, 486 insertions(+), 67 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h index e6469f0370..d3c54a76c8 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h @@ -50,6 +50,10 @@ public: /** Allow instances of this class to be moved */ CLGEMMMatrixMultiplyReshapedKernel &operator=(CLGEMMMatrixMultiplyReshapedKernel &&) = default; /** Initialise the kernel's input and output. + * + * @note The F16 computation also supports mixed precision through the gemm_info.fp_mixed_precision flag. + * Mixed precision combines different floating precisions during the computation, in particular, F32 for the accumulations and F16 for the + * multiplications. i.e. 
float c = (half)a * (half)b * * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less or equal than 4 * @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3 diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index c35d160689..57a5af8ec2 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -1676,8 +1676,66 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), } #endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K) -#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) +#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR) && defined(M) && defined(N) +#if defined(MIXED_PRECISION) +#if K0 == 2 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c += a.s0 * b.s0; \ + c += a.s1 * b.s1; \ + }) +#elif K0 == 3 // K0 == 3 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c += a.s0 * b.s0; \ + c += a.s1 * b.s1; \ + c += a.s2 * b.s2; \ + }) +#elif K0 == 4 // K0 == 4 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c += a.s0 * b.s0; \ + c += a.s1 * b.s1; \ + c += a.s2 * b.s2; \ + c += a.s3 * b.s3; \ + }) +#elif K0 == 8 // K0 == 8 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c += a.s0 * b.s0; \ + c += a.s1 * b.s1; \ + c += a.s2 * b.s2; \ + c += a.s3 * b.s3; \ + c += a.s4 * b.s4; \ + c += a.s5 * b.s5; \ + c += a.s6 * b.s6; \ + c += a.s7 * b.s7; \ + }) +#elif K0 == 16 // K0 == 16 +#define ARM_DOT_K0(a, b, c) \ + ({ \ + c += a.s0 * b.s0; \ + c += a.s1 * b.s1; \ + c += a.s2 * b.s2; \ + c += a.s3 * b.s3; \ + c += a.s4 * b.s4; \ + c += a.s5 * b.s5; \ + c += a.s6 * b.s6; \ + c += a.s7 * b.s7; \ + c += 
a.s8 * b.s8; \ + c += a.s9 * b.s9; \ + c += a.sA * b.sA; \ + c += a.sB * b.sB; \ + c += a.sC * b.sC; \ + c += a.sD * b.sD; \ + c += a.sE * b.sE; \ + c += a.sF * b.sF; \ + }) +#else // K0 not supported +#error "K0 value not supported" +#endif // K0 conditions +#else // defined(MIXED_PRECISION) #if K0 == 2 #define ARM_DOT_K0(a, b, c) \ ({ \ @@ -1734,6 +1792,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), #else // K0 not supported #error "K0 value not supported" #endif // K0 conditions +#endif // defined(MIXED_PRECISION) #if N0 == 2 #define ARM_DOT_K0XN0(a, b, c) \ @@ -1796,6 +1855,9 @@ __kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs), * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed * + * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float) + * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float) + * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time. * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90). * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4). 
@@ -1917,7 +1979,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #endif // defined(MATRIX_B_DEPTH) // Initialize the accumulators - REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0); REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0; REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); @@ -2003,7 +2065,12 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #endif // UNIT_BIAS // c = c + bias[broadcasted] +#if defined(MIXED_PRECISION) + CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp); + ADD_BLOCK_BROADCAST(M0, c, bias_hp0); +#else // defined(MIXED_PRECISION) ADD_BLOCK_BROADCAST(M0, c, bias0); +#endif // defined(MIXED_PRECISION) #else // defined(BROADCAST_BIAS) __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id( @@ -2016,17 +2083,26 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #endif // UNIT_BIAS // c = c + bias +#if defined(MIXED_PRECISION) + CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp); + ADD_BLOCK(M0, c, bias_hp); +#else // defined(MIXED_PRECISION) ADD_BLOCK(M0, c, bias); +#endif // defined(MIXED_PRECISION) #endif // defined(BROADCAST_BIAS) #endif // defined(BETA) #if defined(ACTIVATION_TYPE) - ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL); + ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL); #endif // defined(ACTIVATION_TYPE) // Store output block +#if defined(MIXED_PRECISION) + CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); +#else // defined(MIXED_PRECISION) STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); +#endif // defined(MIXED_PRECISION) #undef 
LHS_BLOCK_SIZE #undef LHS_OFFSET_X @@ -2040,38 +2116,50 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE) -#if GPU_ARCH == GPU_ARCH_MIDGARD -#define ARM_VFMA(a, b, c) c += (a) * (b); +#if defined(MIXED_PRECISION) + +#if(GPU_ARCH == GPU_ARCH_MIDGARD) +#define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))); +#else // GPU_ARCH == GPU_ARCH_MIDGARD +#define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c)); +#endif // GPU_ARCH == GPU_ARCH_MIDGARD + +#else // defined(MIXED_PRECISION) + +#if(GPU_ARCH == GPU_ARCH_MIDGARD) +#define ARM_VFMA(N0, a, b, c) c += (a) * (b); #else // GPU_ARCH == GPU_ARCH_MIDGARD -#define ARM_VFMA(a, b, c) c = fma((a), (b), (c)); +#define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c)); #endif // GPU_ARCH == GPU_ARCH_MIDGARD -#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \ +#endif // defined(MIXED_PRECISION) + +#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \ }) -#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \ +#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \ }) -#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \ +#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \ }) -#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \ 
- ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \ +#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \ }) -#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \ - ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \ +#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \ + ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \ }) // Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1 @@ -2261,7 +2349,7 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs), #endif // defined(MATRIX_B_DEPTH) // Initialize the accumulators - REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... 
c(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0); REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0); @@ -2455,7 +2543,12 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs), #endif // UNIT_BIAS // c = c + bias[broadcasted] +#if defined(MIXED_PRECISION) + CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp); + ADD_BLOCK_BROADCAST(M0, c, bias_hp0); +#else // defined(MIXED_PRECISION) ADD_BLOCK_BROADCAST(M0, c, bias0); +#endif // defined(MIXED_PRECISION) #else // defined(BROADCAST_BIAS) __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z; @@ -2466,8 +2559,12 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs), SCALE_BLOCK(M0, DATA_TYPE, bias, BETA); #endif // UNIT_BIAS - // c = c + bias +#if defined(MIXED_PRECISION) + CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp); + ADD_BLOCK(M0, c, bias_hp); +#else // defined(MIXED_PRECISION) ADD_BLOCK(M0, c, bias); +#endif // defined(MIXED_PRECISION) #endif // defined(BROADCAST_BIAS) #endif // defined(BETA) @@ -2477,7 +2574,11 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs), #endif // defined(ACTIVATION_TYPE) // Store output block +#if defined(MIXED_PRECISION) + CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); +#else // defined(MIXED_PRECISION) STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout); +#endif // defined(MIXED_PRECISION) #undef LHS_BLOCK_SIZE #undef LHS_OFFSET_X diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h index 4715fb737f..fd8c773444 100644 --- a/src/core/CL/cl_kernels/gemm_helpers.h +++ b/src/core/CL/cl_kernels/gemm_helpers.h @@ -689,4 +689,90 @@ /** Apply activation to the variables BASENAME##0... 
BASENAME##(n-1) * Supported cases N=1,2,3..16, for variables BASENAME[0..N] */ -#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \ No newline at end of file +#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) + +#define CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##0 = CONVERT(BASENAME_SRC##0, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##1 = CONVERT(BASENAME_SRC##1, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##2 = CONVERT(BASENAME_SRC##2, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##3 = CONVERT(BASENAME_SRC##3, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##4 = CONVERT(BASENAME_SRC##4, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##5 = CONVERT(BASENAME_SRC##5, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##6 = CONVERT(BASENAME_SRC##6, 
VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##7 = CONVERT(BASENAME_SRC##7, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##8 = CONVERT(BASENAME_SRC##8, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##9 = CONVERT(BASENAME_SRC##9, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##A = CONVERT(BASENAME_SRC##A, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##B = CONVERT(BASENAME_SRC##B, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##C = CONVERT(BASENAME_SRC##C, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##D = CONVERT(BASENAME_SRC##D, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##E = CONVERT(BASENAME_SRC##E, VEC_DATA_TYPE(DATA_TYPE, N)); + +#define CONVERT_ROW_16(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ 
+ CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ + VEC_DATA_TYPE(DATA_TYPE, N) \ + BASENAME_DST##F = CONVERT(BASENAME_SRC##F, VEC_DATA_TYPE(DATA_TYPE, N)); + +// CONVERT_ROW_m applies CONVERT to the variables BASENAME_SRC##0... BASENAME_SRC##(m-1) +#define CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_ROW_##M(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) +/** Apply CONVERT to the variables BASENAME_SRC##0... BASENAME_SRC##(m-1) + * Supported cases M=1,2,3..16, for variables BASENAME_SRC[0..M-1] + */ +#define CONVERT_BLOCK(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \ No newline at end of file diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h index f501077a40..6f51b87bc6 100644 --- a/src/core/CL/cl_kernels/helpers.h +++ b/src/core/CL/cl_kernels/helpers.h @@ -70,6 +70,23 @@ #define vload1(OFFSET, PTR) *(OFFSET + PTR) #define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA +// Convert built-in functions with _sat modifier are not supported in floating point so we create defines +// without _sat to overcome this issue +#define convert_float_sat convert_float +#define convert_float1_sat convert_float +#define convert_float2_sat convert_float2 +#define convert_float3_sat convert_float3 +#define convert_float4_sat convert_float4 +#define convert_float8_sat convert_float8 +#define convert_float16_sat convert_float16 +#define convert_half_sat convert_float +#define convert_half1_sat convert_half +#define convert_half2_sat convert_half2 +#define convert_half3_sat convert_half3 +#define convert_half4_sat convert_half4 +#define convert_half8_sat convert_half8 +#define convert_half16_sat convert_half16 + #define VEC_DATA_TYPE_STR(type, size) type##size #define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp 
b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp index 0c2942a184..0ffbe78449 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp @@ -42,8 +42,6 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro std::pair CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8); - using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); // Configurations for Mali-G76 @@ -65,9 +63,23 @@ std::pair CLGEMMReshapedKernelConfiguratio switch(_target) { case GPUTarget::G76: - return (this->*gemm_configs_G76[data_type])(m, n, k, b); + if (gemm_configs_G76.find(data_type) != gemm_configs_G76.end()) + { + return (this->*gemm_configs_G76[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } default: - return (this->*gemm_configs_G7x[data_type])(m, n, k, b); + if (gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end()) + { + return (this->*gemm_configs_G7x[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } } } diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp index b00faedb2f..a390e34a34 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp @@ -68,6 +68,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) && (!gemm_info.broadcast_bias), 
"Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp index f77ab02810..9b9eb12214 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp @@ -77,6 +77,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision && (input0->data_type() == DataType::F32), "Mixed precision only supported for F16 data type"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; @@ -240,9 +241,11 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICLKernel::configure_internal(win_config.second); + const bool enable_mixed_precision = gemm_info.fp_mixed_precision; + const DataType data_type = input0->info()->data_type(); + // Create build options CLBuildOptions build_opts; - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha)); build_opts.add_option_if(_input2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta)); build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA"); @@ -255,6 +258,12 @@ void 
CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE"); build_opts.add_option_if(lhs_info.transpose, "-DLHS_TRANSPOSE"); build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS"); + build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation()))); + build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); + build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b())); + build_opts.add_option_if(enable_mixed_precision, "-DMIXED_PRECISION"); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type)); + build_opts.add_option("-DDATA_TYPE_ACCUMULATOR=" + (enable_mixed_precision ? get_cl_type_from_data_type(DataType::F32) : get_cl_type_from_data_type(data_type))); build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m)); build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n)); build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); @@ -262,9 +271,6 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0)); build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0)); build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0)); - build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation()))); - build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); - build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + 
float_to_string_with_full_precision(gemm_info.activation_info.b())); std::string kernel_name("gemm_mm_reshaped_"); kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_"; @@ -282,6 +288,7 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : ""); _config_id += lower_string(string_from_data_type(input0->info()->data_type())); _config_id += "_"; + _config_id += (enable_mixed_precision ? "mixed_precision_" : ""); _config_id += support::cpp11::to_string(output->info()->dimension(1)); _config_id += "_"; _config_id += support::cpp11::to_string(output->info()->dimension(0)); diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index fff4da6076..3d5e1486a6 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -68,6 +68,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr) && (!gemm_info.broadcast_bias), "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported"); const unsigned int m = gemm_info.m; const unsigned int n = gemm_info.n; diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index 99f5ffe191..b885bfe4af 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -60,10 +60,20 @@ using CLGEMMMatrixMultiplyReshaped = CLSynthetizeFunction using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture; +// Fixture for 
CLGEMMMatrixMultiplyReshaped mixed precision +template +using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture = + GEMMMatrixMultiplyReshapedValidationFixture; + // Fixture for CLGEMMMatrixMultiplyReshaped3D template using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture; +// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision +template +using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture = + GEMMMatrixMultiplyReshaped3DValidationFixture; + namespace { // *INDENT-OFF* @@ -71,15 +81,12 @@ namespace RelativeTolerance rel_tolerance_f32(0.001f); constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance rel_tolerance_f16_mixed_precision(0.001f); +constexpr float abs_tolerance_f16_mixed_precision(0.01f); + RelativeTolerance rel_tolerance_f16(0.001f); constexpr float abs_tolerance_f16(0.01f); -/** Alpha values to test - Precommit */ -const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); - -/** Beta values to test - Precommit */ -const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} ); - /** M values to test */ const auto m_values = framework::dataset::make("M", 37); @@ -105,6 +112,12 @@ const auto act_values = framework::dataset::make("Activation", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f), }); +/** Alpha values to test - Precommit */ +const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} ); + /** M0 values to test - Precommit */ const auto m0_values_precommit = framework::dataset::make("M0", { 4 }); @@ -120,6 +133,12 @@ const auto v0_values_precommit = framework::dataset::make("V0", 1, 3); /** H0 values to test - Precommit */ const auto h0_values_precommit = framework::dataset::make("H0", 1, 3); +/** Alpha values to test - Nightly */ +const auto a_values_nightly = 
framework::dataset::make("alpha", {1.0f} ); + +/** Beta values to test - Nightly */ +const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} ); + /** M0 values to test - Nightly */ const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 }); @@ -167,8 +186,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -191,8 +210,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fra i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -216,8 +235,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -240,8 +259,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F32)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -266,8 +285,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -290,8 +309,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fram i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - 
beta_values), + a_values_nightly), + beta_values_nightly), broadcast_bias_values), lhs_transpose_values), act_values)) @@ -315,8 +334,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_precommit), + beta_values_precommit), lhs_transpose_values), act_values)) { @@ -339,8 +358,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, i_values_lhs), i_values_rhs), framework::dataset::make("DataType", DataType::F16)), - a_values), - beta_values), + a_values_nightly), + beta_values_nightly), lhs_transpose_values), act_values)) { @@ -348,6 +367,105 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); } TEST_SUITE_END() // FP16 + +TEST_SUITE(MixedPrecision) + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + 
k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_precommit), + beta_values_precommit), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); +} + +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values_nightly), + beta_values_nightly), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision); 
+} +TEST_SUITE_END() // MixedPrecision TEST_SUITE_END() // Float TEST_SUITE_END() // GEMMMatrixMultiplyReshaped TEST_SUITE_END() // CL diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index 854cc4a22b..bf919c9b09 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -667,7 +667,7 @@ protected: SimpleTensor _reference{}; }; -template +template class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture { public: @@ -734,6 +734,7 @@ protected: kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; kernel_info.activation_info = act_info; + kernel_info.fp_mixed_precision = fp_mixed_precision; // The output tensor will be auto-initialized within the function @@ -807,14 +808,21 @@ protected: } } - return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + if(fp_mixed_precision) + { + return reference::activation_layer(reference::gemm_mixed_precision(lhs, rhs, bias, alpha, beta), act_info); + } + else + { + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } } TensorType _target{}; SimpleTensor _reference{}; }; -template +template class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture { public: @@ -879,6 +887,7 @@ protected: kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = true; kernel_info.activation_info = act_info; + kernel_info.fp_mixed_precision = fp_mixed_precision; // The output tensor will be auto-initialized within the function @@ -951,7 +960,14 @@ protected: memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); } - return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + if(fp_mixed_precision) + { + return reference::activation_layer(reference::gemm_mixed_precision(lhs, rhs, bias, alpha, beta), act_info); + } + else + { + return 
reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } } TensorType _target{}; diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp index 2feab89950..3c72b94143 100644 --- a/tests/validation/reference/GEMM.cpp +++ b/tests/validation/reference/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -84,8 +84,61 @@ SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const S return dst; } +template ::value, int>::type> +SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta) +{ + // GEMM mixed-precision combines F32 accumulators with F16 multiplications + // Create reference + SimpleTensor dst{ c.shape(), c.data_type(), 1 }; + + // Compute reference + const int M = a.shape().y(); + const int N = b.shape().x(); + const int K = a.shape().x(); + const int D = a.shape().z(); // Number of matrices in a batch + const int W = a.shape()[3]; // Number of batched-gemm (Winograd case) + + const int a_stride_z = K * M; + const int a_stride_w = K * M * D; + + const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0; // Do not slide the matrix B along the 3th dimension in case matrix B has less than 3 dimensions + const int b_stride_w = b.shape().num_dimensions() > 3 ? 
K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions + + const int c_stride_z = N * M; + const int c_stride_w = N * M * D; + + for(int w = 0; w < W; ++w) + { + for(int depth = 0; depth < D; ++depth) + { + const int base_addr_a = depth * a_stride_z + w * a_stride_w; + const int base_addr_b = depth * b_stride_z + w * b_stride_w; + const int base_addr_c = depth * c_stride_z + w * c_stride_w; + + for(int row = 0; row < M; ++row) + { + for(int col = 0; col < N; ++col) + { + float acc(0); + + for(int k = 0; k < K; ++k) + { + acc += static_cast(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]); + } + + // Finalize the result: alpha * A * B + beta * C + dst[base_addr_c + col + row * N] = static_cast(alpha * acc + beta * c[base_addr_c + col + row * N]); + } + } + } + } + + return dst; +} + template SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); template SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); +template SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h index 39007c60bc..9bcd640770 100644 --- a/tests/validation/reference/GEMM.h +++ b/tests/validation/reference/GEMM.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -38,6 +38,9 @@ namespace reference template ::value, int>::type = 0> SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); +template ::value, int>::type = 0> +SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta); + } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1