From b042e39060901b44e615b923b5723c04d9b42a95 Mon Sep 17 00:00:00 2001
From: Viet-Hoa Do
Date: Tue, 21 Jun 2022 15:56:15 +0100
Subject: Add LUT-based leaky relu for QASYMM8 on CPU

* Add LUT generation function for Leaky ReLU.
* Some additional changes in the existing LUT implementation:
  + Bring back the NEON implementation of hard swish for 32-bit build.
    Library size of 64-bit build is not affected.
  + Add some extra #ifdef to remove unnecessary code in 32-bit build.

Resolves: COMPMID-5386
Signed-off-by: Viet-Hoa Do
Change-Id: I1ea49611cc922765ee741e31138c888401d33e9b
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7845
Benchmark: Arm Jenkins
Tested-by: Arm Jenkins
Reviewed-by: Gunes Bayir
Comments-Addressed: Arm Jenkins
---
 arm_compute/core/QuantizationInfo.h | 11 +++++++++++
 arm_compute/core/Types.h            | 38 ++++++++++++++++++++++++++++++++++---
 2 files changed, 46 insertions(+), 3 deletions(-)

(limited to 'arm_compute/core')

diff --git a/arm_compute/core/QuantizationInfo.h b/arm_compute/core/QuantizationInfo.h
index 0bd0f21bc1..21d962d08b 100644
--- a/arm_compute/core/QuantizationInfo.h
+++ b/arm_compute/core/QuantizationInfo.h
@@ -409,6 +409,17 @@ inline qasymm8_t qasymm8_hard_swish(qasymm8_t in,
     return tmp;
 }
 
+inline qasymm8_t qasymm8_leaky_relu(qasymm8_t in,
+                                    const UniformQuantizationInfo &qi_in,
+                                    const UniformQuantizationInfo &qi_out,
+                                    float alpha)
+{
+    float           tmp_f = dequantize_qasymm8(in, qi_in);
+    tmp_f                 = tmp_f > 0 ? tmp_f : tmp_f * alpha;
+    const qasymm8_t tmp   = quantize_qasymm8(tmp_f, qi_out);
+    return tmp;
+}
+
 /** Dequantize a value given a 8-bit symmetric quantization scheme
  *
  * @param[in] value Value to dequantize
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 4524976d6b..94fe1a07f4 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1680,6 +1680,7 @@ public:
         return _enabled;
     }
 
+#ifdef __aarch64__
     const LookupTable256 &lut() const
     {
         return _lut;
@@ -1687,7 +1688,27 @@
 
     void init_lut(const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
     {
-        qasymm8_hard_swish_populate_table(_lut, qi_in, qi_out);
+        if(_act == ActivationFunction::HARD_SWISH)
+        {
+            qasymm8_hard_swish_populate_table(_lut, qi_in, qi_out);
+        }
+        else if(_act == ActivationFunction::LEAKY_RELU)
+        {
+            qasymm8_leaky_relu_populate_table(_lut, qi_in, qi_out, _a);
+        }
+    }
+#endif // __aarch64__
+
+    static inline bool is_lut_supported(ActivationFunction act_func, DataType data_type)
+    {
+#ifdef __aarch64__
+        auto supported = (data_type == DataType::QASYMM8 && (act_func == ActivationFunction::HARD_SWISH || act_func == ActivationFunction::LEAKY_RELU));
+        return supported;
+#else // __aarch64__
+        ARM_COMPUTE_UNUSED(act_func);
+        ARM_COMPUTE_UNUSED(data_type);
+        return false;
+#endif // __aarch64__
     }
 
 private:
@@ -1695,15 +1716,26 @@
     float          _a       = {};
     float          _b       = {};
     bool           _enabled = { false };
-    LookupTable256 _lut     = {};
 
-    inline void qasymm8_hard_swish_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
+#ifdef __aarch64__
+    LookupTable256 _lut = {};
+
+    static inline void qasymm8_hard_swish_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out)
     {
         for(size_t i = 0; i < lut.size(); ++i)
         {
             lut[i] = qasymm8_hard_swish(i, qi_in, qi_out);
         }
     }
+
+    static inline void qasymm8_leaky_relu_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out, float alpha)
+    {
+        for(size_t i = 0; i < lut.size(); ++i)
+        {
+            lut[i] = qasymm8_leaky_relu(i, qi_in, qi_out, alpha);
+        }
+    }
+#endif // __aarch64__
 };
 
 /** Fully connected layer info */
-- 
cgit v1.2.1
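
Note on the approach (not part of the patch): because a QASYMM8 tensor can only take 256 distinct values, any per-element activation can be precomputed once into a 256-entry table and then applied as a single indexed load per element. The sketch below illustrates this for leaky ReLU as a standalone program; the quantization parameters, the alpha value, and the helper names (dequantize, quantize) are illustrative assumptions, not ComputeLibrary code.

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdio>

using qasymm8_t      = uint8_t;
using LookupTable256 = std::array<qasymm8_t, 256>;

// Illustrative stand-in for arm_compute::UniformQuantizationInfo.
struct UniformQuantizationInfo
{
    float   scale;
    int32_t offset;
};

// Simplified equivalents of the dequantize/quantize helpers in QuantizationInfo.h.
static float dequantize(qasymm8_t v, const UniformQuantizationInfo &qi)
{
    return (static_cast<int32_t>(v) - qi.offset) * qi.scale;
}

static qasymm8_t quantize(float v, const UniformQuantizationInfo &qi)
{
    const int32_t q = static_cast<int32_t>(std::lround(v / qi.scale)) + qi.offset;
    return static_cast<qasymm8_t>(std::max(0, std::min(255, q)));
}

int main()
{
    // Example quantization parameters and slope; real values come from the tensors.
    const UniformQuantizationInfo qi_in{ 0.05f, 128 };
    const UniformQuantizationInfo qi_out{ 0.05f, 128 };
    const float                   alpha = 0.1f;

    // Build the table once: dequantize -> leaky ReLU -> requantize for all 256 inputs.
    LookupTable256 lut{};
    for(size_t i = 0; i < lut.size(); ++i)
    {
        const float x = dequantize(static_cast<qasymm8_t>(i), qi_in);
        const float y = x > 0.f ? x : x * alpha;
        lut[i]        = quantize(y, qi_out);
    }

    // Applying the activation is now one table lookup per element.
    const qasymm8_t src[] = { 0, 100, 128, 200, 255 };
    for(qasymm8_t v : src)
    {
        std::printf("%d -> %d\n", static_cast<int>(v), static_cast<int>(lut[v]));
    }
    return 0;
}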