aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp')
-rw-r--r--src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp35
1 files changed, 33 insertions, 2 deletions
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index b791c1cda5..0c2942a184 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,7 @@ CLGEMMReshapedKernelConfigurationBifrost::CLGEMMReshapedKernelConfigurationBifro
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
{
- ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
- ARM_COMPUTE_UNUSED(data_type);
+ ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
@@ -51,6 +50,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
{
{ DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
+ { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
{ DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
};
@@ -58,6 +58,7 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
{
{ DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
+ { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
{ DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
};
@@ -85,6 +86,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(n <= 4)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);
@@ -129,6 +145,21 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfiguratio
}
}
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+ ARM_COMPUTE_UNUSED(k);
+ ARM_COMPUTE_UNUSED(b);
+
+ if(n <= 4)
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+ }
+ else
+ {
+ return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+ }
+}
+
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
ARM_COMPUTE_UNUSED(k);