diff options
Diffstat (limited to 'src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 64 |
1 files changed, 35 insertions, 29 deletions
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index aec9da193b..8ca128fb07 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 Arm Limited. + * Copyright (c) 2021-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -65,7 +65,6 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.activation_info = info.activation_info(); asm_info.output_stage = info.gemmlowp_output_stage(); asm_info.fast_mode = info.fast_math(); - asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run(); return asm_info; } @@ -120,7 +119,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso _a_offset = a->quantization_info().uniform().offset; _b_offset = b->quantization_info().uniform().offset; _run_vector_matrix_multiplication = a->dimension(1) < 2; - _reshape_b_only_on_first_run = info.reshape_b_only_on_first_run(); + _reshape_b_only_on_first_run = b->are_values_constant(); _is_prepared = false; _fused_assembly_path = false; _flip_signedness = is_data_type_quantized_per_channel(b->data_type()) && (a->data_type() == DataType::QASYMM8) && _reshape_b_only_on_first_run; @@ -167,31 +166,34 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso // Initialize assembly kernel meta-data const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); #ifdef __aarch64__ - switch(a->data_type()) + if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently. { - case DataType::QASYMM8: - case DataType::QASYMM8_SIGNED: - case DataType::U8: - case DataType::S8: + switch(a->data_type()) { - if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::U8: + case DataType::S8: { - auto c_info_to_use = c == nullptr ? nullptr : c; - _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info); - _fused_assembly_path = _asm_glue->is_configured(); + if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + auto c_info_to_use = c == nullptr ? nullptr : c; + _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info); + _fused_assembly_path = _asm_glue->is_configured(); + } + else + { + auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst); + _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info); + } + _assembly_path = _asm_glue->is_configured(); + break; } - else + default: { - auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst); - _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info); + ARM_COMPUTE_ERROR("Datatype not supported"); + break; } - _assembly_path = _asm_glue->is_configured(); - break; - } - default: - { - ARM_COMPUTE_ERROR("Datatype not supported"); - break; } } #endif /* __aarch64__ */ @@ -371,14 +373,18 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITens // Check if we need to run the optimized assembly kernel bool run_optimised = false; bool run_optimised_requantized = false; - if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) - { - run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)); - run_optimised_requantized = run_optimised; - } - else + + if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently. { - run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info)); + if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)); + run_optimised_requantized = run_optimised; + } + else + { + run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info)); + } } if(run_optimised) |