From 37d080f2f11cfd734104b76512e1fb191486216e Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 21 Jun 2019 18:43:12 +0100 Subject: COMPMID-2378: Sanitize GEMM configuration for NEON Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1418 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Michalis Spyrou Tested-by: Arm Jenkins --- .../assembly/NEGEMMInterleavedTransformAWrapper.h | 38 ++++++++++++---------- 1 file changed, 20 insertions(+), 18 deletions(-) (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h') diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h index 5d6cd02398..b18d327339 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h @@ -87,20 +87,22 @@ class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTrans public: /** Configure the reshape A routine. * - * @param[in] a Input matrix A. - * @param[out] transformed_a Reshaped matrix A. - * @param[in] transpose_a Also transpose A ? - * @param[in] block_walker Window representing the layout of the matrix's blocks - * @param[in] params M, N, K sizes. + * @param[in] a Input matrix A. + * @param[out] transformed_a Reshaped matrix A. + * @param[in] transpose_a Also transpose A ? + * @param[in] reinterpret_a_as_3d Re-interpret as 3D ? + * @param[in] block_walker Window representing the layout of the matrix's blocks + * @param[in] params M, N, K sizes. 
*/ - void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params) + void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params &params) { - _a = a; - _transformed_a = transformed_a; - _transpose_a = transpose_a; - _Ksize = params.K; - _Msize = params.M; - _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension + _a = a; + _transformed_a = transformed_a; + _transpose_a = transpose_a; + _reinterpret_a_as_3d = reinterpret_a_as_3d; + _Ksize = params.K; + _Msize = params.M; + _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension } // Inherited methods overridden: @@ -110,12 +112,12 @@ public: TensorAccessor a(*_a); TensorAccessor transformed_a(*_transformed_a); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_reinterpret_a_as_3d) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. - const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize; - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); } unsigned int last_m = 0; @@ -164,8 +166,8 @@ private: unsigned int _Msize{ 0 }; unsigned int _Ksize{ 0 }; bool _transpose_a{ false }; + bool _reinterpret_a_as_3d{ false }; Window _k_multi_window{}; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */ -- cgit v1.2.1