diff options
Diffstat (limited to 'arm_compute/core/Types.h')
-rw-r--r-- | arm_compute/core/Types.h | 48 |
1 files changed, 36 insertions, 12 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 00370918bd..81d652dd7d 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1031,7 +1031,7 @@ class GEMMReshapeInfo final public: /** Default constructor */ GEMMReshapeInfo() - : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1) + : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false) { } /** Constructor @@ -1042,9 +1042,12 @@ public: * @param[in] mult_transpose1xW_width (Optional) Multiplication factor for the width of the 1xW transposed block * @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel + * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used + * to perform 1x1 convolutions with the NHWC data layout) */ - GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1) - : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d) + GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false) + : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height), _depth_output_gemm3d(depth_output_gemm3d), + _reinterpret_input_as_3d(reinterpret_input_as_3d) { } /** Number of matrix A rows @@ -1098,14 +1101,23 @@ public: { return _depth_output_gemm3d; } + /** Flag which specifies if the input tensor has to be reinterpreted as 3D + * + * @return True if the input tensor has to be reinterpreted as 3D tensor + */ + bool reinterpret_input_as_3d() const + { + return _reinterpret_input_as_3d; + }; private: - const int _m; - const int _n; - const int _k; - const int _mult_transpose1xW_width; - const int _mult_interleave4x4_height; - const int _depth_output_gemm3d; + const int _m; + const int _n; + const int _k; + const int _mult_transpose1xW_width; + const int _mult_interleave4x4_height; + const int _depth_output_gemm3d; + const bool _reinterpret_input_as_3d; }; /** GEMM information class. This class stores the necessary information to compute GEMM functions @@ -1118,7 +1130,7 @@ class GEMMInfo public: /** Default constructor */ GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1) + : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false) { } /** Constructor @@ -1127,10 +1139,13 @@ public: * @param[in] is_b_reshaped True if the matrix B has been reshaped * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run * @param[in] depth_output_gemm3d (Optional) Depth (third dimension) of the output tensor to be used with the GEMM3D kernel + * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used + * to perform 1x1 convolutions with the NHWC data layout) * */ - GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1) - : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d) + GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false) + : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d), + _reinterpret_input_as_3d(reinterpret_input_as_3d) { } /** Flag which specifies if the matrix A has been reshaped @@ -1167,12 +1182,21 @@ public: { return _depth_output_gemm3d; }; + /** Flag which specifies if the input tensor has to be reinterpreted as 3D + * + * @return True if the input tensor has to be reinterpreted as 3D tensor + */ + bool reinterpret_input_as_3d() const + { + return _reinterpret_input_as_3d; + }; private: const bool _is_a_reshaped; const bool _is_b_reshaped; const bool _reshape_b_only_on_first_run; const int _depth_output_gemm3d; + const bool _reinterpret_input_as_3d; }; /** Winograd information */ |