Diffstat (limited to 'src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--  src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 190
1 file changed, 98 insertions, 92 deletions
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 36c1bbb1b3..0c511ff548 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -56,14 +56,13 @@ struct Params
     bool indirect;
 };
 
-Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d, const AsmGemmInfo &info)
+Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
-
     Params p;
-    p.M        = d->info()->tensor_shape().y();
-    p.K        = a->info()->tensor_shape().x();
-    p.N        = d->info()->tensor_shape().x();
+    p.M        = d->tensor_shape().y();
+    p.K        = a->tensor_shape().x();
+    p.N        = d->tensor_shape().x();
     p.batches  = 1;
     p.multis   = 1;
     p.sections = 1;
@@ -72,19 +71,19 @@ Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d,
     if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
     {
         p.indirect = true;
-        p.sections = b->info()->tensor_shape()[2] * b->info()->tensor_shape()[3];
+        p.sections = b->tensor_shape()[2] * b->tensor_shape()[3];
     }
     else
     {
-        p.multis  = b->info()->tensor_shape().z();
-        p.batches = d->info()->tensor_shape().total_size_upper(2) / p.multis;
+        p.multis  = b->tensor_shape().z();
+        p.batches = d->tensor_shape().total_size_upper(2) / p.multis;
     }
 
     // Update M in case of GEMM3D for output
     if(info.depth_output_gemm3d != 0)
     {
-        p.M       = d->info()->tensor_shape().y() * d->info()->tensor_shape().z();
-        p.batches = d->info()->tensor_shape().total_size_upper(3) / p.multis;
+        p.M       = d->tensor_shape().y() * d->tensor_shape().z();
+        p.batches = d->tensor_shape().total_size_upper(3) / p.multis;
     }
 
     return p;
@@ -205,11 +204,11 @@ public:
     }
 
 private:
-    Tensor           _output{};
-    int              _ldb{};
-    const TypeInput *_in1_ptr{};
-    int              _multi_stride_b{};
-    size_t           _B_pretranspose_size{};
+    Tensor                                                        _output{};
+    int                                                           _ldb{};
+    const TypeInput                                              *_in1_ptr{};
+    int                                                           _multi_stride_b{};
+    size_t                                                        _B_pretranspose_size{};
     std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
 };
 
@@ -221,8 +220,7 @@ public:
     /** Destructor */
    ~Fallback()
    {
-        // Release memory if we have allocated the memory ourselves
-        if(_pretranspose && !(_weights_manager && _weights_manager->are_weights_managed(_b)))
+        if(_pretranspose && !(is_weight_managed()))
        {
            delete _pretranspose;
        }
@@ -240,7 +238,7 @@ public:
      * @param[in] weights_manager Weights manager to be used by the function.
      * @param[in] os              Output stage meta-data.
      */
-    void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
+    void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                    arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                    MemoryGroup &memory_group, IWeightsManager *weights_manager,
                    const OutputStage &os = {});
@@ -262,8 +260,8 @@ public:
                              const std::vector<int32_t> &multipliers);
 
     // Inherited methods overridden:
-    void run() override;
-    void prepare() override;
+    void run(ITensorPack &tensors) override;
+    void prepare(ITensorPack &tensors) override;
     bool is_configured() const override;
 
 private:
@@ -283,28 +281,12 @@
      */
     void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
     /** Prepare the indirect buffer */
-    void prepare_indirect_buffer();
+    void prepare_indirect_buffer(ITensorPack &tensors);
 
     /** Assembly Gemm kernel */
     std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
     /** Optimised Arm® Neon™ kernel */
     std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
-    /** Input A */
-    const ITensor *_a
-    {
-        nullptr
-    };
-    /** Input B */
-    const ITensor *_b
-    {
-        nullptr
-    };
-    const ITensor *_c
-    {
-        nullptr
-    };
-    /** Output */
-    ITensor *_d{ nullptr };
     /** GEMM workspace */
     Tensor _workspace{};
     /** Pre-transpose tensor */
@@ -328,8 +310,27 @@ private:
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
-    std::vector<TypeInput>          _indirect_pad{};
-    arm_gemm::ConvolutionParameters _cp{};
+    std::vector<TypeInput>                                 _indirect_pad{};
+    arm_gemm::ConvolutionParameters                        _cp{};
+
+    bool is_weight_managed()
+    {
+        // TODO (COMPMID-4539): This function should do the following:
+        // _weights_manager && _weights_manager->are_weights_managed(_b)
+        // , where _b is the second Tensor that is used to be given to the configure().
+        // Currently, however, weight manager is disabled to make this class stateless.
+        // This should be revisited in the future.
+        return false;
+    }
+
+    void acquire_managed_weight()
+    {
+        // TODO (COMPMID-4539): This function should do the following:
+        // _pretranspose = _weights_manager->acquire(_b, &_weights_transform);
+        // , where _b is the second Tensor that is used to be given to the configure().
+        // Currently, however, weight manager is disabled to make this class stateless.
+        _pretranspose = nullptr;
+    }
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -352,14 +353,15 @@ Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vec
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer()
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
 {
-    const TypeInput *A_ptr          = reinterpret_cast<TypeInput *>(_a->buffer());
+    auto             a              = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+    const TypeInput *A_ptr          = reinterpret_cast<TypeInput *>(a->buffer());
     const int        multis         = 1;
-    const int        batches        = _a->info()->tensor_shape().total_size_upper(3);
-    const size_t     stride_A       = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
-    const size_t     batch_stride_A = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
-    const size_t     multi_stride_A = _a->info()->strides_in_bytes()[4] / sizeof(TypeInput);
+    const int        batches        = a->info()->tensor_shape().total_size_upper(3);
+    const size_t     stride_A       = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+    const size_t     batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+    const size_t     multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);
 
     const size_t output_hw  = _cp.output_height * _cp.output_width;
     const int    batch_size = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
@@ -466,10 +468,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                              arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                                                              MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
 {
+    ARM_COMPUTE_UNUSED(c);
     arm_gemm::GemmConfig gemm_cfg;
     _kernel_info     = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
     _weights_manager = weights_manager;
@@ -508,10 +511,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
     }
     _optimised_kernel = std::move(acl_gemm_wrapper);
-    _a                = a;
-    _b                = b;
-    _c                = c;
-    _d                = d;
     _gemm_info        = gemm_info;
     // Check for pre-transposed support
     if(_gemm_kernel_asm->B_pretranspose_required())
     {
@@ -519,10 +518,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
         // Forcing 128-byte alignment (required by 32-bit kernels)
         const unsigned int alignment           = 128;
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
-        if(weights_manager && _weights_manager->are_weights_managed(b))
+        if(is_weight_managed())
         {
             _weights_transform.configure(B_pretranspose_size, alignment);
-            _pretranspose = _weights_manager->acquire(b, &_weights_transform);
+            acquire_managed_weight();
         }
         else
         {
@@ -534,32 +533,34 @@
     // Handle indirect GEMM convolution
     if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
     {
-        configure_indirect(a->info(), b->info(), d->info(), gemm_info);
+        configure_indirect(a, b, d, gemm_info);
     }
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
 {
+    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     if(!_is_prepared)
     {
         // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
-        if(_c && _c->info()->data_type() == DataType::S32)
+        if(c && c->info()->data_type() == DataType::S32)
         {
-            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()), 0);
+            _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
         }
 
        // Pretranspose B if required
        if(_gemm_kernel_asm->B_pretranspose_required())
        {
-            const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
-            const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
-            const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+            const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+            const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+            const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            if(_weights_manager && _weights_manager->are_weights_managed(_b))
+            if(is_weight_managed())
            {
                _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
-                _weights_manager->run(_b, &_weights_transform);
+                _weights_manager->run(b, &_weights_transform);
 
                // If we didn't run the reshape function, set the pretransposed buffer
                if(!_weights_transform.is_reshape_run())
@@ -572,13 +573,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
                static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
                ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
                _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
-                _b->mark_as_unused();
+                b->mark_as_unused();
            }
        }
 
        if(_gemm_info.method == AsmConvMethod::Indirect)
        {
-            prepare_indirect_buffer();
+            prepare_indirect_buffer(tensors);
        }
 
        _is_prepared = true;
@@ -601,37 +602,42 @@ bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::run()
+void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
 {
-    int       lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+    auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+    auto d = tensors.get_tensor(TensorType::ACL_DST);
+
+    int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
     int       ldb = 0;
-    const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+    const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
 
     const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
     const size_t a_multi_idx = a_batch_idx + 1;
     const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
     const size_t d_multi_idx = d_batch_idx + 1;
 
-    int       batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
-    const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
+    int       batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+    const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
 
-    int       multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+    int       multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
     int       multi_stride_b = 0;
-    const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
+    const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
 
-    auto             in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+    auto             in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
     const TypeInput *in1_ptr = nullptr;
-    auto             out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
+    auto             out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
 
     // Check if B is pre-tranposed and de-reference if not
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
-        ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
-        multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
-        in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
+        ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+        in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
 
-    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, _d->info()->data_type());
+    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
 
     // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
     if(_workspace.buffer() != nullptr)
@@ -654,13 +660,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
     }
 
     // Prepare assembly kernel
-    prepare();
+    prepare(tensors);
 
     // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
     TypeOutput *bias = nullptr;
-    if(_c && _c->info()->data_type() != DataType::S32)
+    if(c && c->info()->data_type() != DataType::S32)
     {
-        bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes());
+        bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
     }
 
     if(_gemm_info.method == AsmConvMethod::Indirect)
@@ -682,7 +688,7 @@
 
 template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                     const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
+                     const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
                      IWeightsManager *weights_manager)
 {
     Params p = extract_parameters(a, b, d, info);
@@ -699,7 +705,7 @@
 
 template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                           const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
+                           const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
                            IWeightsManager *weights_manager)
 {
     ARM_COMPUTE_UNUSED(activation);
@@ -714,8 +720,8 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
     // Configure requantization info
     const int32_t negation = info.negated_offsets ? 1 : -1;
-    const int32_t a_offset = -a->info()->quantization_info().uniform().offset * negation;
-    const int32_t b_offset = -b->info()->quantization_info().uniform().offset * negation;
+    const int32_t a_offset = -a->quantization_info().uniform().offset * negation;
+    const int32_t b_offset = -b->quantization_info().uniform().offset * negation;
 
     const GEMMLowpOutputStageInfo os_info = info.output_stage;
 
     arm_gemm::Requantize32 gemm_requant_info{};
@@ -786,18 +792,18 @@ bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo
     return act.type != arm_gemm::Activation::Type::None;
 }
 
-void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info)
+void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
     arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
 
     //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
-    if(!CpuGemmAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info))
+    if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))
     {
         return;
     }
 
-    switch(a->info()->data_type())
+    switch(a->data_type())
     {
         case DataType::F32:
             create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
@@ -805,7 +811,7 @@ void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, cons
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
-            if(d->info()->data_type() == DataType::S32)
+            if(d->data_type() == DataType::S32)
             {
                 create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
@@ -816,7 +822,7 @@ void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, cons
             break;
         case DataType::S8:
        case DataType::QASYMM8_SIGNED:
-            if(d->info()->data_type() == DataType::S32)
+            if(d->data_type() == DataType::S32)
            {
                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
            }
@@ -841,10 +847,10 @@ void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, cons
     }
 }
 
-void CpuGemmAssemblyDispatch::prepare()
+void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors)
 {
     ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
-    _arm_gemm->prepare();
+    _arm_gemm->prepare(tensors);
 }
 
 bool CpuGemmAssemblyDispatch::is_configured() const
@@ -852,12 +858,12 @@ bool CpuGemmAssemblyDispatch::is_configured() const
     return _arm_gemm != nullptr && _arm_gemm->is_configured();
 }
 
-void CpuGemmAssemblyDispatch::run()
+void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
 {
     MemoryGroupResourceScope scope_mg(_memory_group);
 
     ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
-    _arm_gemm->run();
+    _arm_gemm->run(tensors);
 }
 } // namespace cpu
 } // namespace arm_compute
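
The net effect of this patch is that the dispatcher becomes stateless: configure() now consumes only ITensorInfo descriptors, and the actual buffers are supplied to prepare()/run() through an ITensorPack on every call. A minimal caller-side sketch of the new contract follows; the wrapper function and include set are illustrative assumptions, while the pack slots (ACL_SRC_0/1/2, ACL_DST) and the configure()/prepare()/run() signatures are taken from the patch itself.

    // Hypothetical caller-side sketch, not part of this patch.
    #include "arm_compute/core/ITensor.h"
    #include "arm_compute/core/ITensorPack.h"
    #include "arm_compute/core/experimental/Types.h"
    #include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h" // internal header

    using namespace arm_compute;

    void run_assembly_gemm(cpu::CpuGemmAssemblyDispatch &gemm, const cpu::AsmGemmInfo &info,
                           const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d)
    {
        // configure() keeps no tensor pointers any more, only the descriptions.
        gemm.configure(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info);

        // The buffers travel in an ITensorPack, keyed by the same slots that
        // Fallback::run() reads back out (ACL_SRC_0/1/2, ACL_DST).
        ITensorPack pack;
        pack.add_const_tensor(TensorType::ACL_SRC_0, a);
        pack.add_const_tensor(TensorType::ACL_SRC_1, b);
        pack.add_const_tensor(TensorType::ACL_SRC_2, c); // optional bias
        pack.add_tensor(TensorType::ACL_DST, d);

        gemm.prepare(pack); // one-off work, e.g. pretransposing B
        gemm.run(pack);     // may be called repeatedly, with this or another pack
    }

This is also why the weights manager is disabled via is_weight_managed(): caching a managed _b pointer at configure() time would reintroduce the very state the change removes, which the TODO (COMPMID-4539) notes as future work.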
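For reference, extract_parameters() reads sizes off TensorShape in ACL's (x, y, z, ...) ordering, so for a plain (non-indirect) GEMM the mapping is K = a.x, M = d.y, N = d.x, with batches folded into the dimensions above z. A small standalone sketch with hypothetical shapes (public TensorInfo/TensorShape APIs only, assuming multis == 1):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include <iostream>

    using namespace arm_compute;

    int main()
    {
        // D = A * B with M = 32, K = 64, N = 128, 4 batches; shapes are (x, y, z, w).
        TensorInfo a(TensorShape(64U, 32U, 1U, 4U), 1, DataType::F32);  // A is K x M
        TensorInfo d(TensorShape(128U, 32U, 1U, 4U), 1, DataType::F32); // D is N x M

        const size_t M = d.tensor_shape().y(); // 32
        const size_t K = a.tensor_shape().x(); // 64
        const size_t N = d.tensor_shape().x(); // 128
        // total_size_upper(2) is the product of every dimension from z upwards,
        // i.e. 1 * 4 = 4; extract_parameters() then divides by multis (1 here).
        const size_t batches = d.tensor_shape().total_size_upper(2); // 4
        std::cout << "M=" << M << " K=" << K << " N=" << N << " batches=" << batches << "\n";
        return 0;
    }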