aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp')
-rw-r--r--src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp64
1 files changed, 64 insertions, 0 deletions
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index 3105db6693..188ba4d9c5 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -61,6 +61,17 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
{ DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }
};
+ // Configurations for Mali-G52
+ static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G52 =
+ {
+ { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32 },
+ { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 },
+ { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
+ { DataType::QSYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
+ { DataType::QASYMM8_SIGNED, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 },
+ { DataType::QSYMM8_PER_CHANNEL, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
+ };
+
// Configurations for Mali-G76
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
{
@@ -94,6 +105,15 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
{
ARM_COMPUTE_ERROR("Not supported data type");
}
+ case GPUTarget::G52:
+ if(gemm_configs_G52.find(data_type) != gemm_configs_G52.end())
+ {
+ return (this->*gemm_configs_G52[data_type])(m, n, k, b);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Not supported data type");
+ }
case GPUTarget::G51:
if(gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
{
@@ -201,6 +221,50 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+ const float r_nk = static_cast<float>(n) / static_cast<float>(k);
+
+ GEMMLHSMatrixInfo lhs_info_buf;
+ GEMMRHSMatrixInfo rhs_info_buf;
+ GEMMLHSMatrixInfo lhs_info_img;
+ GEMMRHSMatrixInfo rhs_info_img;
+
+ if(m == 1)
+ {
+ if(r_nk <= 0.4664f)
+ {
+ return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
+ }
+ else
+ {
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, false, true, false, true, false);
+
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F32);
+ }
+ }
+ else
+ {
+ if(workload <= 274.4000f)
+ {
+ return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 16, false, false, false, true, false);
+ }
+ else
+ {
+ std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, true);
+ std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, false, false, true, false);
+
+ return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+ std::make_pair(lhs_info_buf, rhs_info_buf),
+ n, k, b, DataType::F32);
+ }
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);