From 2b84be544e4a27f7e8e80827e9c85c8f0d58b4ce Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Wed, 8 Apr 2020 10:15:51 +0100 Subject: COMPMID-3280: Make all ML primitives for CL use the new interface - Part 2 - CLFunctions have been updated Change-Id: Ie3256a6c775bc12f3126482bd8e8a46da54b267c Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3053 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- src/runtime/CL/functions/CLGEMM.cpp | 49 ++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 20 deletions(-) (limited to 'src/runtime/CL/functions/CLGEMM.cpp') diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index 74d59cdad1..8466024c04 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -81,7 +81,8 @@ CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsi return gemm_kernel->select_kernel(params); } -void CLGEMM::configure_native_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) +void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, + const GEMMInfo &gemm_info) { const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); const unsigned int n = b->info()->dimension(0); @@ -94,13 +95,14 @@ void CLGEMM::configure_native_v1(const ICLTensor *a, const ICLTensor *b, const I GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias()); // Configure and tune matrix multiply kernel - _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); + _mm_kernel.configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); // Tune kernel statically CLScheduler::get().tune_kernel_static(_mm_kernel); } -void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) +void CLGEMM::configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, + const GEMMInfo &gemm_info) { bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1); @@ -148,22 +150,22 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const } // Configure interleave kernel - _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d); + _reshape_lhs_kernel.configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d); // Configure transpose kernel ICLTensor *reshaped_rhs = &_tmp_b; if(_weights_manager && _weights_manager->are_weights_managed(b)) { - _reshape_rhs_kernel_managed.configure(b, rhs_info); + _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info); reshaped_rhs = utils::cast::polymorphic_downcast(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed)); } else { - _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); + _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info); } // Configure and tune matrix multiply kernel - _mm_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); + _mm_kernel.configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); CLScheduler::get().tune_kernel_static(_mm_kernel); @@ -176,7 +178,8 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const } } -void CLGEMM::configure_reshaped(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) +void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, + const GEMMInfo &gemm_info) { DataType data_type = a->info()->data_type(); bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); @@ -223,21 +226,21 @@ void CLGEMM::configure_reshaped(const ICLTensor *a, const ICLTensor *b, const IC // Configure lhs_info and rhs_info std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type); - _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d()); + _reshape_lhs_kernel.configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d()); ICLTensor *reshaped_rhs = &_tmp_b; if(_weights_manager && _weights_manager->are_weights_managed(b)) { - _reshape_rhs_kernel_managed.configure(b, rhs_info); + _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info); reshaped_rhs = utils::cast::polymorphic_downcast(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed)); } else { - _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); + _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info); } // Configure and tune matrix multiply kernel - _mm_reshaped_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info); + _mm_reshaped_kernel.configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info); // Allocate intermediate tensors _tmp_a.allocator()->allocate(); @@ -248,7 +251,8 @@ void CLGEMM::configure_reshaped(const ICLTensor *a, const ICLTensor *b, const IC } } -void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) +void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, + const GEMMInfo &gemm_info) { DataType data_type = a->info()->data_type(); bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); @@ -293,16 +297,16 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, ICLTensor *reshaped_rhs = &_tmp_b; if(_weights_manager && _weights_manager->are_weights_managed(b)) { - _reshape_rhs_kernel_managed.configure(b, rhs_info); + _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info); reshaped_rhs = utils::cast::polymorphic_downcast(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed)); } else { - _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); + _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info); } // Configure and tune matrix multiply kernel - _mm_reshaped_only_rhs_kernel.configure(a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info); + _mm_reshaped_only_rhs_kernel.configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info); if(!_reshape_b_only_on_first_run && use_mm_b) { @@ -483,6 +487,11 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf } void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) +{ + configure(CLKernelLibrary::get().get_compile_context(), a, b, c, output, alpha, beta, gemm_info); +} + +void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); @@ -511,22 +520,22 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * { case CLGEMMKernelType::NATIVE_V1: { - configure_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info); + configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); break; } case CLGEMMKernelType::RESHAPED_V1: { - configure_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info); + configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); break; } case CLGEMMKernelType::RESHAPED: { - configure_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info); + configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); break; } case CLGEMMKernelType::RESHAPED_ONLY_RHS: { - configure_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info); + configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); break; } default: -- cgit v1.2.1