aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLGEMM.cpp
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2020-08-28 13:52:12 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2020-08-28 16:28:34 +0000
commit026d04576d3058e34f8f7e23f7a11514a04952dc (patch)
treecc743de9348479a754b3eee0bb4f71444bb5ed1e /src/runtime/CL/functions/CLGEMM.cpp
parent1c76c1ddcd1294ee8149bd74ecf6f62963408286 (diff)
downloadComputeLibrary-026d04576d3058e34f8f7e23f7a11514a04952dc.tar.gz
COMPMID-3770: Add batch size in the OpenCL GEMM kernel selection
Change-Id: Ia3030ea701e9ceb2ef567e0258e8f478e18b8b55 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3871 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r--src/runtime/CL/functions/CLGEMM.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 4a74630036..d56b341abf 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -66,7 +66,7 @@ CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *
{
}
-CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run)
+CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run)
{
std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
@@ -75,6 +75,7 @@ CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsi
params.m = m;
params.n = n;
params.k = k;
+ params.b = b;
params.is_rhs_constant = reshape_b_only_on_first_run;
params.data_type = data_type;
@@ -516,9 +517,10 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor
const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
const unsigned int n = b->info()->dimension(0);
const unsigned int k = a->info()->dimension(0);
+ const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
// Select GEMMType
- _gemm_kernel_type = select_gemm_kernel(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run);
+ _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run);
const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
@@ -560,9 +562,10 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
const unsigned int n = b->dimension(0);
const unsigned int k = a->dimension(0);
+ const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
// Select GEMMType
- CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run());
+ CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run());
const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);