From 05639f6b1ee3dcdd2c7923d0cf3a5d4712bd0071 Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Tue, 24 Sep 2019 12:05:06 +0100
Subject: COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1

Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708
Signed-off-by: Gian Marco Iodice
Reviewed-on: https://review.mlplatform.org/c/1983
Tested-by: Arm Jenkins
Reviewed-by: Georgios Pinitas
Comments-Addressed: Arm Jenkins
---
 src/core/CL/cl_kernels/gemm.cl                     | 257 ++++++++++++++-------
 .../CLGEMMReshapedKernelConfigurationBifrost.cpp   |  35 ++-
 .../kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp |   2 +-
 src/runtime/CL/functions/CLGEMM.cpp                |  56 +++--
 4 files changed, 243 insertions(+), 107 deletions(-)
(limited to 'src')

diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 8e628e8d01..c35d160689 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2041,79 +2041,37 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
 #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
 
 #if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b);
+#define ARM_VFMA(a, b, c) c += (a) * (b);
 #else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA_1(a, b, c)     \
-    ({                          \
-        c = fma((a), (b), (c)); \
-    })
-#define ARM_VFMA_2(a, b, c)                   \
-    ({                                        \
-        (c).s0 = fma((a).s0, (b).s0, (c).s0); \
-        (c).s1 = fma((a).s1, (b).s1, (c).s1); \
-    })
-#define ARM_VFMA_3(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_2(a, b, c);                  \
-        (c).s2 = fma((a).s2, (b).s2, (c).s2); \
-    })
-#define ARM_VFMA_4(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_3(a, b, c);                  \
-        (c).s3 = fma((a).s3, (b).s3, (c).s3); \
-    })
-#define ARM_VFMA_8(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_4(a, b, c);                  \
-        (c).s4 = fma((a).s4, (b).s4, (c).s4); \
-        (c).s5 = fma((a).s5, (b).s5, (c).s5); \
-        (c).s6 = fma((a).s6, (b).s6, (c).s6); \
-        (c).s7 = fma((a).s7, (b).s7, (c).s7); \
-    })
-#define ARM_VFMA_16(a, b, c)                  \
-    ({                                        \
-        ARM_VFMA_8(a, b, c);                  \
-        (c).s8 = fma((a).s8, (b).s8, (c).s8); \
-        (c).s9 = fma((a).s9, (b).s9, (c).s9); \
-        (c).sA = fma((a).sA, (b).sA, (c).sA); \
-        (c).sB = fma((a).sB, (b).sB, (c).sB); \
-        (c).sC = fma((a).sC, (b).sC, (c).sC); \
-        (c).sD = fma((a).sD, (b).sD, (c).sD); \
-        (c).sE = fma((a).sE, (b).sE, (c).sE); \
-        (c).sF = fma((a).sF, (b).sF, (c).sF); \
-    })
-
-// Factory macro for the vector FMA
-#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))
-
+#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
 #endif // GPU_ARCH == GPU_ARCH_MIDGARD
 
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)         \
-    ({                                                 \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)     \
+    ({                                             \
+        ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
     })
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
     })
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
     })
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
     })
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
     })
 
 // Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
@@ -2172,7 +2130,8 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
 // K0: 1, 2, 3, 4, 8, 16
 // This macro calls the vector-by-matrix macro K0 times
 // A, B and C are matrices
-#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
+#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
+    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
     (M0, N0, TYPE, A, B, C)
 
 /** This OpenCL kernel computes the matrix multiplication between 2 matrices.
@@ -2272,11 +2231,9 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
 #if defined(RHS_INTERLEAVE)
 #define RHS_OFFSET_X (N0)
 #define RHS_STEP_X ((N0) * (H0))
-#define RHS_STEP_LOOP (1)
 #else // defined(RHS_INTERLEAVE)
 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
 #define RHS_STEP_X (N0)
-#define RHS_STEP_LOOP (H0)
 #endif // defined(RHS_INTERLEAVE)
 
     const uint x = get_global_id(0);
@@ -2306,28 +2263,160 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
     // Initialize the accumulators
     REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
 
-    REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
     REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
 
+    __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
+    __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
+
     for(int i = 0; i < k; i += K0)
     {
-        // Supported cases (K0, M0):
-        // 1,2 - 2,2 - 3,2 - 4,2 - 5,2 - 6,2 - 7,2 - 8,2
-        // 1,3 - 2,3 - 3,3 - 4,3 - 5,3 - 6,3 - 7,3 - 8,3
-        // 1,4 - 2,4 - 3,4 - 4,4 - 5,4 - 6,4 - 7,4 - 8,4
-        // 1,8 - 2,8 - 3,8 - 4,8 - 5,8 - 6,8 - 7,8 - 8,8
-        // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16
-        // Load values from LHS matrix
-        LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+        VEC_DATA_TYPE(DATA_TYPE, M0)
+        a0 = VLOAD(M0)(0, lhs);
+        VEC_DATA_TYPE(DATA_TYPE, N0)
+        b0 = VLOAD(N0)(0, rhs);
 
-        // Load values from RHS matrix
-        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+#if K0 > 1
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 1
+
+#if K0 > 2
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 2
+
+#if K0 > 3
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 3
+
+#if K0 > 4
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 4
+
+#if K0 > 8
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 8
 
-        // Perform the partial matrix multiplication
-        ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);
+#ifndef LHS_INTERLEAVE
+        lhs += (M0 * K0 * (V0 - 1));
+#endif // LHS_INTERLEAVE
 
-        lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
-        rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
+#ifndef RHS_INTERLEAVE
+        rhs += (N0 * K0 * (H0 - 1));
+#endif // RHS_INTERLEAVE
     }
 
     __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
 
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index b791c1cda5..0c2942a184 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,7 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
 {
-    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
-    ARM_COMPUTE_UNUSED(data_type);
+    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
 
     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
 
@@ -51,6 +50,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
     {
         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
     };
 
@@ -58,6 +58,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
     {
         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
     };
 
@@ -85,6 +86,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
@@ -129,6 +145,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 222a63d86a..f77ab02810 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -63,7 +63,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e78395f1de..762b00177c 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -65,37 +65,53 @@ CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsign
 {
     GEMMType gemm_type = GEMMType::RESHAPED_V1;
 
-    if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+    if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT,
+                        GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72,
+                        GPUTarget::G76, GPUTarget::G77))
     {
-        if((m > 1) && (n < 16))
+        if(data_type == DataType::F32)
         {
-            gemm_type = GEMMType::RESHAPED_V1;
-        }
-        else if((m == 1) && (data_type == DataType::F32))
-        {
-            gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+            if((m > 1) && (n < 16))
+            {
+                gemm_type = GEMMType::RESHAPED_V1;
+            }
+            else if(m == 1)
+            {
+                gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+            }
+            else
+            {
+                // COMPMID-852
+                if((k > 256) && (m > 4) && reshape_b_only_on_first_run)
+                {
+                    constexpr float alpha = 3.2f;
+                    constexpr float fact0 = 1.51f;
+                    constexpr float fact1 = 1.66f;
+                    constexpr float ops   = 12.0f;
+                    const float scale     = k > 1024 ? 1.07f : 1.0f;
+                    gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+                }
+                else
+                {
+                    gemm_type = GEMMType::NATIVE;
+                }
+            }
+
+            const auto workload = static_cast<float>((m * n) / 20.0f);
+
+            gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
         }
         else
         {
-            // COMPMID-852
-            if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+            if((m == 1) || (!reshape_b_only_on_first_run))
             {
-                constexpr float alpha = 3.2f;
-                constexpr float fact0 = 1.51f;
-                constexpr float fact1 = 1.66f;
-                constexpr float ops   = 12.0f;
-                const float scale     = k > 1024 ? 1.07f : 1.0f;
-                gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+                gemm_type = GEMMType::NATIVE;
             }
             else
             {
-                gemm_type = GEMMType::NATIVE;
+                gemm_type = GEMMType::RESHAPED_V2;
             }
         }
-
-        const auto workload = static_cast<float>((m * n) / 20.0f);
-
-        gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
     }
     else
    {
--
cgit v1.2.1
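
A note on the ARM_VFMA change at the top of this patch: OpenCL C builtin math functions such as fma() are defined component-wise for vector types, which is why the per-lane ARM_VFMA_1..ARM_VFMA_16 macros and their SIZE-dispatching factory collapse into a single c = fma(a, b, c) on non-Midgard targets, while Midgard keeps the plain c += a * b form. Below is a minimal standalone sketch of the idea; the kernel name and buffer layout are illustrative only (not part of the patch), and a device exposing cl_khr_fp16 is assumed:

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Non-Midgard ARM_VFMA from gemm.cl: one fused multiply-add across all vector lanes.
#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));

// Hypothetical demo kernel: c[i] = a[i] * b[i] + c[i], four half lanes per work-item.
__kernel void vfma_demo(__global const half4 *a,
                        __global const half4 *b,
                        __global half4 *c)
{
    const size_t i = get_global_id(0);
    half4 acc = c[i];
    ARM_VFMA(a[i], b[i], acc) // expands to: acc = fma(a[i], b[i], acc);
    c[i] = acc;
}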