diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-07-31 17:22:11 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | 40ed6d89b9af1fc3c6fa24a757982d3cd713c6bf (patch) | |
tree | 34009f143c98bcbdfdbb605b6713bf90e405e65d | |
parent | 2f1366a931a8641d0f8c4ce28dc85dfa818236ed (diff) | |
download | ComputeLibrary-40ed6d89b9af1fc3c6fa24a757982d3cd713c6bf.tar.gz |
COMPMID-1440: Access original B in gemm assembly when not pretransposed.
Change-Id: I5f2c198f7ac4d8996180e204e763ab53f5e7ea3d
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142153
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Matteo Martincigh <matteo.martincigh@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index e4a7214c10..c0638c561e 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -208,7 +208,7 @@ template <typename TypeInput, typename TypeOutput> void Fallback<TypeInput, TypeOutput>::run() { const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); - const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + int ldb = 0; const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is @@ -220,12 +220,20 @@ void Fallback<TypeInput, TypeOutput>::run() const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); - const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + int multi_stride_b = 0; const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); - const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes()); - const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); - auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes()); + const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes()); + const TypeInput *in1_ptr = nullptr; + auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes()); + + // Check if B is pre-tranposed and de-reference if not + if(!_gemm_kernel_asm->B_is_pretransposed()) + { + ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); + } // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads if(_workspace.buffer() != nullptr) |