author     Viet-Hoa Do <viet-hoa.do@arm.com>    2023-04-03 16:27:25 +0100
committer  Viet-Hoa Do <viet-hoa.do@arm.com>    2023-04-14 08:57:27 +0000
commit     9b0a6b49e95b221456489dd7c58681ceca5dd8cb (patch)
tree       6afd87f8407fafb3de802e4ce1b4099a579b6ff8
parent     4e84f244548a18e0935502cc443336fc1b8f1454 (diff)
download   ComputeLibrary-9b0a6b49e95b221456489dd7c58681ceca5dd8cb.tar.gz
Fix dynamic weights for CPU connected layer
Resolves: COMPMID-5995

Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I707b8918bebee7e70d4de5207ef555c806e7a305
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9405
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
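For context: dynamic weights are signalled by clearing the constant-values flag on the B (weights) tensor info, and the CPU operators in this patch derive their reshape/caching policy from that flag rather than from GEMMInfo::reshape_b_only_on_first_run(). A minimal sketch, assuming the public TensorInfo API (hypothetical shape and data type):

    // Mark the weights as dynamic: CpuGemm then re-reshapes B on every run
    // instead of caching the reshaped copy from the first run.
    arm_compute::TensorInfo b_info(arm_compute::TensorShape(64U, 64U), 1, arm_compute::DataType::F32);
    b_info.set_are_values_constant(false);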
-rw-r--r-- src/cpu/operators/CpuFullyConnected.cpp | 11
-rw-r--r-- src/cpu/operators/CpuGemm.cpp | 16
-rw-r--r-- src/cpu/operators/CpuGemmConv2d.cpp | 4
-rw-r--r-- src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp | 64
-rw-r--r-- src/runtime/NEON/functions/NEGEMM.cpp | 20
-rw-r--r-- src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp | 21
-rw-r--r-- tests/validation/NEON/FullyConnectedLayer.cpp | 4
7 files changed, 91 insertions, 49 deletions
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index af630154cf..70584a64f8 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -136,7 +136,7 @@ Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITe
}
else
{
- GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+ GEMMInfo gemm_info;
gemm_info.set_weight_format(weight_format);
gemm_info.set_fixed_format(weight_format != WeightFormat::UNSPECIFIED);
gemm_info.set_fast_math(enable_fast_math);
@@ -190,7 +190,7 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo *
const Status status = get_gemmlowp_output_stage_info(&src_info, &weights_info, dst, act, gemmlowp_output_stage_info);
ARM_COMPUTE_ERROR_ON(status.error_code() != ErrorCode::OK);
- GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
+ GEMMInfo gemm_info;
gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
gemm_info.set_activation_info(act);
gemm_info.set_fast_math(_enable_fast_math);
@@ -200,7 +200,7 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo *
else
{
// Configure matrix multiply kernel
- GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
+ GEMMInfo gemm_info;
gemm_info.set_activation_info(act);
gemm_info.set_fast_math(_enable_fast_math);
gemm_info.set_fixed_format(_fixed_format);
@@ -284,6 +284,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
// Reshape the weights
_transpose_weights = std::make_unique<kernels::CpuTransposeKernel>();
_transpose_weights->configure(weights, &_reshaped_weights);
+ _reshaped_weights.set_are_values_constant(weights->are_values_constant());
+
weights_to_use = &_reshaped_weights;
_trans_weights_idx = AuxTensorIdx::TransposedWeights;
}
@@ -297,6 +299,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
&_converted_weights,
src->tensor_shape(),
fc_info.weights_trained_layout);
+ _converted_weights.set_are_values_constant(weights_to_use->are_values_constant());
weights_to_use = &_converted_weights;
_needs_weights_conversion = true;
@@ -364,7 +367,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights,
const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info, WeightsInfo weights_info)
{
- GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+ GEMMInfo gemm_info;
gemm_info.set_activation_info(fc_info.activation_info);
gemm_info.set_fast_math(fc_info.enable_fast_math);
gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED);
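Note on the GEMMInfo changes in this file: the removed three-argument constructor baked the reshape policy into GEMMInfo at configure time, while the policy is now derived inside CpuGemm from the weight tensor's constant-values flag. A condensed before/after sketch (constructor arguments as in the removed lines):

    // Before (removed lines): policy fixed at configure time.
    GEMMInfo gemm_info_old(false, false, true /* Reshape weights only for the first run */);
    // After (added lines): default-constructed; CpuGemm consults
    // b->are_values_constant() per run instead.
    GEMMInfo gemm_info_new;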
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index f914bceec3..b9d18c4cb6 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -65,11 +65,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
const bool is_c_bias = beta == 1 && c != nullptr;
- bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
+ bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) &&
+ (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
+ !(!b->are_values_constant() && b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.
// Check if we need to reshape the matrix B only on the first run
_is_prepared = false;
- _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+ _reshape_b_only_on_first_run = b->are_values_constant();
_run_vector_matrix_multiplication = a->dimension(1) < 2;
_run_alpha_scale = alpha != 1.f;
_run_bias_addition = is_c_bias;
@@ -211,7 +213,9 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
// Check if we need to run the optimized assembly kernel
cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info));
+ const bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info)) &&
+ (c == nullptr || beta == 0.f || beta == 1.f) && // Optimized GeMM doesn't support beta coefficient.
+ !(!b->are_values_constant() && b->tensor_shape().z() > 1); // Disable batch matmul as optimized GeMM handles batching differently.
if(!run_optimised)
{
@@ -221,7 +225,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
// Check if the first input tensor is a vector.
const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
// Check if we need to reshape the matrix A and matrix B
- const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());
+ const bool run_interleave_transpose = !run_vector_matrix_multiplication && !b->are_values_constant();
// Arguments used by GEMMReshapeInfo
// If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
@@ -259,7 +263,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));
- if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
+ if(is_c_bias)
{
ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, d, ConvertPolicy::SATURATE));
}
@@ -294,7 +298,7 @@ void CpuGemm::run(ITensorPack &tensors)
{
// Pass c to asm dispatch only if it's the bias tensor
ITensorPack asm_pack = tensors;
- asm_pack.add_const_tensor(ACL_SRC_2, (_reshape_b_only_on_first_run) ? c : nullptr);
+ asm_pack.add_const_tensor(ACL_SRC_2, _run_bias_addition ? c : nullptr);
_asm_glue->run(asm_pack);
if(_run_alpha_scale)
{
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 9bf6ed1e85..ebf2ebcc1b 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -169,7 +169,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig
{
// Configure matrix multiply function
_mm_gemm = std::make_unique<CpuGemm>();
- _mm_gemm->configure(src, weights, biases, dst, 1.0f, 0.0f, gemm_info);
+ _mm_gemm->configure(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
auto mm_mem_req = _mm_gemm->workspace();
for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont)
{
@@ -235,7 +235,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei
else
{
// Perform validation step on Matrix multiply function
- return CpuGemm::validate(src, weights, nullptr, dst, 1.0f, 0.0f, gemm_info);
+ return CpuGemm::validate(src, weights, biases, dst, 1.0f, 1.0f, gemm_info);
}
}
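On why beta moves from 0.0f to 1.0f here: per the CpuGemm hunks above, c participates as a fused bias only when beta == 1, so passing 1.0f lets the convolution's biases flow into the GEMM as a bias input. The relevant predicate, copied from CpuGemm::configure:

    // c counts as a bias tensor only when beta == 1.
    const bool is_c_bias = (beta == 1) && (c != nullptr);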
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index aec9da193b..8ca128fb07 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -65,7 +65,6 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
asm_info.activation_info = info.activation_info();
asm_info.output_stage = info.gemmlowp_output_stage();
asm_info.fast_mode = info.fast_math();
- asm_info.reshape_b_only_on_first_run = info.reshape_b_only_on_first_run();
return asm_info;
}
@@ -120,7 +119,7 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso
_a_offset = a->quantization_info().uniform().offset;
_b_offset = b->quantization_info().uniform().offset;
_run_vector_matrix_multiplication = a->dimension(1) < 2;
- _reshape_b_only_on_first_run = info.reshape_b_only_on_first_run();
+ _reshape_b_only_on_first_run = b->are_values_constant();
_is_prepared = false;
_fused_assembly_path = false;
_flip_signedness = is_data_type_quantized_per_channel(b->data_type()) && (a->data_type() == DataType::QASYMM8) && _reshape_b_only_on_first_run;
@@ -167,31 +166,34 @@ void CpuGemmLowpMatrixMultiplyCore::configure(const ITensorInfo *a, const ITenso
// Initialize assembly kernel meta-data
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
#ifdef __aarch64__
- switch(a->data_type())
+ if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
{
- case DataType::QASYMM8:
- case DataType::QASYMM8_SIGNED:
- case DataType::U8:
- case DataType::S8:
+ switch(a->data_type())
{
- if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ case DataType::U8:
+ case DataType::S8:
{
- auto c_info_to_use = c == nullptr ? nullptr : c;
- _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info);
- _fused_assembly_path = _asm_glue->is_configured();
+ if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+ {
+ auto c_info_to_use = c == nullptr ? nullptr : c;
+ _asm_glue->configure(a_to_use, b, c_info_to_use, dst, asm_info);
+ _fused_assembly_path = _asm_glue->is_configured();
+ }
+ else
+ {
+ auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst);
+ _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info);
+ }
+ _assembly_path = _asm_glue->is_configured();
+ break;
}
- else
+ default:
{
- auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : dst);
- _asm_glue->configure(a_to_use, b, nullptr, output_to_use, asm_info);
+ ARM_COMPUTE_ERROR("Datatype not supported");
+ break;
}
- _assembly_path = _asm_glue->is_configured();
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR("Datatype not supported");
- break;
}
}
#endif /* __aarch64__ */
@@ -371,14 +373,18 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITens
// Check if we need to run the optimized assembly kernel
bool run_optimised = false;
bool run_optimised_requantized = false;
- if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
- {
- run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
- run_optimised_requantized = run_optimised;
- }
- else
+
+ if(!(!b->are_values_constant() && b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
{
- run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
+ if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
+ {
+ run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
+ run_optimised_requantized = run_optimised;
+ }
+ else
+ {
+ run_optimised = bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
+ }
}
if(run_optimised)
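The quantized core applies the same batching restriction as CpuGemm. The double-negated condition reads more naturally as a named flag (sketch; the name is illustrative):

    const bool dynamic_batched_b = !b->are_values_constant() && b->tensor_shape().z() > 1;
    if(!dynamic_batched_b) // B is constant, or B has no batch dimension
    {
        // the optimized assembly path may be attempted
    }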
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 0266c48f86..e51f2f9eb6 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -71,7 +71,14 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
_impl->original_b = b;
_impl->op = std::make_unique<cpu::CpuGemm>();
- _impl->op->configure(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info);
+ // Make the B matrix dynamic values.
+ auto b_info_to_use = b->info()->clone();
+ if(!gemm_info.reshape_b_only_on_first_run())
+ {
+ b_info_to_use->set_are_values_constant(false);
+ }
+
+ _impl->op->configure(a->info(), b_info_to_use.get(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info);
_impl->aux_mem_req = _impl->op->workspace();
_impl->run_pack = { { ACL_SRC_0, a }, { ACL_SRC_1, b }, { ACL_SRC_2, c }, { ACL_DST, d } };
@@ -81,7 +88,14 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
- return cpu::CpuGemm::validate(a, b, c, output, alpha, beta, gemm_info);
+ // Make the B matrix dynamic values.
+ auto b_to_use = b->clone();
+ if(!gemm_info.reshape_b_only_on_first_run())
+ {
+ b_to_use->set_are_values_constant(false);
+ }
+
+ return cpu::CpuGemm::validate(a, b_to_use.get(), c, output, alpha, beta, gemm_info);
}
Status NEGEMM::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output,
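How a caller reaches the new path through the public API, as a usage sketch (hypothetical tensors; initialization and allocation omitted; the GEMMInfo constructor arguments mirror the removed reshape flag):

    using namespace arm_compute;
    Tensor a, b, d;  // hypothetical tensors; init/allocation omitted
    GEMMInfo info(false, false, false /* reshape_b_only_on_first_run */);
    NEGEMM gemm;
    gemm.configure(&a, &b, nullptr, &d, 1.f /* alpha */, 0.f /* beta */, info);
    // NEGEMM clones b.info(), clears its constant-values flag because
    // reshape_b_only_on_first_run() is false, and forwards the clone to cpu::CpuGemm.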
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 6c179f8387..453d3cedef 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -61,9 +61,17 @@ NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default;
void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *output, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
+
+ // Make the B matrix dynamic values.
+ auto b_info_to_use = b->info()->clone();
+ if(!gemm_info.reshape_b_only_on_first_run())
+ {
+ b_info_to_use->set_are_values_constant(false);
+ }
+
_impl->b = b;
_impl->op = std::make_unique<cpu::CpuGemmLowpMatrixMultiplyCore>();
- _impl->op->configure(a->info(), b->info(), (c != nullptr ? c->info() : nullptr), output->info(), gemm_info);
+ _impl->op->configure(a->info(), b_info_to_use.get(), (c != nullptr ? c->info() : nullptr), output->info(), gemm_info);
_impl->run_pack =
{
{ TensorType::ACL_SRC_0, a },
@@ -82,7 +90,14 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info)
{
- return cpu::CpuGemmLowpMatrixMultiplyCore::validate(a, b, c, output, gemm_info);
+ // Make the B matrix dynamic values.
+ auto b_info_to_use = b->clone();
+ if(!gemm_info.reshape_b_only_on_first_run())
+ {
+ b_info_to_use->set_are_values_constant(false);
+ }
+
+ return cpu::CpuGemmLowpMatrixMultiplyCore::validate(a, b_info_to_use.get(), c, output, gemm_info);
}
void NEGEMMLowpMatrixMultiplyCore::run()
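The quantized function mirrors NEGEMM; a parallel usage sketch (hypothetical tensors; init/allocation omitted):

    using namespace arm_compute;
    Tensor a, b, dst;  // hypothetical quantized tensors
    NEGEMMLowpMatrixMultiplyCore gemm;
    gemm.configure(&a, &b, nullptr, &dst, GEMMInfo(false, false, false /* reshape_b_only_on_first_run */));
    // As in NEGEMM, a non-constant clone of b.info() is what reaches
    // cpu::CpuGemmLowpMatrixMultiplyCore.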
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index 882ad04cd2..04889a9dba 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -438,7 +438,7 @@ FIXTURE_DATA_TEST_CASE(RunDynamicBias, NEFullyConnectedLayerDynamicBiasFixture<u
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
framework::dataset::make("DataType", DataType::QASYMM8)),
framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
- framework::dataset::make("WeightsReshaped", { false, true })))
+ framework::dataset::make("WeightsReshaped", { false })))
{
}
TEST_SUITE_END()
@@ -480,7 +480,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerQuantizedFixture<
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
- framework::dataset::make("WeightsReshaped", { false, true })))
+ framework::dataset::make("WeightsReshaped", { false })))
{
}
TEST_SUITE_END() // QASYMM8_SIGNED