diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-21 18:43:12 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-07-05 15:30:24 +0000 |
commit | 37d080f2f11cfd734104b76512e1fb191486216e (patch) | |
tree | d5df067c826aacc0676e7e9557a54b61a9a3b7eb /arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h | |
parent | 11de30da8a9f79943255ddba7bb70a66b076673b (diff) | |
download | ComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz |
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h index 5d6cd02398..b18d327339 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h @@ -87,20 +87,22 @@ class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTrans public: /** Configure the reshape A routine. * - * @param[in] a Input matrix A. - * @param[out] transformed_a Reshaped matrix A. - * @param[in] transpose_a Also transpose A ? - * @param[in] block_walker Window representing the layout of the matrix's blocks - * @param[in] params M, N, K sizes. + * @param[in] a Input matrix A. + * @param[out] transformed_a Reshaped matrix A. + * @param[in] transpose_a Also transpose A ? + * @param[in] reinterpret_a_as_3d Re-interpret as 3D ? + * @param[in] block_walker Window representing the layout of the matrix's blocks + * @param[in] params M, N, K sizes. */ - void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params ¶ms) + void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params ¶ms) { - _a = a; - _transformed_a = transformed_a; - _transpose_a = transpose_a; - _Ksize = params.K; - _Msize = params.M; - _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension + _a = a; + _transformed_a = transformed_a; + _transpose_a = transpose_a; + _reinterpret_a_as_3d = reinterpret_a_as_3d; + _Ksize = params.K; + _Msize = params.M; + _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension } // Inherited methods overridden: @@ -110,12 +112,12 @@ public: TensorAccessor<typename strategy::operand_type> a(*_a); TensorAccessor<typename strategy::operand_type> transformed_a(*_transformed_a); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_reinterpret_a_as_3d) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. - const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize; - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); } unsigned int last_m = 0; @@ -164,8 +166,8 @@ private: unsigned int _Msize{ 0 }; unsigned int _Ksize{ 0 }; bool _transpose_a{ false }; + bool _reinterpret_a_as_3d{ false }; Window _k_multi_window{}; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */ |