aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp')
-rw-r--r--src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp29
1 files changed, 18 insertions, 11 deletions
diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
index 8b1c9a5622..48f38cac55 100644
--- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
+++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp
@@ -44,7 +44,7 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS
// _target could be used in the future to have a dedicated heuristic for each GPU IP
ARM_COMPUTE_UNUSED(_target);
- using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionBifrost::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+ using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
// Default configurations for Bifrost architectures
static std::map<DataType, FunctionExecutorPtr> gemm_default_configs =
@@ -86,26 +86,28 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS
case GPUTarget::G71:
if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end())
{
- return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant);
+ return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
}
ARM_COMPUTE_ERROR("Not supported data type");
case GPUTarget::G76:
if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end())
{
- return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant);
+ return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
}
ARM_COMPUTE_ERROR("Not supported data type");
default:
if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
{
- return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant);
+ return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
}
ARM_COMPUTE_ERROR("Not supported data type");
}
}
-CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant)
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
+ ARM_COMPUTE_UNUSED(b);
+
CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1;
if(is_rhs_constant)
@@ -143,9 +145,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsig
return gemm_type;
}
-CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant)
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
- ARM_COMPUTE_UNUSED(n, k);
+ ARM_COMPUTE_UNUSED(n, k, b);
+
if(is_rhs_constant)
{
if(m == 1)
@@ -163,9 +166,9 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsig
}
}
-CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant)
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
- ARM_COMPUTE_UNUSED(m, n, k);
+ ARM_COMPUTE_UNUSED(m, n, k, b);
if(is_rhs_constant)
{
@@ -177,8 +180,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsign
}
}
-CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant)
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
+ ARM_COMPUTE_UNUSED(b);
+
CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1;
if(is_rhs_constant)
@@ -207,8 +212,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned
return gemm_type;
}
-CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant)
+CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
{
+ ARM_COMPUTE_UNUSED(b);
+
if(is_rhs_constant)
{
if(m == 1)