From 40ed6d89b9af1fc3c6fa24a757982d3cd713c6bf Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 31 Jul 2018 17:22:11 +0100 Subject: COMPMID-1440: Access original B in gemm assembly when not pretransposed. Change-Id: I5f2c198f7ac4d8996180e204e763ab53f5e7ea3d Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142153 Tested-by: Jenkins Reviewed-by: Matteo Martincigh Reviewed-by: Anthony Barbier --- src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index e4a7214c10..c0638c561e 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -208,7 +208,7 @@ template void Fallback::run() { const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); - const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + int ldb = 0; const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); // In the case of NHWC we want to interpret the output shape as 3D. 
Thus, the batch stride for A is @@ -220,12 +220,20 @@ void Fallback::run() const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); - const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + int multi_stride_b = 0; const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); - const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes()); - const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); - auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes()); + const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes()); + const TypeInput *in1_ptr = nullptr; + auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes()); + + // Check if B is pre-transposed and de-reference if not + if(!_gemm_kernel_asm->B_is_pretransposed()) + { + ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); + } // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads if(_workspace.buffer() != nullptr) -- cgit v1.2.1