diff options
Diffstat (limited to 'arm_compute/core/Types.h')
-rw-r--r-- | arm_compute/core/Types.h | 108 |
1 files changed, 100 insertions, 8 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 5402e358b5..5197000bf9 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2018 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -824,13 +824,95 @@ private: const unsigned int _num_kernels; }; -/** GEMM Information class. This class stores the necessary information to compute GEMM functions */ +/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape. + * + * The matrix A can only be reshaped through @ref CLGEMMInterleave4x4Kernel or @ref NEGEMMInterleave4x4Kernel or @ref GCGEMMInterleave4x4Kernel + * Note: Optionally just for @ref CLGEMMInterleave4x4Kernel is it possible to set mult_interleave4x4_height, the multiplication factor for the height of the 4x4 interleaved block + * + * The matrix B can only be reshaped through @ref CLGEMMTranspose1xWKernel or @ref NEGEMMTranspose1xWKernel or @ref GCGEMMTranspose1xWKernel + * Note: Optionally just for @ref CLGEMMTranspose1xWKernel is it possible to set mult_transpose1xW_width, the multiplication factor for the width of the 1xW transposed block + * + */ +class GEMMReshapeInfo final +{ +public: + /** Default constructor */ + GEMMReshapeInfo() + : _m(1), _n(1), _k(1), _mult_transpose1xW_width(1), _mult_interleave4x4_height(1) + { + } + /** Constructor + * + * @param[in] m Number of matrix A rows + * @param[in] n Number of matrix B columns + * @param[in] k Number of matrix A columns or matrix B rows + * @param[in] mult_transpose1xW_width (Optional) Multiplication factor for the width of the 1xW transposed block + * @param[in] mult_interleave4x4_height (Optional) Multiplication factor for the height of the 4x4 interleaved block + */ + GEMMReshapeInfo(int m, int n, int k, int mult_transpose1xW_width = 1, int mult_interleave4x4_height = 1) + : _m(m), _n(n), _k(k), _mult_transpose1xW_width(mult_transpose1xW_width), _mult_interleave4x4_height(mult_interleave4x4_height) + { + } + /** Number of matrix A rows + * + * @return the number of matrix A rows + */ + int m() const + { + return _m; + } + /** Number of matrix B columns + * + * @return the number of matrix B columns + */ + int n() const + { + return _n; + } + /** Number of matrix A columns or matrix B rows + * + * @return the number of matrix A columns or matrix B rows + */ + int k() const + { + return _k; + } + /** Multiplication factor for the width of the 1xW transposed block + * + * @return the multiplication factor for the width of the 1xW transposed block + */ + int mult_transpose1xW_width() const + { + return _mult_transpose1xW_width; + } + /** Multiplication factor for the height of the 4x4 interleaved block + * + * @return the multiplication factor for the height of the 4x4 interleaved block + */ + int mult_interleave4x4_height() const + { + return _mult_interleave4x4_height; + } + +private: + const int _m; + const int _n; + const int _k; + const int _mult_transpose1xW_width; + const int _mult_interleave4x4_height; +}; + +/** GEMM information class. This class stores the necessary information to compute GEMM functions + * + * This object also contains the information about how matrix A and matrix B have been reshaped + * + */ class GEMMInfo { public: /** Default constructor */ GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false) + : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _reshape_info() { } /** Constructor @@ -838,9 +920,10 @@ public: * @param[in] is_a_reshaped True if the matrix A has been reshaped * @param[in] is_b_reshaped True if the matrix B has been reshaped * @param[in] reshape_b_only_on_first_run Reshape matrix B only for the first run + * @param[in] reshape_info (Optional) GEMM reshape information object */ - GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run) - : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run) + GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, const GEMMReshapeInfo &reshape_info = GEMMReshapeInfo()) + : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _reshape_info(reshape_info) { } /** Flag which specifies if the matrix A has been reshaped @@ -869,11 +952,20 @@ public: { return _reshape_b_only_on_first_run; }; + /** GEMMReshapeInfo object which stores the necessary information to understand how the matrix A and matrix B have been reshaped + * + * @return the GEMMReshapeInfo object + */ + const GEMMReshapeInfo &reshape_info() const + { + return _reshape_info; + } private: - const bool _is_a_reshaped; - const bool _is_b_reshaped; - const bool _reshape_b_only_on_first_run; + const bool _is_a_reshaped; + const bool _is_b_reshaped; + const bool _reshape_b_only_on_first_run; + GEMMReshapeInfo _reshape_info; }; /** IO formatting information class*/ |