diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-21 18:43:12 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-07-05 15:30:24 +0000 |
commit | 37d080f2f11cfd734104b76512e1fb191486216e (patch) | |
tree | d5df067c826aacc0676e7e9557a54b61a9a3b7eb /src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp | |
parent | 11de30da8a9f79943255ddba7bb70a66b076673b (diff) | |
download | ComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz |
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp index 97c20dbd4e..ecdb5a938c 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp @@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel<To, Tr>::run_internal(const Window &window, const TensorAccessor<To> b(*_b); TensorAccessor<Tr> c(*_c); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_gemm_info.reinterpret_input_as_3d()) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. - const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1); - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); + } + + // Handle 3d output re-interpretation + if(_gemm_info.depth_output_gemm3d() != 0) + { + Strides c_strides_as_3d = _c->info()->strides_in_bytes(); + c_strides_as_3d.remove(Window::DimZ); + c.set_strides(c_strides_as_3d); } unsigned int m_end = 0; |