aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemm.cpp
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-03-13 16:20:04 +0000
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-03-21 10:33:53 +0000
commita3e57c20a0b7a174f0c357676a4da40a248d04db (patch)
treed92b2316a00db6ce07dd2af476791281fcc98de6 /src/cpu/operators/CpuGemm.cpp
parent8918b23073851417e8be6e5e53c6380dbdedf201 (diff)
downloadComputeLibrary-a3e57c20a0b7a174f0c357676a4da40a248d04db.tar.gz
Add dynamic weights for CPU fully connected layer
Resolves: COMPMID-5917 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r--src/cpu/operators/CpuGemm.cpp15
1 files changed, 8 insertions, 7 deletions
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 545d59f410..f914bceec3 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -64,7 +64,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info);
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ const bool is_c_bias = beta == 1 && c != nullptr;
bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
// Check if we need to reshape the matrix B only on the first run
@@ -72,8 +72,8 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_run_vector_matrix_multiplication = a->dimension(1) < 2;
_run_alpha_scale = alpha != 1.f;
- _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
- _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
+ _run_bias_addition = is_c_bias;
+ _run_addition = beta != 0 && beta != 1 && c != nullptr;
_run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
if(run_optimised)
@@ -153,12 +153,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ const bool is_c_bias = beta == 1 && c != nullptr;
+ const bool run_addition = c != nullptr && beta != 0 && beta != 1;
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
-
+
if (is_fixed_format_fast_math(gemm_info.weight_format()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
@@ -177,7 +178,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d);
}
- if(c != nullptr && !is_c_bias)
+ if(run_addition)
{
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
@@ -265,7 +266,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
}
// Validate matrix addition kernel
- if(beta != 0 && c != nullptr && !is_c_bias)
+ if(run_addition)
{
ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta));
}