diff options
author | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-04-03 16:27:25 +0100 |
---|---|---|
committer | Viet-Hoa Do <viet-hoa.do@arm.com> | 2023-04-14 08:57:27 +0000 |
commit | 9b0a6b49e95b221456489dd7c58681ceca5dd8cb (patch) | |
tree | 6afd87f8407fafb3de802e4ce1b4099a579b6ff8 /src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | |
parent | 4e84f244548a18e0935502cc443336fc1b8f1454 (diff) | |
download | ComputeLibrary-9b0a6b49e95b221456489dd7c58681ceca5dd8cb.tar.gz |
Fix dynamic weights for CPU connected layer
Resolves: COMPMID-5995
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I707b8918bebee7e70d4de5207ef555c806e7a305
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9405
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 64 |
1 files changed, 35 insertions, 29 deletions
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp index aec9da193b..8ca128fb07 100644 --- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp +++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 Arm Limited. + * Copyright (c) 2021-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -65,7 +65,6 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.activation_info = info.activation_info(); asm_info.output_stage = info.gemmlowp_output_stage(); asm_info.fast_mode = info.fast_math(); - asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run(); return asm_info; } @@ -120,7 +119,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso _a_offset = a->quantization_info().uniform().offset; _b_offset = b->quantization_info().uniform().offset; _run_vector_matrix_multiplication = a->dimension(1) < 2; - _reshape_b_only_on_first_run = info.reshape_b_only_on_first_run(); + _reshape_b_only_on_first_run = b->are_values_constant(); _is_prepared = false; _fused_assembly_path = false; _flip_signedness = is_data_type_quantized_per_channel(b->data_type()) && (a->data_type() == DataType::QASYMM8) && _reshape_b_only_on_first_run; @@ -167,31 +166,34 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso // Initialize assembly kernel meta-data const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); #ifdef __aarch64__ - switch(a->data_type()) + if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently. { - case DataType::QASYMM8: - case DataType::QASYMM8_SIGNED: - case DataType::U8: - case DataType::S8: + switch(a->data_type()) { - if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: + case DataType::U8: + case DataType::S8: { - auto c_info_to_use = c == nullptr ? nullptr : c; - _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info); - _fused_assembly_path = _asm_glue->is_configured(); + if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + auto c_info_to_use = c == nullptr ? nullptr : c; + _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info); + _fused_assembly_path = _asm_glue->is_configured(); + } + else + { + auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst); + _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info); + } + _assembly_path = _asm_glue->is_configured(); + break; } - else + default: { - auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst); - _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info); + ARM_COMPUTE_ERROR("Datatype not supported"); + break; } - _assembly_path = _asm_glue->is_configured(); - break; - } - default: - { - ARM_COMPUTE_ERROR("Datatype not supported"); - break; } } #endif /* __aarch64__ */ @@ -371,14 +373,18 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITens // Check if we need to run the optimized assembly kernel bool run_optimised = false; bool run_optimised_requantized = false; - if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) - { - run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)); - run_optimised_requantized = run_optimised; - } - else + + if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently. { - run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info)); + if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)); + run_optimised_requantized = run_optimised; + } + else + { + run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info)); + } } if(run_optimised) |