diff options
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r-- | src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp | 77 |
1 files changed, 46 insertions, 31 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp index e50099df1f..c2e9f24ff6 100644 --- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp @@ -26,10 +26,10 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/FunctionDescriptors.h" -#include "arm_compute/runtime/NEON/NEScheduler.h" -#include "src/runtime/cpu/operators/CpuActivation.h" -#include "src/runtime/cpu/operators/CpuPermute.h" -#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h" +#include "src/core/helpers/MemoryHelpers.h" +#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h" + +#include "support/Cast.h" #include <set> @@ -37,6 +37,9 @@ namespace arm_compute { namespace cpu { +using namespace arm_compute::experimental; +using namespace arm_compute::utils::cast; + namespace { GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act) @@ -87,12 +90,14 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect } } // namespace -CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) - : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)), +CpuGemmDirectConv2d::CpuGemmDirectConv2d() + : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()), _activation_func(std::make_unique<CpuActivation>()), _weights_permute_func(std::make_unique<CpuPermute>()), - _permuted_weights_info(), - _permuted_weights(std::make_unique<Tensor>()) + _aux_mem(AuxTensorIdx::Count), + _perm_weights(), + _run_activation(false), + _is_prepared(false) { } @@ -106,8 +111,10 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w biases != nullptr ? biases : nullptr, dst, info)); - _original_weights_info = weights; - _weights_permute_func->configure(weights, &_permuted_weights_info, PermutationVector{ 3, 0, 1, 2 }); + _run_activation = info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info); + _is_prepared = false; + + _weights_permute_func->configure(weights, &_perm_weights, PermutationVector{ 3, 0, 1, 2 }); // Configure assembly dispatch cpu::AsmGemmInfo asm_info = init_assembly_metadata(info, false); @@ -115,13 +122,27 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w { asm_info.output_stage = calculate_output_stage_metadata(src, weights, dst, info.act_info); } - _gemm_asm_func->configure(src, &_permuted_weights_info, biases, dst, asm_info); + _gemm_asm_func->configure(src, &_perm_weights, biases, dst, asm_info); // Configure activation - if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info)) + if(_run_activation) { _activation_func->configure(dst, nullptr, info.act_info); - _run_activation = true; + } + + // Add auxiliary memory requirements of the assembly dispatch + auto asm_mem_req = _gemm_asm_func->workspace(); + _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace]; + _aux_mem[Pretranspose] = asm_mem_req[Pretranspose]; + + if(_aux_mem[Pretranspose].size > 0) + { + // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch + _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, weights->total_size()); + } + else + { + _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); } } Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info) @@ -172,35 +193,29 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors) } } -void CpuGemmDirectConv2d::allocate_permuted_weights() -{ - // TODO: This function will be removed when memory injection is implemeted. - ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr); - _permuted_weights->allocator()->free(); - _permuted_weights->allocator()->init(_permuted_weights_info); - _permuted_weights->allocator()->allocate(); -} - void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { - allocate_permuted_weights(); - ITensorPack permute_tensors - { - { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) }, - { TensorType::ACL_DST, _permuted_weights.get() }, - }; + const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1); + ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights))); + ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux); + CpuAuxTensorHandler permuted_weights(_perm_weights, *weights_aux); + ITensorPack permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } }; _weights_permute_func->run(permute_tensors); - tensors.get_const_tensor(TensorType::ACL_SRC_1)->mark_as_unused(); + tensors.add_const_tensor(ACL_SRC_1, permuted_weights.get()); + // Call prepare of assembly dispatch + _gemm_asm_func->prepare(tensors); - // switch the original tensor with permuted tensor - tensors.add_const_tensor(TensorType::ACL_SRC_1, _permuted_weights.get()); _is_prepared = true; } } +experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const +{ + return _aux_mem; +} } // namespace cpu } // namespace arm_compute
\ No newline at end of file |