From 13ec5f0a09e038f12cbe0f3b119a215934b72b42 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 2 Jan 2020 12:11:13 +0000 Subject: COMPMID-2800: Add support for QASYMM8_SIGNED in NEDepthwiseConvolutionLayer3x3Kernel Change-Id: Ia5d23ff2c9e59c80ded2fac5ca02704214f0a01a Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/2537 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- .../kernels/detail/NEDirectConvolutionDetail.h | 743 +++++++++------------ 1 file changed, 320 insertions(+), 423 deletions(-) (limited to 'arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h') diff --git a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h index c4f6ac7c66..788bb649b0 100644 --- a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h +++ b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,8 @@ #include "arm_compute/core/AccessWindowStatic.h" #include "arm_compute/core/NEON/NEFixedPoint.h" +#include "arm_compute/core/NEON/wrapper/wrapper.h" +#include "arm_compute/core/utils/misc/Requires.h" #include @@ -55,14 +57,15 @@ inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0) return r; } -/** Loads a 3x3 matrix as a row (uint8_t). +/** Loads a 3x3 matrix as a row (uint8_t/int8_t). * - * @param[in] ptr Pointer to a uint8_t 3x3 matrix. + * @param[in] ptr Pointer to a uint8_t/int8_t 3x3 matrix. * @param[in] weights_offset (Optional) Weights quantization offset. * * @return The loaded matrix. */ -inline int32x4x3_t load_matrix_row(const uint8_t *ptr, int weights_offset = 0) +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0) { const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset); @@ -145,22 +148,16 @@ inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. * @param[in] dilation_x Dilation, in elements across x. + * @param[in] stridex Stride value in elements across x. * @param[in] input_offset (Optional) Input quantization offset. * */ -template -float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset = 0); - -template <> -inline float32x4x2_t convolve_3x3_dilation<1>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset) +inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low, + const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, + const size_t dilation_x, unsigned int stridex, int input_offset = 0) { - ARM_COMPUTE_UNUSED(input_offset); - - const float32x4x2_t out = + ARM_COMPUTE_ERROR_ON(stridex > 3); + float32x4x2_t out = { { single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset), @@ -168,33 +165,17 @@ inline float32x4x2_t convolve_3x3_dilation<1>(const float *in_top, const float * } }; - return out; -} - -template <> -inline float32x4x2_t convolve_3x3_dilation<2>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - - float32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3); - return out; -} - -template <> -inline float32x4x2_t convolve_3x3_dilation<3>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); + if(stridex == 2) + { + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3); + } + else if(stridex == 3) + { + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); + } - float32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - ; - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); return out; } @@ -206,123 +187,111 @@ inline float32x4x2_t convolve_3x3_dilation<3>(const float *in_top, const float * * @param[in] m0 First row of the filter. * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. + * @param[in] stridex Stride value in elements across x. * @param[in] input_offset (Optional) Input quantization offset. * */ -template float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset = 0); + unsigned int stridex, int input_offset = 0); -template <> -inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset) +inline float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, + const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, + unsigned int stridex, int input_offset) { ARM_COMPUTE_UNUSED(input_offset); + ARM_COMPUTE_ERROR_ON(stridex > 3); - const float32x4x3_t vtop = + float32x4x2_t out = { { - vld1q_f32(in_top), - vld1q_f32(in_top + 4), - vld1q_f32(in_top + 8) + vdupq_n_f32(0.f), + vdupq_n_f32(0.f) } }; - const float32x4x3_t vmid = + if(stridex == 2) { - { - vld1q_f32(in_mid), - vld1q_f32(in_mid + 4), - vld1q_f32(in_mid + 8) - } - }; - const float32x4x3_t vlow = + const float32x4x2_t vtop = vld2q_f32(in_top); + const float32x4x2_t vmid = vld2q_f32(in_mid); + const float32x4x2_t vlow = vld2q_f32(in_low); + const float32x4_t vtop_end = vld1q_f32(in_top + 8); + const float32x4_t vmid_end = vld1q_f32(in_mid + 8); + const float32x4_t vlow_end = vld1q_f32(in_low + 8); + + out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]); + + out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]); + + out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]); + + out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]); + } + else { + const float32x4x3_t vtop = { - vld1q_f32(in_low), - vld1q_f32(in_low + 4), - vld1q_f32(in_low + 8) - } - }; - float32x4x2_t out = - { + { + vld1q_f32(in_top), + vld1q_f32(in_top + 4), + vld1q_f32(in_top + 8) + } + }; + const float32x4x3_t vmid = { - vmulq_f32(vtop.val[0], m0.val[0]), - vmulq_f32(vtop.val[1], m0.val[0]) - } - }; - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]); + { + vld1q_f32(in_mid), + vld1q_f32(in_mid + 4), + vld1q_f32(in_mid + 8) + } + }; + const float32x4x3_t vlow = + { + { + vld1q_f32(in_low), + vld1q_f32(in_low + 4), + vld1q_f32(in_low + 8) + } + }; + out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]); + out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]); - return out; -} + out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]); -template <> -inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - const float32x4x2_t vtop = vld2q_f32(in_top); - const float32x4x2_t vmid = vld2q_f32(in_mid); - const float32x4x2_t vlow = vld2q_f32(in_low); - const float32x4_t vtop_end = vld1q_f32(in_top + 8); - const float32x4_t vmid_end = vld1q_f32(in_mid + 8); - const float32x4_t vlow_end = vld1q_f32(in_low + 8); + out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]); - float32x4x2_t out = - { + if(stridex == 3) { - vmulq_f32(vtop.val[0], m0.val[0]), - vdupq_n_f32(0) + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); } - }; - out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]); - - out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]); - - out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]); - - return out; -} - -template <> -inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); + } - float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); return out; } -/** Perform a 3x3 convolution for 4 consecutive elements on uint8_t when dilation.x() or dilation.y() is not 1. +/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1. * * @param[in] in_top Pointer to the first row of the input. * @param[in] in_mid Pointer to the second row of the input. @@ -334,78 +303,82 @@ inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, c * @param[in] input_offset Input quantization offset. * */ -inline int32x4_t single_convolve_3x3_dilation(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, size_t dilation_x, int input_offset) { - const int32x4_t v_input_offset = vdupq_n_s32(input_offset); + using VectorType = typename std::conditional::value, uint8x8x3_t, int8x8x3_t>::type; + using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t; - const uint8x8x3_t vtop = + const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{}); + + const VectorType vtop = { { - vld1_u8(in_top), - vld1_u8(in_top + dilation_x), - vld1_u8(in_top + 2 * dilation_x) + wrapper::vload(in_top), + wrapper::vload(in_top + dilation_x), + wrapper::vload(in_top + 2 * dilation_x) } }; - const uint8x8x3_t vmid = + const VectorType vmid = { { - vld1_u8(in_mid), - vld1_u8(in_mid + dilation_x), - vld1_u8(in_mid + 2 * dilation_x) + wrapper::vload(in_mid), + wrapper::vload(in_mid + dilation_x), + wrapper::vload(in_mid + 2 * dilation_x) } }; - const uint8x8x3_t vlow = + const VectorType vlow = { { - vld1_u8(in_low), - vld1_u8(in_low + dilation_x), - vld1_u8(in_low + 2 * dilation_x) + wrapper::vload(in_low), + wrapper::vload(in_low + dilation_x), + wrapper::vload(in_low + 2 * dilation_x) } }; const int32x4x3_t vtop_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))), //convert from uint8x8 to uint16x8, to uint16x4(lower or bottom half) to int16x4 to int32x4 - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[2])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))), } }; const int32x4x3_t vmid_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[2])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))), } }; const int32x4x3_t vlow_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[2])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))), } }; - int32x4_t out = vmulq_s32(vtop_s32.val[0], m0.val[0]); - out = vmlaq_s32(out, vtop_s32.val[1], m0.val[1]); - out = vmlaq_s32(out, vtop_s32.val[2], m0.val[2]); + int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]); + out = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]); + out = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]); - out = vmlaq_s32(out, vmid_s32.val[0], m1.val[0]); - out = vmlaq_s32(out, vmid_s32.val[1], m1.val[1]); - out = vmlaq_s32(out, vmid_s32.val[2], m1.val[2]); + out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]); + out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]); + out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]); - out = vmlaq_s32(out, vlow_s32.val[0], m2.val[0]); - out = vmlaq_s32(out, vlow_s32.val[1], m2.val[1]); - out = vmlaq_s32(out, vlow_s32.val[2], m2.val[2]); + out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]); + out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]); + out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]); return out; } -/** Perform a 3x3 convolution for 4 consecutive elements on uint8_t when dilation.x() or dilation.y() is not 1. +/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1. * * @param[in] in_top Pointer to the first row of the input. * @param[in] in_mid Pointer to the second row of the input. @@ -414,52 +387,37 @@ inline int32x4_t single_convolve_3x3_dilation(const uint8_t *in_top, const uint8 * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. * @param[in] dilation_x Dilation, in elements across x. + * @param[in] stridex Stride value in elements across x. * @param[in] input_offset Input quantization offset. * */ -template -int32x4x2_t convolve_3x3_dilation(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset); - -template <> -inline int32x4x2_t convolve_3x3_dilation<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset) +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, + const size_t dilation_x, unsigned int stridex, int input_offset) { - const int32x4x2_t out = + ARM_COMPUTE_ERROR_ON(stridex > 3); + int32x4x2_t out = { { single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset), single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset) } }; - return out; -} - -template <> -inline int32x4x2_t convolve_3x3_dilation<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - int32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 0), out.val[0], 2); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 2), out.val[0], 3); - return out; -} -template <> -inline int32x4x2_t convolve_3x3_dilation<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - int32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 3), out.val[0], 1); + if(stridex == 2) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3); + } + else if(stridex == 3) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1); + } return out; } -/** Perform a convolve3x3 on uint8_t +/** Perform a convolve3x3 on 8-bit elements * * @param[in] in_top Pointer to the first row of the input. * @param[in] in_mid Pointer to the second row of the input. @@ -467,123 +425,112 @@ inline int32x4x2_t convolve_3x3_dilation<3>(const uint8_t *in_top, const uint8_t * @param[in] m0 First row of the filter. * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. - * @param[in] input_offset (Optional) Input quantization offset. + * @param[in] stridex Stride value in elements across x. + * @param[in] input_offset Input quantization offset. * */ -template -int32x4x2_t convolve_3x3(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +int32x4x2_t convolve_3x3(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset); - -template <> -inline int32x4x2_t convolve_3x3<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset) + unsigned int stridex, int input_offset) { - const int32x4_t v_input_offset = vdupq_n_s32(input_offset); + ARM_COMPUTE_ERROR_ON(stridex > 3); + using VectorType = typename std::conditional::value, uint8x8x2_t, int8x8x2_t>::type; + using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t; - const uint8x8x2_t vtop = + const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{}); + + const VectorType vtop = { { - vld1_u8(in_top), - vld1_u8(in_top + 8) + wrapper::vload(in_top), + wrapper::vload(in_top + 8) } }; - const uint8x8x2_t vmid = + const VectorType vmid = { { - vld1_u8(in_mid), - vld1_u8(in_mid + 8) + wrapper::vload(in_mid), + wrapper::vload(in_mid + 8) } }; - const uint8x8x2_t vlow = + const VectorType vlow = { { - vld1_u8(in_low), - vld1_u8(in_low + 8) + wrapper::vload(in_low), + wrapper::vload(in_low + 8) } }; const int32x4x3_t vtop_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vtop.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))), } }; const int32x4x3_t vmid_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vmid.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))), } }; const int32x4x3_t vlow_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vlow.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))), } }; int32x4x2_t out { { - vdupq_n_s32(0), - vdupq_n_s32(0), + wrapper::vdup_n(0, OutputTagType{}), + wrapper::vdup_n(0, OutputTagType{}), } }; // 0 - out.val[0] = vmlaq_s32(out.val[0], vtop_s32.val[0], m0.val[0]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 1), m0.val[1]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 2), m0.val[2]); + out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]); - out.val[0] = vmlaq_s32(out.val[0], vmid_s32.val[0], m1.val[0]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 1), m1.val[1]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 2), m1.val[2]); + out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]); - out.val[0] = vmlaq_s32(out.val[0], vlow_s32.val[0], m2.val[0]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 1), m2.val[1]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 2), m2.val[2]); + out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]); // 1 - out.val[1] = vmlaq_s32(out.val[1], vtop_s32.val[1], m0.val[0]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 1), m0.val[1]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 2), m0.val[2]); + out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]); - out.val[1] = vmlaq_s32(out.val[1], vmid_s32.val[1], m1.val[0]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 1), m1.val[1]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 2), m1.val[2]); + out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]); - out.val[1] = vmlaq_s32(out.val[1], vlow_s32.val[1], m2.val[0]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 1), m2.val[1]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 2), m2.val[2]); + out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]); - return out; -} - -template <> -inline int32x4x2_t convolve_3x3<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset) -{ - int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 0), out.val[0], 2); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 2), out.val[0], 3); - return out; -} - -template <> -inline int32x4x2_t convolve_3x3<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset) -{ - int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 3), out.val[0], 1); + if(stridex == 2) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3); + } + else if(stridex == 3) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1); + } return out; } @@ -731,174 +678,143 @@ inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const f * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. * @param[in] dilation_x Dilation, in elements across x. - * @param[in] input_offset (Optional)Input quantization offset. + * @param[in] stridex Stride value in elements across x. + * @param[in] input_offset (Optional) Input quantization offset. * */ -template -float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset = 0); - -template <> -inline float16x8x2_t convolve_3x3_dilation<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset) +inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, + const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, + const size_t dilation_x, unsigned int stridex, int input_offset = 0) { - const float16x8x2_t out = + float16x8x2_t out = { { single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset), single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset) } }; - return out; -} -template <> -inline float16x8x2_t convolve_3x3_dilation<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - float16x8x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7); - return out; -} + if(stridex == 2) + { + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7); + } + else if(stridex == 3) + { + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); + } -template <> -inline float16x8x2_t convolve_3x3_dilation<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - float16x8x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); return out; } /** Perform a convolve3x3 on float16. * - * @param[in] in_top Pointer to the first row of the input. - * @param[in] in_mid Pointer to the second row of the input. - * @param[in] in_low Pointer to the third row of the input. - * @param[in] m0 First row of the filter. - * @param[in] m1 Second row of the filter. - * @param[in] m2 Third row of the filter. + * @param[in] in_top Pointer to the first row of the input. + * @param[in] in_mid Pointer to the second row of the input. + * @param[in] in_low Pointer to the third row of the input. + * @param[in] m0 First row of the filter. + * @param[in] m1 Second row of the filter. + * @param[in] m2 Third row of the filter. + * @param[in] stridex Stride value in elements across x. + * @param[in] input_offset (Optional) Input quantization offset. * */ -template -float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset = 0); - -template <> -inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - const float16x8x3_t vtop = - { - { - vld1q_f16(in_top), - vld1q_f16(in_top + 8), - vld1q_f16(in_top + 16) - } - }; - const float16x8x3_t vmid = - { - { - vld1q_f16(in_mid), - vld1q_f16(in_mid + 8), - vld1q_f16(in_mid + 16) - } - }; - const float16x8x3_t vlow = - { - { - vld1q_f16(in_low), - vld1q_f16(in_low + 8), - vld1q_f16(in_low + 16) - } - }; - float16x8x2_t out = - { - { - vmulq_f16(vtop.val[0], m0.val[0]), - vmulq_f16(vtop.val[1], m0.val[0]) - } - }; - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2])); - return out; -} - -template <> -inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset) +inline float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, + const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, + unsigned int stridex, int input_offset = 0) { ARM_COMPUTE_UNUSED(input_offset); - const float16x8x2_t vtop = vld2q_f16(in_top); - const float16x8x2_t vmid = vld2q_f16(in_mid); - const float16x8x2_t vlow = vld2q_f16(in_low); - const float16x8_t vtop_end = vld1q_f16(in_top + 16); - const float16x8_t vmid_end = vld1q_f16(in_mid + 16); - const float16x8_t vlow_end = vld1q_f16(in_low + 16); float16x8x2_t out = { { - vmulq_f16(vtop.val[0], m0.val[0]), + vdupq_n_f16(0), vdupq_n_f16(0) } }; - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2])); + if(stridex == 2) + { + const float16x8x2_t vtop = vld2q_f16(in_top); + const float16x8x2_t vmid = vld2q_f16(in_mid); + const float16x8x2_t vlow = vld2q_f16(in_low); + const float16x8_t vtop_end = vld1q_f16(in_top + 16); + const float16x8_t vmid_end = vld1q_f16(in_mid + 16); + const float16x8_t vlow_end = vld1q_f16(in_low + 16); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2])); + out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2])); - return out; -} + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2])); + + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2])); + } + else + { + const float16x8x3_t vtop = + { + { + vld1q_f16(in_top), + vld1q_f16(in_top + 8), + vld1q_f16(in_top + 16) + } + }; + const float16x8x3_t vmid = + { + { + vld1q_f16(in_mid), + vld1q_f16(in_mid + 8), + vld1q_f16(in_mid + 16) + } + }; + const float16x8x3_t vlow = + { + { + vld1q_f16(in_low), + vld1q_f16(in_low + 8), + vld1q_f16(in_low + 16) + } + }; + out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]); + out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]); + + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2])); + + if(stridex == 3) + { + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); + } + } -template <> -inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); return out; } @@ -934,39 +850,20 @@ inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values) /** Get the number of elements processed on 3x3 convolution. * * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution. + * @param[in] stridex Stride value in elements across x. * * @return The number of elements processed. */ -template -int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration); - -template <> -inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration) -{ - return num_elems_written_per_iteration; -} - -template <> -inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration) -{ - return num_elems_written_per_iteration << 1; -} - -template <> -inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration) -{ - return num_elems_written_per_iteration * 3; -} inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex) { switch(stridex) { case 1: - return get_input_num_elems_processed<1>(num_elems_written_per_iteration); + return num_elems_written_per_iteration; case 2: - return get_input_num_elems_processed<2>(num_elems_written_per_iteration); + return num_elems_written_per_iteration << 1; case 3: - return get_input_num_elems_processed<3>(num_elems_written_per_iteration); + return num_elems_written_per_iteration * 3; default: ARM_COMPUTE_ERROR("stridex not supported"); return 0; -- cgit v1.2.1