diff options
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index e4a7214c10..c0638c561e 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -208,7 +208,7 @@ template <typename TypeInput, typename TypeOutput> void Fallback<TypeInput, TypeOutput>::run() { const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput); - const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + int ldb = 0; const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is @@ -220,12 +220,20 @@ void Fallback<TypeInput, TypeOutput>::run() const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); - const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput); + int multi_stride_b = 0; const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); - const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes()); - const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); - auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes()); + const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes()); + const TypeInput *in1_ptr = nullptr; + auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes()); + + // Check if B is pre-transposed and de-reference if not + if(!_gemm_kernel_asm->B_is_pretransposed()) + { + ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); + multi_stride_b = _b->info()->strides_in_bytes().z() 
/ sizeof(TypeInput); + in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes()); + } // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads if(_workspace.buffer() != nullptr) |