From 05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 24 Sep 2019 12:05:06 +0100 Subject: COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1 Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1983 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../CLGEMMReshapedKernelConfigurationBifrost.h | 2 + .../kernels/CLGEMMMatrixMultiplyReshapedKernel.h | 4 +- src/core/CL/cl_kernels/gemm.cl | 257 ++++++++++++++------- .../CLGEMMReshapedKernelConfigurationBifrost.cpp | 35 ++- .../kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 2 +- src/runtime/CL/functions/CLGEMM.cpp | 56 +++-- tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp | 182 ++++++++------- 7 files changed, 348 insertions(+), 190 deletions(-) diff --git a/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h b/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h index a0aae190e8..3ce2776bce 100644 --- a/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h +++ b/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h @@ -54,6 +54,8 @@ public: private: std::pair configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); }; diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h index 2a76f44284..e6469f0370 100644 --- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h @@ -51,7 +51,7 @@ public: CLGEMMMatrixMultiplyReshapedKernel &operator=(CLGEMMMatrixMultiplyReshapedKernel &&) = default; /** Initialise the kernel's input and output. * - * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F32. The number of dimensions for the LHS matrix must be less or equal than 4 + * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less or equal than 4 * @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3 * @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0. * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 @@ -74,7 +74,7 @@ public: const GEMMKernelInfo &gemm_info); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedKernel * - * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F32. The number of dimensions for the LHS matrix must be less or equal than 4 + * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less or equal than 4 * @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3 * @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0. * @param[in] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0 diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl index 8e628e8d01..c35d160689 100644 --- a/src/core/CL/cl_kernels/gemm.cl +++ b/src/core/CL/cl_kernels/gemm.cl @@ -2041,79 +2041,37 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE) #if GPU_ARCH == GPU_ARCH_MIDGARD -#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b); +#define ARM_VFMA(a, b, c) c += (a) * (b); #else // GPU_ARCH == GPU_ARCH_MIDGARD -#define ARM_VFMA_1(a, b, c) \ - ({ \ - c = fma((a), (b), (c)); \ - }) -#define ARM_VFMA_2(a, b, c) \ - ({ \ - (c).s0 = fma((a).s0, (b).s0, (c).s0); \ - (c).s1 = fma((a).s1, (b).s1, (c).s1); \ - }) -#define ARM_VFMA_3(a, b, c) \ - ({ \ - ARM_VFMA_2(a, b, c); \ - (c).s2 = fma((a).s2, (b).s2, (c).s2); \ - }) -#define ARM_VFMA_4(a, b, c) \ - ({ \ - ARM_VFMA_3(a, b, c); \ - (c).s3 = fma((a).s3, (b).s3, (c).s3); \ - }) -#define ARM_VFMA_8(a, b, c) \ - ({ \ - ARM_VFMA_4(a, b, c); \ - (c).s4 = fma((a).s4, (b).s4, (c).s4); \ - (c).s5 = fma((a).s5, (b).s5, (c).s5); \ - (c).s6 = fma((a).s6, (b).s6, (c).s6); \ - (c).s7 = fma((a).s7, (b).s7, (c).s7); \ - }) -#define ARM_VFMA_16(a, b, c) \ - ({ \ - ARM_VFMA_8(a, b, c); \ - (c).s8 = fma((a).s8, (b).s8, (c).s8); \ - (c).s9 = fma((a).s9, (b).s9, (c).s9); \ - (c).sA = fma((a).sA, (b).sA, (c).sA); \ - (c).sB = fma((a).sB, (b).sB, (c).sB); \ - (c).sC = fma((a).sC, (b).sC, (c).sC); \ - (c).sD = fma((a).sD, (b).sD, (c).sD); \ - (c).sE = fma((a).sE, (b).sE, (c).sE); \ - (c).sF = fma((a).sF, (b).sF, (c).sF); \ - }) - -// Factory macro for the vector FMA -#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c)) - +#define ARM_VFMA(a, b, c) c = fma((a), (b), (c)); #endif // GPU_ARCH == GPU_ARCH_MIDGARD -#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \ +#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \ }) -#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \ +#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \ }) -#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \ +#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \ }) -#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \ +#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \ }) -#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \ - ({ \ - ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \ - ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \ +#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \ + ({ \ + ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \ + ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \ }) // Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1 @@ -2172,7 +2130,8 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), // K0: 1, 2, 3, 4, 8, 16 // This macro calls the vector-by-matrix macro K0 times // A, B and C are matrices -#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \ +#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \ + CONCAT(ARM_MM_T_NT_M0xN0x, K0) \ (M0, N0, TYPE, A, B, C) /** This OpenCL kernel computes the matrix multiplication between 2 matrices. @@ -2272,11 +2231,9 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs), #if defined(RHS_INTERLEAVE) #define RHS_OFFSET_X (N0) #define RHS_STEP_X ((N0) * (H0)) -#define RHS_STEP_LOOP (1) #else // defined(RHS_INTERLEAVE) #define RHS_OFFSET_X (RHS_BLOCK_SIZE) #define RHS_STEP_X (N0) -#define RHS_STEP_LOOP (H0) #endif // defined(RHS_INTERLEAVE) const uint x = get_global_id(0); @@ -2306,28 +2263,160 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs), // Initialize the accumulators REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; - REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0; REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0); + __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr); + __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr); + for(int i = 0; i < k; i += K0) { - // Supported cases (K0, M0): - // 1,2 - 2,2 - 3,2 - 4,2 - 5,2 - 6,2 - 7,2 - 8,2 - // 1,3 - 2,3 - 3,3 - 4,3 - 5,3 - 6,3 - 7,3 - 8,3 - // 1,4 - 2,4 - 3,4 - 4,4 - 5,4 - 6,4 - 7,4 - 8,4 - // 1,8 - 2,8 - 3,8 - 4,8 - 5,8 - 6,8 - 7,8 - 8,8 - // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16 - // Load values from LHS matrix - LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs); + VEC_DATA_TYPE(DATA_TYPE, M0) + a0 = VLOAD(M0)(0, lhs); + VEC_DATA_TYPE(DATA_TYPE, N0) + b0 = VLOAD(N0)(0, rhs); - // Load values from RHS matrix - LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs); + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + +#if K0 > 1 + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; +#endif // K0 > 1 + +#if K0 > 2 + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; +#endif // K0 > 2 + +#if K0 > 3 + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; +#endif // K0 > 3 + +#if K0 > 4 + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; +#endif // K0 > 4 + +#if K0 > 8 + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; + + a0 = VLOAD(M0)(0, lhs); + b0 = VLOAD(N0)(0, rhs); + + ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c); + + lhs += LHS_STEP_X; + rhs += RHS_STEP_X; +#endif // K0 > 8 - // Perform the partial matrix multiplication - ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c); +#ifndef LHS_INTERLEAVE + lhs += (M0 * K0 * (V0 - 1)); +#endif // LHS_INTERLEAVE - lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE); - rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE); +#ifndef RHS_INTERLEAVE + rhs += (N0 * K0 * (H0 - 1)); +#endif // RHS_INTERLEAVE } __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y); diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp index b791c1cda5..0c2942a184 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp @@ -42,8 +42,7 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro std::pair CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8); - ARM_COMPUTE_UNUSED(data_type); + ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8); using ConfigurationFunctionExecutorPtr = std::pair (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); @@ -51,6 +50,7 @@ std::pair CLGEMMReshapedKernelConfiguratio static std::map gemm_configs_G76 = { { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 }, + { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 }, { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } }; @@ -58,6 +58,7 @@ std::pair CLGEMMReshapedKernelConfiguratio static std::map gemm_configs_G7x = { { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 }, + { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 }, { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } }; @@ -85,6 +86,21 @@ std::pair CLGEMMReshapedKernelConfiguratio } } +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false); + } +} + std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); @@ -129,6 +145,21 @@ std::pair CLGEMMReshapedKernelConfiguratio } } +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + ARM_COMPUTE_UNUSED(k); + ARM_COMPUTE_UNUSED(b); + + if(n <= 4) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false); + } +} + std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp index 222a63d86a..f77ab02810 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp @@ -63,7 +63,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index e78395f1de..762b00177c 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -65,37 +65,53 @@ CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsign { GEMMType gemm_type = GEMMType::RESHAPED_V1; - if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76)) + if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, + GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, + GPUTarget::G76, GPUTarget::G77)) { - if((m > 1) && (n < 16)) + if(data_type == DataType::F32) { - gemm_type = GEMMType::RESHAPED_V1; - } - else if((m == 1) && (data_type == DataType::F32)) - { - gemm_type = GEMMType::RESHAPED_ONLY_RHS; + if((m > 1) && (n < 16)) + { + gemm_type = GEMMType::RESHAPED_V1; + } + else if(m == 1) + { + gemm_type = GEMMType::RESHAPED_ONLY_RHS; + } + else + { + // COMPMID-852 + if((k > 256) && (m > 4) && reshape_b_only_on_first_run) + { + constexpr float alpha = 3.2f; + constexpr float fact0 = 1.51f; + constexpr float fact1 = 1.66f; + constexpr float ops = 12.0f; + const float scale = k > 1024 ? 1.07f : 1.0f; + gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE; + } + else + { + gemm_type = GEMMType::NATIVE; + } + } + + const auto workload = static_cast((m * n) / 20.0f); + + gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type; } else { - // COMPMID-852 - if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run) + if((m == 1) || (!reshape_b_only_on_first_run)) { - constexpr float alpha = 3.2f; - constexpr float fact0 = 1.51f; - constexpr float fact1 = 1.66f; - constexpr float ops = 12.0f; - const float scale = k > 1024 ? 1.07f : 1.0f; - gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE; + gemm_type = GEMMType::NATIVE; } else { - gemm_type = GEMMType::NATIVE; + gemm_type = GEMMType::RESHAPED_V2; } } - - const auto workload = static_cast((m * n) / 20.0f); - - gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type; } else { diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp index ba218f7cd1..99f5ffe191 100644 --- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp +++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp @@ -71,6 +71,9 @@ namespace RelativeTolerance rel_tolerance_f32(0.001f); constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance rel_tolerance_f16(0.001f); +constexpr float abs_tolerance_f16(0.01f); + /** Alpha values to test - Precommit */ const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} ); @@ -103,7 +106,7 @@ const auto act_values = framework::dataset::make("Activation", }); /** M0 values to test - Precommit */ -const auto m0_values_precommit = framework::dataset::make("M0", {4, 8}); +const auto m0_values_precommit = framework::dataset::make("M0", { 4 }); /** N0 values to test - Precommit */ const auto n0_values_precommit = framework::dataset::make("N0", { 4 }); @@ -143,80 +146,19 @@ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { /** LHS transposed values */ const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } ); - -/** Configuration test */ -void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, bool lhs_transpose, DataType data_type, const ActivationLayerInfo &act_info) -{ - const unsigned int M = m_value; - const unsigned int N = n_value; - const unsigned int K = k_value; - - GEMMLHSMatrixInfo lhs_info; - lhs_info.m0 = m0_value; - lhs_info.k0 = k0_value; - lhs_info.v0 = v0_value; - lhs_info.interleave = i_value_lhs; - lhs_info.transpose = lhs_transpose; - - GEMMRHSMatrixInfo rhs_info; - rhs_info.n0 = n0_value; - rhs_info.k0 = k0_value; - rhs_info.h0 = h0_value; - rhs_info.interleave = i_value_rhs; - rhs_info.transpose = !lhs_transpose; - - GEMMKernelInfo kernel_info; - kernel_info.m = M; - kernel_info.n = N; - kernel_info.k = K; - kernel_info.depth_output_gemm3d = 0; - kernel_info.reinterpret_input_as_3d = false; - kernel_info.broadcast_bias = broadcast_bias; - kernel_info.activation_info = act_info; - - const TensorShape lhs_shape(K, M, b_value); - const TensorShape lhs_shape_reshaped = compute_lhs_reshaped_shape(TensorInfo(lhs_shape, 1, data_type), - lhs_info, - false); - - const TensorShape rhs_shape(N, K, b_value); - const TensorShape rhs_shape_reshaped = compute_rhs_reshaped_shape(TensorInfo(rhs_shape, 1, data_type), - rhs_info); - - const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape_reshaped, 1, data_type), - TensorInfo(rhs_shape_reshaped, 1, data_type), - kernel_info); - - const TensorShape bias_shape(N, - broadcast_bias? 1 : M, - broadcast_bias? 1 : b_value); - - // Create tensors - CLTensor lhs_reshaped = create_tensor(lhs_shape_reshaped, data_type); - CLTensor rhs_reshaped = create_tensor(rhs_shape_reshaped, data_type); - CLTensor bias = create_tensor(bias_shape, data_type); - CLTensor dst = create_tensor(dst_shape, data_type); - - ARM_COMPUTE_EXPECT(lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); - - // Create and configure function - CLGEMMMatrixMultiplyReshaped gemm; - gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, kernel_info); -} } // namespace TEST_SUITE(CL) TEST_SUITE(GEMMMatrixMultiplyReshaped) TEST_SUITE(Float) TEST_SUITE(FP32) -DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), - framework::dataset::make("batch_size", 1)), + b_values), m0_values_precommit), n0_values_precommit), k0_values_precommit), @@ -224,20 +166,48 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi h0_values_precommit), i_values_lhs), i_values_rhs), + framework::dataset::make("DataType", DataType::F32)), + a_values), + beta_values), broadcast_bias_values), lhs_transpose_values), - act_values), -m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, act_value) + act_values)) { - validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, DataType::F32, act_value); + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::ALL, +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_values, n_values), k_values), b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F32)), + a_values), + beta_values), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_w_values, + m_h_values), + n_values), + k_values), + b_values), m0_values_precommit), n0_values_precommit), k0_values_precommit), @@ -248,7 +218,6 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), - broadcast_bias_values), lhs_transpose_values), act_values)) { @@ -256,9 +225,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, fra validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( - m_values, + m_w_values, + m_h_values), n_values), k_values), b_values), @@ -272,15 +242,65 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, fra framework::dataset::make("DataType", DataType::F32)), a_values), beta_values), - broadcast_bias_values), lhs_transpose_values), act_values)) { // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); } +TEST_SUITE_END() // FP32 -FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::ALL, +TEST_SUITE(FP16) + +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + v0_values_precommit), + h0_values_precommit), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values), + beta_values), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); +} + +FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture, framework::DatasetMode::NIGHTLY, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_nightly), + n0_values_nightly), + k0_values_nightly), + v0_values_nightly), + h0_values_nightly), + i_values_lhs), + i_values_rhs), + framework::dataset::make("DataType", DataType::F16)), + a_values), + beta_values), + broadcast_bias_values), + lhs_transpose_values), + act_values)) +{ + // Validate output + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); +} + +FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), @@ -294,17 +314,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture, h0_values_precommit), i_values_lhs), i_values_rhs), - framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataType", DataType::F16)), a_values), beta_values), lhs_transpose_values), act_values)) { // Validate output - validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); } -FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::NIGHTLY, +FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( m_w_values, m_h_values), @@ -318,16 +338,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture, h0_values_nightly), i_values_lhs), i_values_rhs), - framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataType", DataType::F16)), a_values), beta_values), lhs_transpose_values), act_values)) { // Validate output - validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); } -TEST_SUITE_END() // FP32 +TEST_SUITE_END() // FP16 TEST_SUITE_END() // Float TEST_SUITE_END() // GEMMMatrixMultiplyReshaped TEST_SUITE_END() // CL -- cgit v1.2.1