aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp')
-rw-r--r--src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp43
1 files changed, 43 insertions, 0 deletions
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index 483bab832f..9f3fc3aae7 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -48,6 +48,13 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
unsigned int b);
+ // Configurations for Mali-G51
+ static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G51 =
+ {
+ { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32 },
+ { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }
+ };
+
// Configurations for Mali-G76
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
{
@@ -66,6 +73,8 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
{
case GPUTarget::G76:
return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+ case GPUTarget::G51:
+ return (this->*gemm_configs_G51[data_type])(m, n, k, b);
default:
return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
}
@@ -111,6 +120,23 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(m == 1)
+ {
+ const unsigned int n0 = n < 1280? 2 : 4;
+ const unsigned int h0 = std::max(n / n0, static_cast<unsigned int>(1));
+ return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
@@ -159,5 +185,22 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfi
return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, 2, false, true, false, true);
}
}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(m == 1)
+ {
+ const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+ return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
+ }
+ else
+ {
+ const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+ return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
+ }
+}
} // namespace cl_gemm
} // namespace arm_compute \ No newline at end of file