diff options
Diffstat (limited to 'src/runtime/CL/functions/CLGEMM.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 4a74630036..d56b341abf 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -66,7 +66,7 @@ CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager * { } -CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run) +CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run) { std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target()); ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get()); @@ -75,6 +75,7 @@ CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsi params.m = m; params.n = n; params.k = k; + params.b = b; params.is_rhs_constant = reshape_b_only_on_first_run; params.data_type = data_type; @@ -516,9 +517,10 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); const unsigned int n = b->info()->dimension(0); const unsigned int k = a->info()->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2); // Select GEMMType - _gemm_kernel_type = select_gemm_kernel(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run); + _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -560,9 +562,10 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); const unsigned int n = b->dimension(0); const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); // Select GEMMType - CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run()); + CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); |