aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp')
-rw-r--r--src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp77
1 files changed, 46 insertions, 31 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index e50099df1f..c2e9f24ff6 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -26,10 +26,10 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/runtime/cpu/operators/CpuActivation.h"
-#include "src/runtime/cpu/operators/CpuPermute.h"
-#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+#include "support/Cast.h"
#include <set>
@@ -37,6 +37,9 @@ namespace arm_compute
{
namespace cpu
{
+using namespace arm_compute::experimental;
+using namespace arm_compute::utils::cast;
+
namespace
{
GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act)
@@ -87,12 +90,14 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
}
} // namespace
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+ : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
_activation_func(std::make_unique<CpuActivation>()),
_weights_permute_func(std::make_unique<CpuPermute>()),
- _permuted_weights_info(),
- _permuted_weights(std::make_unique<Tensor>())
+ _aux_mem(AuxTensorIdx::Count),
+ _perm_weights(),
+ _run_activation(false),
+ _is_prepared(false)
{
}
@@ -106,8 +111,10 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
biases != nullptr ? biases : nullptr,
dst,
info));
- _original_weights_info = weights;
- _weights_permute_func->configure(weights, &_permuted_weights_info, PermutationVector{ 3, 0, 1, 2 });
+ _run_activation = info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info);
+ _is_prepared = false;
+
+ _weights_permute_func->configure(weights, &_perm_weights, PermutationVector{ 3, 0, 1, 2 });
// Configure assembly dispatch
cpu::AsmGemmInfo asm_info = init_assembly_metadata(info, false);
@@ -115,13 +122,27 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
{
asm_info.output_stage = calculate_output_stage_metadata(src, weights, dst, info.act_info);
}
- _gemm_asm_func->configure(src, &_permuted_weights_info, biases, dst, asm_info);
+ _gemm_asm_func->configure(src, &_perm_weights, biases, dst, asm_info);
// Configure activation
- if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info))
+ if(_run_activation)
{
_activation_func->configure(dst, nullptr, info.act_info);
- _run_activation = true;
+ }
+
+ // Add auxiliary memory requirements of the assembly dispatch
+ auto asm_mem_req = _gemm_asm_func->workspace();
+ _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
+ _aux_mem[Pretranspose] = asm_mem_req[Pretranspose];
+
+ if(_aux_mem[Pretranspose].size > 0)
+ {
+ // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, weights->total_size());
+ }
+ else
+ {
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
}
}
Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -172,35 +193,29 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
}
}
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
- // TODO: This function will be removed when memory injection is implemeted.
- ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
- _permuted_weights->allocator()->free();
- _permuted_weights->allocator()->init(_permuted_weights_info);
- _permuted_weights->allocator()->allocate();
-}
-
void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- allocate_permuted_weights();
- ITensorPack permute_tensors
- {
- { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
- { TensorType::ACL_DST, _permuted_weights.get() },
- };
+ const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1);
+ ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
+ CpuAuxTensorHandler permuted_weights(_perm_weights, *weights_aux);
+ ITensorPack permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } };
_weights_permute_func->run(permute_tensors);
- tensors.get_const_tensor(TensorType::ACL_SRC_1)->mark_as_unused();
+ tensors.add_const_tensor(ACL_SRC_1, permuted_weights.get());
+ // Call prepare of assembly dispatch
+ _gemm_asm_func->prepare(tensors);
- // switch the original tensor with permuted tensor
- tensors.add_const_tensor(TensorType::ACL_SRC_1, _permuted_weights.get());
_is_prepared = true;
}
}
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+ return _aux_mem;
+}
} // namespace cpu
} // namespace arm_compute \ No newline at end of file