From 72219330fd85b1271e714d4ba894d6d8e26340c9 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 5 Jun 2018 14:56:06 +0100 Subject: COMPMID-1145: (API) Introduce prepare() stage (NEON/CL/GLES) Change-Id: I5b46764f9c3154ec3e3b9c951cc9e6dfbcb81dfb Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134255 Reviewed-by: Anthony Barbier Tested-by: Jenkins Reviewed-by: Pablo Tello Reviewed-by: Michele DiGiorgio --- arm_compute/runtime/NEON/AssemblyHelper.h | 52 ++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 15 deletions(-) (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h') diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 3aa43ec96e..c4ba1a584e 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -51,7 +51,7 @@ public: using TypeResult = TypeOutput; /** Default constructor. */ AssemblyKernelGlue() - : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr) + : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr), _is_prepared(false) { } /** Assembly Gemm */ @@ -76,6 +76,31 @@ public: ITensor *_workspace; /** Pre-transpose tensor */ ITensor *_pretranspose; + /** Prepared flag */ + bool _is_prepared; + + /** Runs a preparation step, usually for pre-transposing matrix b */ + void prepare() + { + // Pretranspose B if required + if(_gemm_kernel_asm->B_pretranspose_required()) + { + const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer()); + const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + + // Forcing 128-byte alignment (required by 32-bit kernels) + const unsigned int alignment = 128; + void *raw_ptr = reinterpret_cast<void *>(_pretranspose->buffer()); + size_t space = 
_pretranspose->info()->total_size(); + void *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space); + ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr); + _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, in1_ptr, ldb, multi_stride_b); + _b->mark_as_unused(); + } + + _is_prepared = true; + } /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel. * The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2) @@ -102,28 +127,25 @@ public: const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer()); auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer()); - // Set workspace if needed + // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads if(_workspace != nullptr) { _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace->buffer())); + const unsigned int window_size = _gemm_kernel_asm->get_window_size(); + unsigned int num_threads = NEScheduler::get().num_threads(); + if(window_size < num_threads) + { + num_threads = window_size; + _gemm_kernel_asm->set_nthreads(num_threads); + } } + // Prepare assembly kernel + prepare(); + // Set gemm parameters _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); - // Pretranspose B if required - if(_gemm_kernel_asm->B_pretranspose_required()) - { - // Forcing 128-byte alignment (required by 32-bit kernels) - const unsigned int alignment = 128; - void *raw_ptr = reinterpret_cast<void *>(_pretranspose->buffer()); - size_t space = _pretranspose->info()->total_size(); - void *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space); - ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr); - _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, 
in1_ptr, ldb, multi_stride_b); - _b->mark_as_unused(); - } - // Schedule assembly kernel NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } -- cgit v1.2.1