From d7316eb877cc4ff8573219374335e917b19a0203 Mon Sep 17 00:00:00 2001
From: Michele Di Giorgio <michele.digiorgio@arm.com>
Date: Wed, 16 Jun 2021 11:14:41 +0100
Subject: Port NEGEMMConv2d to memory injecting interface

Resolves: COMPMID-4506, COMPMID-4570

Change-Id: I6d37a06da141f1fcfcaa8525322a319cb0234791
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5824
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
---
 src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp  |  77 +++---
 src/runtime/cpu/operators/CpuGemmDirectConv2d.h    |  40 ++--
 .../operators/internal/CpuGemmAssemblyDispatch.cpp | 265 ++++++---------------
 .../operators/internal/CpuGemmAssemblyDispatch.h   |  20 +-
 4 files changed, 145 insertions(+), 257 deletions(-)
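This commit switches both operators to the experimental memory-injection contract used
across src/runtime/cpu: configure() consumes ITensorInfo metadata only, workspace()
reports the auxiliary buffers the chosen kernels will need, and every buffer (user
tensors and scratch alike) is handed in through an ITensorPack at prepare()/run() time.
A minimal caller-side sketch of that contract follows; it is inferred from the hunks
below, and build_pack() is a hypothetical helper, not a library function:

    // Phase 1: configure on metadata only - the operator owns no buffers.
    cpu::CpuGemmDirectConv2d op;
    op.configure(&src_info, &wei_info, &bias_info, &dst_info, conv_info);

    // Phase 2: ask what auxiliary memory the selected kernels need.
    const experimental::MemoryRequirements reqs = op.workspace();

    // Phase 3: the caller allocates and injects everything, then drives the operator.
    ITensorPack pack = build_pack(src, wei, bias, dst, reqs); // hypothetical helper
    op.prepare(pack); // one-off work: weight permute + assembly pretranspose
    op.run(pack);     // per-inference work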
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index e50099df1f..c2e9f24ff6 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -26,10 +26,10 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/runtime/cpu/operators/CpuActivation.h"
-#include "src/runtime/cpu/operators/CpuPermute.h"
-#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+#include "support/Cast.h"
 
 #include <set>
 
@@ -37,6 +37,9 @@ namespace arm_compute
 {
 namespace cpu
 {
+using namespace arm_compute::experimental;
+using namespace arm_compute::utils::cast;
+
 namespace
 {
 GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act)
@@ -87,12 +90,14 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
 }
 } // namespace
 
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
-    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
       _activation_func(std::make_unique<CpuActivation>()),
       _weights_permute_func(std::make_unique<CpuPermute>()),
-      _permuted_weights_info(),
-      _permuted_weights(std::make_unique<Tensor>())
+      _aux_mem(AuxTensorIdx::Count),
+      _perm_weights(),
+      _run_activation(false),
+      _is_prepared(false)
 {
 }
 
@@ -106,8 +111,10 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
                                                               biases != nullptr ? biases : nullptr,
                                                               dst,
                                                               info));
 
-    _original_weights_info = weights;
-    _weights_permute_func->configure(weights, &_permuted_weights_info, PermutationVector{ 3, 0, 1, 2 });
+    _run_activation = info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info);
+    _is_prepared    = false;
+
+    _weights_permute_func->configure(weights, &_perm_weights, PermutationVector{ 3, 0, 1, 2 });
 
     // Configure assembly dispatch
     cpu::AsmGemmInfo asm_info = init_assembly_metadata(info, false);
@@ -115,13 +122,27 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
     {
         asm_info.output_stage = calculate_output_stage_metadata(src, weights, dst, info.act_info);
     }
-    _gemm_asm_func->configure(src, &_permuted_weights_info, biases, dst, asm_info);
+    _gemm_asm_func->configure(src, &_perm_weights, biases, dst, asm_info);
 
     // Configure activation
-    if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info))
+    if(_run_activation)
     {
         _activation_func->configure(dst, nullptr, info.act_info);
-        _run_activation = true;
+    }
+
+    // Add auxiliary memory requirements of the assembly dispatch
+    auto asm_mem_req           = _gemm_asm_func->workspace();
+    _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
+    _aux_mem[Pretranspose]     = asm_mem_req[Pretranspose];
+
+    if(_aux_mem[Pretranspose].size > 0)
+    {
+        // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, weights->total_size());
+    }
+    else
+    {
+        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
     }
 }
 
 Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
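The if/else above decides the lifetime of the PermutedWeights slot: when the assembly
kernel keeps its own pretransposed copy of B, the permuted weights are only a staging
buffer and can be reclaimed once prepare() finishes; otherwise they are what run() reads
every time. A worked size example, with hypothetical weight shapes:

    // Hypothetical OHWI weights: 128 filters of 3x3 over 64 input channels, F32.
    TensorInfo weights_info(TensorShape(64U, 3U, 3U, 128U), 1, DataType::F32);
    const size_t permuted_bytes = weights_info.total_size(); // 64*3*3*128*4 = 294912 bytes
    // -> _aux_mem[PermutedWeights] asks the caller for 294912 bytes, tagged either
    //    MemoryLifetime::Prepare (pretranspose path) or MemoryLifetime::Persistent.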
@@ -172,35 +193,29 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
     }
 }
 
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
-    // TODO: This function will be removed when memory injection is implemeted.
-    ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
-    _permuted_weights->allocator()->free();
-    _permuted_weights->allocator()->init(_permuted_weights_info);
-    _permuted_weights->allocator()->allocate();
-}
-
 void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
 {
     if(!_is_prepared)
     {
-        allocate_permuted_weights();
-        ITensorPack permute_tensors
-        {
-            { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
-            { TensorType::ACL_DST, _permuted_weights.get() },
-        };
+        const ITensor *weights     = tensors.get_const_tensor(ACL_SRC_1);
+        ITensor       *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
+        ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
 
+        CpuAuxTensorHandler permuted_weights(_perm_weights, *weights_aux);
+        ITensorPack         permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } };
+
         _weights_permute_func->run(permute_tensors);
-        tensors.get_const_tensor(TensorType::ACL_SRC_1)->mark_as_unused();
 
+        tensors.add_const_tensor(ACL_SRC_1, permuted_weights.get());
+        // Call prepare of assembly dispatch
+        _gemm_asm_func->prepare(tensors);
 
-        // switch the original tensor with permuted tensor
-        tensors.add_const_tensor(TensorType::ACL_SRC_1, _permuted_weights.get());
         _is_prepared = true;
     }
 }
 
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+    return _aux_mem;
+}
 } // namespace cpu
 } // namespace arm_compute
\ No newline at end of file
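prepare() now permutes straight into caller-provided storage: the injected buffer bound at
offset_int_vec(PermutedWeights) is viewed through _perm_weights via CpuAuxTensorHandler,
ACL_SRC_1 is rebound to that view, and the same pack is forwarded to the assembly dispatch
so its pretranspose also reads the permuted copy. One practical consequence, sketched under
the assumption that the caller owns the original weights and nothing else uses them:

    // After prepare() nothing reads the original OHWI weights any more (the old code
    // explicitly marked them unused), so a caller may reclaim that allocation.
    op.prepare(pack);
    original_weights.allocator()->free(); // caller-owned tensor; hedged sketch only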
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 6aa17c2349..b572f36a3a 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -24,14 +24,12 @@
 #ifndef ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
 #define ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
 
-#include "arm_compute/core/ITensorInfo.h"
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/core/TensorInfo.h"
 #include "src/core/common/Macros.h"
-#include "src/core/cpu/ICpuKernel.h"
 #include "src/runtime/cpu/ICpuOperator.h"
-
-#include <memory>
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuPermute.h"
+#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 namespace arm_compute
 {
@@ -40,15 +38,11 @@ class ITensor;
 struct Conv2dInfo;
 namespace cpu
 {
-class CpuGemmAssemblyDispatch;
-class CpuActivation;
-class CpuPermute;
-
 class CpuGemmDirectConv2d : public ICpuOperator
 {
 public:
     /** Constructor */
-    CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+    CpuGemmDirectConv2d();
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
     /** Destructor */
     ~CpuGemmDirectConv2d();
@@ -89,22 +83,24 @@ public:
     // Inherited methods overridden:
     void run(ITensorPack &tensors) override;
     void prepare(ITensorPack &constants) override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
+    enum AuxTensorIdx
+    {
+        AsmGemmWorkspace = 0,
+        Pretranspose,
+        PermutedWeights,
+        Count
+    };
+
     std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
     std::unique_ptr<CpuActivation>           _activation_func;
     std::unique_ptr<CpuPermute>              _weights_permute_func;
-    const ITensorInfo                       *_original_weights_info{};
-    TensorInfo                               _permuted_weights_info;
-    std::unique_ptr<Tensor>                  _permuted_weights{ nullptr };
-    bool                                     _is_prepared{ false };
-    bool                                     _run_activation{ false };
-
-    /** Function to allocated a tensor for permuted weights
-     *
-     * @note This function will be removed when memory injection is properly implemented.
-     */
-    void allocate_permuted_weights();
+    experimental::MemoryRequirements _aux_mem;
+    TensorInfo                       _perm_weights;
+    bool                             _run_activation;
+    bool                             _is_prepared;
 };
 } // namespace cpu
 } // namespace arm_compute
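The AuxTensorIdx values above are local to the operator; offset_int_vec() shifts them into
a reserved id range so they can share one ITensorPack with the regular ACL_SRC_*/ACL_DST
entries. A paraphrase of the idea - the real helper lives in src/core/helpers/MemoryHelpers.h
and the base constant below is illustrative, not the library's actual value:

    constexpr int kIntVecBase = 512; // illustrative only
    inline int offset_int_vec(int idx)
    {
        return kIntVecBase + idx;
    }
    // e.g. PermutedWeights (= 2 above) is bound at pack slot kIntVecBase + 2,
    // safely apart from ACL_SRC_0/ACL_SRC_1/ACL_SRC_2/ACL_DST.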
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 1101e05a0d..79ea1cb5a7 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -27,15 +27,18 @@
 #include "src/core/CPP/Validate.h"
 #include "src/core/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
 #include "src/core/cpu/kernels/assembly/arm_gemm.hpp"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/core/utils/AssemblyUtils.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
 
 #include 
-#include 
 
 namespace arm_compute
 {
 namespace cpu
 {
+using namespace arm_compute::experimental;
+
 namespace
 {
 struct free_delete
@@ -113,103 +116,27 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
     return scheduling_hint;
 }
 
-template <typename TypeInput, typename TypeOutput>
-class FallbackTransform : public ITransformWeights
-{
-public:
-    FallbackTransform() noexcept {};
-    /** Prevent instances of this class from being copied (As this class contains pointers) */
-    FallbackTransform(const FallbackTransform &) = delete;
-    /** Default move constructor */
-    FallbackTransform(FallbackTransform &&) = default;
-    /** Prevent instances of this class from being copied (As this class contains pointers) */
-    FallbackTransform &operator=(const FallbackTransform &) = delete;
-    /** Default move assignment operator */
-    FallbackTransform &operator=(FallbackTransform &&) = default;
-    void run() override
-    {
-        _output.allocator()->allocate();
-        ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
-        _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
-        _reshape_run = true;
-    }
-
-    void release() override
-    {
-        _output.allocator()->free();
-    }
-
-    ITensor *get_weights() override
-    {
-        return &_output;
-    }
-
-    uint32_t uid() override
-    {
-        uint32_t id = (_B_pretranspose_size | 0x80000000);
-        return id;
-    }
-
-    void configure(size_t B_pretranspose_size, unsigned int alignment)
-    {
-        _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
-        _B_pretranspose_size = B_pretranspose_size;
-    }
-
-    void set_pretranspose(ITensor *tensor)
-    {
-        if(!_reshape_run)
-        {
-            _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
-        }
-    }
-
-    void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
-    {
-        _ldb             = ldb;
-        _in1_ptr         = in1_ptr;
-        _multi_stride_b  = multi_stride_b;
-        _gemm_kernel_asm = gemm_kernel_asm;
-    }
-
-private:
-    Tensor           _output{};
-    int              _ldb{};
-    const TypeInput *_in1_ptr{};
-    int              _multi_stride_b{};
-    size_t           _B_pretranspose_size{};
-    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
-};
-
 /** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
 class Fallback : public CpuGemmAssemblyDispatch::IFallback
 {
 public:
     /** Destructor */
-    ~Fallback()
-    {
-        if(_pretranspose && !(is_weight_managed()))
-        {
-            delete _pretranspose;
-        }
-    }
+    ~Fallback() = default;
 
     /** Initialise the function's input and output.
      *
-     * @param[in]  a               Input tensor containing the Matrix A.
-     * @param[in]  b               Input tensor containing the Matrix B.
-     * @param[in]  c               Input tensor containing the Matrix C.
-     * @param[out] d               Output tensor to store the result of matrix multiplication.
-     * @param[in]  args            Matrix multiplication information.
-     * @param[in]  gemm_info       GEMM meta-data
-     * @param[in]  memory_group    Memory group to be used by the function.
-     * @param[in]  weights_manager Weights manager to be used by the function.
-     * @param[in]  os              Output stage meta-data.
+     * @param[in]  a         Input tensor containing the Matrix A.
+     * @param[in]  b         Input tensor containing the Matrix B.
+     * @param[in]  c         Input tensor containing the Matrix C.
+     * @param[out] d         Output tensor to store the result of matrix multiplication.
+     * @param[in]  args      Matrix multiplication information.
+     * @param[in]  gemm_info GEMM meta-data
+     * @param[in]  os        Output stage meta-data.
      */
     void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                   MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+                   const OutputStage &os = {});
 
     /** Set requantization shifts to be used
      *
@@ -231,16 +158,17 @@ public:
     // Inherited methods overridden:
     void run(ITensorPack &tensors) override;
     void prepare(ITensorPack &tensors) override;
-    bool is_configured() const override;
+    bool                             is_configured() const override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
-    /** Allocate a workspace tensor.
-     *
-     * @param[in] workspace_size Size to allocate.
-     * @param[in] memory_group   Tensor memory group.
-     * @param[in] alignment      Workspace memory alignment.
-     */
-    void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+    enum AuxTensorIdx
+    {
+        AsmGemmWorkspace = 0,
+        Pretranspose,
+        Count
+    };
+
     /** Configure the indirect buffer
      *
      * @param[in] a    Input tensor containing the Matrix A.
@@ -256,18 +184,14 @@ private:
     std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
     /** Optimised Arm® Neon™ kernel */
    std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
-    /** GEMM workspace */
-    Tensor _workspace{};
-    /** Pre-transpose tensor */
-    ITensor *_pretranspose{ nullptr };
+    /** Assembly GEMM workspace tensor info */
+    TensorInfo _workspace_info{};
+    /** Pre-transpose tensor info */
+    TensorInfo _pretranspose_info{};
     /** Prepared flag */
     bool _is_prepared{ false };
     /** GEMM meta-data */
     AsmGemmInfo _gemm_info{};
-    /** Weights manager */
-    IWeightsManager *_weights_manager{ nullptr };
-    /** Weights transform object */
-    FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
     /** GEMM kernel description */
     arm_gemm::KernelDescription _kernel_info{};
     /** Per channel quantization shifts */
@@ -279,27 +203,9 @@ private:
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
-    std::vector<TypeInput>          _indirect_pad{};
-    arm_gemm::ConvolutionParameters _cp{};
-
-    bool is_weight_managed()
-    {
-        // TODO (COMPMID-4539): This function should do the following:
-        //  _weights_manager && _weights_manager->are_weights_managed(_b)
-        //  , where _b is the second Tensor that is used to be given to the configure().
-        //  Currently, however, weight manager is disabled to make this class stateless.
-        //  This should be revisited in the future.
-        return false;
-    }
-
-    void acquire_managed_weight()
-    {
-        // TODO (COMPMID-4539): This function should do the following:
-        //  _pretranspose = _weights_manager->acquire(_b, &_weights_transform);
-        //  , where _b is the second Tensor that is used to be given to the configure().
-        //  Currently, however, weight manager is disabled to make this class stateless.
-        _pretranspose = nullptr;
-    }
+    std::vector<TypeInput>           _indirect_pad{};
+    arm_gemm::ConvolutionParameters  _cp{};
+    experimental::MemoryRequirements _aux_mem{ Count };
 };
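With FallbackTransform and the weights manager gone, Fallback keeps only descriptions of
memory (TensorInfo) plus a MemoryRequirements vector; it never owns a buffer. The shape of
that metadata, paraphrased from arm_compute::experimental with member names matching their
usage in this patch (the .size access above, and the slot/lifetime/size/alignment
constructor arguments used below):

    enum class MemoryLifetime
    {
        Temporary,  // live only while run() executes
        Persistent, // must stay valid for the operator's whole life (e.g. pretransposed B)
        Prepare,    // needed only until prepare() returns
    };
    struct MemoryInfo
    {
        int            slot{ -1 };  // ITensorPack id the caller must bind the buffer to
        MemoryLifetime lifetime{ MemoryLifetime::Temporary };
        size_t         size{ 0 };      // bytes required
        size_t         alignment{ 0 }; // base-address alignment, e.g. 4096 for the workspace
    };
    using MemoryRequirements = std::vector<MemoryInfo>;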
@@ -439,12 +345,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                                                             MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
+                                                             const OutputStage &os)
 {
     ARM_COMPUTE_UNUSED(c);
     arm_gemm::GemmConfig gemm_cfg;
-    _kernel_info     = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
-    _weights_manager = weights_manager;
+    _kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
     if(_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
     {
         gemm_cfg.filter = _kernel_info.name;
@@ -461,13 +366,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
     auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
     ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
     acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
-    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
-    if(workspace_size > 0)
-    {
-        // Allocate workspace
-        const unsigned int alignment = 4096;
-        allocate_workspace(workspace_size, memory_group, alignment);
-    }
+    const size_t       workspace_size = _gemm_kernel_asm->get_working_size();
+    const unsigned int alignment      = 4096;
+    _workspace_info                   = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
+    _aux_mem[AsmGemmWorkspace]        = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);
 
     //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
     //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
@@ -487,16 +389,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
         // Forcing 128-byte alignment (required by 32-bit kernels)
         const unsigned int alignment           = 128;
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
-        if(is_weight_managed())
-        {
-            _weights_transform.configure(B_pretranspose_size, alignment);
-            acquire_managed_weight();
-        }
-        else
-        {
-            _pretranspose = new Tensor();
-            static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
-        }
+        _pretranspose_info     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
+        _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
     }
 
     // Handle indirect GEMM convolution
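The two registrations above carry different lifetimes on purpose: the 4096-byte-aligned
working area is Temporary (only live inside run()), while the pretransposed-B buffer is
Persistent (its contents must survive between runs). A hedged sketch of how a memory
planner can exploit that distinction; ACL's own planning sits in the function-level
wrappers, and fallback_reqs below stands for the result of workspace():

    for(const auto &mi : fallback_reqs)
    {
        if(mi.size == 0)
        {
            continue; // the selected kernel needs no buffer in this slot
        }
        if(mi.lifetime == MemoryLifetime::Temporary)
        {
            // scratch: may alias other Temporary buffers that are never live together
        }
        else
        {
            // Persistent: holds the pretransposed B matrix across every run()
        }
    }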
@@ -509,10 +403,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
 {
-    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
-    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     if(!_is_prepared)
     {
+        auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+        auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+
         // Set up matrix bias in the assembly kernel, it's just a pointer to matrix C.
         if(c && c->info()->data_type() == DataType::S32)
         {
@@ -526,24 +421,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            if(is_weight_managed())
-            {
-                _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
-                _weights_manager->run(b, &_weights_transform);
-
-                // If we didn't run the reshape function, set the pretransposed buffer
-                if(!_weights_transform.is_reshape_run())
-                {
-                    _weights_transform.set_pretranspose(_pretranspose);
-                }
-            }
-            else
-            {
-                static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
-                ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
-                _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
-                b->mark_as_unused();
-            }
+            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
+            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
+            _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
         }
 
         if(_gemm_info.method == AsmConvMethod::Indirect)
@@ -556,18 +436,15 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
+bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
 {
-    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
-    _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment);
-    memory_group.manage(&_workspace);
-    _workspace.allocator()->allocate();
+    return _optimised_kernel != nullptr;
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
+experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
 {
-    return _optimised_kernel != nullptr;
+    return _aux_mem;
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -609,9 +486,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
     const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
 
     // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
-    if(_workspace.buffer() != nullptr)
+    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
+    if(workspace.get()->buffer() != nullptr)
     {
-        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
+        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
         const unsigned int split_dim   = scheduling_hint.split_dimension();
         const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
         unsigned int       num_threads = NEScheduler::get().num_threads();
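CpuAuxTensorHandler, used in both prepare() and run() above, is the bridge between injected
memory and the kernels: it looks up the caller's buffer by slot id and presents it through
the operator's own TensorInfo. A condensed model of that pattern - not the library's exact
class, and the trailing bool argument seen above is omitted here:

    class AuxTensorView
    {
    public:
        AuxTensorView(int slot, const TensorInfo &info, ITensorPack &pack)
        {
            ITensor *backing = pack.get_tensor(slot); // buffer injected by the caller
            if(backing != nullptr && info.total_size() != 0)
            {
                _tensor.allocator()->init(info);
                _tensor.allocator()->import_memory(backing->buffer()); // no copy, no ownership
            }
        }
        ITensor *get() { return &_tensor; }

    private:
        Tensor _tensor{};
    };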
@@ -656,9 +534,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                     const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
-                     IWeightsManager *weights_manager)
+void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+                     const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+                     arm_gemm::Activation activation, const AsmGemmInfo &info)
 {
     Params         p  = extract_parameters(a, b, d, info);
     const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -668,14 +546,14 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
-    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
+    fallback->configure(a, b, c, d, args, info);
     arm_gemm = std::move(fallback);
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                           const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
-                           IWeightsManager *weights_manager)
+void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+                           const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+                           arm_gemm::Activation activation, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_UNUSED(activation);
     Params         p  = extract_parameters(a, b, d, info);
@@ -713,14 +591,13 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
     }
 
     // Configure fallback
-    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
+    fallback->configure(a, b, c, d, args, info, gemm_requant_info);
     arm_gemm = std::move(fallback);
 }
-
 } //namespace
 
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
+    : _arm_gemm(nullptr)
 {
 }
 
@@ -775,40 +652,40 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
     switch(a->data_type())
     {
         case DataType::F32:
-            create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
             }
             else
             {
-                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
             }
             break;
         case DataType::S8:
        case DataType::QASYMM8_SIGNED:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
             }
             else
             {
-                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
             }
             break;
 #endif /* __aarch64__ */
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
         case DataType::BFLOAT16:
-            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
             break;
 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
@@ -829,10 +706,14 @@ bool CpuGemmAssemblyDispatch::is_configured() const
 
 void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
 {
-    MemoryGroupResourceScope scope_mg(_memory_group);
-
     ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
     _arm_gemm->run(tensors);
 }
+
+experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
+{
+    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+    return _arm_gemm->workspace();
+}
 } // namespace cpu
 } // namespace arm_compute
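Because workspace() simply forwards to the Fallback instance, the requirements are only
meaningful after configure() has chosen a kernel - hence the assert on _arm_gemm. Call
order, sketched with placeholder ITensorInfo arguments:

    CpuGemmAssemblyDispatch gemm;
    gemm.configure(&a_info, &b_info, &c_info, &d_info, asm_info); // instantiates a Fallback<...>
    const experimental::MemoryRequirements reqs = gemm.workspace(); // reflects that choice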
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index ffc097c75c..355273adeb 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,10 +24,6 @@
 #ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 #define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 
-#include "arm_compute/runtime/IMemoryManager.h"
-#include "arm_compute/runtime/IWeightsManager.h"
-#include "arm_compute/runtime/MemoryGroup.h"
-#include "arm_compute/runtime/Tensor.h"
 #include "src/core/common/Macros.h"
 #include "src/runtime/cpu/ICpuOperator.h"
 
@@ -62,7 +58,7 @@ class CpuGemmAssemblyDispatch : public ICpuOperator
 {
 public:
     /** Constructor */
-    CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+    CpuGemmAssemblyDispatch();
     /** Default destructor */
     ~CpuGemmAssemblyDispatch() = default;
 
@@ -71,10 +67,11 @@ public:
     class IFallback
     {
     public:
-        virtual void run(ITensorPack &tensors)     = 0;
-        virtual void prepare(ITensorPack &tensors) = 0;
-        virtual bool is_configured() const         = 0;
-        virtual ~IFallback()                       = default;
+        virtual void                             run(ITensorPack &tensors)     = 0;
+        virtual void                             prepare(ITensorPack &tensors) = 0;
+        virtual experimental::MemoryRequirements workspace() const             = 0;
+        virtual bool                             is_configured() const         = 0;
+        virtual ~IFallback()                                                   = default;
     };
 
 public:
@@ -115,11 +112,10 @@ public:
     // Inherited methods overridden:
     void prepare(ITensorPack &tensors) override;
     void run(ITensorPack &tensors) override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
-    std::unique_ptr<IFallback> _arm_gemm;        /**< Interface for the arm_gemm fallback  */
-    MemoryGroup                _memory_group;    /**< Function memory group */
-    IWeightsManager           *_weights_manager; /**< Pointer to the weights manager */
+    std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback  */
 };
 } // namespace cpu
 } // namespace arm_compute
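Putting the interface together, a hedged end-to-end sketch of driving the ported operator.
src, weights, bias and dst are assumed to be allocated Tensors, conv_info a Conv2dInfo, and
the MemoryInfo members (slot, size, alignment, lifetime) follow their usage throughout this
patch:

    using namespace arm_compute;
    using namespace arm_compute::experimental;

    cpu::CpuGemmDirectConv2d conv;
    conv.configure(src.info(), weights.info(), bias.info(), dst.info(), conv_info);

    ITensorPack pack{ { TensorType::ACL_SRC_0, &src },
                      { TensorType::ACL_SRC_1, &weights },
                      { TensorType::ACL_SRC_2, &bias },
                      { TensorType::ACL_DST, &dst } };

    // Allocate one backing tensor per advertised requirement and bind it to its slot.
    const MemoryRequirements reqs = conv.workspace();
    std::vector<std::unique_ptr<Tensor>> aux(reqs.size());
    for(size_t i = 0; i < reqs.size(); ++i)
    {
        if(reqs[i].size == 0)
        {
            continue;
        }
        aux[i] = std::make_unique<Tensor>();
        aux[i]->allocator()->init(TensorInfo(TensorShape(reqs[i].size), 1, DataType::U8), reqs[i].alignment);
        aux[i]->allocator()->allocate();
        pack.add_tensor(reqs[i].slot, aux[i].get());
    }

    conv.prepare(pack);

    // Buffers tagged MemoryLifetime::Prepare are dead once prepare() returns.
    for(size_t i = 0; i < reqs.size(); ++i)
    {
        if(aux[i] != nullptr && reqs[i].lifetime == MemoryLifetime::Prepare)
        {
            aux[i]->allocator()->free();
        }
    }

    conv.run(pack);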