diff options
Diffstat (limited to 'src/cpu/operators/CpuGemmConv2d.cpp')
-rw-r--r-- | src/cpu/operators/CpuGemmConv2d.cpp | 29 |
1 files changed, 16 insertions, 13 deletions
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp index 31c873c2ba..7460f2020c 100644 --- a/src/cpu/operators/CpuGemmConv2d.cpp +++ b/src/cpu/operators/CpuGemmConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023 Arm Limited. + * Copyright (c) 2021-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -839,23 +839,26 @@ void CpuGemmConv2d::run(ITensorPack &tensors) auto weights = gemm_pack.get_const_tensor(TensorType::ACL_SRC_1); ARM_COMPUTE_ERROR_ON_NULLPTR(weights); // Re-interpreted weights. Only tensor shape is changed. Only memory import, no allocation + const bool use_reinterpreted_wei = (_run_wt && _wt_method == WeightTransformMethod::ReinterpretThenTranspose); CpuAuxTensorHandler reinterpreted_wei( _weights_reshaped, *weights, /* import only if we chose the ReinterpretThenTranspose path, because otherwise the weight may have been freed */ - !(_run_wt && _wt_method == WeightTransformMethod::ReinterpretThenTranspose)); - CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors); + !use_reinterpreted_wei); + + const bool use_reshaped_wei = (_run_wt && (_wt_method == WeightTransformMethod::ReshapeThenTranspose || + _wt_method == WeightTransformMethod::FusedReshapeAndTranspose)); + CpuAuxTensorHandler reshaped_wei(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors, + false /* pack_inject */, !use_reshaped_wei /* bypass_alloc */, + !use_reshaped_wei /* bypass_import */ + ); // Update the weights to use if it has been reshaped - if (_run_wt) + if (use_reinterpreted_wei) { - if (_wt_method == WeightTransformMethod::ReinterpretThenTranspose) - { - gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get()); - } - else if (_wt_method == WeightTransformMethod::ReshapeThenTranspose || - _wt_method == WeightTransformMethod::FusedReshapeAndTranspose) - { - gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); - } + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reinterpreted_wei.get()); + } + else if (use_reshaped_wei) + { + gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); } // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions |