From 1a569a30a2f456ff1a3e0a665201e1c3ab92df80 Mon Sep 17 00:00:00 2001
From: Michalis Spyrou
Date: Tue, 10 Sep 2019 17:20:34 +0100
Subject: COMPMID-2161 [NEON] Create IWeightManager class

Change-Id: I1a9a46da2f98e896b825099151b56d1d8271dd31
Signed-off-by: Michalis Spyrou
Reviewed-on: https://review.mlplatform.org/c/1915
Comments-Addressed: Arm Jenkins
Reviewed-by: Georgios Pinitas
Tested-by: Arm Jenkins
---
 .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 163 +++++++++++++++++----
 1 file changed, 131 insertions(+), 32 deletions(-)

diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 2a4498b0a9..956ded55d2 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -38,7 +38,8 @@ namespace
 std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
                                                      const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
-                                                     std::shared_ptr<IMemoryManager> memory_manager)
+                                                     std::shared_ptr<IMemoryManager> memory_manager,
+                                                     IWeightsManager *weights_manager)
 {
     // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
@@ -50,7 +51,7 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
             {
                 return nullptr;
             }
-            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager);
             function->configure(a, b, d, alpha, beta, gemm_info);
             return std::move(function);
         }
@@ -73,25 +74,95 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
     }
 }
 
+template <typename TypeInput, typename TypeOutput>
+class FallbackTransform : public ITransformWeights
+{
+public:
+    void run() override
+    {
+        _output.allocator()->allocate();
+        ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
+        _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
+        _reshape_run = true;
+    }
+
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    ITensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    uint32_t uid() override
+    {
+        uint32_t id = (_B_pretranspose_size | 0x80000000);
+        return id;
+    }
+
+    void configure(size_t B_pretranspose_size, unsigned int alignment)
+    {
+        _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+        _B_pretranspose_size = B_pretranspose_size;
+    }
+
+    void set_pretranspose(ITensor *tensor)
+    {
+        if(!_reshape_run)
+        {
+            _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
+        }
+    }
+
+    void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
+    {
+        _ldb             = ldb;
+        _in1_ptr         = in1_ptr;
+        _multi_stride_b  = multi_stride_b;
+        _gemm_kernel_asm = gemm_kernel_asm;
+    }
+
+private:
+    Tensor           _output{};
+    int              _ldb{};
+    const TypeInput *_in1_ptr{};
+    int              _multi_stride_b{};
+    size_t           _B_pretranspose_size{};
+    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+};
+
 /** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
 class Fallback : public NEGEMMAssemblyDispatch::IFallback
 {
 public:
+    /** Destructor */
+    ~Fallback()
+    {
+        // Release memory if we have allocated the memory ourselves
+        if(_pretranspose && !(_weights_manager && _weights_manager->are_weights_managed(_b)))
+        {
+            delete _pretranspose;
+        }
+    }
+
     /** Initialise the functions's input and output.
      *
-     * @param[in]  a            Input tensor containing the Matrix A.
-     * @param[in]  b            Input tensor containing the Matrix B.
-     * @param[in]  c            Input tensor containing the Matrix C.
-     * @param[out] d            Output tensor to store the result of matrix multiplication.
-     * @param[in]  args         Matrix multiplication information.
-     * @param[in]  gemm_info    GEMM meta-data
-     * @param[in]  memory_group Memory group to be used by the function.
-     * @param[in]  os           Output stage meta-data.
+     * @param[in]  a               Input tensor containing the Matrix A.
+     * @param[in]  b               Input tensor containing the Matrix B.
+     * @param[in]  c               Input tensor containing the Matrix C.
+     * @param[out] d               Output tensor to store the result of matrix multiplication.
+     * @param[in]  args            Matrix multiplication information.
+     * @param[in]  gemm_info       GEMM meta-data
+     * @param[in]  memory_group    Memory group to be used by the function.
+     * @param[in]  weights_manager Weights manager to be used by the function.
+     * @param[in]  os              Output stage meta-data.
      */
     void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
-                   MemoryGroup &memory_group, const OutputStage &os = {});
+                   MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
 
     // Inherited methods overridden:
     void run() override;
@@ -108,7 +179,7 @@ private:
     void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
 
     /** Assembly Gemm kernel */
-    std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
     /** Optimised NEON kernel */
     std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
     /** Input A */
@@ -130,20 +201,25 @@ private:
     /** GEMM workspace */
     Tensor _workspace{};
     /** Pre-transpose tensor */
-    Tensor _pretranspose{};
+    ITensor *_pretranspose{ nullptr };
     /** Prepared flag */
     bool _is_prepared{ false };
     /** GEMM meta-data */
     GEMMInfo _gemm_info{};
+    /** Weights manager */
+    IWeightsManager *_weights_manager{ nullptr };
+    /** Weights transform object */
+    FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
-                                                             MemoryGroup &memory_group, const OutputStage &os)
+                                                             MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
 {
     arm_gemm::GemmConfig              gemm_cfg;
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
+    _weights_manager = weights_manager;
     if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
     {
         gemm_cfg.filter = gemm_kernel_info.name;
@@ -190,7 +266,16 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
         // Forcing 128-byte alignment (required by 32-bit kernels)
         const unsigned int alignment           = 128;
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
-        _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+        if(weights_manager && _weights_manager->are_weights_managed(b))
+        {
+            _weights_transform.configure(B_pretranspose_size, alignment);
+            _pretranspose = _weights_manager->acquire(b, &_weights_transform);
+        }
+        else
+        {
+            _pretranspose = new Tensor();
+            static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+        }
     }
 }
 
@@ -208,14 +293,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
         // Pretranspose B if required
         if(_gemm_kernel_asm->B_pretranspose_required())
         {
-            _pretranspose.allocator()->allocate();
-            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
             const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
-            _b->mark_as_unused();
+            if(_weights_manager && _weights_manager->are_weights_managed(_b))
+            {
+                _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
+                _weights_manager->run(_b, &_weights_transform);
+
+                // If we didn't run the reshape function, set the pretransposed buffer
+                if(!_weights_transform.is_reshape_run())
+                {
+                    _weights_transform.set_pretranspose(_pretranspose);
+                }
+            }
+            else
+            {
+                static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
+                ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
+                _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
+                _b->mark_as_unused();
+            }
         }
 
         _is_prepared = true;
@@ -294,7 +393,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
 template <typename TypeInput, typename TypeOutput>
 void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
                                  const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
-                                 std::shared_ptr<IMemoryManager> memory_manager)
+                                 std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
 {
     INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
     const CPUInfo &ci              = NEScheduler::get().cpu_info();
@@ -304,14 +403,14 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
     // Try to create an ACL function:
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
-    acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+    acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
 
     // If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         //Fallback onto arm_gemm function if ACL doesn't support this method.
         auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
-        fallback->configure(a, b, c, d, args, gemm_info, memory_group);
+        fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
         arm_gemm = std::move(fallback);
     }
 }
 
@@ -319,7 +418,7 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
 template <typename TypeInput, typename TypeOutput>
 void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
                                        const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
-                                       std::shared_ptr<IMemoryManager> memory_manager)
+                                       std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
 {
     INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
     const CPUInfo &ci              = NEScheduler::get().cpu_info();
@@ -339,22 +438,22 @@ void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function,
     // Try to create an ACL function:
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info);
-    acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+    acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
 
     // If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         // Fallback onto arm_gemm function if ACL doesn't support this method.
         auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
-        fallback->configure(a, b, c, d, args, gemm_info, memory_group, gemm_requant_info);
+        fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
         arm_gemm = std::move(fallback);
     }
 }
 } //namespace
 
-NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
-    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
+NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager)
 {
 }
 
@@ -390,27 +489,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
            if(d->info()->data_type() == DataType::S32)
            {
-               create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+               create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
            }
            else
            {
-               create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+               create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
            }
            break;
        case DataType::S8:
-           create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+           create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
           break;
 #endif /* __aarch64__ */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        case DataType::F16:
-           create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+           create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
            break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
        default:
-- 
cgit v1.2.1