From eaca67a249b5338fb286c8e7c24253c5fc8ca7ac Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 10 Nov 2020 10:41:37 +0000 Subject: COMPMID-3959: Update Mali-G52 heuristic for CLGEMM - F32 - Add heuristic in CLGEMMKernelSelection - Add heuristic in CLGEMMReshapedRHSOnly - Add heuristic in CLGEMMReshaped Change-Id: Ibaa13398f7a5976418a0ab1b6696ace09cc480fa Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4366 Tested-by: Arm Jenkins Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- .../CLGEMMReshapedKernelConfigurationBifrost.cpp | 110 ++++++++++++++++ .../CLGEMMReshapedKernelConfigurationBifrost.h | 1 + ...MMReshapedOnlyRHSKernelConfigurationBifrost.cpp | 64 +++++++++ ...GEMMReshapedOnlyRHSKernelConfigurationBifrost.h | 1 + .../CL/gemm/CLGEMMKernelSelectionBifrost.cpp | 144 +++++++++++++++++++++ src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h | 1 + 6 files changed, 321 insertions(+) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp index c1ca187a70..70992974a3 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp @@ -60,6 +60,17 @@ std::pair CLGEMMReshapedKernelConfiguratio { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 } }; + // Configurations for Mali-G52 + static std::map gemm_configs_G52 = + { + { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32 }, + { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 }, + { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 } + }; + // Configurations for Mali-G7x static std::map gemm_configs_G7x = { @@ -153,6 +164,105 @@ std::pair CLGEMMReshapedKernelConfiguratio } } +std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(workload <= 274.4000f) + { + if(r_nk <= 0.7461f) + { + if(r_mn <= 21.1667f) + { + return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(r_mk <= 17.3926f) + { + if(workload <= 542.4000f) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(r_nk <= 0.5463f) + { + if(workload <= 11767.6001f) + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + } +} + std::pair CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h index e3b62ced6a..6c67b70962 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h @@ -45,6 +45,7 @@ public: private: std::pair configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp index 3105db6693..188ba4d9c5 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp @@ -61,6 +61,17 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 } }; + // Configurations for Mali-G52 + static std::map gemm_configs_G52 = + { + { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32 }, + { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 }, + { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 } + }; + // Configurations for Mali-G76 static std::map gemm_configs_G76 = { @@ -94,6 +105,15 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi { ARM_COMPUTE_ERROR("Not supported data type"); } + case GPUTarget::G52: + if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end()) + { + return (this->*gemm_configs_G52[data_type])(m, n, k, b); + } + else + { + ARM_COMPUTE_ERROR("Not supported data type"); + } case GPUTarget::G51: if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end()) { @@ -201,6 +221,50 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi } } +std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_nk = static_cast(n) / static_cast(k); + + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + if(m == 1) + { + if(r_nk <= 0.4664f) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + if(workload <= 274.4000f) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false); + } + else + { + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } +} + std::pair CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { ARM_COMPUTE_UNUSED(k); diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h index 618dbd9923..3dfd96a822 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.h @@ -46,6 +46,7 @@ public: private: std::pair configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp index 7c6efe3f11..c77746a044 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp @@ -68,6 +68,17 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS { DataType::QSYMM8_PER_CHANNEL, &CLGEMMKernelSelectionBifrost::default_q8 } }; + // Mali-G52 configurations + static std::map gemm_g52_configs = + { + { DataType::F32, &CLGEMMKernelSelectionBifrost::g52_f32 }, + { DataType::F16, &CLGEMMKernelSelectionBifrost::default_f16 }, + { DataType::QASYMM8, &CLGEMMKernelSelectionBifrost::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMKernelSelectionBifrost::default_q8 }, + { DataType::QSYMM8, &CLGEMMKernelSelectionBifrost::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMKernelSelectionBifrost::default_q8 } + }; + // Mali-G76 configurations static std::map gemm_g76_configs = { @@ -95,6 +106,12 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); + case GPUTarget::G52: + if(gemm_g52_configs.find(data_type) != gemm_g52_configs.end()) + { + return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + } + ARM_COMPUTE_ERROR("Not supported data type"); default: if(gemm_default_configs.find(data_type) != gemm_default_configs.end()) { @@ -237,6 +254,133 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned } } +CLGEMMKernelType CLGEMMKernelSelectionBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + ARM_COMPUTE_UNUSED(b); + + if (!is_rhs_constant) + { + return CLGEMMKernelType::NATIVE_V1; + } + + if (m == 1) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float r_mnk = static_cast(m) / (static_cast(n) * static_cast(k)); + + if(r_mn <= 1.5469f) + { + if(r_mk <= 0.8766f) + { + if(r_mk <= 0.0211f) + { + if(r_mnk <= 77.5833f) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + else + { + if(r_nk <= 0.0832f) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + } + else + { + if(r_mnk <= 193.0000f) + { + if(r_mn <= 0.9948f) + { + if(r_mk <= 2.5453f) + { + return CLGEMMKernelType::RESHAPED; + } + else + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + } + else + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + } + else + { + if(r_mn <= 17.7370f) + { + if(r_mnk <= 1391.2875f) + { + if(r_mk <= 2.9724f) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + if(r_mnk <= 470.0000f) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + } + else + { + if(r_nk <= 0.1381f) + { + if(r_mnk <= 9040.5000f) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + else + { + if(r_mn <= 5.6790f) + { + return CLGEMMKernelType::RESHAPED; + } + else + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + } + } + } + else + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + } +} + CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(b); diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h index e3cc8e4a27..fbafc531f5 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h @@ -44,6 +44,7 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: + CLGEMMKernelType g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); -- cgit v1.2.1