From 8e74f4488daf1b628ca718396d5fc72fea95a83d Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Thu, 1 Mar 2018 16:42:00 +0000 Subject: COMPMID-911: Allow GEMM to work with 3D tensors Change-Id: I8c4823a0d909e19e9ef548f00b9ae98c66de61dd Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123569 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- arm_compute/core/Types.h | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) (limited to 'arm_compute/core/Types.h') diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 12c4e25222..da28e131de 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1013,7 +1013,7 @@ class GEMMReshapeInfo final public: /** Default constructor */ GEMMReshapeInfo() - : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1) + : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1) { } /** Constructor @@ -1023,9 +1023,10 @@ public: * @param[in] k Number of matrix A columns or matrix B rows * @param[in] mult_transpose1xW_width (Optional) Multiplication factor for the width of the 1xW transposed block * @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block + * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel */ - GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1) - : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height) + GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1) + : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d) { } /** Number of matrix A rows @@ -1068,6 +1069,17 @@ public: { return _mult_interleave4x4_height; } + /** Depth (third dimension) of the output tensor to be used with the GEMM3D kernel + * + * @note GEMM3D kernel is used when the output has to be reinterpret as 3D tensor. In that case: + * m = depth_output_gemm3d * output_height + * + * @return the depth of the output tensor to be used with the GEMM3D kernel + */ + int depth_output_gemm3d() const + { + return _depth_output_gemm3d; + } private: const int _m; @@ -1075,6 +1087,7 @@ private: const int _k; const int _mult_transpose1xW_width; const int _mult_interleave4x4_height; + const int _depth_output_gemm3d; }; /** GEMM information class. This class stores the necessary information to compute GEMM functions @@ -1087,7 +1100,7 @@ class GEMMInfo public: /** Default constructor */ GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _reshape_info() + : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1) { } /** Constructor @@ -1095,10 +1108,11 @@ public: * @param[in] is_a_reshaped True if the matrix A has been reshaped * @param[in] is_b_reshaped True if the matrix B has been reshaped * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run - * @param[in] reshape_info (Optional) GEMM reshape information object + * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel + * */ - GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo()) - : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _reshape_info(reshape_info) + GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1) + : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d) { } /** Flag which specifies if the matrix A has been reshaped @@ -1127,20 +1141,20 @@ public: { return _reshape_b_only_on_first_run; }; - /** GEMMReshapeInfo object which stores the necessary information to understand how the matrix A and matrix B have been reshaped + /** Depth of the output when GEMM output is reinterpreted as 3D tensor * - * @return the GEMMReshapeInfo object + * @return the depth of the output tensor */ - const GEMMReshapeInfo &reshape_info() const + int depth_output_gemm3d() const { - return _reshape_info; - } + return _depth_output_gemm3d; + }; private: - const bool _is_a_reshaped; - const bool _is_b_reshaped; - const bool _reshape_b_only_on_first_run; - GEMMReshapeInfo _reshape_info; + const bool _is_a_reshaped; + const bool _is_b_reshaped; + const bool _reshape_b_only_on_first_run; + const int _depth_output_gemm3d; }; /** Winograd information */ -- cgit v1.2.1