aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r-- src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp 58
1 files changed, 45 insertions, 13 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index b47a08a5e9..7b7b68a93b 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -53,8 +53,10 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
};
- PixelValue type_min{};
- PixelValue type_max{};
+
+ PixelValue type_min{};
+ PixelValue type_max{};
+
std::tie(type_min, type_max) = get_min_max(data_type);
int32_t min_activation = type_min.get<int32_t>();
int32_t max_activation = type_max.get<int32_t>();
@@ -87,8 +89,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
}
} // namespace
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+ : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
_activation_func(std::make_unique<CpuActivation>()),
_weights_permute_func(std::make_unique<CpuPermute>()),
_permuted_weights_info(),
@@ -163,6 +165,8 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *
}
void CpuGemmDirectConv2d::run(ITensorPack &tensors)
{
+ import_workspace_memory(tensors);
+
prepare(tensors);
_gemm_asm_func->run(tensors);
@@ -170,22 +174,14 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
{
_activation_func->run(tensors);
}
-}
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
- // TODO: This function will be removed when memory injection is implemeted.
- ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
- _permuted_weights->allocator()->free();
- _permuted_weights->allocator()->init(_permuted_weights_info);
- _permuted_weights->allocator()->allocate();
+ free_imported_workspace_memory();
}
void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- allocate_permuted_weights();
ITensorPack permute_tensors
{
{ TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
@@ -202,5 +198,41 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
}
}
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const // Returns the assembly GEMM's memory requirements plus one extra entry sized for the permuted weights.
+{
+ experimental::MemoryRequirements req = _gemm_asm_func->workspace();
+
+ auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0); // Slot used when the GEMM reports no requirements.
+
+ if(req.size() > 0)
+ {
+ index = req.back().slot + 1; // Take the slot right after the GEMM's last entry; NOTE(review): relies on the last entry carrying the highest slot - confirm.
+
+ constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4);
+ ARM_COMPUTE_UNUSED(max_index); // Prevents a build error when assertions are disabled.
+ ARM_COMPUTE_ERROR_ON(index > max_index);
+ }
+
+ req.emplace_back(index, _permuted_weights_info.total_size(), 0); // NOTE(review): the trailing 0 is presumably the alignment argument - confirm against the entry type's constructor.
+
+ return req;
+}
+
+void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors) // Re-initialises _permuted_weights and makes it import the buffer of the pack tensor found at the last workspace() slot.
+{
+ auto imported_tensor = tensors.get_tensor(workspace().back().slot); // The last workspace() entry is the one appended for the permuted weights.
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
+
+ auto imported_memory = imported_tensor->buffer();
+ _permuted_weights->allocator()->init(_permuted_weights_info);
+ _permuted_weights->allocator()->import_memory(imported_memory);
+}
+
+void CpuGemmDirectConv2d::free_imported_workspace_memory() // Calls free() on the _permuted_weights allocator; run() invokes this as its last step.
+{
+ _permuted_weights->allocator()->free();
+}
+
} // namespace cpu
} // namespace arm_compute \ No newline at end of file