Diffstat (limited to 'src/runtime/CL/functions')
-rw-r--r-- | src/runtime/CL/functions/CLGEMM.cpp | 52 |
1 files changed, 39 insertions, 13 deletions
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index ccae6713a6..80c5496ede 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -60,11 +60,15 @@ CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *
       _reshape_rhs_kernel_managed(),
       _mm_reshaped_kernel(),
       _mm_reshaped_only_rhs_kernel(),
+      _mm_reshaped_only_rhs_fallback_kernel(),
       _tmp_a(),
       _tmp_b(),
       _original_b(nullptr),
+      _lhs(nullptr),
+      _dst(nullptr),
       _reshape_b_only_on_first_run(false),
       _is_prepared(false),
+      _has_pad_y(false),
       _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
 {
 }
@@ -295,16 +299,8 @@ void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context
     std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
     ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
 
-    unsigned int m_internal = m;
-    unsigned int b_internal = batch_size;
-    if(reinterpret_input_as_3d)
-    {
-        m_internal = a->info()->dimension(1);
-        b_internal = a->info()->dimension(2);
-    }
-
     // Configure lhs_info and rhs_info
-    std::tie(lhs_info, rhs_info) = gemm_config->configure(m_internal, n, k, b_internal, data_type);
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
 
     ICLTensor *reshaped_rhs = &_tmp_b;
     if(_weights_manager && _weights_manager->are_weights_managed(b))
@@ -317,9 +313,18 @@ void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context
         _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
     }
 
-    // Configure and tune matrix multiply kernel
+    // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
+    // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
+    // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
+
+    // Configure matrix multiply kernel with no y padding support
+    kernel_info.has_pad_y = false;
     _mm_reshaped_only_rhs_kernel.configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
 
+    // Configure matrix multiply kernel with y padding support
+    kernel_info.has_pad_y = true;
+    _mm_reshaped_only_rhs_fallback_kernel.configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+
     if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _tmp_b.allocator()->allocate();
@@ -493,6 +498,10 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
 
     // Validate matrix multiply
+    kernel_info.has_pad_y = false;
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
+
+    kernel_info.has_pad_y = true;
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
 
     return Status{};
@@ -514,6 +523,8 @@ void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor
     _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
     _is_prepared                 = gemm_info.retain_internal_weights();
     _original_b                  = b;
+    _lhs                         = a;
+    _dst                         = output;
 
     // Get the GPU target
     bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
@@ -608,7 +619,6 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
 void CLGEMM::run()
 {
     prepare();
-
     MemoryGroupResourceScope scope_mg(_memory_group);
 
     // Run matrix multiply kernel
@@ -675,8 +685,14 @@ void CLGEMM::run()
                 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
             }
         }
-
-        CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
+        if(_has_pad_y)
+        {
+            CLScheduler::get().enqueue(_mm_reshaped_only_rhs_fallback_kernel, true);
+        }
+        else
+        {
+            CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
+        }
         break;
     }
     default:
@@ -690,6 +706,16 @@ void CLGEMM::prepare()
 {
     if(!_is_prepared)
     {
+        // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
+        if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS)
+        {
+            // Check if the lhs or dst tensors have padding
+            const unsigned int cross_plane_pad_lhs = _lhs->info()->padding().top + _lhs->info()->padding().bottom;
+            const unsigned int cross_plane_pad_dst = _dst->info()->padding().top + _dst->info()->padding().bottom;
+
+            _has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
+        }
+
         if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
         {
             if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
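
In short, the patch configures two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel up front (has_pad_y = false and has_pad_y = true), checks the lhs/dst padding requirement once in prepare(), and lets run() dispatch the fallback variant only when y padding is actually present. The following is a minimal, self-contained sketch of that configure/prepare/run split; Gemm, Kernel, Tensor and Padding are hypothetical stand-ins, not Compute Library types, and only the control flow mirrors the patch:

// Sketch of the two-variant dispatch pattern added by this patch.
// All types below are simplified stand-ins for illustration only.
#include <iostream>

struct Padding
{
    unsigned int top{ 0 };
    unsigned int bottom{ 0 };
};

struct Tensor
{
    Padding pad;
};

struct Kernel
{
    bool has_pad_y{ false };
    void run() const
    {
        std::cout << "kernel(has_pad_y=" << has_pad_y << ")\n";
    }
};

class Gemm
{
public:
    // Mirrors configure_reshaped_only_rhs(): both variants are configured up
    // front, since the final padding of lhs/dst is not yet known here.
    void configure(Tensor *lhs, Tensor *dst)
    {
        _lhs                       = lhs;
        _dst                       = dst;
        _kernel.has_pad_y          = false;
        _fallback_kernel.has_pad_y = true;
    }

    // Mirrors prepare(): the padding requirement is checked once, after the
    // tensors have been allocated.
    void prepare()
    {
        const unsigned int pad_lhs = _lhs->pad.top + _lhs->pad.bottom;
        const unsigned int pad_dst = _dst->pad.top + _dst->pad.bottom;
        _has_pad_y = (pad_lhs != 0) || (pad_dst != 0);
    }

    // Mirrors run(): the fallback variant is dispatched only when y padding
    // is present on lhs or dst.
    void run() const
    {
        (_has_pad_y ? _fallback_kernel : _kernel).run();
    }

private:
    Tensor *_lhs{ nullptr };
    Tensor *_dst{ nullptr };
    Kernel  _kernel{};
    Kernel  _fallback_kernel{};
    bool    _has_pad_y{ false };
};

int main()
{
    Tensor lhs{};
    Tensor dst{ { 0, 2 } }; // dst has two rows of bottom padding
    Gemm gemm;
    gemm.configure(&lhs, &dst);
    gemm.prepare();
    gemm.run(); // prints: kernel(has_pad_y=1)
}

Configuring both variants at configure() time trades a second kernel configuration for a cheap runtime switch: by the time prepare() runs, allocation has fixed the actual tensor padding, so the cheaper no-padding kernel can be selected whenever the fallback is unnecessary.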