diff options
author | Sang-Hoon Park <sang-hoon.park@arm.com> | 2021-05-18 10:46:00 +0100 |
---|---|---|
committer | Pablo Marquez Tello <pablo.tello@arm.com> | 2021-05-27 16:33:44 +0000 |
commit | b3be45759bdd0749ae3a16fe470820f0d9830ea9 (patch) | |
tree | 10bb8c1c0a049a23c00781c64e993f1b197c0d05 /src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp | |
parent | bc91297c865808ed2c321febc405179f63195ff8 (diff) | |
download | ComputeLibrary-b3be45759bdd0749ae3a16fe470820f0d9830ea9.tar.gz |
Implement memory injection in CpuDirectGemmConv2d
The following operators are now stateless by implementing
memory injection.
- CpuDirectGemmConv2d
- CpuGemmAssemblyDispatch
A test case is added to test if CpuDirectGemmConv2d can
run on different group of tensors with a single configure.
Resolves: COMPMID-4506
Change-Id: I48f44ed41236ca7e18da2de07bdbacc9007a3c5e
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5718
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r-- | src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp | 58 |
1 files changed, 45 insertions, 13 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp index b47a08a5e9..7b7b68a93b 100644 --- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp @@ -53,8 +53,10 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - PixelValue type_min{}; - PixelValue type_max{}; + + PixelValue type_min{}; + PixelValue type_max{}; + std::tie(type_min, type_max) = get_min_max(data_type); int32_t min_activation = type_min.get<int32_t>(); int32_t max_activation = type_max.get<int32_t>(); @@ -87,8 +89,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect } } // namespace -CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) - : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)), +CpuGemmDirectConv2d::CpuGemmDirectConv2d() + : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()), _activation_func(std::make_unique<CpuActivation>()), _weights_permute_func(std::make_unique<CpuPermute>()), _permuted_weights_info(), @@ -163,6 +165,8 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo * } void CpuGemmDirectConv2d::run(ITensorPack &tensors) { + import_workspace_memory(tensors); + prepare(tensors); _gemm_asm_func->run(tensors); @@ -170,22 +174,14 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors) { _activation_func->run(tensors); } -} -void CpuGemmDirectConv2d::allocate_permuted_weights() -{ - // TODO: This function will be removed when memory injection is implemeted. - ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr); - _permuted_weights->allocator()->free(); - _permuted_weights->allocator()->init(_permuted_weights_info); - _permuted_weights->allocator()->allocate(); + free_imported_workspace_memory(); } void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { - allocate_permuted_weights(); ITensorPack permute_tensors { { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) }, @@ -202,5 +198,41 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) } } +experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const +{ + experimental::MemoryRequirements req = _gemm_asm_func->workspace(); + + auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0); + + if(req.size() > 0) + { + index = req.back().slot + 1; + + constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4); + ARM_COMPUTE_UNUSED(max_index); // in order to prevent build error with assertion is disabled. + ARM_COMPUTE_ERROR_ON(index > max_index); + } + + req.emplace_back(index, _permuted_weights_info.total_size(), 0); + + return req; +} + +void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors) +{ + auto imported_tensor = tensors.get_tensor(workspace().back().slot); + + ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor); + + auto imported_memory = imported_tensor->buffer(); + _permuted_weights->allocator()->init(_permuted_weights_info); + _permuted_weights->allocator()->import_memory(imported_memory); +} + +void CpuGemmDirectConv2d::free_imported_workspace_memory() +{ + _permuted_weights->allocator()->free(); +} + } // namespace cpu } // namespace arm_compute
\ No newline at end of file |