aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.cpp')
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp29
1 files changed, 16 insertions, 13 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 31c873c2ba..7460f2020c 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -839,23 +839,26 @@ void CpuGemmConv2d::run(ITensorPack &tensors)
auto weights = gemm_pack.get_const_tensor(TensorType::ACL_SRC_1);
ARM_COMPUTE_ERROR_ON_NULLPTR(weights);
// Re-interpreted weights. Only tensor shape is changed. Only memory import, no allocation
+ const bool use_reinterpreted_wei = (_run_wt && _wt_method == WeightTransformMethod::ReinterpretThenTranspose);
CpuAuxTensorHandler reinterpreted_wei(
_weights_reshaped, *weights,
/* import only if we chose the ReinterpretThenTranspose path, because otherwise the weight may have been freed */
- !(_run_wt && _wt_method == WeightTransformMethod::ReinterpretThenTranspose));
- CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
+ !use_reinterpreted_wei);
+
+ const bool use_reshaped_wei = (_run_wt && (_wt_method == WeightTransformMethod::ReshapeThenTranspose ||
+ _wt_method == WeightTransformMethod::FusedReshapeAndTranspose));
+ CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors,
+ false /* pack_inject */, !use_reshaped_wei /* bypass_alloc */,
+ !use_reshaped_wei /* bypass_import */
+ );
// Update the weights to use if it has been reshaped
- if (_run_wt)
+ if (use_reinterpreted_wei)
{
- if (_wt_method == WeightTransformMethod::ReinterpretThenTranspose)
- {
- gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
- }
- else if (_wt_method == WeightTransformMethod::ReshapeThenTranspose ||
- _wt_method == WeightTransformMethod::FusedReshapeAndTranspose)
- {
- gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
- }
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get());
+ }
+ else if (use_reshaped_wei)
+ {
+ gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get());
}
// Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions