diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2020-02-07 13:46:45 +0000 |
---|---|---|
committer | Giorgio Arena <giorgio.arena@arm.com> | 2020-03-02 15:51:39 +0000 |
commit | 1856ff7ebb29e04c3549b74d7ced336111cbf05e (patch) | |
tree | c94654f0d8535930a81712bf7aadffd757c82577 /arm_compute/core | |
parent | 3c4bf0c4eab5ead756c472f17ddf008b882cc905 (diff) | |
download | ComputeLibrary-1856ff7ebb29e04c3549b74d7ced336111cbf05e.tar.gz |
COMPMID-3097 Fuse activation with fully connected layer CL
Change-Id: I447030e69b9e565f2f81529a41af8c5e7ece7ecf
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2702
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- | arm_compute/core/PixelValue.h | 18 | ||||
-rw-r--r-- | arm_compute/core/Types.h | 87 |
2 files changed, 52 insertions, 53 deletions
```diff
diff --git a/arm_compute/core/PixelValue.h b/arm_compute/core/PixelValue.h
index c5f6608163..31bc55098a 100644
--- a/arm_compute/core/PixelValue.h
+++ b/arm_compute/core/PixelValue.h
@@ -41,11 +41,11 @@ public:
     }
     /** Initialize the union with a pixel value of chosen datatype
      *
-     * @param[in] v int value.
+     * @param[in] v value.
      * @param[in] datatype DataType that @p v have to be stored
      * @param[in] qinfo (Optional) QuantizationInfo to apply in case of quantized data types to @p v
      */
-    PixelValue(int64_t v, DataType datatype, QuantizationInfo qinfo = QuantizationInfo())
+    PixelValue(double v, DataType datatype, QuantizationInfo qinfo = QuantizationInfo())
         : PixelValue()
     {
         switch(datatype)
@@ -57,13 +57,13 @@ public:
                 value.s8 = static_cast<int8_t>(v);
                 break;
             case DataType::QASYMM8:
-                value.u8 = quantize_qasymm8(static_cast<uint8_t>(v), qinfo);
+                value.u8 = quantize_qasymm8(static_cast<float>(v), qinfo);
                 break;
             case DataType::QASYMM8_SIGNED:
-                value.s8 = quantize_qasymm8_signed(static_cast<int8_t>(v), qinfo);
+                value.s8 = quantize_qasymm8_signed(static_cast<float>(v), qinfo);
                 break;
             case DataType::QSYMM8:
-                value.s8 = quantize_qsymm8(static_cast<int8_t>(v), qinfo);
+                value.s8 = quantize_qsymm8(static_cast<float>(v), qinfo);
                 break;
             case DataType::U16:
                 value.u16 = static_cast<uint16_t>(v);
@@ -72,10 +72,10 @@ public:
                 value.s16 = static_cast<int16_t>(v);
                 break;
             case DataType::QASYMM16:
-                value.u16 = quantize_qasymm16(static_cast<uint16_t>(v), qinfo);
+                value.u16 = quantize_qasymm16(static_cast<float>(v), qinfo);
                 break;
             case DataType::QSYMM16:
-                value.s16 = quantize_qsymm16(static_cast<int16_t>(v), qinfo);
+                value.s16 = quantize_qsymm16(static_cast<float>(v), qinfo);
                 break;
             case DataType::U32:
                 value.u32 = static_cast<uint32_t>(v);
@@ -96,10 +96,8 @@ public:
                 value.f32 = static_cast<float>(v);
                 break;
             case DataType::F64:
-                value.f64 = static_cast<double>(v);
-                break;
             default:
-                value.s64 = v;
+                value.f64 = v;
                 break;
         }
     }
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 2030b171c6..cf689d757c 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -799,39 +799,6 @@ private:
     DimensionRoundingType _round_type;
 };
 
-/** Fully connected layer info */
-struct FullyConnectedLayerInfo
-{
-    DataLayout weights_trained_layout{ DataLayout::NCHW }; /**< Layout that the weights have been trained with. */
-    bool transpose_weights{ true }; /**< Transpose weights if true. */
-    bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */
-    bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */
-    bool fp_mixed_precision{ false }; /**< Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */
-
-    /** Sets the weights trained data layout
-     *
-     * @param[in] layout Data layout that the weights were trained with
-     *
-     * @return Updated object
-     */
-    FullyConnectedLayerInfo &set_weights_trained_layout(DataLayout layout)
-    {
-        weights_trained_layout = layout;
-        return *this;
-    }
-    /** Sets the transpose weights flag
-     *
-     * @param[in] should_transpose_weights Boolean flag indicating if weights should be transposed
-     *
-     * @return Updated object
-     */
-    FullyConnectedLayerInfo &set_transpose_weights(bool should_transpose_weights)
-    {
-        transpose_weights = should_transpose_weights;
-        return *this;
-    }
-};
-
 /** PriorBox layer info */
 class PriorBoxLayerInfo final
 {
@@ -1674,6 +1641,40 @@ private:
     bool _enabled = { false };
 };
 
+/** Fully connected layer info */
+struct FullyConnectedLayerInfo
+{
+    DataLayout weights_trained_layout{ DataLayout::NCHW }; /**< Layout that the weights have been trained with. */
+    bool transpose_weights{ true }; /**< Transpose weights if true. */
+    bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */
+    bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */
+    bool fp_mixed_precision{ false }; /**< Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */
+    ActivationLayerInfo activation_info{}; /**< Fused activation to apply after the matrix multiplication. */
+
+    /** Sets the weights trained data layout
+     *
+     * @param[in] layout Data layout that the weights were trained with
+     *
+     * @return Updated object
+     */
+    FullyConnectedLayerInfo &set_weights_trained_layout(DataLayout layout)
+    {
+        weights_trained_layout = layout;
+        return *this;
+    }
+    /** Sets the transpose weights flag
+     *
+     * @param[in] should_transpose_weights Boolean flag indicating if weights should be transposed
+     *
+     * @return Updated object
+     */
+    FullyConnectedLayerInfo &set_transpose_weights(bool should_transpose_weights)
+    {
+        transpose_weights = should_transpose_weights;
+        return *this;
+    }
+};
+
 /** Normalization Layer Information class */
 class NormalizationLayerInfo
 {
@@ -1944,16 +1945,16 @@ enum class GEMMLowpOutputStageType
 /** GEMMLowp output stage info */
 struct GEMMLowpOutputStageInfo
 {
-    GEMMLowpOutputStageType type{ GEMMLowpOutputStageType::NONE }; /**< GEMMLowp output stage type */
-    int32_t gemmlowp_offset{ 0 }; /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
-    int32_t gemmlowp_multiplier{ 0 }; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
-    int32_t gemmlowp_shift{ 0 }; /**< GEMMLowp output stage shift used for quantizing to uint8 */
-    int32_t gemmlowp_min_bound{ 0 }; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
-    int32_t gemmlowp_max_bound{ 0 }; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
-    std::vector<int32_t> gemmlowp_multipliers{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
-    std::vector<int32_t> gemmlowp_shifts{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
-    bool is_quantized_per_channel{ false }; /**< GEMMLowp quantized per-channel flag */
-    DataType output_data_type{ DataType::UNKNOWN }; /**< Output tensor data type to use if the output is not initialized */
+    GEMMLowpOutputStageType type{ GEMMLowpOutputStageType::NONE }; /**< GEMMLowp output stage type */
+    int32_t gemmlowp_offset{ 0 }; /**< GEMMLowp output stage offset used for quantizing to QASYMM8 */
+    int32_t gemmlowp_multiplier{ 0 }; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
+    int32_t gemmlowp_shift{ 0 }; /**< GEMMLowp output stage shift used for quantizing to uint8 */
+    int32_t gemmlowp_min_bound{ std::numeric_limits<int32_t>::lowest() }; /**< GEMMLowp min value used to saturate down the output result before converting back to QASYMM8 */
+    int32_t gemmlowp_max_bound{ std::numeric_limits<int32_t>::max() }; /**< GEMMLowp max value used to saturate down the output result before converting back to QASYMM8 */
+    std::vector<int32_t> gemmlowp_multipliers{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
+    std::vector<int32_t> gemmlowp_shifts{}; /**< GEMMLowp output stage multiplier used for quantizing to QASYMM8 */
+    bool is_quantized_per_channel{ false }; /**< GEMMLowp quantized per-channel flag */
+    DataType output_data_type{ DataType::UNKNOWN }; /**< Output tensor data type to use if the output is not initialized */
 };
 
 /** GEMM LHS (Left Hand Side) matrix information */
```