From a3e57c20a0b7a174f0c357676a4da40a248d04db Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Mon, 13 Mar 2023 16:20:04 +0000 Subject: Add dynamic weights for CPU fully connected layer Resolves: COMPMID-5917 Signed-off-by: Viet-Hoa Do Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/cpu/operators/CpuGemm.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'src/cpu/operators/CpuGemm.cpp') diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index 545d59f410..f914bceec3 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -64,7 +64,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info); const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + const bool is_c_bias = beta == 1 && c != nullptr; bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run(); // Check if we need to reshape the matrix B only on the first run @@ -72,8 +72,8 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); _run_vector_matrix_multiplication = a->dimension(1) < 2; _run_alpha_scale = alpha != 1.f; - _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run(); - _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run(); + _run_bias_addition = is_c_bias; + _run_addition = beta != 0 && beta != 1 && c != nullptr; _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info()))); if(run_optimised) @@ -153,12 +153,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + const bool is_c_bias = beta == 1 && c != nullptr; + const bool run_addition = c != nullptr && beta != 0 && beta != 1; ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32); - + if (is_fixed_format_fast_math(gemm_info.weight_format())) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); @@ -177,7 +178,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d); } - if(c != nullptr && !is_c_bias) + if(run_addition) { ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0); ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d()); @@ -265,7 +266,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens } // Validate matrix addition kernel - if(beta != 0 && c != nullptr && !is_c_bias) + if(run_addition) { ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta)); } -- cgit v1.2.1