aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemm.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r--src/cpu/operators/CpuGemm.cpp15
1 files changed, 8 insertions, 7 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 545d59f410..f914bceec3 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -64,7 +64,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info);
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ const bool is_c_bias = beta == 1 && c != nullptr;
bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
// Check if we need to reshape the matrix B only on the first run
@@ -72,8 +72,8 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_run_vector_matrix_multiplication = a->dimension(1) < 2;
_run_alpha_scale = alpha != 1.f;
- _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
- _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
+ _run_bias_addition = is_c_bias;
+ _run_addition = beta != 0 && beta != 1 && c != nullptr;
_run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
if(run_optimised)
@@ -153,12 +153,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ const bool is_c_bias = beta == 1 && c != nullptr;
+ const bool run_addition = c != nullptr && beta != 0 && beta != 1;
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
-
+
if (is_fixed_format_fast_math(gemm_info.weight_format()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
@@ -177,7 +178,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d);
}
- if(c != nullptr && !is_c_bias)
+ if(run_addition)
{
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
@@ -265,7 +266,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
}
// Validate matrix addition kernel
- if(beta != 0 && c != nullptr && !is_c_bias)
+ if(run_addition)
{
ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta));
}