diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-06-05 14:56:06 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:53:09 +0000 |
commit | 72219330fd85b1271e714d4ba894d6d8e26340c9 (patch) | |
tree | 9ae0510087a1ca77b1695252a8621de3f2ab98af /arm_compute/runtime/NEON/AssemblyHelper.h | |
parent | c42f28d45e9b990276d54880d2cee9c9ee675a41 (diff) | |
download | ComputeLibrary-72219330fd85b1271e714d4ba894d6d8e26340c9.tar.gz |
COMPMID-1145: (API) Introduce prepare() stage (NEON/CL/GLES)
Change-Id: I5b46764f9c3154ec3e3b9c951cc9e6dfbcb81dfb
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134255
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r-- | arm_compute/runtime/NEON/AssemblyHelper.h | 52 |
1 files changed, 37 insertions, 15 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 3aa43ec96e..c4ba1a584e 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -51,7 +51,7 @@ public: using TypeResult = TypeOutput; /** Default constructor. */ AssemblyKernelGlue() - : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr) + : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr), _is_prepared(false) { } /** Assembly Gemm */ @@ -76,6 +76,31 @@ public: ITensor *_workspace; /** Pre-transpose tensor */ ITensor *_pretranspose; + /** Prepared flag */ + bool _is_prepared; + + /** Runs a preparation step, usually for pre-transposing matrix b */ + void prepare() + { + // Pretranspose B if required + if(_gemm_kernel_asm->B_pretranspose_required()) + { + const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer()); + const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + + // Forcing 128-byte alignment (required by 32-bit kernels) + const unsigned int alignment = 128; + void *raw_ptr = reinterpret_cast<void *>(_pretranspose->buffer()); + size_t space = _pretranspose->info()->total_size(); + void *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space); + ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr); + _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, in1_ptr, ldb, multi_stride_b); + _b->mark_as_unused(); + } + + _is_prepared = true; + } /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel. * The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2) @@ -102,28 +127,25 @@ public: const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer()); auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer()); - // Set workspace if needed + // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads if(_workspace != nullptr) { _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace->buffer())); + const unsigned int window_size = _gemm_kernel_asm->get_window_size(); + unsigned int num_threads = NEScheduler::get().num_threads(); + if(window_size < num_threads) + { + num_threads = window_size; + _gemm_kernel_asm->set_nthreads(num_threads); + } } + // Prepare assembly kernel + prepare(); + // Set gemm parameters _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d); - // Pretranspose B if required - if(_gemm_kernel_asm->B_pretranspose_required()) - { - // Forcing 128-byte alignment (required by 32-bit kernels) - const unsigned int alignment = 128; - void *raw_ptr = reinterpret_cast<void *>(_pretranspose->buffer()); - size_t space = _pretranspose->info()->total_size(); - void *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space); - ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr); - _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, in1_ptr, ldb, multi_stride_b); - _b->mark_as_unused(); - } - // Schedule assembly kernel NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX); } |