From afd38f0c617d6f89b2b4532c6c44f116617e2b6f Mon Sep 17 00:00:00 2001 From: Felix Thomasmathibalan Date: Wed, 27 Sep 2023 17:46:17 +0100 Subject: Apply clang-format on repository Code is formatted as per a revised clang-format configuration file (not part of this delivery). Version 14.0.6 is used. Exclusion List: - files with .cl extension - files that are not strictly C/C++ (e.g. Android.bp, SConscript ...) And the following directories - compute_kernel_writer/validation/ - tests/ - include/ - src/core/NEON/kernels/convolution/ - src/core/NEON/kernels/arm_gemm/ - src/core/NEON/kernels/arm_conv/ - data/ There will be a follow-up for formatting of .cl files and the files under tests/ and compute_kernel_writer/validation/. Signed-off-by: Felix Thomasmathibalan Change-Id: Ib7eb1fcf4e7537b9feaefcfc15098a804a3fde0a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10391 Benchmark: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Gunes Bayir --- .../ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp | 242 ++++++++++++--------- 1 file changed, 136 insertions(+), 106 deletions(-) (limited to 'src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp') diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp index 9c23d9c998..c4825bfbeb 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyBifrost.cpp @@ -29,7 +29,9 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" + #include namespace arm_compute @@ -47,33 +49,39 @@ ClGemmDefaultConfigReshapedRhsOnlyBifrost::ClGemmDefaultConfigReshapedRhsOnlyBif { } -std::pair 
ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure( + unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type) { - using ConfigurationFunctionExecutorPtr = std::pair (ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, - unsigned int b); - - CLGEMMConfigArray configs_G51(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); - - CLGEMMConfigArray configs_G52(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); - - CLGEMMConfigArray configs_G31(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8); - - CLGEMMConfigArray configs_G76(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); - - CLGEMMConfigArray configs_G7x(&ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, - &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + using ConfigurationFunctionExecutorPtr = std::pair ( + ClGemmDefaultConfigReshapedRhsOnlyBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + + CLGEMMConfigArray configs_G51( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8); + + 
CLGEMMConfigArray configs_G52( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); + + CLGEMMConfigArray configs_G31( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8); + + CLGEMMConfigArray configs_G76( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8); + + CLGEMMConfigArray configs_G7x( + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16, + &ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8); ConfigurationFunctionExecutorPtr func = nullptr; - switch(_target) + switch (_target) { case GPUTarget::G76: func = configs_G76.get_function(data_type); @@ -96,14 +104,15 @@ std::pair ClGemmDefaultConfigReshapedRhsOn return (this->*func)(m, n, k, b); } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { - if(n <= 2548) + if (n <= 2548) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, true, false); } @@ -118,12 +127,13 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G31_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { 
ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1); @@ -131,7 +141,7 @@ std::pair ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast(n / 4), static_cast(256)), static_cast(1)); - if(m >= 28) + if (m >= 28) { return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, 0, 1, 0, 1); } @@ -142,7 +152,8 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); @@ -154,9 +165,9 @@ std::pair ClGemmDefaultConfigReshapedRhsOn const bool is_workload_big = ((m * n * b) / 16) >= 2048; - if(m == 1) + if (m == 1) { - if(n >= 8192) + if (n >= 8192) { const unsigned int h0 = std::max(n / 4, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true, false); @@ -164,7 +175,7 @@ std::pair ClGemmDefaultConfigReshapedRhsOn else { const unsigned int h0 = std::max(n / 2, 1U); - if(n <= 204) + if (n <= 204) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true, false); } @@ -177,25 +188,29 @@ std::pair ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast(n / 4), static_cast(16)), static_cast(1)); - if(is_workload_big) + if (is_workload_big) { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, true); } else { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + 
configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true); } } // Get lhs_info/rhs_info in case of OpenCL image const int h0 = std::max(std::min(static_cast(n / 4), static_cast(16)), static_cast(1)); - if(is_workload_big) + if (is_workload_big) { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, false, true, false, true, true); } const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); @@ -205,7 +220,7 @@ std::pair ClGemmDefaultConfigReshapedRhsOn // In case of vector by matrix or small workloads, we use the OpenCL buffer rather than the OpenCL image2d const bool use_cl_image2d = ((m == 1) || ((((m * n * b) / 16) < 2048) && n < 128)) ? 
false : true; - if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) + if (bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) { return std::make_pair(lhs_info_img, rhs_info_img); } @@ -215,7 +230,8 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; const float r_nk = static_cast(n) / static_cast(k); @@ -225,46 +241,49 @@ std::pair ClGemmDefaultConfigReshapedRhsOn GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - if(m == 1) + if (m == 1) { - if(r_nk <= 0.4664f) + if (r_nk <= 0.4664f) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } else { - if(workload <= 274.4000f) + if (workload <= 274.4000f) { return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 
false, false, false, true, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F32); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F32); } } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f32( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int n0 = n < 1280 ? 
2 : 4; const unsigned int h0 = std::max(n / n0, 1U); @@ -276,14 +295,15 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { - if(n > 2048) + if (n > 2048) { const unsigned int h0 = std::max(n / 4, 1U); return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); @@ -300,7 +320,8 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G52_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { const float r_mn = static_cast(m) / static_cast(n); const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; @@ -312,57 +333,59 @@ std::pair ClGemmDefaultConfigReshapedRhsOn GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - if(m == 1) + if (m == 1) { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, false); - if(r_mk <= 0.0026f) + if (r_mk <= 0.0026f) { - if(r_nk <= 0.4664f) + if (r_nk <= 0.4664f) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); return 
select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } else { - if(r_mk <= 0.0148f) + if (r_mk <= 0.0148f) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } } else { - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 8, 4, 1, 2, false, false, false, false, false); - if(workload <= 362.6000f) + if (workload <= 362.6000f) { return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, false, true, false); } else { - if(r_mn <= 22.6067f) + if (r_mn <= 22.6067f) { - if(workload <= 708.8000f) + if (workload <= 708.8000f) { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } else { @@ -371,27 +394,28 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } else { - if(r_nk <= 0.0917f) + if (r_nk <= 0.0917f) { return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 16, false, false, 
false, true, false); } else { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, false, false, false, true); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } } } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); - if(m == 1) + if (m == 1) { return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 32, false, true, false, true, false); } @@ -400,15 +424,15 @@ std::pair ClGemmDefaultConfigReshapedRhsOn const float r_mn = static_cast(m) / static_cast(n); const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; - if(workload <= 7449.60f) + if (workload <= 7449.60f) { - if(workload <= 691.60f) + if (workload <= 691.60f) { return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false); } else { - if(workload <= 4155.20f) + if (workload <= 4155.20f) { return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); } @@ -420,21 +444,22 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } else { - if(workload <= 16300.80f) + if (workload <= 16300.80f) { - if(r_mn <= 44.56f) + if (r_mn <= 44.56f) { GEMMLHSMatrixInfo lhs_info_buf; GEMMRHSMatrixInfo rhs_info_buf; GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, 
false, false, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } else { @@ -448,23 +473,25 @@ std::pair ClGemmDefaultConfigReshapedRhsOn GEMMLHSMatrixInfo lhs_info_img; GEMMRHSMatrixInfo rhs_info_img; - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); + std::tie(lhs_info_img, rhs_info_img) = + configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = + configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false); return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), - std::make_pair(lhs_info_buf, rhs_info_buf), - n, k, b, DataType::F16); + std::make_pair(lhs_info_buf, rhs_info_buf), n, k, b, DataType::F16); } } } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_f16( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int n0 = n < 1280 ? 
2 : 4; const unsigned int h0 = std::max(n / n0, 1U); @@ -476,14 +503,15 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G7x_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(dot8_supported(CLKernelLibrary::get().get_device())) + if (dot8_supported(CLKernelLibrary::get().get_device())) { - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); @@ -497,7 +525,7 @@ std::pair ClGemmDefaultConfigReshapedRhsOn else { const int h0 = std::max(std::min(static_cast(n / 2), static_cast(128)), static_cast(1)); - if(m == 1) + if (m == 1) { return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true); } @@ -508,12 +536,13 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G76_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true); @@ -524,12 +553,13 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } -std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +std::pair ClGemmDefaultConfigReshapedRhsOnlyBifrost::configure_G51_u8( + unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(m == 1) + if (m == 1) { const unsigned int h0 = std::max(n / 2, 1U); return 
configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true); -- cgit v1.2.1