Diffstat (limited to 'src/cpu/operators/CpuGemm.cpp')
-rw-r--r-- src/cpu/operators/CpuGemm.cpp | 16
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f914bceec3..b9d18c4cb6 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -65,11 +65,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
const bool is_c_bias = beta == 1 && c != nullptr;
- bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
+ bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) &&
+ (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
+ !(!b->are_values_constant() && b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.
// Check if we need to reshape the matrix B only on the first run
_is_prepared = false;
- _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+ _reshape_b_only_on_first_run = b->are_values_constant();
_run_vector_matrix_multiplication = a->dimension(1) < 2;
_run_alpha_scale = alpha != 1.f;
_run_bias_addition = is_c_bias;
@@ -211,7 +213,9 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
// Check if we need to run the optimized assembly kernel
cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info));
+ const bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info)) &&
+ (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
+ !(!b->are_values_constant() && b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.
if(!run_optimised)
{
@@ -221,7 +225,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
// Check if the first input tensor is a vector.
const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
// Check if we need to reshape the matrix A and matrix B
- const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication && !b->are_values_constant();
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
@@ -259,7 +263,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
- if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
+ if(is_c_bias)
{
ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, d, ConvertPolicy::SATURATE));
}
@@ -294,7 +298,7 @@ void CpuGemm::run(ITensorPack &tensors)
{
// Pass c to asm dispatch only if it's the bias tensor
ITensorPack asm_pack = tensors;
- asm_pack.add_const_tensor(ACL_SRC_2, (_reshape_b_only_on_first_run) ? c : nullptr);
+ asm_pack.add_const_tensor(ACL_SRC_2, _run_bias_addition ? c : nullptr);
_asm_glue->run(asm_pack);
if(_run_alpha_scale)
{