From 026d04576d3058e34f8f7e23f7a11514a04952dc Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 28 Aug 2020 13:52:12 +0100 Subject: COMPMID-3770: Add batch size in the OpenCL GEMM kernel selection Change-Id: Ia3030ea701e9ceb2ef567e0258e8f478e18b8b55 Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3871 Tested-by: Arm Jenkins Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- arm_compute/runtime/CL/CLTypes.h | 1 + arm_compute/runtime/CL/functions/CLGEMM.h | 2 +- .../runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h | 10 ++++---- .../runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h | 6 ++--- .../runtime/CL/gemm/CLGEMMKernelSelectionValhall.h | 6 ++--- src/runtime/CL/functions/CLGEMM.cpp | 9 ++++--- .../CL/gemm/CLGEMMKernelSelectionBifrost.cpp | 29 ++++++++++++++-------- .../CL/gemm/CLGEMMKernelSelectionMidgard.cpp | 16 ++++++------ .../CL/gemm/CLGEMMKernelSelectionValhall.cpp | 16 ++++++------ 9 files changed, 53 insertions(+), 42 deletions(-) diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h index cbc525308f..19095a5589 100644 --- a/arm_compute/runtime/CL/CLTypes.h +++ b/arm_compute/runtime/CL/CLTypes.h @@ -53,6 +53,7 @@ struct CLGEMMKernelSelectionParams unsigned int m{ 0 }; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ unsigned int n{ 0 }; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ unsigned int k{ 0 }; /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */ + unsigned int b{ 0 }; /**< Batch size */ bool is_rhs_constant{ false }; /**< True if the content of the rhs matrix is constant */ DataType data_type{ DataType::UNKNOWN }; /**< Data type */ }; diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h index 8e4d3906d1..6e9cf0e2ca 100644 --- a/arm_compute/runtime/CL/functions/CLGEMM.h +++ b/arm_compute/runtime/CL/functions/CLGEMM.h @@ -185,7 +185,7 @@ public: void prepare() override; private: - static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run); + static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run); void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info); diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h index 815c2c8cef..579bbe32ad 100644 --- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h +++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h @@ -44,11 +44,11 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: - CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h index 4689f0c041..5547731821 100644 --- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h +++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h @@ -44,9 +44,9 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: - CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h index 8712be7531..782ef7474d 100644 --- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h +++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h @@ -44,9 +44,9 @@ public: CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams ¶ms) override; private: - CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); - CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 4a74630036..d56b341abf 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -66,7 +66,7 @@ CLGEMM::CLGEMM(std::shared_ptr memory_manager, IWeightsManager * { } -CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run) +CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run) { std::unique_ptr gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target()); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); @@ -75,6 +75,7 @@ CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsi params.m = m; params.n = n; params.k = k; + params.b = b; params.is_rhs_constant = reshape_b_only_on_first_run; params.data_type = data_type; @@ -516,9 +517,10 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); const unsigned int n = b->info()->dimension(0); const unsigned int k = a->info()->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2); // Select GEMMType - _gemm_kernel_type = select_gemm_kernel(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run); + _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -560,9 +562,10 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); const unsigned int n = b->dimension(0); const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); // Select GEMMType - CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run()); + CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp index 8b1c9a5622..48f38cac55 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.cpp @@ -44,7 +44,7 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionBifrost::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Default configurations for Bifrost architectures static std::map gemm_default_configs = @@ -86,26 +86,28 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::select_kernel(const CLGEMMKernelS case GPUTarget::G71: if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end()) { - return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G76: if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end()) { - return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); default: if(gemm_default_configs.find(data_type) != gemm_default_configs.end()) { - return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(b); + CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1; if(is_rhs_constant) @@ -143,9 +145,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f32(unsigned int m, unsig return gemm_type; } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(n, k, b); + if(is_rhs_constant) { if(m == 1) @@ -163,9 +166,9 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_f16(unsigned int m, unsig } } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); if(is_rhs_constant) { @@ -177,8 +180,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::default_q8(unsigned int m, unsign } } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(b); + CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE_V1; if(is_rhs_constant) @@ -207,8 +212,10 @@ CLGEMMKernelType CLGEMMKernelSelectionBifrost::g76_f32(unsigned int m, unsigned return gemm_type; } -CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { + ARM_COMPUTE_UNUSED(b); + if(is_rhs_constant) { if(m == 1) diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp index 44700ad4f4..324c2f7dca 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.cpp @@ -45,7 +45,7 @@ CLGEMMKernelType CLGEMMKernelSelectionMidgard::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionMidgard::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionMidgard::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Configurations for Midgard architectures static std::map gemm_configs = @@ -62,31 +62,31 @@ CLGEMMKernelType CLGEMMKernelSelectionMidgard::select_kernel(const CLGEMMKernelS if(gemm_configs.find(data_type) != gemm_configs.end()) { - return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } -CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(n, k, b); // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once return ((m != 1) && is_rhs_constant) ? CLGEMMKernelType::RESHAPED_V1 : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(n, k); + ARM_COMPUTE_UNUSED(n, k, b); // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once return ((m != 1) && is_rhs_constant) ? CLGEMMKernelType::RESHAPED_V1 : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionMidgard::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k, is_rhs_constant); + ARM_COMPUTE_UNUSED(m, n, k, b, is_rhs_constant); return CLGEMMKernelType::NATIVE; } diff --git a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp index 8b4c9e75e8..c50c7ae76b 100644 --- a/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMKernelSelectionValhall.cpp @@ -44,7 +44,7 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS // _target could be used in the future to have a dedicated heuristic for each GPU IP ARM_COMPUTE_UNUSED(_target); - using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant); + using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMKernelSelectionValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); // Configurations for Valhall architectures static std::map gemm_configs = @@ -61,29 +61,29 @@ CLGEMMKernelType CLGEMMKernelSelectionValhall::select_kernel(const CLGEMMKernelS if(gemm_configs.find(data_type) != gemm_configs.end()) { - return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.is_rhs_constant); + return (this->*gemm_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); } ARM_COMPUTE_ERROR("Not supported data type"); } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE_V1; } -CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant) +CLGEMMKernelType CLGEMMKernelSelectionValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) { - ARM_COMPUTE_UNUSED(m, n, k); + ARM_COMPUTE_UNUSED(m, n, k, b); if(is_rhs_constant) { -- cgit v1.2.1