aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-31 17:22:11 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit40ed6d89b9af1fc3c6fa24a757982d3cd713c6bf (patch)
tree34009f143c98bcbdfdbb605b6713bf90e405e65d /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent2f1366a931a8641d0f8c4ce28dc85dfa818236ed (diff)
downloadComputeLibrary-40ed6d89b9af1fc3c6fa24a757982d3cd713c6bf.tar.gz
COMPMID-1440: Access original B in gemm assembly when not pretransposed.
Change-Id: I5f2c198f7ac4d8996180e204e763ab53f5e7ea3d Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142153 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Matteo Martincigh <matteo.martincigh@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index e4a7214c10..c0638c561e 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -208,7 +208,7 @@ template <typename TypeInput, typename TypeOutput>
void Fallback<TypeInput, TypeOutput>::run()
{
const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ int ldb = 0;
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
// In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
@@ -220,12 +220,20 @@ void Fallback<TypeInput, TypeOutput>::run()
const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
- const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ int multi_stride_b = 0;
const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
- const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
- auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
+ const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+ const TypeInput *in1_ptr = nullptr;
+ auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
+
+ // Check if B is pre-tranposed and de-reference if not
+ if(!_gemm_kernel_asm->B_is_pretransposed())
+ {
+ ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
+ }
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
if(_workspace.buffer() != nullptr)