From f7a3bf2519a0bdf38e12b3cfa9073dc4f293316d Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Fri, 15 Mar 2019 14:58:24 +0000 Subject: COMPMID-1995: Fix NEActivation Logistic Code simplified due to accuracy problem. Change-Id: Ife14656ca831655489bf43d6cf59b241d482b11e Signed-off-by: Isabella Gottardi Reviewed-on: https://review.mlplatform.org/c/861 Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- src/core/NEON/kernels/NEActivationLayerKernel.cpp | 48 +++++++---------------- 1 file changed, 14 insertions(+), 34 deletions(-) (limited to 'src/core') diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp index b67396c5a1..cf31cb841a 100644 --- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp @@ -319,6 +319,7 @@ typename std::enable_if::value, void>::type NEActivat const qasymm8_t b = sqcvt_qasymm8_f32(_act_info.b(), qi_in.scale, qi_in.offset); const qasymm8_t const_0 = sqcvt_qasymm8_f32(0.f, qi_in.scale, qi_in.offset); const qasymm8x16_t vconst_0 = vdupq_n_u8(const_0); + const auto vconst_1 = vdupq_n_f32(1.f); // Initialise scale/offset for re-quantization float s = qi_in.scale / qi_out.scale; @@ -361,41 +362,20 @@ typename std::enable_if::value, void>::type NEActivat } else if(act == ActivationFunction::LOGISTIC) { - const auto scale_in = vdupq_n_f32(qi_in.scale); - const auto off_in = vdupq_n_f32(qi_in.offset); - const auto scale_out = vdupq_n_f32(qi_out.scale); - const auto off_out = vdupq_n_f32(qi_out.offset); - const auto vconst_1 = vdupq_n_f32(1.f); - - const auto vin_low = wrapper::vgetlow(vin); - const auto vin_high = wrapper::vgethigh(vin); - uint16x8_t vin_low_u16x8 = wrapper::vmovl(vin_low); - uint16x8_t vin_high_u16x8 = wrapper::vmovl(vin_high); - // Convert uint16 vectors to uint32 vectors - uint32x4_t A_u32x4 = wrapper::vmovl(wrapper::vgetlow(vin_low_u16x8)); - uint32x4_t B_u32x4 = wrapper::vmovl(wrapper::vgethigh(vin_low_u16x8)); - uint32x4_t C_u32x4 = wrapper::vmovl(wrapper::vgetlow(vin_high_u16x8)); - uint32x4_t D_u32x4 = wrapper::vmovl(wrapper::vgethigh(vin_high_u16x8)); - // Convert uint32 vectors to float32 vectors - float32x4_t A_f32x4 = wrapper::vmul(wrapper::vsub(vcvtq_f32_u32(A_u32x4), off_in), scale_in); - float32x4_t B_f32x4 = wrapper::vmul(wrapper::vsub(vcvtq_f32_u32(B_u32x4), off_in), scale_in); - float32x4_t C_f32x4 = wrapper::vmul(wrapper::vsub(vcvtq_f32_u32(C_u32x4), off_in), scale_in); - float32x4_t D_f32x4 = wrapper::vmul(wrapper::vsub(vcvtq_f32_u32(D_u32x4), off_in), scale_in); + // De-quantize + const auto vin_deq = vdequantize(vin, qi_in); // Perform activation - A_f32x4 = wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(A_f32x4)))); - B_f32x4 = wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(B_f32x4)))); - C_f32x4 = wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(C_f32x4)))); - D_f32x4 = wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(D_f32x4)))); - // Convert float32 vectors to uint32 vectors - A_u32x4 = vcvtq_u32_f32(wrapper::vadd(wrapper::vdiv(A_f32x4, scale_out), off_out)); - B_u32x4 = vcvtq_u32_f32(wrapper::vadd(wrapper::vdiv(B_f32x4, scale_out), off_out)); - C_u32x4 = vcvtq_u32_f32(wrapper::vadd(wrapper::vdiv(C_f32x4, scale_out), off_out)); - D_u32x4 = vcvtq_u32_f32(wrapper::vadd(wrapper::vdiv(D_f32x4, scale_out), off_out)); - // Convert uint32 vectors to uint16 vectors (with saturation) - vin_low_u16x8 = wrapper::vcombine(wrapper::vqmovn(A_u32x4), wrapper::vqmovn(B_u32x4)); - vin_high_u16x8 = wrapper::vcombine(wrapper::vqmovn(C_u32x4), wrapper::vqmovn(D_u32x4)); - // convert uint16 vectors to uint8 vectors (with saturation) - tmp = wrapper::vcombine(wrapper::vqmovn(vin_low_u16x8), wrapper::vqmovn(vin_high_u16x8)); + const float32x4x4_t tmp_dep = + { + { + wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(vin_deq.val[0])))), + wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(vin_deq.val[1])))), + wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(vin_deq.val[2])))), + wrapper::vdiv(vconst_1, wrapper::vadd(vconst_1, wrapper::vexpq(wrapper::vneg(vin_deq.val[3])))), + } + }; + // Re-quantize to new output space + tmp = vquantize(tmp_dep, qi_out); } else { -- cgit v1.2.1