Diffstat (limited to 'src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp')
-rw-r--r--  src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 64 +++++++++++++++++++++++++++++++++++-----------------------------
 1 file changed, 35 insertions(+), 29 deletions(-)
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index aec9da193b..8ca128fb07 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -65,7 +65,6 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.output_stage = info.gemmlowp_output_stage();
asm_info.fast_mode = info.fast_math();
- asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run();
return asm_info;
}
@@ -120,7 +119,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso
     _a_offset                         = a->quantization_info().uniform().offset;
     _b_offset                         = b->quantization_info().uniform().offset;
     _run_vector_matrix_multiplication = a->dimension(1) < 2;
-    _reshape_b_only_on_first_run      = info.reshape_b_only_on_first_run();
+    _reshape_b_only_on_first_run      = b->are_values_constant();
     _is_prepared                      = false;
     _fused_assembly_path              = false;
     _flip_signedness = is_data_type_quantized_per_channel(b->data_type()) && (a->data_type() == DataType::QASYMM8) && _reshape_b_only_on_first_run;
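Aside: with this hunk, `_reshape_b_only_on_first_run` is derived from the B tensor itself (`are_values_constant()`) rather than from the `reshape_b_only_on_first_run` hint in `GEMMInfo`. A minimal sketch of what the flag implies for the prepare/run split, assuming only the `ITensorInfo::are_values_constant()` accessor used above (the `pack_rhs()` helper is hypothetical, not library code):

    // Packing/reshaping B once is only sound when B's values are constant
    // across runs; a dynamic B must be re-packed on every run.
    void run(ITensor *a, ITensor *b)
    {
        const bool reshape_b_only_on_first_run = b->info()->are_values_constant();
        if(!reshape_b_only_on_first_run || !_is_prepared)
        {
            pack_rhs(b); // hypothetical: pretranspose/pack the RHS for the kernel
            _is_prepared = true;
        }
        // ... launch the matrix-multiply kernels on the packed B ...
    }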
@@ -167,31 +166,34 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso
     // Initialize assembly kernel meta-data
     const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
 #ifdef __aarch64__
-    switch(a->data_type())
+    if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
     {
-        case DataType::QASYMM8:
-        case DataType::QASYMM8_SIGNED:
-        case DataType::U8:
-        case DataType::S8:
+        switch(a->data_type())
         {
-            if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+            case DataType::QASYMM8:
+            case DataType::QASYMM8_SIGNED:
+            case DataType::U8:
+            case DataType::S8:
             {
-                auto c_info_to_use = c == nullptr ? nullptr : c;
-                _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info);
-                _fused_assembly_path = _asm_glue->is_configured();
+                if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+                {
+                    auto c_info_to_use = c == nullptr ? nullptr : c;
+                    _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info);
+                    _fused_assembly_path = _asm_glue->is_configured();
+                }
+                else
+                {
+                    auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst);
+                    _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info);
+                }
+                _assembly_path = _asm_glue->is_configured();
+                break;
             }
-            else
+            default:
             {
-                auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst);
-                _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info);
+                ARM_COMPUTE_ERROR("Datatype not supported");
+                break;
             }
-            _assembly_path = _asm_glue->is_configured();
-            break;
-        }
-        default:
-        {
-            ARM_COMPUTE_ERROR("Datatype not supported");
-            break;
-        }
+        }
     }
 #endif /* __aarch64__ */
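The condition guarding the assembly path above reads as a double negation. By De Morgan's law it is equivalent to the positive form below (the helper name is illustrative only, not part of the patch; it uses only the `are_values_constant()` and `tensor_shape().z()` calls that appear in the diff):

    // Assembly dispatch stays enabled unless B is both dynamic (values can
    // change between runs) and batched (more than one matrix in dimension z).
    static bool assembly_path_allowed(const ITensorInfo *b)
    {
        return b->are_values_constant() || b->tensor_shape().z() <= 1;
    }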
@@ -371,14 +373,18 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITens
     // Check if we need to run the optimized assembly kernel
     bool run_optimised             = false;
     bool run_optimised_requantized = false;
-    if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
-    {
-        run_optimised             = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
-        run_optimised_requantized = run_optimised;
-    }
-    else
+
+    if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
     {
-        run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
+        if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+        {
+            run_optimised             = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
+            run_optimised_requantized = run_optimised;
+        }
+        else
+        {
+            run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
+        }
     }
     if(run_optimised)
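The same guard is mirrored in validate() so that configure() and validate() agree on when the assembly kernel is dispatched. From the caller's side the new behaviour is driven entirely by how B is described. A hypothetical caller-side sketch (the variables `n`, `k`, `batches`, and `b_info` are illustrative; `set_are_values_constant()` is assumed to be the existing `ITensorInfo` setter paired with the `are_values_constant()` accessor used in the diff):

    // A batched B whose values change between runs now bypasses the
    // optimized assembly GEMM and falls back to the reference kernels.
    TensorInfo b_info(TensorShape(n, k, batches), 1, DataType::QASYMM8_SIGNED);
    b_info.set_are_values_constant(false); // B is updated between runs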