From f3622becf1f0d6bf5147ebb7d6d0f14d5252860a Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 29 Jul 2019 14:27:16 +0100 Subject: COMPMID-1979: Fuse Activation Function in CLGEMM - part 4 Fused activation function in CLGEMM Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d Signed-off-by: Gian Marco Iodice Reviewed-on: https://review.mlplatform.org/c/1640 Comments-Addressed: Arm Jenkins Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins --- src/runtime/CL/functions/CLGEMM.cpp | 91 ++++++++++++------------------------- 1 file changed, 28 insertions(+), 63 deletions(-) (limited to 'src/runtime/CL/functions/CLGEMM.cpp') diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index c0ccd0f451..e78395f1de 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -48,7 +48,6 @@ using namespace arm_compute::cl_gemm; CLGEMM::CLGEMM(std::shared_ptr memory_manager) : _memory_group(std::move(memory_manager)), _mm_kernel(), - _ma_kernel(), _reshape_lhs_kernel(), _reshape_rhs_kernel(), _mm_reshaped_kernel(), @@ -56,7 +55,6 @@ CLGEMM::CLGEMM(std::shared_ptr memory_manager) _tmp_a(), _tmp_b(), _original_b(nullptr), - _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _gemm_type(GEMMType::NATIVE) @@ -118,10 +116,10 @@ void CLGEMM::configure_native(const ICLTensor *a, const ICLTensor *b, const ICLT // Set the target for the kernels _mm_kernel.set_target(gpu_target); - GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d()); + GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias()); // Configure and tune matrix multiply kernel - _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision()); + _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); // Tune kernel statically CLScheduler::get().tune_kernel_static(_mm_kernel); @@ -162,7 +160,7 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const lhs_info.interleave = true; lhs_info.transpose = true; - GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false); + GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias()); _memory_group.manage(&_tmp_a); if(!_reshape_b_only_on_first_run) @@ -177,7 +175,7 @@ void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info); // Configure and tune matrix multiply kernel - _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision()); + _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info()); CLScheduler::get().tune_kernel_static(_mm_kernel); @@ -200,13 +198,15 @@ void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const GPUTarget gpu_target = CLScheduler::get().target(); bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); // Set the target for the kernels _reshape_lhs_kernel.set_target(gpu_target); @@ -255,13 +255,15 @@ void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const GPUTarget gpu_target = CLScheduler::get().target(); bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); // Set the target for the kernels _mm_kernel.set_target(gpu_target); @@ -305,21 +307,12 @@ Status CLGEMM::validate_native(const ITensorInfo *a, const ITensorInfo *b, const const unsigned int n = b->dimension(0); const unsigned int k = a->dimension(0); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); - const bool add_c = (beta != 0.f && c != nullptr); - const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; - const bool fuse_add = is_beta_one && (c != nullptr && c->num_dimensions() == 1); - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias()); // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, - false, reshape_info, gpu_target, gemm_info.fp_mixed_precision())); - - if(add_c && !fuse_add) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta, + false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info())); return Status{}; } @@ -340,9 +333,6 @@ Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, int mult_transpose1xW_width = 1; int mult_interleave4x4_height = 1; const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); - const bool add_c = (beta != 0.f && c != nullptr); - const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; - const bool fuse_add = is_beta_one && (c != nullptr && c->num_dimensions() == 1); if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST) { @@ -364,7 +354,7 @@ Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, lhs_info.interleave = true; lhs_info.transpose = true; - const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false); + const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias()); // Validate interleave kernel auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d()))); @@ -375,14 +365,8 @@ Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)); // Validate matrix multiply - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, - true, reshape_info, gpu_target, gemm_info.fp_mixed_precision())); - - if(add_c && !fuse_add) - { - // Validate matrix addition kernel - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } + ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, + true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info())); return Status{}; } @@ -405,13 +389,15 @@ Status CLGEMM::validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = false; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; @@ -452,13 +438,15 @@ Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2); const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); const bool broadcast_bias = gemm_info.broadcast_bias(); - GEMMKernelInfo kernel_info; + + GEMMKernelInfo kernel_info; kernel_info.m = m; kernel_info.n = n; kernel_info.k = k; kernel_info.depth_output_gemm3d = depth_output_gemm3d; kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); GEMMLHSMatrixInfo lhs_info; GEMMRHSMatrixInfo rhs_info; @@ -501,9 +489,7 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * // Select GEMMType _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target); - const bool is_fuse_add_c_supported = (_gemm_type == GEMMType::RESHAPED_V2) || (_gemm_type == GEMMType::RESHAPED_ONLY_RHS); - const bool add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); - const bool fuse_add_c = add_c && is_fuse_add_c_supported; + const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); const ICLTensor *c_to_use = fuse_add_c ? c : nullptr; @@ -534,13 +520,6 @@ void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor * ARM_COMPUTE_ERROR("GEMMType not supported"); } } - - // Configure matrix addition kernel - if(add_c && !fuse_add_c) - { - _ma_kernel.configure(c, output, beta); - _run_addition = true; - } } Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) @@ -555,9 +534,7 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso // Select GEMMType GEMMType gemm_type = select_gemm_type(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run(), gpu_target); - const bool is_fuse_add_c_supported = (gemm_type == GEMMType::RESHAPED_V2) || (gemm_type == GEMMType::RESHAPED_ONLY_RHS); - const bool add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); - const bool fuse_add_c = add_c && is_fuse_add_c_supported; + const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr; @@ -589,12 +566,6 @@ Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso } } - // Validate matrix addition kernel - if(add_c && !fuse_add_c) - { - ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta)); - } - return Status{}; } @@ -609,7 +580,7 @@ void CLGEMM::run() { case GEMMType::NATIVE: { - CLScheduler::get().enqueue(_mm_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_kernel, true); break; } case GEMMType::RESHAPED_V1: @@ -623,7 +594,7 @@ void CLGEMM::run() CLScheduler::get().enqueue(_reshape_rhs_kernel, false); } - CLScheduler::get().enqueue(_mm_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_kernel, true); break; } case GEMMType::RESHAPED_V2: @@ -637,7 +608,7 @@ void CLGEMM::run() CLScheduler::get().enqueue(_reshape_rhs_kernel, false); } - CLScheduler::get().enqueue(_mm_reshaped_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_reshaped_kernel, true); break; } case GEMMType::RESHAPED_ONLY_RHS: @@ -648,7 +619,7 @@ void CLGEMM::run() CLScheduler::get().enqueue(_reshape_rhs_kernel, false); } - CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, !_run_addition); + CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true); break; } default: @@ -656,12 +627,6 @@ void CLGEMM::run() ARM_COMPUTE_ERROR("GEMMType not supported"); } } - - // Run matrix addition kernel - if(_run_addition) - { - CLScheduler::get().enqueue(_ma_kernel); - } } void CLGEMM::prepare() -- cgit v1.2.1