From f16eed979ecaa234b308c8eb145c5f9512673a54 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Wed, 19 Apr 2023 16:29:26 +0100 Subject: =?UTF-8?q?Change=20fp16=20GeMM=20heuristic=20for=20Arm=C2=AE=20Ma?= =?UTF-8?q?li=E2=84=A2-G77?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace existing heuristic with look-up tables - Expected performance improvement is between 5-15% on various models Signed-off-by: Gian Marco Iodice Change-Id: Ie26ddf66895ede131aa06fde7b200ef94d2dd467 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9472 Tested-by: Arm Jenkins Reviewed-by: Viet-Hoa Do Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 64 +++++++- src/gpu/cl/kernels/gemm/ClGemmHelpers.h | 16 +- .../ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp | 181 ++++++++++++++++----- src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp | 98 +---------- 4 files changed, 225 insertions(+), 134 deletions(-) diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp index 9e7029a7ae..b97ffedfe5 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022 Arm Limited. + * Copyright (c) 2019-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "arm_compute/core/CL/OpenCL.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include #include namespace arm_compute @@ -42,8 +43,18 @@ std::pair configure_lhs_rhs_info(unsigned bool lhs_interleave, bool rhs_interleave, bool lhs_transpose, bool rhs_transpose, bool export_to_cl_image) { ARM_COMPUTE_ERROR_ON(m0 == 0 || n0 == 0); + ARM_COMPUTE_ERROR_ON(v0 == 0); v0 = std::max(std::min(static_cast(m / m0), static_cast(v0)), static_cast(1)); - h0 = std::max(std::min(static_cast(n / n0), static_cast(h0)), static_cast(1)); + + if(h0 == 0) + { + // When h0 is 0, we should take the maximum H0 possible + h0 = std::max(n / n0, 1U); + } + else + { + h0 = std::max(std::min(static_cast(n / n0), static_cast(h0)), static_cast(1)); + } const GEMMLHSMatrixInfo lhs_info(m0, k0, v0, lhs_transpose, lhs_interleave); const GEMMRHSMatrixInfo rhs_info(n0, k0, h0, rhs_transpose, rhs_interleave, export_to_cl_image); @@ -55,6 +66,8 @@ std::pair select_lhs_rhs_info(std::pair info_buf, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { + ARM_COMPUTE_ERROR_ON_MSG(info_buf.second.export_to_cl_image == true, "The fallback GeMM configuration cannot have export_to_cl_image = true"); + const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, data_type); const TensorShape shape = misc::shape_calculator::compute_rhs_reshaped_shape(tensor_rhs_info, info_img.second); const TensorInfo tensor_reshaped_info(shape, 1, data_type); @@ -127,6 +140,53 @@ bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const return ((k % mmul_k0) == 0) && (gws_y > 4); } + +std::pair find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + float min_acc = std::numeric_limits::max(); + size_t min_idx = 0; + + ARM_COMPUTE_ERROR_ON(configs.size() == 0); + const size_t num_rows = configs.size(); + const size_t num_cols = configs[0].size(); + + ARM_COMPUTE_ERROR_ON_MSG(num_cols != 14U, "The entry should have 14 integer values representing: M, N, K, B, M0, N0. 
K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS");
+    ARM_COMPUTE_UNUSED(num_cols);
+
+    // Find nearest GeMM shape
+    for(size_t y = 0; y < num_rows; ++y)
+    {
+        float mc0 = configs[y][0];
+        float nc0 = configs[y][1];
+        float kc0 = configs[y][2];
+        float bc0 = configs[y][3];
+        float acc = 0;
+        acc += (m - mc0) * (m - mc0);
+        acc += (n - nc0) * (n - nc0);
+        acc += (k - kc0) * (k - kc0);
+        acc += (b - bc0) * (b - bc0);
+        acc = std::sqrt(acc);
+        if(acc < min_acc)
+        {
+            min_acc = acc;
+            min_idx = y;
+        }
+    }
+
+    // Get the configuration from the nearest GeMM shape
+    const int m0     = configs[min_idx][4];
+    const int n0     = configs[min_idx][5];
+    const int k0     = configs[min_idx][6];
+    const int v0     = configs[min_idx][7];
+    const int h0     = configs[min_idx][8];
+    const int i_lhs  = configs[min_idx][9];
+    const int i_rhs  = configs[min_idx][10];
+    const int t_lhs  = configs[min_idx][11];
+    const int t_rhs  = configs[min_idx][12];
+    const int im_rhs = configs[min_idx][13];
+
+    return configure_lhs_rhs_info(m, n, m0, n0, k0, v0, h0, i_lhs, i_rhs, t_lhs, t_rhs, im_rhs);
+}
 } // namespace gemm
 } // namespace kernels
 } // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
index bf1e8fce82..6689b10e69 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,6 +35,8 @@ namespace kernels
 {
 namespace gemm
 {
+using GeMMConfigsMatrix = std::vector<std::vector<int32_t>>;
+
 /** Configure @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
  *
  * @param[in] m Number of rows (M) in the LHS matrix not reshaped
@@ -103,6 +105,18 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info,
  */
 bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b,
                               const DataType data_type, unsigned int &best_m0, unsigned int &best_n0);
+
+/** Find the preferred configurations for the LHS and RHS tensor using the GeMMConfigsMatrix provided by the user
+ *
+ * @param[in] configs List of best configurations for a limited number of GeMM shapes
+ * @param[in] m       Number of rows of the LHS matrix
+ * @param[in] n       Number of columns of the RHS matrix
+ * @param[in] k       Number of columns of the LHS matrix, rows of the RHS matrix
+ * @param[in] b       Batch size
+ *
+ * @return @ref GEMMLHSMatrixInfo and @ref GEMMRHSMatrixInfo
+ */
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> find_lhs_rhs_info(const GeMMConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b);
 } // namespace gemm
 } // namespace kernels
 } // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
index e4e35cb8ce..5d666c03a5 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2023 Arm Limited.
* * SPDX-License-Identifier: MIT * @@ -196,52 +196,159 @@ std::pair ClGemmDefaultConfigReshapedRhsOn std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - ARM_COMPUTE_UNUSED(k); - ARM_COMPUTE_UNUSED(b); + const GeMMConfigsMatrix configs_1nkb_best = + { + { 1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0 }, + { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0 }, + { 1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0 }, + { 1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, + { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_n_small_best = + { + { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 }, + { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 }, + { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 }, + { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1 } + }; - if(m == 1) + const GeMMConfigsMatrix configs_mnkb_n_small_fallback = { - const unsigned int h0 = std::max(n / 2, 1U); - if(n <= 836.0) - { - return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0); - } - else + { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 }, + { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 }, + { 16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0 }, + { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_best = + { + { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 }, + { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, + { 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, + { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, + { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 }, + { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 } + }; + + const GeMMConfigsMatrix configs_mnkb_fallback = + { + { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, + { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, + { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, + { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 }, + { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, + { 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 
8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, + { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, + { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, + { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, + { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_best_batched = + { + { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, + { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix configs_mnkb_fallback_batched = + { + { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, + { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, + { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 } + }; + + const GeMMConfigsMatrix *configs_best_to_use = nullptr; + const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; + + if(b == 1) + { + constexpr float ratio_m_gt_n = 10.f; + constexpr unsigned int n_small_thr = 4; + const float ratio = static_cast(m) / static_cast(n); + + if(m == 1) { - return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0); + // We do not need fallback in this case, as we never use cl_image for the rhs tensor + configs_best_to_use = &configs_1nkb_best; + configs_fallback_to_use = &configs_1nkb_best; } - } - else if(m < 128) - { - const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); - if(k >= 512) + else if(n <= n_small_thr && ratio > ratio_m_gt_n) { - return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); + configs_best_to_use = &configs_mnkb_n_small_best; + configs_fallback_to_use = &configs_mnkb_n_small_fallback; } else { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); + configs_best_to_use = &configs_mnkb_best; + configs_fallback_to_use = &configs_mnkb_fallback; } } else { - const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); - if(n >= 64) - { - return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0); - } - else - { - if(k >= 512) - { - return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0); - } - else - { - return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0); - } - } + configs_best_to_use = &configs_mnkb_best_batched; + configs_fallback_to_use = &configs_mnkb_fallback_batched; } + + GEMMLHSMatrixInfo lhs_info0; + GEMMRHSMatrixInfo rhs_info0; + GEMMLHSMatrixInfo lhs_info1; + GEMMRHSMatrixInfo 
rhs_info1; + + std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); + std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); + + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), + std::make_pair(lhs_info1, rhs_info1), + n, k, b, DataType::F16); } std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) @@ -414,9 +521,9 @@ std::pair ClGemmDefaultConfigReshapedRhsOn std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; - const float r_mn = static_cast(m) / static_cast(n); - const float r_mk = static_cast(m) / static_cast(k); - const float r_nk = static_cast(n) / static_cast(k); + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); if(m == 1) { diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp index 4c7daf916e..29d3177424 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -138,105 +138,15 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::default_f16(unsigned int m, unsigned CLGEMMKernelType CLGEMMDefaultTypeValhall::g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(m, n, k, b); + if(!is_rhs_constant) { return CLGEMMKernelType::NATIVE; } - - if(m == 1) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - - const float r_mn = static_cast(m) / static_cast(n); - const float r_mk = static_cast(m) / static_cast(k); - const float r_nk = static_cast(n) / static_cast(k); - const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; - - if(r_mk <= 0.6817956566810608) - { - if(workload <= 801.6000061035156) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - if(r_mn <= 0.0839829258620739) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - if(r_mk <= 0.24917218834161758) - { - return CLGEMMKernelType::RESHAPED; - } - else - { - if(workload <= 2551.75) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - if(workload <= 5061.574951171875) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - return CLGEMMKernelType::RESHAPED; - } - } - } - } - } - } else { - if(r_mk <= 4.849947690963745) - { - if(workload <= 17618.4501953125) - { - if(workload <= 5224.699951171875) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - if(r_nk <= 0.7933054566383362) - { - return CLGEMMKernelType::RESHAPED; - } - else - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - } - } - else - { - if(workload <= 20275.2001953125) - { - return CLGEMMKernelType::RESHAPED; - } - else - { - if(r_mk <= 3.07421875) - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } - else - { - return CLGEMMKernelType::RESHAPED; - } - } - } - } - else - { - return CLGEMMKernelType::RESHAPED_ONLY_RHS; - } + return CLGEMMKernelType::RESHAPED_ONLY_RHS; } } -- cgit v1.2.1
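
Note on the lookup scheme introduced by this patch: each GeMMConfigsMatrix row stores a profiled GeMM shape (M, N, K, B) followed by ten tuning parameters, and find_lhs_rhs_info() returns the parameters of the row whose shape is nearest, in Euclidean distance, to the query shape, so shapes that were not profiled fall back to the closest entry in the table. The standalone sketch below reproduces only that selection step outside the library; nearest_config() and ConfigsMatrix are illustrative names, not part of the Compute Library API, although the two table rows are copied from configs_mnkb_best above.

// Standalone illustration of the nearest-shape lookup performed by find_lhs_rhs_info().
// Row layout mirrors the tables in this patch:
// M, N, K, B, M0, N0, K0, V0, H0, INT_LHS, INT_RHS, TRA_LHS, TRA_RHS, IMG_RHS.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

using ConfigsMatrix = std::vector<std::vector<int32_t>>;

// Return the index of the row whose (M, N, K, B) prefix is closest to the query shape.
std::size_t nearest_config(const ConfigsMatrix &configs, unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    std::size_t best_idx = 0;
    float       best_dst = std::numeric_limits<float>::max();

    for(std::size_t y = 0; y < configs.size(); ++y)
    {
        const float dm = static_cast<float>(m) - static_cast<float>(configs[y][0]);
        const float dn = static_cast<float>(n) - static_cast<float>(configs[y][1]);
        const float dk = static_cast<float>(k) - static_cast<float>(configs[y][2]);
        const float db = static_cast<float>(b) - static_cast<float>(configs[y][3]);

        // Plain Euclidean distance over the four shape dimensions; the remaining
        // ten columns are the tuning parameters returned for the winning row.
        const float dst = std::sqrt(dm * dm + dn * dn + dk * dk + db * db);
        if(dst < best_dst)
        {
            best_dst = dst;
            best_idx = y;
        }
    }
    return best_idx;
}

int main()
{
    const ConfigsMatrix configs =
    {
        { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 },
        { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }
    };

    // Arbitrary query shape: prints the selected row and its M0/N0 block sizes.
    const std::size_t idx = nearest_config(configs, 64, 96, 256, 1);
    std::cout << "nearest row: " << idx << "  M0=" << configs[idx][4] << "  N0=" << configs[idx][5] << "\n";
    return 0;
}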
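For every query, configure_G77_f16() resolves two candidates: a best configuration, which may request the reshaped RHS to be exported to a cl_image, and a buffer-only fallback; select_lhs_rhs_info() (see the ClGemmHelpers.cpp hunk) keeps the best one only when the image import is actually usable for the resulting reshaped tensor. Below is a minimal sketch of that best/fallback pattern under stated assumptions: the LhsConfig/RhsConfig types and the supports_cl_image() predicate are stand-ins, not Compute Library types, and the predicate replaces the library's validate_image2d_support_on_rhs() check.

// Illustrative only: selection between a preferred cl_image configuration and a
// buffer-only fallback. The real check depends on the reshaped RHS shape, the
// OpenCL device capabilities and the data type.
#include <iostream>
#include <utility>

struct LhsConfig { int m0, k0, v0; };
struct RhsConfig { int n0, k0, h0; bool export_to_cl_image; };
using GemmConfig = std::pair<LhsConfig, RhsConfig>;

// Placeholder predicate: a real implementation validates image pitch alignment,
// maximum image dimensions and cl_image support on the target device.
bool supports_cl_image(const RhsConfig &rhs, unsigned int n, unsigned int k, unsigned int b)
{
    (void)rhs; (void)n; (void)k; (void)b;
    return false;
}

// Keep the preferred configuration when cl_image can be used, otherwise fall back
// to the buffer-only configuration, which by construction never sets export_to_cl_image.
GemmConfig select_config(const GemmConfig &best, const GemmConfig &fallback,
                         unsigned int n, unsigned int k, unsigned int b)
{
    if(best.second.export_to_cl_image && supports_cl_image(best.second, n, k, b))
    {
        return best;
    }
    return fallback;
}

int main()
{
    const GemmConfig best     = { { 4, 4, 1 }, { 4, 4, 64, true } };
    const GemmConfig fallback = { { 4, 4, 1 }, { 8, 4, 128, false } };

    const GemmConfig chosen = select_config(best, fallback, 1024, 512, 1);
    std::cout << "export_to_cl_image = " << chosen.second.export_to_cl_image << "\n";
    return 0;
}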