about | summary | refs | log | tree | commit | diff
path: root/arm_compute/core/NEON
diff options
context:
space:
mode:
author: alankelly <me@alankelly.dev> 2019-05-15 23:05:31 +0200
committer: Pablo Marquez <pablo.tello@arm.com> 2019-05-21 12:43:13 +0000
commit: 1f103d3bb2191d70806c9d9c0113aa5602079828 (patch)
tree: 7f55bb9cb143a53f776ca7908ab21aa2e9ea5130 /arm_compute/core/NEON
parent: 09f24975437e2e141ba51a07055a9372b0d173a2 (diff)
download: ComputeLibrary-1f103d3bb2191d70806c9d9c0113aa5602079828.tar.gz
Optimizes stride 2 NEDirectConvolution
Change-Id: I3d0593541af2fa9c9afe23224d2150a47c2092c5
Signed-off-by: Alan Kelly <me@alankelly.dev>
Reviewed-on: https://review.mlplatform.org/c/1147
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/NEON')
-rw-r--r-- arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h 61
1 file changed, 49 insertions, 12 deletions
diff --git a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
index 3547d2d110..1684894a5c 100644
--- a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
+++ b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h
@@ -282,11 +282,31 @@ inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, c
int input_offset)
{
ARM_COMPUTE_UNUSED(input_offset);
+ const float32x4x2_t vtop = vld2q_f32(in_top);
+ const float32x4x2_t vmid = vld2q_f32(in_mid);
+ const float32x4x2_t vlow = vld2q_f32(in_low);
+ const float32x4_t vtop_end = vld1q_f32(in_top + 8);
+ const float32x4_t vmid_end = vld1q_f32(in_mid + 8);
+ const float32x4_t vlow_end = vld1q_f32(in_low + 8);
+
+ float32x4x2_t out =
+ {
+ {
+ vmulq_f32(vtop.val[0], m0.val[0]),
+ vdupq_n_f32(0)
+ }
+ };
+ out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);
+
+ out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
+ out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
+ out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);
- float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
- out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
return out;
}
@@ -841,14 +861,31 @@ inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *i
int input_offset)
{
ARM_COMPUTE_UNUSED(input_offset);
- float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
- out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
+ const float16x8x2_t vtop = vld2q_f16(in_top);
+ const float16x8x2_t vmid = vld2q_f16(in_mid);
+ const float16x8x2_t vlow = vld2q_f16(in_low);
+ const float16x8_t vtop_end = vld1q_f16(in_top + 16);
+ const float16x8_t vmid_end = vld1q_f16(in_mid + 16);
+ const float16x8_t vlow_end = vld1q_f16(in_low + 16);
+
+ float16x8x2_t out =
+ {
+ {
+ vmulq_f16(vtop.val[0], m0.val[0]),
+ vdupq_n_f16(0)
+ }
+ };
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));
+
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));
+
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
+ out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));
+
return out;
}