From 026d04576d3058e34f8f7e23f7a11514a04952dc Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 28 Aug 2020 13:52:12 +0100 Subject: COMPMID-3770: Add batch size in the OpenCL GEMM kernel selection Change-Id: Ia3030ea701e9ceb2ef567e0258e8f478e18b8b55 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3871 Tested-by: Arm Jenkins Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp') diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp index 8b4c9e75e8..c50c7ae76b 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp @@ -44,7 +44,7 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Configurations for Valhall architectures static std::map gemm_configs = @@ -61,29 +61,29 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS if(gemm_configs.find(data_type) != gemm_configs.end()) { - return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); if(is_rhs_constant) { -- cgit v1.2.1