aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index e60fe80e0f..e4a7214c10 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -174,7 +174,7 @@ void Fallback<TypeInput, TypeOutput>::prepare()
if(_gemm_kernel_asm->B_pretranspose_required())
{
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
@@ -223,9 +223,9 @@ void Fallback<TypeInput, TypeOutput>::run()
const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
- const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
- auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
+ const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
+ auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
if(_workspace.buffer() != nullptr)