author     Gian Marco Iodice <gianmarco.iodice@arm.com>   2019-09-24 12:05:06 +0100
committer  Gian Marco Iodice <gianmarco.iodice@arm.com>   2019-09-26 11:44:34 +0000
commit     05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071 (patch)
tree       8bd2f0f2b874ac7347777b40d563506348fff754
parent     1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (diff)
download   ComputeLibrary-05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071.tar.gz
COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1
Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1983
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h |   2
-rw-r--r--  arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h             |   4
-rw-r--r--  src/core/CL/cl_kernels/gemm.cl                                                | 257
-rw-r--r--  src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp       |  35
-rw-r--r--  src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp                   |   2
-rw-r--r--  src/runtime/CL/functions/CLGEMM.cpp                                          |  56
-rw-r--r--  tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp                           | 182
7 files changed, 348 insertions, 190 deletions
diff --git a/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h b/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
index a0aae190e8..3ce2776bce 100644
--- a/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
+++ b/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
@@ -54,6 +54,8 @@ public:
private:
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
};
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
index 2a76f44284..e6469f0370 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
@@ -51,7 +51,7 @@ public:
CLGEMMMatrixMultiplyReshapedKernel &operator=(CLGEMMMatrixMultiplyReshapedKernel &&) = default;
/** Initialise the kernel's input and output.
*
- * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F32. The number of dimensions for the LHS matrix must be less or equal than 4
+ * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less or equal than 4
* @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3
* @param[in] input2 Input tensor containing the bias matrix. Data type supported: same as @p input0.
* @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
@@ -74,7 +74,7 @@ public:
const GEMMKernelInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedKernel
*
- * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F32. The number of dimensions for the LHS matrix must be less or equal than 4
+ * @param[in] input0 Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less or equal than 4
* @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less or equal than 3
* @param[in] input2 Input tensor info containing the bias matrix. Data type supported: same as @p input0.
* @param[in] output Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
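Note: the configure() call documented above takes the reshaped LHS/RHS tensors, the bias tensor, the destination, the alpha/beta scalars, the two reshape descriptors and a GEMMKernelInfo; the validate_configuration() helper removed from the tests further down shows the full call. A minimal sketch of wiring it up for the newly supported F16 path is given below. The block sizes and the header locations of the GEMM info structs are illustrative assumptions, not the tuned Bifrost configuration.

    // Minimal sketch: configuring the reshaped GEMM kernel for F16 inputs.
    // Block sizes below are illustrative; in the library they come from
    // CLGEMMReshapedKernelConfigurationBifrost (see the .cpp diff below).
    #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
    #include "arm_compute/core/KernelDescriptors.h" // GEMMKernelInfo (header location assumed)
    #include "arm_compute/core/Types.h"             // GEMMLHSMatrixInfo / GEMMRHSMatrixInfo
    #include "arm_compute/runtime/CL/CLTensor.h"

    using namespace arm_compute;

    void configure_f16_reshaped_gemm(CLTensor &lhs_reshaped, CLTensor &rhs_reshaped, CLTensor &bias, CLTensor &dst,
                                     unsigned int m, unsigned int n, unsigned int k)
    {
        GEMMLHSMatrixInfo lhs_info;
        lhs_info.m0         = 4;     // rows processed per work-item
        lhs_info.k0         = 4;     // accumulations per step (must match rhs_info.k0)
        lhs_info.v0         = 4;     // vertical blocks interleaved together
        lhs_info.interleave = true;
        lhs_info.transpose  = true;

        GEMMRHSMatrixInfo rhs_info;
        rhs_info.n0         = 4;     // columns processed per work-item
        rhs_info.k0         = 4;
        rhs_info.h0         = 8;     // horizontal blocks interleaved together
        rhs_info.interleave = true;
        rhs_info.transpose  = false; // opposite of lhs_info.transpose

        GEMMKernelInfo kernel_info;
        kernel_info.m                       = m;
        kernel_info.n                       = n;
        kernel_info.k                       = k;
        kernel_info.depth_output_gemm3d     = 0;
        kernel_info.reinterpret_input_as_3d = false;
        kernel_info.broadcast_bias          = true;

        // All tensors are expected to be F16 and already reshaped with the
        // CLGEMMReshapeLHSMatrixKernel / CLGEMMReshapeRHSMatrixKernel kernels.
        CLGEMMMatrixMultiplyReshapedKernel mm_kernel;
        mm_kernel.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, kernel_info);
    }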
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 8e628e8d01..c35d160689 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2041,79 +2041,37 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
#if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b);
+#define ARM_VFMA(a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA_1(a, b, c) \
- ({ \
- c = fma((a), (b), (c)); \
- })
-#define ARM_VFMA_2(a, b, c) \
- ({ \
- (c).s0 = fma((a).s0, (b).s0, (c).s0); \
- (c).s1 = fma((a).s1, (b).s1, (c).s1); \
- })
-#define ARM_VFMA_3(a, b, c) \
- ({ \
- ARM_VFMA_2(a, b, c); \
- (c).s2 = fma((a).s2, (b).s2, (c).s2); \
- })
-#define ARM_VFMA_4(a, b, c) \
- ({ \
- ARM_VFMA_3(a, b, c); \
- (c).s3 = fma((a).s3, (b).s3, (c).s3); \
- })
-#define ARM_VFMA_8(a, b, c) \
- ({ \
- ARM_VFMA_4(a, b, c); \
- (c).s4 = fma((a).s4, (b).s4, (c).s4); \
- (c).s5 = fma((a).s5, (b).s5, (c).s5); \
- (c).s6 = fma((a).s6, (b).s6, (c).s6); \
- (c).s7 = fma((a).s7, (b).s7, (c).s7); \
- })
-#define ARM_VFMA_16(a, b, c) \
- ({ \
- ARM_VFMA_8(a, b, c); \
- (c).s8 = fma((a).s8, (b).s8, (c).s8); \
- (c).s9 = fma((a).s9, (b).s9, (c).s9); \
- (c).sA = fma((a).sA, (b).sA, (c).sA); \
- (c).sB = fma((a).sB, (b).sB, (c).sB); \
- (c).sC = fma((a).sC, (b).sC, (c).sC); \
- (c).sD = fma((a).sD, (b).sD, (c).sD); \
- (c).sE = fma((a).sE, (b).sE, (c).sE); \
- (c).sF = fma((a).sF, (b).sF, (c).sF); \
- })
-
-// Factory macro for the vector FMA
-#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))
-
+#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
})
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
})
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
})
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
})
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
})
// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
@@ -2172,7 +2130,8 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-matrix macro K0 times
// A, B and C are matrices
-#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
+#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
+ CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
(M0, N0, TYPE, A, B, C)
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
@@ -2272,11 +2231,9 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
-#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
-#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)
const uint x = get_global_id(0);
@@ -2306,28 +2263,160 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
// Initialize the accumulators
REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
- REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
+ __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
+ __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
+
for(int i = 0; i < k; i += K0)
{
- // Supported cases (K0, M0):
- // 1,2 - 2,2 - 3,2 - 4,2 - 5,2 - 6,2 - 7,2 - 8,2
- // 1,3 - 2,3 - 3,3 - 4,3 - 5,3 - 6,3 - 7,3 - 8,3
- // 1,4 - 2,4 - 3,4 - 4,4 - 5,4 - 6,4 - 7,4 - 8,4
- // 1,8 - 2,8 - 3,8 - 4,8 - 5,8 - 6,8 - 7,8 - 8,8
- // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16
- // Load values from LHS matrix
- LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+ VEC_DATA_TYPE(DATA_TYPE, M0)
+ a0 = VLOAD(M0)(0, lhs);
+ VEC_DATA_TYPE(DATA_TYPE, N0)
+ b0 = VLOAD(N0)(0, rhs);
- // Load values from RHS matrix
- LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+#if K0 > 1
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 1
+
+#if K0 > 2
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 2
+
+#if K0 > 3
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 3
+
+#if K0 > 4
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 4
+
+#if K0 > 8
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 8
- // Perform the partial matrix multiplication
- ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);
+#ifndef LHS_INTERLEAVE
+ lhs += (M0 * K0 * (V0 - 1));
+#endif // LHS_INTERLEAVE
- lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
- rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
+#ifndef RHS_INTERLEAVE
+ rhs += (N0 * K0 * (H0 - 1));
+#endif // RHS_INTERLEAVE
}
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
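Note: the rewritten loop above drops the SIZE-parameterised ARM_VFMA_* chain and the LOAD_BLOCK helpers. Each K step now VLOADs M0 LHS values and N0 RHS values and feeds them to ARM_MM_T_NT(M0, N0, 1, ...), which performs a rank-1 fma update of the M0 x N0 accumulator; OpenCL's fma applies component-wise to vector operands, which is why the macro no longer needs a SIZE argument. A scalar C++ sketch of what a single step computes, with plain arrays standing in for the kernel's vector variables:

    #include <cmath>
    #include <cstddef>

    // One K-step of the rewritten gemm_mm_reshaped_lhs_t_rhs_nt inner loop, written out
    // in scalar form: c[m][n] = fma(a[m], b[n], c[m][n]) over the M0 x N0 accumulator.
    // M0, N0 and the array arguments are illustrative stand-ins for the kernel's registers.
    template <std::size_t M0, std::size_t N0>
    void rank1_fma_step(const float (&a)[M0], const float (&b)[N0], float (&c)[M0][N0])
    {
        for(std::size_t m = 0; m < M0; ++m)
        {
            for(std::size_t n = 0; n < N0; ++n)
            {
                // The lane-wise operation ARM_VFMA now expands to on non-Midgard GPUs
                // (on Midgard it falls back to c += a * b).
                c[m][n] = std::fma(a[m], b[n], c[m][n]);
            }
        }
    }

In the kernel this step is repeated K0 times per loop iteration, with lhs advanced by LHS_STEP_X and rhs by RHS_STEP_X elements between steps, plus an extra M0 * K0 * (V0 - 1) or N0 * K0 * (H0 - 1) elements at the end of the iteration when the corresponding blocks are not interleaved.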
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index b791c1cda5..0c2942a184 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,7 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
- ARM_COMPUTE_UNUSED(data_type);
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
@@ -51,6 +50,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
{
{ DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
+ { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
{ DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
};
@@ -58,6 +58,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
{
{ DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
+ { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
{ DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
};
@@ -85,6 +86,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(n <= 4)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
@@ -129,6 +145,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(n <= 4)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 222a63d86a..f77ab02810 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -63,7 +63,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e78395f1de..762b00177c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -65,37 +65,53 @@ CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsign
{
GEMMType gemm_type = GEMMType::RESHAPED_V1;
- if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+ if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT,
+ GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72,
+ GPUTarget::G76, GPUTarget::G77))
{
- if((m > 1) && (n < 16))
+ if(data_type == DataType::F32)
{
- gemm_type = GEMMType::RESHAPED_V1;
- }
- else if((m == 1) && (data_type == DataType::F32))
- {
- gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+ if((m > 1) && (n < 16))
+ {
+ gemm_type = GEMMType::RESHAPED_V1;
+ }
+ else if(m == 1)
+ {
+ gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+ }
+ else
+ {
+ // COMPMID-852
+ if((k > 256) && (m > 4) && reshape_b_only_on_first_run)
+ {
+ constexpr float alpha = 3.2f;
+ constexpr float fact0 = 1.51f;
+ constexpr float fact1 = 1.66f;
+ constexpr float ops = 12.0f;
+ const float scale = k > 1024 ? 1.07f : 1.0f;
+ gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+ }
+ else
+ {
+ gemm_type = GEMMType::NATIVE;
+ }
+ }
+
+ const auto workload = static_cast<float>((m * n) / 20.0f);
+
+ gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
}
else
{
- // COMPMID-852
- if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+ if((m == 1) || (!reshape_b_only_on_first_run))
{
- constexpr float alpha = 3.2f;
- constexpr float fact0 = 1.51f;
- constexpr float fact1 = 1.66f;
- constexpr float ops = 12.0f;
- const float scale = k > 1024 ? 1.07f : 1.0f;
- gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+ gemm_type = GEMMType::NATIVE;
}
else
{
- gemm_type = GEMMType::NATIVE;
+ gemm_type = GEMMType::RESHAPED_V2;
}
}
-
- const auto workload = static_cast<float>((m * n) / 20.0f);
-
- gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
}
else
{
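Note: in the reworked select_gemm_type() above, the COMPMID-852 cost model is now applied only to F32 (it previously guarded on is_data_type_float()), while the non-F32 branch picks NATIVE when m == 1 or the RHS is reshaped on every run, and RESHAPED_V2 otherwise. A condensed sketch of the F32 cost model with the constants copied from the diff (the function name is a local stand-in):

    // Condensed restatement of the COMPMID-852 cost model, reached for F32 on the listed
    // GPU targets when m > 1, n >= 16, k > 256, m > 4 and the RHS is reshaped only once.
    // Returns true when the RESHAPED_V1 path is expected to beat the NATIVE kernel.
    bool prefer_reshaped_v1(unsigned int n, unsigned int k)
    {
        constexpr float alpha = 3.2f;
        constexpr float fact0 = 1.51f;
        constexpr float fact1 = 1.66f;
        constexpr float ops   = 12.0f;
        const float     scale = (k > 1024) ? 1.07f : 1.0f;

        return (alpha + ((n * fact0) / ops)) < ((fact1 * n * scale) / ops);
    }

For example, with n = 64 and k = 2048 the left-hand side is 3.2 + (64 * 1.51) / 12 ≈ 11.25 and the right-hand side is (1.66 * 64 * 1.07) / 12 ≈ 9.47, so NATIVE is selected; when RESHAPED_V1 does win and the workload m * n / 20 exceeds 1600, it is further promoted to RESHAPED_V2.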
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index ba218f7cd1..99f5ffe191 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -71,6 +71,9 @@ namespace
RelativeTolerance<float> rel_tolerance_f32(0.001f);
constexpr float abs_tolerance_f32(0.0001f);
+RelativeTolerance<float> rel_tolerance_f16(0.001f);
+constexpr float abs_tolerance_f16(0.01f);
+
/** Alpha values to test - Precommit */
const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
@@ -103,7 +106,7 @@ const auto act_values = framework::dataset::make("Activation",
});
/** M0 values to test - Precommit */
-const auto m0_values_precommit = framework::dataset::make("M0", {4, 8});
+const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
@@ -143,80 +146,19 @@ const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {
/** LHS transposed values */
const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );
-
-/** Configuration test */
-void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, bool lhs_transpose, DataType data_type, const ActivationLayerInfo &act_info)
-{
- const unsigned int M = m_value;
- const unsigned int N = n_value;
- const unsigned int K = k_value;
-
- GEMMLHSMatrixInfo lhs_info;
- lhs_info.m0 = m0_value;
- lhs_info.k0 = k0_value;
- lhs_info.v0 = v0_value;
- lhs_info.interleave = i_value_lhs;
- lhs_info.transpose = lhs_transpose;
-
- GEMMRHSMatrixInfo rhs_info;
- rhs_info.n0 = n0_value;
- rhs_info.k0 = k0_value;
- rhs_info.h0 = h0_value;
- rhs_info.interleave = i_value_rhs;
- rhs_info.transpose = !lhs_transpose;
-
- GEMMKernelInfo kernel_info;
- kernel_info.m = M;
- kernel_info.n = N;
- kernel_info.k = K;
- kernel_info.depth_output_gemm3d = 0;
- kernel_info.reinterpret_input_as_3d = false;
- kernel_info.broadcast_bias = broadcast_bias;
- kernel_info.activation_info = act_info;
-
- const TensorShape lhs_shape(K, M, b_value);
- const TensorShape lhs_shape_reshaped = compute_lhs_reshaped_shape(TensorInfo(lhs_shape, 1, data_type),
- lhs_info,
- false);
-
- const TensorShape rhs_shape(N, K, b_value);
- const TensorShape rhs_shape_reshaped = compute_rhs_reshaped_shape(TensorInfo(rhs_shape, 1, data_type),
- rhs_info);
-
- const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape_reshaped, 1, data_type),
- TensorInfo(rhs_shape_reshaped, 1, data_type),
- kernel_info);
-
- const TensorShape bias_shape(N,
- broadcast_bias? 1 : M,
- broadcast_bias? 1 : b_value);
-
- // Create tensors
- CLTensor lhs_reshaped = create_tensor<CLTensor>(lhs_shape_reshaped, data_type);
- CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type);
- CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type);
- CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
-
- ARM_COMPUTE_EXPECT(lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
-
- // Create and configure function
- CLGEMMMatrixMultiplyReshaped gemm;
- gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, kernel_info);
-}
} // namespace
TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)
TEST_SUITE(Float)
TEST_SUITE(FP32)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
- framework::dataset::make("batch_size", 1)),
+ b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
@@ -224,20 +166,48 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi
h0_values_precommit),
i_values_lhs),
i_values_rhs),
+ framework::dataset::make("DataType", DataType::F32)),
+ a_values),
+ beta_values),
broadcast_bias_values),
lhs_transpose_values),
- act_values),
-m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, act_value)
+ act_values))
{
- validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, DataType::F32, act_value);
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
+ m0_values_nightly),
+ n0_values_nightly),
+ k0_values_nightly),
+ v0_values_nightly),
+ h0_values_nightly),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F32)),
+ a_values),
+ beta_values),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
m0_values_precommit),
n0_values_precommit),
k0_values_precommit),
@@ -248,7 +218,6 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
framework::dataset::make("DataType", DataType::F32)),
a_values),
beta_values),
- broadcast_bias_values),
lhs_transpose_values),
act_values))
{
@@ -256,9 +225,10 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- m_values,
+ m_w_values,
+ m_h_values),
n_values),
k_values),
b_values),
@@ -272,15 +242,65 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
framework::dataset::make("DataType", DataType::F32)),
a_values),
beta_values),
- broadcast_bias_values),
lhs_transpose_values),
act_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
+TEST_SUITE_END() // FP32
-FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
+TEST_SUITE(FP16)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit),
+ n0_values_precommit),
+ k0_values_precommit),
+ v0_values_precommit),
+ h0_values_precommit),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values),
+ beta_values),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_nightly),
+ n0_values_nightly),
+ k0_values_nightly),
+ v0_values_nightly),
+ h0_values_nightly),
+ i_values_lhs),
+ i_values_rhs),
+ framework::dataset::make("DataType", DataType::F16)),
+ a_values),
+ beta_values),
+ broadcast_bias_values),
+ lhs_transpose_values),
+ act_values))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
@@ -294,17 +314,17 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
h0_values_precommit),
i_values_lhs),
i_values_rhs),
- framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataType", DataType::F16)),
a_values),
beta_values),
lhs_transpose_values),
act_values))
{
// Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
@@ -318,16 +338,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
h0_values_nightly),
i_values_lhs),
i_values_rhs),
- framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataType", DataType::F16)),
a_values),
beta_values),
lhs_transpose_values),
act_values))
{
// Validate output
- validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+ validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
}
-TEST_SUITE_END() // FP32
+TEST_SUITE_END() // FP16
TEST_SUITE_END() // Float
TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
TEST_SUITE_END() // CL