diff options
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp index 97c20dbd4e..ecdb5a938c 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp @@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel<To, Tr>::run_internal(const Window &window, const TensorAccessor<To> b(*_b); TensorAccessor<Tr> c(*_c); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_gemm_info.reinterpret_input_as_3d()) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. - const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1); - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); + } + + // Handle 3d output re-interpretation + if(_gemm_info.depth_output_gemm3d() != 0) + { + Strides c_strides_as_3d = _c->info()->strides_in_bytes(); + c_strides_as_3d.remove(Window::DimZ); + c.set_strides(c_strides_as_3d); } unsigned int m_end = 0; |