aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/function_info/GEMMInfo.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/function_info/GEMMInfo.h')
-rw-r--r--arm_compute/function_info/GEMMInfo.h44
1 files changed, 30 insertions, 14 deletions
diff --git a/arm_compute/function_info/GEMMInfo.h b/arm_compute/function_info/GEMMInfo.h
index 29a57a00c2..c24762c0aa 100644
--- a/arm_compute/function_info/GEMMInfo.h
+++ b/arm_compute/function_info/GEMMInfo.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/CoreTypes.h"
#include "arm_compute/function_info/ActivationLayerInfo.h"
+
#include <vector>
namespace arm_compute
@@ -43,17 +44,22 @@ enum class GEMMLowpOutputStageType
/** GEMMLowp output stage info */
struct GEMMLowpOutputStageInfo
{
- GEMMLowpOutputStageType type{ GEMMLowpOutputStageType::NONE }; /**< GEMMLowp output stage type */
- int32_t gemmlowp_offset{ 0 }; /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
- int32_t gemmlowp_multiplier{ 0 }; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
- int32_t gemmlowp_shift{ 0 }; /**< GEMMLowp output stage shift used for quantizing to uint8 */
- int32_t gemmlowp_min_bound{ std::numeric_limits<int32_t>::lowest() }; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
- int32_t gemmlowp_max_bound{ std::numeric_limits<int32_t>::max() }; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
- std::vector<int32_t> gemmlowp_multipliers{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
- std::vector<int32_t> gemmlowp_shifts{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
- float gemmlowp_real_multiplier{ 0 }; /**< GEMMLowp output stage real multiplier used for quantizing to QASYMM8 */
- bool is_quantized_per_channel{ false }; /**< GEMMLowp quantized per-channel flag */
- DataType output_data_type{ DataType::UNKNOWN }; /**< Output tensor data type to use if the output is not initialized */
+ GEMMLowpOutputStageType type{GEMMLowpOutputStageType::NONE}; /**< GEMMLowp output stage type */
+ int32_t gemmlowp_offset{0}; /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
+ int32_t gemmlowp_multiplier{0}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
+ int32_t gemmlowp_shift{0}; /**< GEMMLowp output stage shift used for quantizing to uint8 */
+ int32_t gemmlowp_min_bound{
+ std::numeric_limits<int32_t>::
+ lowest()}; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
+ int32_t gemmlowp_max_bound{
+ std::numeric_limits<int32_t>::
+ max()}; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
+ std::vector<int32_t> gemmlowp_multipliers{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
+ std::vector<int32_t> gemmlowp_shifts{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
+ float gemmlowp_real_multiplier{0}; /**< GEMMLowp output stage real multiplier used for quantizing to QASYMM8 */
+ bool is_quantized_per_channel{false}; /**< GEMMLowp quantized per-channel flag */
+ DataType output_data_type{
+ DataType::UNKNOWN}; /**< Output tensor data type to use if the output is not initialized */
};
/** GEMM information class. This class stores the necessary information to compute GEMM functions
*
@@ -100,9 +106,19 @@ public:
* @param[in] fixed_format (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in amemory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
* @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
*/
- GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
- GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
- const ActivationLayerInfo &activation_info = ActivationLayerInfo(), bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
+ GEMMInfo(bool is_a_reshaped,
+ bool is_b_reshaped,
+ bool reshape_b_only_on_first_run,
+ int depth_output_gemm3d = 0,
+ bool reinterpret_input_as_3d = false,
+ bool retain_internal_weights = false,
+ GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(),
+ bool fp_mixed_precision = false,
+ bool fast_math = false,
+ bool broadcast_bias = false,
+ const ActivationLayerInfo &activation_info = ActivationLayerInfo(),
+ bool fixed_format = false,
+ arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
: _is_a_reshaped(is_a_reshaped),
_is_b_reshaped(is_b_reshaped),
_reshape_b_only_on_first_run(reshape_b_only_on_first_run),