diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConv2d.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMConv2d.cpp | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp index 860b6bb4e1..b8349d98db 100644 --- a/src/runtime/NEON/functions/NEGEMMConv2d.cpp +++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp @@ -22,9 +22,11 @@ * SOFTWARE. */ #include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h" + #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "src/runtime/NEON/functions/NEGEMMAssemblyDispatch.h" #include <set> @@ -81,9 +83,13 @@ AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect) } // namespace NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager) - : _gemm_asm_func(memory_manager), _activation_func(), _weights_permute_func(), _original_weights(nullptr), _permuted_weights(), _is_prepared(false), _run_activation(false) + : _gemm_asm_func(std::make_unique<NEGEMMAssemblyDispatch>(memory_manager)), _activation_func(), _weights_permute_func(), _original_weights(nullptr), _permuted_weights(), _is_prepared(false), + _run_activation(false) { } + +NEGEMMConv2d::~NEGEMMConv2d() = default; + void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); @@ -101,10 +107,10 @@ void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITens { asm_info.output_stage = calculate_output_stage_metadata(input->info(), weights->info(), output->info(), info.act_info); } - _gemm_asm_func.configure(input, &_permuted_weights, biases, output, asm_info); + _gemm_asm_func->configure(input, &_permuted_weights, biases, output, asm_info); // Configure activation - if(info.act_info.enabled() && !_gemm_asm_func.is_activation_supported(info.act_info)) + if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info)) { _activation_func.configure(output, nullptr, info.act_info); _run_activation = true; @@ -150,7 +156,7 @@ void NEGEMMConv2d::run() { prepare(); - _gemm_asm_func.run(); + _gemm_asm_func->run(); if(_run_activation) { _activation_func.run(); |