diff options
author | Felix Thomasmathibalan <felixjohnny.thomasmathibalan@arm.com> | 2023-09-27 17:46:17 +0100 |
---|---|---|
committer | felixjohnny.thomasmathibalan <felixjohnny.thomasmathibalan@arm.com> | 2023-09-28 12:08:05 +0000 |
commit | afd38f0c617d6f89b2b4532c6c44f116617e2b6f (patch) | |
tree | 03bc7d5a762099989b16a656fa8d397b490ed70e /src/gpu/cl/kernels/gemm/reshaped_only_rhs | |
parent | bdcb4c148ee2fdeaaddf4cf1e57bbb0de02bb894 (diff) | |
download | ComputeLibrary-afd38f0c617d6f89b2b4532c6c44f116617e2b6f.tar.gz |
Apply clang-format on repository
Code is formatted as per a revised clang format configuration
file (not part of this delivery). Version 14.0.6 is used.
Exclusion List:
- files with .cl extension
- files that are not strictly C/C++ (e.g. Android.bp, Sconscript ...)
And the following directories
- compute_kernel_writer/validation/
- tests/
- include/
- src/core/NEON/kernels/convolution/
- src/core/NEON/kernels/arm_gemm/
- src/core/NEON/kernels/arm_conv/
- data/
There will be a follow-up for formatting of .cl files and the
files under tests/ and compute_kernel_writer/validation/.
Signed-off-by: Felix Thomasmathibalan <felixjohnny.thomasmathibalan@arm.com>
Change-Id: Ib7eb1fcf4e7537b9feaefcfc15098a804a3fde0a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10391
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Diffstat (limited to 'src/gpu/cl/kernels/gemm/reshaped_only_rhs')
5 files changed, 402 insertions, 458 deletions
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp index 9c23d9c998..c4825bfbeb 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp @@ -29,7 +29,9 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" + #include <utility> namespace arm_compute @@ -47,33 +49,39 @@ ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBif { } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, - unsigned int b); - - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); - - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); - - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> 
configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8); - - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); - - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ( + ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G51( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G31( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8); + + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); + + 
CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); ConfigurationFunctionExecutorPtr func = nullptr; - switch(_target) + switch (_target) { case GPUTarget::G76: func = configs_G76.get_function(data_type); @@ -96,14 +104,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn return (this->*func)(m, n, k, b); } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { - if(n <= 2548) + if (n <= 2548) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false); } @@ -118,12 +127,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); @@ -131,7 +141,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1)); - if(m >= 28) + if (m >= 28) { return 
configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1); } @@ -142,7 +152,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); @@ -154,9 +165,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn const bool is_workload_big = ((m * n * b) / 16) >= 2048; - if(m == 1) + if (m == 1) { - if(n >= 8192) + if (n >= 8192) { const unsigned int h0 = std::max(n / 4, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false); @@ -164,7 +175,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn else { const unsigned int h0 = std::max(n / 2, 1U); - if(n <= 204) + if (n <= 204) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false); } @@ -177,25 +188,29 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1)); - if(is_workload_big) + if (is_workload_big) { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); } else { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); } } // Get lhs_info/rhs_info in case of OpenCL image const int h0 = 
std::max(std::min(static_cast<int>(n / 4), static_cast<int>(16)), static_cast<int>(1)); - if(is_workload_big) + if (is_workload_big) { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); } const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); @@ -205,7 +220,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? false : true; - if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) + if (bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) { return std::make_pair(lhs_info_img, rhs_info_img); } @@ -215,7 +230,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; const float r_nk = static_cast<float>(n) / static_cast<float>(k); @@ -225,46 +241,49 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn 
GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - if(m == 1) + if (m == 1) { - if(r_nk <= 0.4664f) + if (r_nk <= 0.4664f) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } else { - if(workload <= 274.4000f) + if (workload <= 274.4000f) { return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int n0 = n < 1280 ? 2 : 4; const unsigned int h0 = std::max(n / n0, 1U); @@ -276,14 +295,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { - if(n > 2048) + if (n > 2048) { const unsigned int h0 = std::max(n / 4, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); @@ -300,7 +320,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float r_mn = static_cast<float>(m) / static_cast<float>(n); const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; @@ -312,57 +333,59 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - if(m == 1) + if (m == 1) { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); + 
std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); - if(r_mk <= 0.0026f) + if (r_mk <= 0.0026f) { - if(r_nk <= 0.4664f) + if (r_nk <= 0.4664f) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } else { - if(r_mk <= 0.0148f) + if (r_mk <= 0.0148f) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } } else { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); - if(workload <= 362.6000f) + if (workload <= 362.6000f) { return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); } else { - if(r_mn <= 22.6067f) + if (r_mn <= 22.6067f) { - if(workload <= 708.8000f) + if (workload <= 708.8000f) { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, 
false, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } else { @@ -371,27 +394,28 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(r_nk <= 0.0917f) + if (r_nk <= 0.0917f) { return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } } } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); - if(m == 1) + if (m == 1) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); } @@ -400,15 +424,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn const float r_mn = static_cast<float>(m) / static_cast<float>(n); const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; - if(workload <= 7449.60f) + if (workload <= 7449.60f) { - if(workload <= 691.60f) + if (workload <= 691.60f) { return 
configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false); } else { - if(workload <= 4155.20f) + if (workload <= 4155.20f) { return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); } @@ -420,21 +444,22 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 16300.80f) + if (workload <= 16300.80f) { - if(r_mn <= 44.56f) + if (r_mn <= 44.56f) { GEMMLHSMatrixInfo lhs_info_buf; GEMMRHSMatrixInfo rhs_info_buf; GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } else { @@ -448,23 +473,25 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, 
rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int n0 = n < 1280 ? 2 : 4; const unsigned int h0 = std::max(n / n0, 1U); @@ -476,14 +503,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(dot8_supported(CLKernelLibrary::get().get_device())) + if (dot8_supported(CLKernelLibrary::get().get_device())) { - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); @@ -497,7 +525,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast<int>(n / 2), static_cast<int>(128)), static_cast<int>(1)); - if(m == 1) + if (m == 1) { return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); } @@ -508,12 +536,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> 
ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); @@ -524,12 +553,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true); diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h index 321cbb5250..77c0c8d500 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.h @@ -45,21 +45,34 @@ public: ClGemmDefaultConfigReshapedRhsOnlyBifrost(GPUTarget gpu); // Inherited overridden method - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; 
private: - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned 
int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); }; } // namespace gemm } // namespace kernels diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp index d08bf84c72..da3e2ec912 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -50,30 +50,35 @@ ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyVal { } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k, - unsigned int b); + using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ( + ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G710(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G710( + 
&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); - CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16, - &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715( + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); ConfigurationFunctionExecutorPtr func = nullptr; - switch(_target) + switch (_target) { case GPUTarget::G78: func = configs_G78.get_function(data_type); @@ -96,29 +101,29 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn return (this->*func)(m, n, k, b); } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - if(m == 1) + if (m == 1) { const float r_mn = static_cast<float>(m) / static_cast<float>(n); const float r_mk = static_cast<float>(m) / static_cast<float>(k); - if(r_mk <= 0.0064484127797186375) + if (r_mk <= 0.0064484127797186375) { - if(r_mn <= 0.0028273810748942196) + if (r_mn <= 0.0028273810748942196) { GEMMLHSMatrixInfo lhs_info_buf; GEMMRHSMatrixInfo rhs_info_buf; GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - const unsigned int h0 = std::max(n / 4, 1U); + const unsigned int h0 = std::max(n / 4, 1U); std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1); 
std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } else { @@ -127,7 +132,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(r_mk <= 0.020312500186264515) + if (r_mk <= 0.020312500186264515) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0); } @@ -143,9 +148,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; const float r_mk = static_cast<float>(m) / static_cast<float>(k); - if(workload <= 1999.2000122070312) + if (workload <= 1999.2000122070312) { - if(workload <= 747.1999816894531) + if (workload <= 747.1999816894531) { return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); } @@ -159,15 +164,14 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } else { - if(r_mn <= 0.03348214365541935) + if (r_mn <= 0.03348214365541935) { - if(r_mk <= 0.028125000186264515) + if (r_mk <= 0.028125000186264515) { return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); } @@ -181,8 +185,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - 
std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } else @@ -195,168 +198,112 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - const GeMMConfigsMatrix configs_1nkb_best = - { - { 1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0 }, - { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0 }, - { 1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0 }, - { 1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 } - }; - - const GeMMConfigsMatrix configs_mnkb_n_small_best = - { - { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 }, - { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 }, - { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1 }, - { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1 } - }; - - const GeMMConfigsMatrix configs_mnkb_n_small_fallback = - { - { 102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0 }, - { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 }, - { 16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0 }, - { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0 } - }; - - const 
GeMMConfigsMatrix configs_mnkb_m_gt_n_best = - { - { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, - { 25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 }, - { 369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, - { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 } - }; - - const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = - { - { 25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, - { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0 }, - { 369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, - { 65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, - { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, - { 12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, - { 29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 } - }; - - const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = - { - { 24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, - { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1 }, - { 49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - }; - - const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = - { - { 24, 488, 88, 
1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0 }, - { 49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0 }, + const GeMMConfigsMatrix configs_1nkb_best = { + {1, 8984, 640, 1, 1, 8, 8, 1, 0, 1, 1, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, + {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 4, 8, 1, 0, 1, 0, 1, 0, 0}, + {1, 5304, 640, 1, 1, 4, 4, 1, 0, 1, 0, 1, 1, 0}, {1, 1352, 1520, 1, 1, 2, 8, 1, 0, 1, 1, 1, 1, 0}, + {1, 4096, 25088, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0}, + {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1}, + {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 1, 1}, + {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 1}}; + + const GeMMConfigsMatrix configs_mnkb_n_small_fallback = {{102400, 4, 96, 1, 2, 2, 16, 1, 4, 1, 1, 1, 1, 0}, + {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0}, + {16384, 4, 128, 1, 2, 2, 16, 1, 2, 1, 1, 1, 1, 0}, + {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = { + {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1}, + {369664, 32, 28, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {12604, 60, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0}, + {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback 
= { + {25584, 88, 16, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 0, 0}, + {369664, 32, 28, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, + {8944, 32, 776, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, + {16544, 104, 160, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, {12604, 60, 160, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0}, + {29584, 32, 28, 1, 4, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 128, 1, 1, 1, 0, 0}, + {2688, 136, 1492, 1, 8, 4, 4, 1, 128, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = { + {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0}, + {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 1}, + {49, 1024, 1024, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, }; - const GeMMConfigsMatrix configs_mnkb_squared_best = - { - { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, - { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, - { 180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1 }, - { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, - { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 } + const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = { + {24, 488, 88, 1, 2, 4, 16, 1, 4, 1, 1, 1, 0, 0}, + {49, 1024, 512, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {49, 1024, 1024, 1, 4, 4, 8, 1, 256, 1, 1, 1, 0, 0}, }; - const GeMMConfigsMatrix configs_mnkb_squared_fallback = - { - { 72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0 }, - { 268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0 }, - { 180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 272, 400, 2116, 1, 4, 
8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0 }, - { 24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 }, - { 24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0 } - }; - - const GeMMConfigsMatrix configs_mnkb_best_batched = - { - { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1 }, - { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 } - }; - - const GeMMConfigsMatrix configs_mnkb_fallback_batched = - { - { 3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0 }, - { 112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 }, - { 2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0 } - }; + const GeMMConfigsMatrix configs_mnkb_squared_best = { + {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0}, + {180, 420, 952, 1, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 64, 1, 1, 1, 0, 1}, + {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_squared_fallback = { + {72, 92, 136, 1, 2, 2, 8, 1, 128, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 256, 1, 1, 1, 0, 0}, + {180, 420, 952, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, {1000, 152, 304, 1, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 4, 8, 4, 1, 64, 1, 
1, 1, 0, 0}, {196, 512, 512, 1, 5, 4, 4, 1, 256, 1, 1, 1, 0, 0}, + {24, 88, 236, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 64, 1, 1, 1, 1, 0}}; + + const GeMMConfigsMatrix configs_mnkb_best_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 1}, + {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, + {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_fallback_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {688, 92, 68, 32, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 4, 4, 8, 1, 128, 1, 1, 1, 0, 0}, + {112, 184, 144, 28, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, + {1568, 64, 40, 36, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 64, 1, 1, 1, 0, 0}}; const GeMMConfigsMatrix *configs_best_to_use = nullptr; const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; - if(b == 1) + if (b == 1) { constexpr float ratio_m_gt_n = 10.f; constexpr float ratio_n_gt_m = 0.1f; constexpr unsigned int n_small_thr = 4; const float ratio = static_cast<float>(m) / static_cast<float>(n); - if(m == 1) + if (m == 1) { // We do not need fallback in this case, as we never use cl_image for the rhs tensor configs_best_to_use = &configs_1nkb_best; configs_fallback_to_use = &configs_1nkb_best; } - else if(n <= n_small_thr && ratio > ratio_m_gt_n) + else if (n <= n_small_thr && ratio > ratio_m_gt_n) { configs_best_to_use = &configs_mnkb_n_small_best; configs_fallback_to_use = &configs_mnkb_n_small_fallback; } - else if(ratio > ratio_m_gt_n) + else if (ratio > ratio_m_gt_n) { configs_best_to_use = &configs_mnkb_m_gt_n_best; 
configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback; } - else if(ratio < ratio_n_gt_m) + else if (ratio < ratio_n_gt_m) { configs_best_to_use = &configs_mnkb_n_gt_m_best; configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback; @@ -381,17 +328,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); - return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), - std::make_pair(lhs_info1, rhs_info1), - n, k, b, DataType::F16); + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k, b, + DataType::F16); } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); @@ -399,7 +346,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1)); - if(m >= 28) + if (m >= 28) { return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1); } @@ -410,30 +357,31 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> 
ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float r_mn = static_cast<float>(m) / static_cast<float>(n); const float r_mk = static_cast<float>(m) / static_cast<float>(k); const float r_nk = static_cast<float>(n) / static_cast<float>(k); const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; - if(m == 1) + if (m == 1) { - if(workload <= 278.7000f) + if (workload <= 278.7000f) { - if(workload <= 7.5000f) + if (workload <= 7.5000f) { return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); } else { - if(r_mn <= 0.0031f) + if (r_mn <= 0.0031f) { - if(workload <= 256.6000f) + if (workload <= 256.6000f) { - if(workload <= 16.7500f) + if (workload <= 16.7500f) { - if(r_nk <= 1.6671f) + if (r_nk <= 1.6671f) { return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); } @@ -454,15 +402,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(r_mk <= 0.0027f) + if (r_mk <= 0.0027f) { - if(r_mk <= 0.0014f) + if (r_mk <= 0.0014f) { return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); } else { - if(workload <= 8.9500f) + if (workload <= 8.9500f) { return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); } @@ -474,13 +422,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 14.1500f) + if (workload <= 14.1500f) { return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0); } else { - if(r_mk <= 0.0041f) + if (r_mk <= 0.0041f) { return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0); } @@ -495,9 +443,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 363.7000f) + if (workload <= 363.7000f) { - if(r_mk <= 0.0031f) + if (r_mk <= 0.0031f) { return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0); } @@ -514,9 
+462,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 1384.8000f) + if (workload <= 1384.8000f) { - if(workload <= 704.0000f) + if (workload <= 704.0000f) { return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0); } @@ -527,9 +475,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 16761.6006f) + if (workload <= 16761.6006f) { - if(r_mn <= 187.1250f) + if (r_mn <= 187.1250f) { return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1); } @@ -540,7 +488,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(r_mk <= 432.4630f) + if (r_mk <= 432.4630f) { return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1); } @@ -553,42 +501,37 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f; const float r_mn = static_cast<float>(m) / static_cast<float>(n); const float r_mk = static_cast<float>(m) / static_cast<float>(k); const float r_nk = static_cast<float>(n) / static_cast<float>(k); - if(m == 1) + if (m == 1) { - const GeMMConfigsMatrix configs_mnkb_best = - { - { 1, 8984, 640, 1, 1, 4, 2, 1, 0, 1, 0, 1, 1, 0 }, - { 1, 420, 392, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 644, 5288, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 6512, 6404, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1, 0 }, - { 1, 5304, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 4096, 25088, 1, 
1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 732, 8988, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 } - }; + const GeMMConfigsMatrix configs_mnkb_best = { + {1, 8984, 640, 1, 1, 4, 2, 1, 0, 1, 0, 1, 1, 0}, {1, 420, 392, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 644, 5288, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 2, 1, 0, 1, 0, 1, 1, 0}, + {1, 5304, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 4096, 25088, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 732, 8988, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}}; return find_lhs_rhs_info(configs_mnkb_best, m, n, k, b); } else { - if(workload <= 1384.8000f) + if (workload <= 1384.8000f) { - if(r_nk <= 0.8333f) + if (r_nk <= 0.8333f) { - if(r_mk <= 0.9119f) + if (r_mk <= 0.9119f) { return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1); } else { - if(r_nk <= 0.1181f) + if (r_nk <= 0.1181f) { return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0); } @@ -600,7 +543,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(r_mk <= 1.0013f) + if (r_mk <= 1.0013f) { return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1); } @@ -612,11 +555,11 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 11404.7998f) + if (workload <= 11404.7998f) { - if(r_mk <= 2.2884f) + if (r_mk <= 2.2884f) { - if(r_nk <= 0.9286f) + if (r_nk <= 0.9286f) { return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1); } @@ -632,9 +575,9 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } else { - if(r_nk <= 1.1926f) + if (r_nk <= 1.1926f) { - if(r_mn <= 1385.7917f) + if (r_mn <= 1385.7917f) { return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1); } @@ -652,12 +595,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> 
ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { unsigned int best_m0; unsigned int best_n0; - if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) + if (is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) { return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); } @@ -667,153 +611,101 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn } } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G710_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - const GeMMConfigsMatrix configs_1nkb_best = - { - { 1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0 }, - { 1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0 }, - { 1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0 } + const GeMMConfigsMatrix configs_1nkb_best = { + {1, 8984, 640, 1, 1, 2, 2, 1, 0, 1, 0, 1, 0, 0}, {1, 420, 392, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, + {1, 644, 5288, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}, {1, 6512, 6404, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 5304, 640, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, {1, 1352, 1520, 1, 1, 2, 4, 1, 0, 1, 0, 1, 0, 0}, + {1, 4096, 25088, 1, 1, 2, 8, 1, 0, 1, 0, 1, 1, 0}, {1, 732, 8988, 1, 1, 2, 8, 1, 0, 1, 0, 1, 0, 0}}; + + 
const GeMMConfigsMatrix configs_mnkb_n_small_best = {{102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, + {102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, + {16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}, + {16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0}}; + + const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = { + {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1}, + {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, }; - const GeMMConfigsMatrix configs_mnkb_n_small_best = - { - { 102400, 4, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, - { 102400, 2, 96, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, - { 16384, 4, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 }, - { 16384, 2, 128, 1, 1, 2, 16, 1, 0, 1, 0, 1, 0, 0 } + const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = { + {25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}, + {369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, + {23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0}, + {8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, {2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0}, + {50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0}, {16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, + {12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0}, {3728, 96, 196, 1, 2, 8, 8, 1, 
64, 1, 1, 1, 0, 0}, + {29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0}, }; - const GeMMConfigsMatrix configs_mnkb_m_gt_n_best = - { - { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 }, - { 25584, 16, 68, 1, 2, 4, 16, 1, 8, 1, 1, 1, 0, 1 }, - { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, - { 23036, 56, 736, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 8944, 32, 776, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 2688, 136, 1492, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 50176, 64, 300, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 }, - { 16544, 104, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 12604, 60, 160, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 3728, 96, 196, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, - { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 }, - }; + const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}}; - const GeMMConfigsMatrix configs_mnkb_m_gt_n_fallback = - { - { 25584, 88, 16, 1, 4, 8, 4, 1, 4, 1, 1, 1, 0, 0 }, - { 25584, 16, 68, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 }, - { 369664, 32, 28, 1, 2, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 65792, 44, 24, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, - { 23036, 56, 736, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, - { 90968, 40, 600, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 0 }, - { 8944, 32, 776, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 }, - { 2688, 136, 1492, 1, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 }, - { 50176, 64, 300, 1, 4, 8, 4, 1, 128, 1, 1, 1, 0, 0 }, - { 16544, 104, 160, 1, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, - { 12604, 60, 160, 1, 2, 8, 8, 1, 8, 1, 1, 1, 0, 0 }, - { 3728, 96, 196, 1, 2, 8, 8, 1, 64, 1, 1, 1, 0, 0 }, - { 29584, 32, 28, 1, 2, 8, 4, 1, 16, 1, 1, 1, 0, 0 }, - { 12544, 32, 27, 1, 2, 8, 8, 1, 16, 1, 1, 1, 0, 0 
}, - }; + const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = {{24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0}, + {49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0}}; - const GeMMConfigsMatrix configs_mnkb_n_gt_m_best = - { - { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, - { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 }, - { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 } + const GeMMConfigsMatrix configs_mnkb_squared_best = { + {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, + {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, + {272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1}, }; - const GeMMConfigsMatrix configs_mnkb_n_gt_m_fallback = - { - { 24, 488, 88, 1, 2, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, - { 49, 1024, 512, 1, 2, 4, 8, 1, 8, 1, 1, 1, 1, 0 }, - { 49, 1024, 1024, 1, 2, 4, 8, 1, 4, 1, 1, 1, 1, 0 } - }; - - const GeMMConfigsMatrix configs_mnkb_squared_best = - { - { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, - { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, - { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 }, - { 268, 824, 5076, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 180, 420, 952, 1, 4, 4, 8, 1, 16, 1, 1, 1, 0, 1 }, - { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, - { 272, 400, 2116, 1, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 196, 512, 512, 1, 5, 2, 8, 1, 4, 1, 1, 1, 1, 1 }, + const GeMMConfigsMatrix configs_mnkb_squared_fallback = { + {24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, {24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0}, + {72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0}, {268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, + {180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0}, {1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, + {272, 400, 
2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0}, {196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0}, }; - const GeMMConfigsMatrix configs_mnkb_squared_fallback = - { - { 24, 88, 236, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, - { 24, 88, 488, 1, 2, 2, 8, 1, 4, 1, 1, 1, 1, 0 }, - { 72, 92, 136, 1, 2, 2, 8, 1, 32, 1, 1, 1, 1, 0 }, - { 268, 824, 5076, 1, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, - { 180, 420, 952, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, - { 1000, 152, 304, 1, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, - { 272, 400, 2116, 1, 2, 8, 4, 1, 4, 1, 1, 1, 0, 0 }, - { 196, 512, 512, 1, 5, 2, 8, 1, 8, 1, 1, 1, 1, 0 }, - }; + const GeMMConfigsMatrix configs_mnkb_best_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1}, {4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1}, {24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, {5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1}, + {1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1}}; - const GeMMConfigsMatrix configs_mnkb_best_batched = - { - { 3136, 64, 64, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 1 }, - { 4096, 48, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 1 }, - { 24, 464, 412, 24, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 112, 184, 144, 28, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 5776, 64, 32, 36, 4, 4, 8, 1, 4, 1, 1, 1, 0, 1 }, - { 1568, 64, 40, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 }, - { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 1 } - }; - - const GeMMConfigsMatrix configs_mnkb_fallback_batched = - { - { 3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 }, - { 4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0 }, - { 688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, - { 24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0 }, - { 112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0 }, - { 5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0 }, - { 1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0 
}, - { 2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0 } - }; + const GeMMConfigsMatrix configs_mnkb_fallback_batched = { + {3136, 64, 64, 36, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}, {4096, 48, 32, 36, 4, 4, 8, 1, 64, 1, 1, 1, 0, 0}, + {688, 92, 68, 32, 4, 8, 4, 1, 32, 1, 1, 1, 0, 0}, {24, 464, 412, 24, 2, 8, 4, 1, 32, 1, 1, 1, 0, 0}, + {112, 184, 144, 28, 4, 4, 8, 1, 8, 1, 1, 1, 0, 0}, {5776, 64, 32, 36, 2, 8, 8, 1, 32, 1, 1, 1, 0, 0}, + {1568, 64, 40, 36, 4, 8, 4, 1, 16, 1, 1, 1, 0, 0}, {2920, 64, 64, 24, 4, 8, 4, 1, 8, 1, 1, 1, 0, 0}}; const GeMMConfigsMatrix *configs_best_to_use = nullptr; const GeMMConfigsMatrix *configs_fallback_to_use = nullptr; - if(b == 1) + if (b == 1) { constexpr float ratio_m_gt_n = 10.f; constexpr float ratio_n_gt_m = 0.1f; constexpr unsigned int n_small_thr = 4; const float ratio = static_cast<float>(m) / static_cast<float>(n); - if(m == 1) + if (m == 1) { // We do not need fallback in this case, as we never use cl_image for the rhs tensor configs_best_to_use = &configs_1nkb_best; configs_fallback_to_use = &configs_1nkb_best; } - else if(n <= n_small_thr && ratio > ratio_m_gt_n) + else if (n <= n_small_thr && ratio > ratio_m_gt_n) { configs_best_to_use = &configs_mnkb_n_small_best; configs_fallback_to_use = &configs_mnkb_n_small_best; } - else if(ratio > ratio_m_gt_n) + else if (ratio > ratio_m_gt_n) { configs_best_to_use = &configs_mnkb_m_gt_n_best; configs_fallback_to_use = &configs_mnkb_m_gt_n_fallback; } - else if(ratio < ratio_n_gt_m) + else if (ratio < ratio_n_gt_m) { configs_best_to_use = &configs_mnkb_n_gt_m_best; configs_fallback_to_use = &configs_mnkb_n_gt_m_fallback; @@ -838,17 +730,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOn std::tie(lhs_info0, rhs_info0) = find_lhs_rhs_info(*configs_best_to_use, m, n, k, b); std::tie(lhs_info1, rhs_info1) = find_lhs_rhs_info(*configs_fallback_to_use, m, n, k, b); - return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), - std::make_pair(lhs_info1, 
rhs_info1), - n, k, b, DataType::F16); + return select_lhs_rhs_info(std::make_pair(lhs_info0, rhs_info0), std::make_pair(lhs_info1, rhs_info1), n, k, b, + DataType::F16); } -std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { unsigned int best_m0; unsigned int best_n0; - if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0)) + if (is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0)) { return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); } diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h index f2952a3d30..a0ea337eb1 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h @@ -45,17 +45,26 @@ public: ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu); // Inherited overridden method - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) override; private: - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> 
configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); - std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> + configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); }; } // namespace gemm } // namespace kernels diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h 
b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h index 1503e74eb6..e07ad993ed 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmReshapedOnlyRhsKernelConfig.h @@ -50,7 +50,7 @@ public: */ static std::unique_ptr<IClGemmKernelConfig> create(GPUTarget gpu) { - switch(get_arch_from_target(gpu)) + switch (get_arch_from_target(gpu)) { case GPUTarget::MIDGARD: case GPUTarget::BIFROST: |