From 37d080f2f11cfd734104b76512e1fb191486216e Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Fri, 21 Jun 2019 18:43:12 +0100
Subject: COMPMID-2378: Sanitize GEMM configuration for NEON

Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins
Reviewed-by: Michele Di Giorgio
Reviewed-by: Michalis Spyrou
Tested-by: Arm Jenkins
---
 src/runtime/NEON/functions/NEGEMM.cpp              | 10 ++--
 .../NEON/functions/NEGEMMAssemblyDispatch.cpp      | 68 +++++++++++++---------
 .../NEGEMMLowpAssemblyMatrixMultiplyCore.cpp       |  2 +-
 .../functions/NEGEMMLowpMatrixMultiplyCore.cpp     |  9 ++-
 .../assembly/NEGEMMInterleavedWrapper.cpp          | 12 ++--
 5 files changed, 56 insertions(+), 45 deletions(-)

(limited to 'src/runtime/NEON/functions')

diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 55bcc45d12..2f36397c8e 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -58,17 +58,19 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
     _run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
     _original_b = b;
 
-    bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run));
+    bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info));
 
     if(run_optimised)
     {
         if(MEMInfo::get_policy() == MemoryPolicy::MINIMIZE)
         {
-            _asm_glue.configure(a, b, d, alpha, beta, false);
+            GEMMInfo gemm_info_ntb = gemm_info;
+            gemm_info_ntb.set_pretranpose_B(false);
+            _asm_glue.configure(a, b, d, alpha, beta, gemm_info_ntb);
         }
         else
         {
-            _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run);
+            _asm_glue.configure(a, b, d, alpha, beta, gemm_info);
         }
         ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
     }
@@ -176,7 +178,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
     }
 
     // Check if we need to run the optimized assembly kernel
-    const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true));
+    const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, gemm_info));
 
     if(!run_optimised)
     {
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 55e067f52d..2de7d2b279 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -36,21 +36,22 @@ namespace arm_compute
 namespace
 {
 std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
-                                                     const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                     const ITensor *a, const ITensor *b, ITensor *d,
+                                                     float alpha, float beta, const GEMMInfo &gemm_info,
                                                      std::shared_ptr<IMemoryManager> memory_manager)
 {
-    //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
+    // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
     switch(gemm_kernel_info.method)
     {
         case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
         {
-            if(!pretranspose_hint)
+            if(!gemm_info.pretranpose_B())
            {
                 return nullptr;
             }
             auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
-            function->configure(a, b, d, alpha, beta, pretranspose_hint);
+            function->configure(a, b, d, alpha, beta, gemm_info);
             return std::move(function);
         }
 #if defined(__aarch64__)
@@ -59,7 +60,7 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
             if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
             {
                 auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
-                kernel->configure(a, b, d, alpha, beta);
+                kernel->configure(a, b, d, alpha, beta, gemm_info);
                 auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
                 function->configure(std::move(kernel));
                 return std::move(function);
             }
@@ -83,9 +84,11 @@ public:
      * @param[in] b Input tensor containing the Matrix B.
      * @param[out] d Output tensor to store the result of matrix multiplication.
      * @param[in] args Matrix multiplication information.
+     * @param[in] gemm_info GEMM meta-data
      * @param[in] memory_group Memory group to be used by the function.
      */
-    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+                   const GEMMInfo &gemm_info, MemoryGroup &memory_group);
 
     // Inherited methods overridden:
     void run() override;
@@ -123,10 +126,13 @@ private:
     Tensor _pretranspose{};
     /** Prepared flag */
     bool _is_prepared{ false };
+    /** GEMM meta-data */
+    GEMMInfo _gemm_info{};
 };
 
 template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+                                                const GEMMInfo &gemm_info, MemoryGroup &memory_group)
 {
     arm_gemm::GemmConfig gemm_cfg;
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
@@ -168,6 +174,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
     _a = a;
     _b = b;
     _d = d;
+    _gemm_info = gemm_info;
     // Check for pre-transposed support
     if(_gemm_kernel_asm->B_pretranspose_required())
     {
@@ -222,17 +229,17 @@ void Fallback<TypeInput, TypeOutput>::run()
     int ldb = 0;
     const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
 
-    // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
-    // the relevant multiple of the row stride.
-    const bool is_nhwc           = _a->info()->data_layout() == DataLayout::NHWC;
-    const int  stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
+    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2;
+    const size_t a_multi_idx = a_batch_idx + 1;
+    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2;
+    const size_t d_multi_idx = d_batch_idx + 1;
 
-    const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
-    const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+    const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+    const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
 
-    const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+    const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
     int multi_stride_b = 0;
-    const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+    const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
 
     const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
     const TypeInput *in1_ptr = nullptr;
@@ -270,24 +277,27 @@ void Fallback<TypeInput, TypeOutput>::run()
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
-                                 ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
+void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function,
+                                 std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm,
+                                 MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
+                                 ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
+                                 std::shared_ptr<IMemoryManager> memory_manager)
 {
-    INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+    INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
     const CPUInfo &ci = NEScheduler::get().cpu_info();
     unsigned int num_threads = NEScheduler::get().num_threads();
 
-    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
 
     //Try to create an ACL function:
-    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager));
+    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
 
     //If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         //Fallback onto arm_gemm function if ACL doesn't support this method.
         auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
-        fallback->configure(a, b, d, args, memory_group);
+        fallback->configure(a, b, d, args, gemm_info, memory_group);
         arm_gemm = std::move(fallback);
     }
 }
@@ -299,11 +309,11 @@ NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> m
 {
 }
 
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_UNUSED(gemm_info);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
 #ifndef __aarch64__
@@ -319,14 +329,14 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
     return Status{};
 }
 
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a);
     ARM_COMPUTE_ERROR_ON_NULLPTR(b);
     ARM_COMPUTE_ERROR_ON_NULLPTR(d);
 
     //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
-    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
+    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info))
     {
         return;
     }
@@ -334,20 +344,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITens
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
-            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
         case DataType::S8:
-            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
 #endif /* __aarch64__ */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
index ede89bf558..5b70c8724c 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
@@ -59,7 +59,7 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe
         case DataType::QASYMM8:
         case DataType::U8:
         {
-            _asm_glue.configure(a, b, output, 1.f, 0.f, true);
+            _asm_glue.configure(a, b, output, 1.f, 0.f, GEMMInfo(false, false, true));
             run_optimised = _asm_glue.is_configured();
             break;
         }
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index d8773e37ab..f10f114287 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -87,7 +87,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
         case DataType::U8:
         case DataType::S8:
         {
-            _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, _reshape_b_only_on_first_run);
+            _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, gemm_info);
             _dot_product_path = _asm_glue.is_configured();
             break;
         }
@@ -224,9 +224,8 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
     TensorInfo tmp_b_info{};
     TensorInfo mm_result_s32_info{};
-    int32_t    a_offset                    = a->quantization_info().uniform().offset;
-    int32_t    b_offset                    = b->quantization_info().uniform().offset;
-    const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+    int32_t a_offset = a->quantization_info().uniform().offset;
+    int32_t b_offset = b->quantization_info().uniform().offset;
 
     bool fuse_output_stage = gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
     if(fuse_output_stage)
     {
@@ -235,7 +234,7 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
     }
 
     // Check if we need to run the optimized assembly kernel
-    const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, reshape_b_only_on_first_run));
+    const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, gemm_info));
 
     if(run_optimised)
     {
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index 20aa1496b6..ac809fa142 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -339,19 +339,19 @@ void NEGEMMInterleavedWrapper::prepare()
     }
 }
 
-void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b)
+void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
 {
-    _params = INEGEMMWrapperKernel::extract_parameters(a, b, c);
+    _params = INEGEMMWrapperKernel::extract_parameters(a, b, c, gemm_info);
     _a = a;
     _b = b;
     _c = c;
-    _pretranspose_b = pretranspose_b;
+    _pretranspose_b = gemm_info.pretranpose_B();
 
     const DataType input_type = a->info()->data_type();
     const CPUInfo &ci = NEScheduler::get().cpu_info();
     const unsigned int num_threads = NEScheduler::get().num_threads();
 
-    const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, pretranspose_b);
+    const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, _pretranspose_b);
     ARM_COMPUTE_ERROR_ON(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMM_INTERLEAVED);
 
     // Forcing 128-byte alignment (required by 32-bit kernels)
@@ -411,8 +411,8 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe
         _memory_group.manage(&_transformed_a);
         _memory_group.manage(&_tmp_c);
 
-        _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params);
-        _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, pretranspose_b, num_threads);
+        _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params, gemm_info);
+        _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, gemm_info, num_threads);
 
         ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
         ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);
--
cgit v1.2.1
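Editor's note, not part of the patch: the one behavioural change buried in the plumbing above is in Fallback::run(), where the batch/multi stride indices are now derived from GEMMInfo instead of a hard-coded NHWC special case. The sketch below mirrors that index arithmetic under stated assumptions — GemmShapeFlags, StrideIndices and select_stride_indices are invented names standing in for the two GEMMInfo accessors the patch reads (reinterpret_input_as_3d() and depth_output_gemm3d()); it is an illustration, not ACL code.

```cpp
#include <cstddef>

// Hypothetical stand-ins for the two GEMMInfo fields the new code consults.
struct GemmShapeFlags
{
    bool reinterpret_input_as_3d; // input A is a 3D reinterpretation (M spans two dims)
    int  depth_output_gemm3d;     // non-zero: output D is reinterpreted as 3D
};

struct StrideIndices
{
    std::size_t a_batch, a_multi, d_batch, d_multi;
};

// Mirrors the + lines in Fallback<TypeInput, TypeOutput>::run(): when a tensor
// is reinterpreted as 3D, dimensions 0-2 together hold one matrix, so the
// batch dimension moves from index 2 to index 3, and the "multi" dimension is
// always the one immediately after the batch.
StrideIndices select_stride_indices(const GemmShapeFlags &info)
{
    StrideIndices idx{};
    idx.a_batch = info.reinterpret_input_as_3d ? 3 : 2;
    idx.a_multi = idx.a_batch + 1;
    idx.d_batch = (info.depth_output_gemm3d != 0) ? 3 : 2;
    idx.d_multi = idx.d_batch + 1;
    return idx;
}
```

Read this way, the deleted is_nhwc branch was an approximation of the 3D-input case, while the new code lets the stride for each tensor be taken directly from strides_in_bytes() at the computed index.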