diff options
Diffstat (limited to 'arm_compute/core/Types.h')
-rw-r--r-- | arm_compute/core/Types.h | 49 |
1 files changed, 40 insertions, 9 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 5e04bcd0f4..134b8e2905 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1205,6 +1205,26 @@ private: const bool _reinterpret_input_as_3d; }; +/** GEMMLowp output stage type */ +enum class GEMMLowpOutputStageType +{ + NONE, /**< No quantization to uint8 */ + QUANTIZE_DOWN, /**< Quantize to uint8 using an integer multiplication */ + QUANTIZE_DOWN_FIXEDPOINT, /**< Quantize to uint8 using a fixed point multiplication */ + QUANTIZE_DOWN_FLOAT /**< Quantize to uint8 using a floating point multiplication */ +}; + +/** GEMMLowp output stage info */ +struct GEMMLowpOutputStageInfo +{ + GEMMLowpOutputStageType type{ GEMMLowpOutputStageType::NONE }; /**< GEMMLowp output stage type */ + int gemmlowp_offset{ 0 }; /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */ + int gemmlowp_multiplier{ 0 }; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */ + int gemmlowp_shift{ 0 }; /**< GEMMLowp output stage shift used for quantizing to uint8 */ + int gemmlowp_min_bound{ 0 }; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */ + int gemmlowp_max_bound{ 0 }; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */ +}; + /** GEMM information class. This class stores the necessary information to compute GEMM functions * * This object also contains the information about how matrix A and matrix B have been reshaped @@ -1215,7 +1235,7 @@ class GEMMInfo public: /** Default constructor */ GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false), _retain_internal_weights(false) + : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(1), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage() { } /** Constructor @@ -1227,11 +1247,13 @@ public: * @param[in] reinterpret_input_as_3d (Optional) Reinterpret the input as 3D tensor. (i.e. this flag should be set to true when GEMM is used * to perform 1x1 convolutions with the NHWC data layout) * @param[in] retain_internal_weights (Optional) Retain the weights tensor from previous run + * @param[in] gemmlowp_output_stage (Optional) GEMMLowp Output stage info * */ - GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false) + GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 1, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false, + GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo()) : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d), - _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights) + _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage) { } /** Flag which specifies if the matrix A has been reshaped @@ -1284,14 +1306,23 @@ public: { return _retain_internal_weights; }; + /** GEMMLowp output stage + * + * @return the GEMMLowp output stage info + */ + GEMMLowpOutputStageInfo gemmlowp_output_stage() const + { + return _gemmlowp_output_stage; + }; private: - const bool _is_a_reshaped; - const bool _is_b_reshaped; - const bool _reshape_b_only_on_first_run; - const int _depth_output_gemm3d; - const bool _reinterpret_input_as_3d; - const bool _retain_internal_weights; + const bool _is_a_reshaped; + const bool _is_b_reshaped; + const bool _reshape_b_only_on_first_run; + const int _depth_output_gemm3d; + const bool _reinterpret_input_as_3d; + const bool _retain_internal_weights; + const GEMMLowpOutputStageInfo _gemmlowp_output_stage; }; /** Winograd information */ |