From 491f30c0fff416007d97f4a5a043923861ef7b64 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 2 Nov 2020 15:43:57 +0000 Subject: COMPMID-3939: Update GEMM heuristic Mali-G77 - Update heuristic for GEMM reshaped RHS only - Fix left-over block size in CLGEMMMatrixMultiplyReshapedOlyRHSKernel Change-Id: I34c738821ed2e4a537da4a15058eec164cb6b61f Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4305 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../CLGEMMReshapedKernelConfigurationValhall.cpp | 81 +++++++++++- ...MMReshapedOnlyRHSKernelConfigurationValhall.cpp | 129 ++++++++++++------- .../CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp | 8 +- .../CL/gemm/CLGEMMKernelSelectionValhall.cpp | 136 ++++++++++++++++++++- src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h | 1 + 5 files changed, 298 insertions(+), 57 deletions(-) diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp index 519e903a5a..3f82dcab00 100644 --- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationValhall.cpp @@ -90,13 +90,88 @@ std::pair CLGEMMReshapedKernelConfiguratio ARM_COMPUTE_UNUSED(k); ARM_COMPUTE_UNUSED(b); - if(n <= 4) + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + + if(r_mk <= 0.11824845522642136) { - return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false); + if(workload <= 880.0) + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, false, false, true, false, false); + } + else + { + if(r_nk <= 0.42521367967128754) + { + if(workload <= 1726.4000244140625) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, false, false, true, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true); + } + } + else + { + if(workload <= 1241.6000366210938) + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, false, false, true, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, false, false, true, false, false); + } + } + } } else { - return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false); + if(workload <= 11404.7998046875) + { + if(r_mk <= 1.0126488208770752) + { + if(r_mn <= 2.545312523841858) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, false, false, true, false, false); + } + } + else + { + if(workload <= 2881.199951171875) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, false, false, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true); + } + } + } + else + { + if(r_nk <= 0.5765306055545807) + { + if(r_mn <= 6.010416746139526) + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, false, true, true, false, true); + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true); + } + } + else + { + return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, false, true, false, true); + } + } } } diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp index f7939d29c0..e0991674b1 100644 --- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp +++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationValhall.cpp @@ -78,66 +78,107 @@ std::pair CLGEMMReshapedOnlyRHSKernelConfi std::pair CLGEMMReshapedOnlyRHSKernelConfigurationValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) { - ARM_COMPUTE_UNUSED(k); - - GEMMLHSMatrixInfo lhs_info_buf; - GEMMRHSMatrixInfo rhs_info_buf; - GEMMLHSMatrixInfo lhs_info_img; - GEMMRHSMatrixInfo rhs_info_img; - - // Get lhs_info/rhs_info in case of OpenCL buffer if(m == 1) { - const unsigned int h0 = std::max(n / 4, 1U); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true); - } - else - { - if(m > 256) + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + + if(r_mk <= 0.0064484127797186375) { - const int v0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, v0, false, true, false, true); + if(r_mn <= 0.0028273810748942196) + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + + const unsigned int h0 = std::max(n / 4, 1U); + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, false, true, false, false, false); + } } else { - const int v0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, v0, false, true, false, true); + if(r_mk <= 0.020312500186264515) + { + return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, false, true, false, false, false); + } + else + { + return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, false, true, false, true, false); + } } } - - // Get lhs_info/rhs_info in case of OpenCL image - if(m == 1) - { - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 8, true, true, false, false, true); - } else { - if((m / 4) * (n / 4) > 4096) + const float r_mn = static_cast(m) / static_cast(n); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + const float r_mk = static_cast(m) / static_cast(k); + + if(workload <= 1999.2000122070312) { - const int h0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, h0, false, true, false, false, true); + if(workload <= 747.1999816894531) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); + + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } } else { - const int h0 = std::max(std::min(static_cast(n / 4), static_cast(8)), static_cast(1)); - std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, h0, false, true, false, false, true); - } - } - - const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32); - const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img); - const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32); + if(r_mn <= 0.03348214365541935) + { + if(r_mk <= 0.028125000186264515) + { + return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, false, false, false, true, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, false, true, false, true, false); - // In case of small workloads, we use the OpenCL buffer rather than the OpenCL image2d - const bool use_cl_image2d = ((m / lhs_info_img.m0) * (n / rhs_info_img.n0)) * b < 1024 ? false : true; + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } + else + { + GEMMLHSMatrixInfo lhs_info_buf; + GEMMRHSMatrixInfo rhs_info_buf; + GEMMLHSMatrixInfo lhs_info_img; + GEMMRHSMatrixInfo rhs_info_img; + std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, false, true); + std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, false, true, false, true, false); - if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d) - { - return std::make_pair(lhs_info_img, rhs_info_img); - } - else - { - return std::make_pair(lhs_info_buf, rhs_info_buf); + return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img), + std::make_pair(lhs_info_buf, rhs_info_buf), + n, k, b, DataType::F32); + } + } } } diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp index 68f761b9e7..d53aede3c8 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -247,14 +247,14 @@ void CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::configure(const CLCompileContext const unsigned int h_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(1) : input0->info()->dimension(1); const unsigned int d_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(2) : input0->info()->dimension(2); - // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding. - const unsigned int partial_store_m0 = internal_m % lhs_info.m0; - const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0; - // Shrink M0 to be always <= M (internal_m) to prevent out-of-bounds reads. // NOTE: This might have implications on heuristics and performance const unsigned int internal_m0 = std::min(internal_m, lhs_info.m0); + // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding. + const unsigned int partial_store_m0 = internal_m % internal_m0; + const unsigned int partial_store_n0 = gemm_info.n % rhs_info.n0; + // Create build options CLBuildOptions build_opts; build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp index acae0e7565..da41859b87 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp @@ -46,8 +46,8 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); - // Configurations for Valhall architectures - static std::map gemm_configs = + // Default configurations for Valhall architectures + static std::map gemm_default_configs = { { DataType::F32, &CLGEMMKernelSelectionValhall::default_f32 }, { DataType::F16, &CLGEMMKernelSelectionValhall::default_f16 }, @@ -57,14 +57,34 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS { DataType::QSYMM8_PER_CHANNEL, &CLGEMMKernelSelectionValhall::default_q8 } }; + // Mali-G77 configurations + static std::map gemm_g77_configs = + { + { DataType::F32, &CLGEMMKernelSelectionValhall::default_f32 }, + { DataType::F16, &CLGEMMKernelSelectionValhall::g77_f16 }, + { DataType::QASYMM8, &CLGEMMKernelSelectionValhall::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMKernelSelectionValhall::default_q8 }, + { DataType::QSYMM8, &CLGEMMKernelSelectionValhall::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMKernelSelectionValhall::default_q8 } + }; + const DataType data_type = params.data_type; - if(gemm_configs.find(data_type) != gemm_configs.end()) + switch(_target) { - return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + case GPUTarget::G77: + if(gemm_g77_configs.find(data_type) != gemm_g77_configs.end()) + { + return (this->*gemm_g77_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + } + ARM_COMPUTE_ERROR("Not supported data type"); + default: + if(gemm_default_configs.find(data_type) != gemm_default_configs.end()) + { + return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + } + ARM_COMPUTE_ERROR("Not supported data type"); } - - ARM_COMPUTE_ERROR("Not supported data type"); } CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) @@ -81,6 +101,110 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsig return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } +CLGEMMKernelType CLGEMMKernelSelectionValhall::g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + if (!is_rhs_constant) + { + return CLGEMMKernelType::NATIVE_V1; + } + + if (m == 1) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + + const float r_mn = static_cast(m) / static_cast(n); + const float r_mk = static_cast(m) / static_cast(k); + const float r_nk = static_cast(n) / static_cast(k); + const float workload = (static_cast(m) * static_cast(n) * static_cast(b)) / 20.0f; + + if(r_mk <= 0.6817956566810608) + { + if(workload <= 801.6000061035156) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + if(r_mn <= 0.0839829258620739) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + if(r_mk <= 0.24917218834161758) + { + return CLGEMMKernelType::RESHAPED; + } + else + { + if(workload <= 2551.75) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + if(workload <= 5061.574951171875) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + } + } + } + } + else + { + if(r_mk <= 4.849947690963745) + { + if(workload <= 17618.4501953125) + { + if(workload <= 5224.699951171875) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + if(r_nk <= 0.7933054566383362) + { + return CLGEMMKernelType::RESHAPED; + } + else + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + } + } + else + { + if(workload <= 20275.2001953125) + { + return CLGEMMKernelType::RESHAPED; + } + else + { + if(r_mk <= 3.07421875) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + else + { + return CLGEMMKernelType::RESHAPED; + } + } + } + } + else + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS; + } + } +} + CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { ARM_COMPUTE_UNUSED(m, n, k, b); diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h index cbea9ea548..82e46f694e 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h @@ -47,6 +47,7 @@ private: CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute -- cgit v1.2.1