aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
diff options
context:
space:
mode:
author Georgios Pinitas <georgios.pinitas@arm.com> 2019-06-21 18:43:12 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com> 2019-07-05 15:30:24 +0000
commit 37d080f2f11cfd734104b76512e1fb191486216e (patch)
tree d5df067c826aacc0676e7e9557a54b61a9a3b7eb /src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
parent 11de30da8a9f79943255ddba7bb70a66b076673b (diff)
download ComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp 18
1 files changed, 13 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
index 97c20dbd4e..ecdb5a938c 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
@@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel<To, Tr>::run_internal(const Window &window, const
TensorAccessor<To> b(*_b);
TensorAccessor<Tr> c(*_c);
- if(_a->info()->data_layout() == DataLayout::NHWC)
+ // Handle 3d input re-interpretation
+ if(_gemm_info.reinterpret_input_as_3d())
{
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1);
- a.set_stride(2, nhwc_batch_stride);
+ Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+ a_strides_as_3d.remove(Window::DimZ);
+ a.set_strides(a_strides_as_3d);
+ }
+
+ // Handle 3d output re-interpretation
+ if(_gemm_info.depth_output_gemm3d() != 0)
+ {
+ Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+ c_strides_as_3d.remove(Window::DimZ);
+ c.set_strides(c_strides_as_3d);
}
unsigned int m_end = 0;