diff options
Diffstat (limited to 'src/core/NEON/NEMath.inl')
-rw-r--r-- | src/core/NEON/NEMath.inl | 105 |
1 files changed, 50 insertions, 55 deletions
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl index 1cbe669373..f875917988 100644 --- a/src/core/NEON/NEMath.inl +++ b/src/core/NEON/NEMath.inl @@ -29,19 +29,16 @@ namespace arm_compute { /** Logarithm polynomial coefficients */ -const std::array<float32x4_t, 8> log_tab = -{ - { - vdupq_n_f32(-2.29561495781f), - vdupq_n_f32(-2.47071170807f), - vdupq_n_f32(-5.68692588806f), - vdupq_n_f32(-0.165253549814f), - vdupq_n_f32(5.17591238022f), - vdupq_n_f32(0.844007015228f), - vdupq_n_f32(4.58445882797f), - vdupq_n_f32(0.0141278216615f), - } -}; +const std::array<float32x4_t, 8> log_tab = {{ + vdupq_n_f32(-2.29561495781f), + vdupq_n_f32(-2.47071170807f), + vdupq_n_f32(-5.68692588806f), + vdupq_n_f32(-0.165253549814f), + vdupq_n_f32(5.17591238022f), + vdupq_n_f32(0.844007015228f), + vdupq_n_f32(4.58445882797f), + vdupq_n_f32(0.0141278216615f), +}}; /** Sin polynomial coefficients */ constexpr float te_sin_coeff2 = 0.166666666666f; // 1/(2*3) @@ -54,7 +51,7 @@ inline float32x4_t prefer_vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c) { #if __ARM_FEATURE_FMA return vfmaq_f32(a, b, c); -#else // __ARM_FEATURE_FMA +#else // __ARM_FEATURE_FMA return vmlaq_f32(a, b, c); #endif // __ARM_FEATURE_FMA } @@ -73,13 +70,14 @@ inline float32x4_t vroundq_rte_f32(float32x4_t val) { #ifdef __aarch64__ return vrndnq_f32(val); -#else // __aarch64__ +#else // __aarch64__ static const float32x4_t CONST_HALF_FLOAT = vdupq_n_f32(0.5f); static const float32x4_t CONST_1_FLOAT = vdupq_n_f32(1.f); static const int32x4_t CONST_1_INT = vdupq_n_s32(1); const float32x4_t floor_val = vfloorq_f32(val); const float32x4_t diff = vsubq_f32(val, floor_val); - const float32x4_t fp32_upper_limit = vreinterpretq_f32_u32(vdupq_n_u32(0x4B000000)); // 0x4B000000 = (23U + 127U) << 23U + const float32x4_t fp32_upper_limit = + vreinterpretq_f32_u32(vdupq_n_u32(0x4B000000)); // 0x4B000000 = (23U + 127U) << 23U /* * 1. Select the floor value when (diff<0.5 || (diff==0.5 && floor_val%2==0). @@ -95,12 +93,13 @@ inline float32x4_t vroundq_rte_f32(float32x4_t val) * Threshold upper limit with format |S|E(8bits)| Fraction(23bits) | = (23 + 127) << 23 (assuming positive sign): Adding 127, because 127 represents the actual zero in this format. */ - float32x4_t rounded_val = vbslq_f32(vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT), - vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT), - vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT),CONST_1_INT)))), - floor_val, vaddq_f32(floor_val, CONST_1_FLOAT)); + float32x4_t rounded_val = vbslq_f32( + vorrq_u32(vcltq_f32(diff, CONST_HALF_FLOAT), + vandq_u32(vceqq_f32(diff, CONST_HALF_FLOAT), + vmvnq_u32(vtstq_s32(vandq_s32(vcvtq_s32_f32(floor_val), CONST_1_INT), CONST_1_INT)))), + floor_val, vaddq_f32(floor_val, CONST_1_FLOAT)); - float32x4_t result = vbslq_f32(vcgeq_f32(vabsq_f32(val), fp32_upper_limit), val, rounded_val); + float32x4_t result = vbslq_f32(vcgeq_f32(vabsq_f32(val), fp32_upper_limit), val, rounded_val); return result; #endif // __aarch64__ @@ -118,8 +117,8 @@ inline float32x2_t vinvsqrt_f32(float32x2_t x) inline float32x4_t vinvsqrtq_f32(float32x4_t x) { float32x4_t sqrt_reciprocal = vrsqrteq_f32(x); - sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); - sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); + sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); + sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); return sqrt_reciprocal; } @@ -152,8 +151,7 @@ inline float32x4_t vtaylor_polyq_f32(float32x4_t x, const std::array<float32x4_t return res; } -static const uint32_t exp_f32_coeff[] = -{ +static const uint32_t exp_f32_coeff[] = { 0x3f7ffff6, // x^1: 0x1.ffffecp-1f 0x3efffedb, // x^2: 0x1.fffdb6p-2f 0x3e2aaf33, // x^3: 0x1.555e66p-3f @@ -169,10 +167,12 @@ inline float32x4_t vexpq_f32(float32x4_t x) const auto c4 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[3])); const auto c5 = vreinterpretq_f32_u32(vdupq_n_u32(exp_f32_coeff[4])); - const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f - const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f - const auto neg_ln2_hi = vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f - const auto neg_ln2_lo = vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f const auto inf = vdupq_n_f32(std::numeric_limits<float>::infinity()); const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5) @@ -224,9 +224,9 @@ inline float32x4_t vexpq_f32(float32x4_t x) #ifdef __aarch64__ inline float32x4_t verfq_f32(float32x4_t x) { - static const float erffdata[4] = { 0.278393f, 0.230389f, 0.000972f, 0.078108f }; + static const float erffdata[4] = {0.278393f, 0.230389f, 0.000972f, 0.078108f}; static const float32x4_t coeffdata = vld1q_f32(erffdata); - static const float32x4_t onev{ vdupq_n_f32(1.0f) }; + static const float32x4_t onev{vdupq_n_f32(1.0f)}; uint32x4_t selector = vcltzq_f32(x); @@ -287,10 +287,12 @@ inline float32x4_t vtanhq_f32(float32x4_t val) float32x4_t x = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH); // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise - float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x)); - float32x4_t num = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x)); - float32x4_t den = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num)); - float32x4_t tanh = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den)); + float32x4_t exp2x = + vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x)); + float32x4_t num = + vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x)); + float32x4_t den = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num)); + float32x4_t tanh = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den)); return tanh; } @@ -456,30 +458,23 @@ inline float32x4x4_t convert_to_float32x4x4(const int8x16_t &in) inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out) { - out.val[0] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[0])), - vqmovn_u32(vcvtq_u32_f32(in2.val[0])))); - out.val[1] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[1])), - vqmovn_u32(vcvtq_u32_f32(in2.val[1])))); - out.val[2] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[2])), - vqmovn_u32(vcvtq_u32_f32(in2.val[2])))); + out.val[0] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[0])), vqmovn_u32(vcvtq_u32_f32(in2.val[0])))); + out.val[1] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[1])), vqmovn_u32(vcvtq_u32_f32(in2.val[1])))); + out.val[2] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[2])), vqmovn_u32(vcvtq_u32_f32(in2.val[2])))); } inline void convert_float32x4x4_to_uint8x16(const float32x4x4_t &in, uint8x16_t &out) { - const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])), - vqmovn_u32(vcvtq_u32_f32(in.val[1]))); - const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])), - vqmovn_u32(vcvtq_u32_f32(in.val[3]))); - out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high)); + const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])), vqmovn_u32(vcvtq_u32_f32(in.val[1]))); + const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])), vqmovn_u32(vcvtq_u32_f32(in.val[3]))); + out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high)); } inline void convert_float32x4x4_to_int8x16(const float32x4x4_t &in, int8x16_t &out) { - const auto low = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[0])), - vqmovn_s32(vcvtq_s32_f32(in.val[1]))); - const auto high = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[2])), - vqmovn_s32(vcvtq_s32_f32(in.val[3]))); - out = vcombine_s8(vqmovn_s16(low), vqmovn_s16(high)); + const auto low = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[0])), vqmovn_s32(vcvtq_s32_f32(in.val[1]))); + const auto high = vcombine_s16(vqmovn_s32(vcvtq_s32_f32(in.val[2])), vqmovn_s32(vcvtq_s32_f32(in.val[3]))); + out = vcombine_s8(vqmovn_s16(low), vqmovn_s16(high)); } template <> @@ -552,8 +547,8 @@ inline float16x4_t vinvsqrt_f16(float16x4_t x) inline float16x8_t vinvsqrtq_f16(float16x8_t x) { float16x8_t sqrt_reciprocal = vrsqrteq_f16(x); - sqrt_reciprocal = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); - sqrt_reciprocal = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); + sqrt_reciprocal = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); + sqrt_reciprocal = vmulq_f16(vrsqrtsq_f16(vmulq_f16(x, sqrt_reciprocal), sqrt_reciprocal), sqrt_reciprocal); return sqrt_reciprocal; } @@ -602,8 +597,8 @@ inline float16x4_t vtanh_rational_approx_f16(float16x4_t x16) inline float16x8_t vtanhq_f16(float16x8_t x) { // Split into high/low and use rational approximation on both parts exactly - const float16x8_t tanh = vcombine_f16(vtanh_rational_approx_f16(vget_low_f16(x)), - vtanh_rational_approx_f16(vget_high_f16(x))); + const float16x8_t tanh = + vcombine_f16(vtanh_rational_approx_f16(vget_low_f16(x)), vtanh_rational_approx_f16(vget_high_f16(x))); // tanh(x) == sign(x) to F16 precision for |x| >= 4.508, use sign after this const float16x8_t ONE = vdupq_n_f16(1.0f); |