diff options
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemm.cpp | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index 545d59f410..f914bceec3 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -64,7 +64,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info); const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + const bool is_c_bias = beta == 1 && c != nullptr; bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run(); // Check if we need to reshape the matrix B only on the first run @@ -72,8 +72,8 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); _run_vector_matrix_multiplication = a->dimension(1) < 2; _run_alpha_scale = alpha != 1.f; - _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run(); - _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run(); + _run_bias_addition = is_c_bias; + _run_addition = beta != 0 && beta != 1 && c != nullptr; _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info()))); if(run_optimised) @@ -153,12 +153,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + const bool is_c_bias = beta == 1 && c != nullptr; + const bool run_addition = c != nullptr && beta != 0 && beta != 1; ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32); - + if (is_fixed_format_fast_math(gemm_info.weight_format())) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); @@ -177,7 +178,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d); } - if(c != nullptr && !is_c_bias) + if(run_addition) { ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0); ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d()); @@ -265,7 +266,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens } // Validate matrix addition kernel - if(beta != 0 && c != nullptr && !is_c_bias) + if(run_addition) { ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta)); } |