aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r--src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp58
1 files changed, 13 insertions, 45 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index 7b7b68a93b..e50099df1f 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -53,13 +53,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
};
-
PixelValue type_min{};
PixelValue type_max{};
-
std::tie(type_min, type_max) = get_min_max(data_type);
- int32_t min_activation = type_min.get<int32_t>();
- int32_t max_activation = type_max.get<int32_t>();
+ int32_t min_activation = type_min.get<int32_t>();
+ int32_t max_activation = type_max.get<int32_t>();
if(supported_acts.count(act.activation()) != 0)
{
std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
@@ -89,8 +87,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
}
} // namespace
-CpuGemmDirectConv2d::CpuGemmDirectConv2d()
- : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
+ : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
_activation_func(std::make_unique<CpuActivation>()),
_weights_permute_func(std::make_unique<CpuPermute>()),
_permuted_weights_info(),
@@ -165,8 +163,6 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *
}
void CpuGemmDirectConv2d::run(ITensorPack &tensors)
{
- import_workspace_memory(tensors);
-
prepare(tensors);
_gemm_asm_func->run(tensors);
@@ -174,14 +170,22 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
{
_activation_func->run(tensors);
}
+}
- free_imported_workspace_memory();
+void CpuGemmDirectConv2d::allocate_permuted_weights()
+{
+ // TODO: This function will be removed when memory injection is implemeted.
+ ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
+ _permuted_weights->allocator()->free();
+ _permuted_weights->allocator()->init(_permuted_weights_info);
+ _permuted_weights->allocator()->allocate();
}
void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
+ allocate_permuted_weights();
ITensorPack permute_tensors
{
{ TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
@@ -198,41 +202,5 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
}
}
-experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
-{
- experimental::MemoryRequirements req = _gemm_asm_func->workspace();
-
- auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0);
-
- if(req.size() > 0)
- {
- index = req.back().slot + 1;
-
- constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4);
- ARM_COMPUTE_UNUSED(max_index); // in order to prevent build error with assertion is disabled.
- ARM_COMPUTE_ERROR_ON(index > max_index);
- }
-
- req.emplace_back(index, _permuted_weights_info.total_size(), 0);
-
- return req;
-}
-
-void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors)
-{
- auto imported_tensor = tensors.get_tensor(workspace().back().slot);
-
- ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
-
- auto imported_memory = imported_tensor->buffer();
- _permuted_weights->allocator()->init(_permuted_weights_info);
- _permuted_weights->allocator()->import_memory(imported_memory);
-}
-
-void CpuGemmDirectConv2d::free_imported_workspace_memory()
-{
- _permuted_weights->allocator()->free();
-}
-
} // namespace cpu
} // namespace arm_compute \ No newline at end of file