aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-24 12:05:06 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-26 11:44:34 +0000
commit05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071 (patch)
tree8bd2f0f2b874ac7347777b40d563506348fff754 /src
parent1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (diff)
downloadComputeLibrary-05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071.tar.gz
COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1
Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1983 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/cl_kernels/gemm.cl257
-rw-r--r--src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp35
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp2
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp56
4 files changed, 243 insertions, 107 deletions
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 8e628e8d01..c35d160689 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2041,79 +2041,37 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
#if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b);
+#define ARM_VFMA(a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA_1(a, b, c) \
- ({ \
- c = fma((a), (b), (c)); \
- })
-#define ARM_VFMA_2(a, b, c) \
- ({ \
- (c).s0 = fma((a).s0, (b).s0, (c).s0); \
- (c).s1 = fma((a).s1, (b).s1, (c).s1); \
- })
-#define ARM_VFMA_3(a, b, c) \
- ({ \
- ARM_VFMA_2(a, b, c); \
- (c).s2 = fma((a).s2, (b).s2, (c).s2); \
- })
-#define ARM_VFMA_4(a, b, c) \
- ({ \
- ARM_VFMA_3(a, b, c); \
- (c).s3 = fma((a).s3, (b).s3, (c).s3); \
- })
-#define ARM_VFMA_8(a, b, c) \
- ({ \
- ARM_VFMA_4(a, b, c); \
- (c).s4 = fma((a).s4, (b).s4, (c).s4); \
- (c).s5 = fma((a).s5, (b).s5, (c).s5); \
- (c).s6 = fma((a).s6, (b).s6, (c).s6); \
- (c).s7 = fma((a).s7, (b).s7, (c).s7); \
- })
-#define ARM_VFMA_16(a, b, c) \
- ({ \
- ARM_VFMA_8(a, b, c); \
- (c).s8 = fma((a).s8, (b).s8, (c).s8); \
- (c).s9 = fma((a).s9, (b).s9, (c).s9); \
- (c).sA = fma((a).sA, (b).sA, (c).sA); \
- (c).sB = fma((a).sB, (b).sB, (c).sB); \
- (c).sC = fma((a).sC, (b).sC, (c).sC); \
- (c).sD = fma((a).sD, (b).sD, (c).sD); \
- (c).sE = fma((a).sE, (b).sE, (c).sE); \
- (c).sF = fma((a).sF, (b).sF, (c).sF); \
- })
-
-// Factory macro for the vector FMA
-#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))
-
+#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
})
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
})
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
})
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
})
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
- ({ \
- ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
- ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C) \
+ ({ \
+ ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+ ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
})
// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
@@ -2172,7 +2130,8 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-matrix macro K0 times
// A, B and C are matrices
-#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
+#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
+ CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
(M0, N0, TYPE, A, B, C)
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
@@ -2272,11 +2231,9 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
-#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
-#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)
const uint x = get_global_id(0);
@@ -2306,28 +2263,160 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
// Initialize the accumulators
REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
- REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
+ __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
+ __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
+
for(int i = 0; i < k; i += K0)
{
- // Supported cases (K0, M0):
- // 1,2 - 2,2 - 3,2 - 4,2 - 5,2 - 6,2 - 7,2 - 8,2
- // 1,3 - 2,3 - 3,3 - 4,3 - 5,3 - 6,3 - 7,3 - 8,3
- // 1,4 - 2,4 - 3,4 - 4,4 - 5,4 - 6,4 - 7,4 - 8,4
- // 1,8 - 2,8 - 3,8 - 4,8 - 5,8 - 6,8 - 7,8 - 8,8
- // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16
- // Load values from LHS matrix
- LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+ VEC_DATA_TYPE(DATA_TYPE, M0)
+ a0 = VLOAD(M0)(0, lhs);
+ VEC_DATA_TYPE(DATA_TYPE, N0)
+ b0 = VLOAD(N0)(0, rhs);
- // Load values from RHS matrix
- LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+#if K0 > 1
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 1
+
+#if K0 > 2
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 2
+
+#if K0 > 3
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 3
+
+#if K0 > 4
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 4
+
+#if K0 > 8
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+
+ a0 = VLOAD(M0)(0, lhs);
+ b0 = VLOAD(N0)(0, rhs);
+
+ ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+ lhs += LHS_STEP_X;
+ rhs += RHS_STEP_X;
+#endif // K0 > 8
- // Perform the partial matrix multiplication
- ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);
+#ifndef LHS_INTERLEAVE
+ lhs += (M0 * K0 * (V0 - 1));
+#endif // LHS_INTERLEAVE
- lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
- rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
+#ifndef RHS_INTERLEAVE
+ rhs += (N0 * K0 * (H0 - 1));
+#endif // RHS_INTERLEAVE
}
__global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index b791c1cda5..0c2942a184 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,7 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
- ARM_COMPUTE_UNUSED(data_type);
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
@@ -51,6 +50,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
{
{ DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
+ { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
{ DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
};
@@ -58,6 +58,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
{
{ DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
+ { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
{ DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
};
@@ -85,6 +86,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(n <= 4)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
@@ -129,6 +145,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(n <= 4)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 222a63d86a..f77ab02810 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -63,7 +63,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e78395f1de..762b00177c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -65,37 +65,53 @@ CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsign
{
GEMMType gemm_type = GEMMType::RESHAPED_V1;
- if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+ if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT,
+ GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72,
+ GPUTarget::G76, GPUTarget::G77))
{
- if((m > 1) && (n < 16))
+ if(data_type == DataType::F32)
{
- gemm_type = GEMMType::RESHAPED_V1;
- }
- else if((m == 1) && (data_type == DataType::F32))
- {
- gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+ if((m > 1) && (n < 16))
+ {
+ gemm_type = GEMMType::RESHAPED_V1;
+ }
+ else if(m == 1)
+ {
+ gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+ }
+ else
+ {
+ // COMPMID-852
+ if((k > 256) && (m > 4) && reshape_b_only_on_first_run)
+ {
+ constexpr float alpha = 3.2f;
+ constexpr float fact0 = 1.51f;
+ constexpr float fact1 = 1.66f;
+ constexpr float ops = 12.0f;
+ const float scale = k > 1024 ? 1.07f : 1.0f;
+ gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+ }
+ else
+ {
+ gemm_type = GEMMType::NATIVE;
+ }
+ }
+
+ const auto workload = static_cast<float>((m * n) / 20.0f);
+
+ gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
}
else
{
- // COMPMID-852
- if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+ if((m == 1) || (!reshape_b_only_on_first_run))
{
- constexpr float alpha = 3.2f;
- constexpr float fact0 = 1.51f;
- constexpr float fact1 = 1.66f;
- constexpr float ops = 12.0f;
- const float scale = k > 1024 ? 1.07f : 1.0f;
- gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+ gemm_type = GEMMType::NATIVE;
}
else
{
- gemm_type = GEMMType::NATIVE;
+ gemm_type = GEMMType::RESHAPED_V2;
}
}
-
- const auto workload = static_cast<float>((m * n) / 20.0f);
-
- gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
}
else
{