diff options
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r-- | src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp | 58 |
1 files changed, 13 insertions, 45 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp index 7b7b68a93b..e50099df1f 100644 --- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp @@ -53,13 +53,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - PixelValue type_min{}; PixelValue type_max{}; - std::tie(type_min, type_max) = get_min_max(data_type); - int32_t min_activation = type_min.get<int32_t>(); - int32_t max_activation = type_max.get<int32_t>(); + int32_t min_activation = type_min.get<int32_t>(); + int32_t max_activation = type_max.get<int32_t>(); if(supported_acts.count(act.activation()) != 0) { std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo); @@ -89,8 +87,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect } } // namespace -CpuGemmDirectConv2d::CpuGemmDirectConv2d() - : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()), +CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) + : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)), _activation_func(std::make_unique<CpuActivation>()), _weights_permute_func(std::make_unique<CpuPermute>()), _permuted_weights_info(), @@ -165,8 +163,6 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo * } void CpuGemmDirectConv2d::run(ITensorPack &tensors) { - import_workspace_memory(tensors); - prepare(tensors); _gemm_asm_func->run(tensors); @@ -174,14 +170,22 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors) { _activation_func->run(tensors); } +} - free_imported_workspace_memory(); +void CpuGemmDirectConv2d::allocate_permuted_weights() +{ + // TODO: This function will be removed when memory injection is implemeted. + ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr); + _permuted_weights->allocator()->free(); + _permuted_weights->allocator()->init(_permuted_weights_info); + _permuted_weights->allocator()->allocate(); } void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { + allocate_permuted_weights(); ITensorPack permute_tensors { { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) }, @@ -198,41 +202,5 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) } } -experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const -{ - experimental::MemoryRequirements req = _gemm_asm_func->workspace(); - - auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0); - - if(req.size() > 0) - { - index = req.back().slot + 1; - - constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4); - ARM_COMPUTE_UNUSED(max_index); // in order to prevent build error with assertion is disabled. - ARM_COMPUTE_ERROR_ON(index > max_index); - } - - req.emplace_back(index, _permuted_weights_info.total_size(), 0); - - return req; -} - -void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors) -{ - auto imported_tensor = tensors.get_tensor(workspace().back().slot); - - ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor); - - auto imported_memory = imported_tensor->buffer(); - _permuted_weights->allocator()->init(_permuted_weights_info); - _permuted_weights->allocator()->import_memory(imported_memory); -} - -void CpuGemmDirectConv2d::free_imported_workspace_memory() -{ - _permuted_weights->allocator()->free(); -} - } // namespace cpu } // namespace arm_compute
\ No newline at end of file |