aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/NEAsymm.h
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2019-10-29 10:58:13 +0000
committerMichele Di Giorgio <michele.digiorgio@arm.com>2019-12-20 14:05:24 +0000
commitf29d1b7d8bf2d1619554eb3443556b44d4aa1a4c (patch)
tree0a427f7fda2131f39e055f27b97f0a612aff990c /arm_compute/core/NEON/NEAsymm.h
parent748a7c81245ae81d04607b3a762cf65cd39026f2 (diff)
downloadComputeLibrary-f29d1b7d8bf2d1619554eb3443556b44d4aa1a4c.tar.gz
COMPMID-2608: Enable quantization with multiplier greater than 1 on NEON
Change-Id: Ib2b0c9ac88fc2b645f478c9981f71ee28f2c77fd Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/2425 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/NEON/NEAsymm.h')
-rw-r--r--arm_compute/core/NEON/NEAsymm.h160
1 files changed, 120 insertions, 40 deletions
diff --git a/arm_compute/core/NEON/NEAsymm.h b/arm_compute/core/NEON/NEAsymm.h
index 67adcef9b1..c09a7d9028 100644
--- a/arm_compute/core/NEON/NEAsymm.h
+++ b/arm_compute/core/NEON/NEAsymm.h
@@ -88,17 +88,32 @@ uint8x16_t finalize_quantization(int32x4x4_t &in_s32,
{
const static int32x4_t zero_s32 = vdupq_n_s32(0);
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
- in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
- in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
- in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
-
- // Round to the nearest division by a power-of-two using result_shift_s32
- in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
- in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
- in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
- in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+ if(result_shift < 0)
+ {
+ in_s32.val[0] = vmulq_n_s32(in_s32.val[0], (1 << (-result_shift)));
+ in_s32.val[1] = vmulq_n_s32(in_s32.val[1], (1 << (-result_shift)));
+ in_s32.val[2] = vmulq_n_s32(in_s32.val[2], (1 << (-result_shift)));
+ in_s32.val[3] = vmulq_n_s32(in_s32.val[3], (1 << (-result_shift)));
+
+ in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+ in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+ in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+ in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+ }
+ else
+ {
+ // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+ in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+ in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+ in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+ in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+
+ // Round to the nearest division by a power-of-two using result_shift_s32
+ in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
+ in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+ in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
+ in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+ }
// Add the offset terms
in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_after_shift_s32);
@@ -154,17 +169,32 @@ int8x16_t finalize_quantization(int32x4x4_t &in_s32,
int8x16_t min_s8,
int8x16_t max_s8)
{
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
- in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
- in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
- in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
-
- // Round to the nearest division by a power-of-two using result_shift_s32
- in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
- in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
- in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
- in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+ if(result_shift < 0)
+ {
+ in_s32.val[0] = vmulq_n_s32(in_s32.val[0], (1 << (-result_shift)));
+ in_s32.val[1] = vmulq_n_s32(in_s32.val[1], (1 << (-result_shift)));
+ in_s32.val[2] = vmulq_n_s32(in_s32.val[2], (1 << (-result_shift)));
+ in_s32.val[3] = vmulq_n_s32(in_s32.val[3], (1 << (-result_shift)));
+
+ in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+ in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+ in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+ in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+ }
+ else
+ {
+ // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+ in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+ in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+ in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+ in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+
+ // Round to the nearest division by a power-of-two using result_shift_s32
+ in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
+ in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+ in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
+ in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+ }
// Add the offset terms
in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_after_shift_s32);
@@ -214,17 +244,54 @@ inline int8x16_t finalize_quantization_symm(int32x4x4_t &in_s32,
const int8x16_t &min_s8,
const int8x16_t &max_s8)
{
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_s32.val[0] = vqrdmulhq_s32(in_s32.val[0], result_fixedpoint_multiplier.val[0]);
- in_s32.val[1] = vqrdmulhq_s32(in_s32.val[1], result_fixedpoint_multiplier.val[1]);
- in_s32.val[2] = vqrdmulhq_s32(in_s32.val[2], result_fixedpoint_multiplier.val[2]);
- in_s32.val[3] = vqrdmulhq_s32(in_s32.val[3], result_fixedpoint_multiplier.val[3]);
+ const static int32x4_t one_s32 = vdupq_n_s32(1);
+ // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+ int32x4x4_t res_shift_gt0 =
+ {
+ vqrdmulhq_s32(in_s32.val[0], result_fixedpoint_multiplier.val[0]),
+ vqrdmulhq_s32(in_s32.val[1], result_fixedpoint_multiplier.val[1]),
+ vqrdmulhq_s32(in_s32.val[2], result_fixedpoint_multiplier.val[2]),
+ vqrdmulhq_s32(in_s32.val[3], result_fixedpoint_multiplier.val[3]),
+ };
// Round to the nearest division by a power-of-two using result_shift_s32
- in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift.val[0]);
- in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift.val[1]);
- in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift.val[2]);
- in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift.val[3]);
+ res_shift_gt0.val[0] = rounding_divide_by_pow2(res_shift_gt0.val[0], result_shift.val[0]);
+ res_shift_gt0.val[1] = rounding_divide_by_pow2(res_shift_gt0.val[1], result_shift.val[1]);
+ res_shift_gt0.val[2] = rounding_divide_by_pow2(res_shift_gt0.val[2], result_shift.val[2]);
+ res_shift_gt0.val[3] = rounding_divide_by_pow2(res_shift_gt0.val[3], result_shift.val[3]);
+
+ int32x4x4_t res_shift_lt0 =
+ {
+ vmulq_s32(in_s32.val[0], vshlq_s32(one_s32, vnegq_s32(result_shift.val[0]))),
+ vmulq_s32(in_s32.val[1], vshlq_s32(one_s32, vnegq_s32(result_shift.val[1]))),
+ vmulq_s32(in_s32.val[2], vshlq_s32(one_s32, vnegq_s32(result_shift.val[2]))),
+ vmulq_s32(in_s32.val[3], vshlq_s32(one_s32, vnegq_s32(result_shift.val[3]))),
+ };
+ res_shift_lt0.val[0] = vqrdmulhq_s32(res_shift_lt0.val[0], result_fixedpoint_multiplier.val[0]);
+ res_shift_lt0.val[1] = vqrdmulhq_s32(res_shift_lt0.val[1], result_fixedpoint_multiplier.val[1]);
+ res_shift_lt0.val[2] = vqrdmulhq_s32(res_shift_lt0.val[2], result_fixedpoint_multiplier.val[2]);
+ res_shift_lt0.val[3] = vqrdmulhq_s32(res_shift_lt0.val[3], result_fixedpoint_multiplier.val[3]);
+
+ // Select result depending on shift value
+ const uint32x4x4_t mask_lt0 =
+ {
+#ifdef __aarch64__
+ vcltzq_s32(result_shift.val[0]),
+ vcltzq_s32(result_shift.val[1]),
+ vcltzq_s32(result_shift.val[2]),
+ vcltzq_s32(result_shift.val[3]),
+#else //__aarch64__
+ vcltq_s32(result_shift.val[0], vdupq_n_s32(0)),
+ vcltq_s32(result_shift.val[1], vdupq_n_s32(0)),
+ vcltq_s32(result_shift.val[2], vdupq_n_s32(0)),
+ vcltq_s32(result_shift.val[3], vdupq_n_s32(0)),
+#endif //__aarch64__
+ };
+
+ in_s32.val[0] = vbslq_s32(mask_lt0.val[0], res_shift_lt0.val[0], res_shift_gt0.val[0]);
+ in_s32.val[1] = vbslq_s32(mask_lt0.val[1], res_shift_lt0.val[1], res_shift_gt0.val[1]);
+ in_s32.val[2] = vbslq_s32(mask_lt0.val[2], res_shift_lt0.val[2], res_shift_gt0.val[2]);
+ in_s32.val[3] = vbslq_s32(mask_lt0.val[3], res_shift_lt0.val[3], res_shift_gt0.val[3]);
// Add the offset terms
in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_after_shift_s32);
@@ -273,11 +340,17 @@ inline uint8_t finalize_quantization(int32_t in_value, int result_fixedpoint_mul
{
int32x4_t in_s32 = vdupq_n_s32(in_value);
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
-
- // Shift value by result_shift_s32
- in_value = rounding_divide_by_pow2(in_value, result_shift);
+ if(result_shift < 0)
+ {
+ in_value = vgetq_lane_s32(vqrdmulhq_n_s32(vmulq_n_s32(in_s32, (1 << (-result_shift))), result_fixedpoint_multiplier), 0);
+ }
+ else
+ {
+ // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+ in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
+ // Shift value by result_shift_s32
+ in_value = rounding_divide_by_pow2(in_value, result_shift);
+ }
// Add the offset term
in_value += result_offset_after_shift_s32;
@@ -312,11 +385,18 @@ inline int8_t finalize_quantization(int32_t in_value, int result_fixedpoint_mult
{
int32x4_t in_s32 = vdupq_n_s32(in_value);
- // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
- in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
+ if(result_shift < 0)
+ {
+ in_value = vgetq_lane_s32(vqrdmulhq_n_s32(vmulq_n_s32(in_s32, (1 << (-result_shift))), result_fixedpoint_multiplier), 0);
+ }
+ else
+ {
+ // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+ in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
- // Shift value by result_shift_s32
- in_value = rounding_divide_by_pow2(in_value, result_shift);
+ // Shift value by result_shift_s32
+ in_value = rounding_divide_by_pow2(in_value, result_shift);
+ }
// Add the offset term
in_value += result_offset_after_shift_s32;