diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2020-08-28 13:52:12 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2020-08-28 16:28:34 +0000 |
commit | 026d04576d3058e34f8f7e23f7a11514a04952dc (patch) | |
tree | cc743de9348479a754b3eee0bb4f71444bb5ed1e /src/runtime/CL/gemm | |
parent | 1c76c1ddcd1294ee8149bd74ecf6f62963408286 (diff) | |
download | ComputeLibrary-026d04576d3058e34f8f7e23f7a11514a04952dc.tar.gz |
COMPMID-3770: Add batch size in the OpenCL GEMM kernel selection
Change-Id: Ia3030ea701e9ceb2ef567e0258e8f478e18b8b55
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3871
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/gemm')
-rw-r--r-- | src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp | 29 | ||||
-rw-r--r-- | src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp | 16 | ||||
-rw-r--r-- | src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp | 16 |
3 files changed, 34 insertions, 27 deletions
diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp index 8b1c9a5622..48f38cac55 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp @@ -44,7 +44,7 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionBifrost::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Default configurations for Bifrost architectures static std::map<DataType, FunctionExecutorPtr> gemm_default_configs = @@ -86,26 +86,28 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS case GPUTarget::G71: if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end()) { - return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G76: if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end()) { - return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); default: if(gemm_default_configs.find(data_type) != gemm_default_configs.end()) { - return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(b); + CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1; if(is_rhs_constant) @@ -143,9 +145,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsig return gemm_type; } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(n, k, b); + if(is_rhs_constant) { if(m == 1) @@ -163,9 +166,9 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsig } } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); if(is_rhs_constant) { @@ -177,8 +180,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsign } } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(b); + CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1; if(is_rhs_constant) @@ -207,8 +212,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned return gemm_type; } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(b); + if(is_rhs_constant) { if(m == 1) diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp index 44700ad4f4..324c2f7dca 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp @@ -45,7 +45,7 @@ CLGEMMKernelType CLGEMMKernelSelectionMidgard::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionMidgard::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionMidgard::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Configurations for Midgard architectures static std::map<DataType, FunctionExecutorPtr> gemm_configs = @@ -62,31 +62,31 @@ CLGEMMKernelType CLGEMMKernelSelectionMidgard::select_kernel(const CLGEMMKernelS if(gemm_configs.find(data_type) != gemm_configs.end()) { - return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } -CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(n, k, b); // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once return ((m != 1) && is_rhs_constant) ? CLGEMMKernelType::RESHAPED_V1 : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(n, k, b); // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once return ((m != 1) && is_rhs_constant) ? CLGEMMKernelType::RESHAPED_V1 : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k, is_rhs_constant); + ARM_COMPUTE_UNUSED(m, n, k, b, is_rhs_constant); return CLGEMMKernelType::NATIVE; } diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp index 8b4c9e75e8..c50c7ae76b 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp @@ -44,7 +44,7 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Configurations for Valhall architectures static std::map<DataType, FunctionExecutorPtr> gemm_configs = @@ -61,29 +61,29 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS if(gemm_configs.find(data_type) != gemm_configs.end()) { - return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); if(is_rhs_constant) { |