From 13ec5f0a09e038f12cbe0f3b119a215934b72b42 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 2 Jan 2020 12:11:13 +0000 Subject: COMPMID-2800: Add support for QASYMM8_SIGNED in NEDepthwiseConvolutionLayer3x3Kernel Change-Id: Ia5d23ff2c9e59c80ded2fac5ca02704214f0a01a Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/2537 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- .../kernels/NEDepthwiseConvolutionLayer3x3Kernel.h | 6 +- .../kernels/detail/NEDirectConvolutionDetail.h | 743 +++++++++------------ arm_compute/core/NEON/wrapper/intrinsics/ext.h | 45 ++ .../core/NEON/wrapper/intrinsics/intrinsics.h | 3 +- .../core/NEON/wrapper/intrinsics/reinterpret.h | 24 +- arm_compute/core/NEON/wrapper/intrinsics/setlane.h | 4 +- 6 files changed, 387 insertions(+), 438 deletions(-) create mode 100644 arm_compute/core/NEON/wrapper/intrinsics/ext.h (limited to 'arm_compute') diff --git a/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h b/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h index efde38b47a..227ddb4743 100644 --- a/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h +++ b/arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,7 +53,7 @@ public: * * @note Supported data layouts: NCHW and NHWC * - * @param[in] input Source tensor. DataType supported: QASYMM8/F16/F32. + * @param[in] input Source tensor. DataType supported: QASYMM8/QASYMM8_SIGNED/F16/F32. * @param[in] weights Weights tensor. This is a 3D tensor with dimensions [3, 3, IFM] for NCHW or [IFM, 3, 3] if NHWC data layout. Data type supported: Same as @p input. * @param[out] output Destination tensor. Data type supported: Same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. @@ -66,7 +66,7 @@ public: * * @note Supported data layouts: NCHW and NHWC * - * @param[in] input Source tensor info. DataType supported: QASYMM8/F16/F32. + * @param[in] input Source tensor info. DataType supported: QASYMM8/QASYMM8_SIGNED/F16/F32. * @param[in] weights Weights tensor info. This is a 3D tensor with dimensions [3, 3, IFM] for NCHW or [IFM, 3, 3] if NHWC data layout. Data type supported: Same as @p input. * @param[in] output Destination tensor info. Data type supported: Same as @p input. * @param[in] conv_info Padding and stride information to use for the convolution. diff --git a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h index c4f6ac7c66..788bb649b0 100644 --- a/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h +++ b/arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,8 @@ #include "arm_compute/core/AccessWindowStatic.h" #include "arm_compute/core/NEON/NEFixedPoint.h" +#include "arm_compute/core/NEON/wrapper/wrapper.h" +#include "arm_compute/core/utils/misc/Requires.h" #include @@ -55,14 +57,15 @@ inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0) return r; } -/** Loads a 3x3 matrix as a row (uint8_t). +/** Loads a 3x3 matrix as a row (uint8_t/int8_t). * - * @param[in] ptr Pointer to a uint8_t 3x3 matrix. 
+ * @param[in] ptr Pointer to a uint8_t/int8_t 3x3 matrix. * @param[in] weights_offset (Optional) Weights quantization offset. * * @return The loaded matrix. */ -inline int32x4x3_t load_matrix_row(const uint8_t *ptr, int weights_offset = 0) +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0) { const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset); @@ -145,22 +148,16 @@ inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. * @param[in] dilation_x Dilation, in elements across x. + * @param[in] stridex Stride value in elements across x. * @param[in] input_offset (Optional) Input quantization offset. * */ -template -float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset = 0); - -template <> -inline float32x4x2_t convolve_3x3_dilation<1>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset) +inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low, + const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, + const size_t dilation_x, unsigned int stridex, int input_offset = 0) { - ARM_COMPUTE_UNUSED(input_offset); - - const float32x4x2_t out = + ARM_COMPUTE_ERROR_ON(stridex > 3); + float32x4x2_t out = { { single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset), @@ -168,33 +165,17 @@ inline float32x4x2_t convolve_3x3_dilation<1>(const float *in_top, const float * } }; - return out; -} - -template <> -inline float32x4x2_t convolve_3x3_dilation<2>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - - float32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3); - return out; -} - -template <> -inline float32x4x2_t convolve_3x3_dilation<3>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); + if(stridex == 2) + { + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2); + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3); + } + else if(stridex == 3) + { + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); + } - float32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - ; - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); return out; } @@ -206,123 +187,111 @@ inline float32x4x2_t convolve_3x3_dilation<3>(const float *in_top, const float * * @param[in] m0 First row 
of the filter. * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. + * @param[in] stridex Stride value in elements across x. * @param[in] input_offset (Optional) Input quantization offset. * */ -template float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset = 0); + unsigned int stridex, int input_offset = 0); -template <> -inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset) +inline float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, + const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, + unsigned int stridex, int input_offset) { ARM_COMPUTE_UNUSED(input_offset); + ARM_COMPUTE_ERROR_ON(stridex > 3); - const float32x4x3_t vtop = + float32x4x2_t out = { { - vld1q_f32(in_top), - vld1q_f32(in_top + 4), - vld1q_f32(in_top + 8) + vdupq_n_f32(0.f), + vdupq_n_f32(0.f) } }; - const float32x4x3_t vmid = + if(stridex == 2) { - { - vld1q_f32(in_mid), - vld1q_f32(in_mid + 4), - vld1q_f32(in_mid + 8) - } - }; - const float32x4x3_t vlow = + const float32x4x2_t vtop = vld2q_f32(in_top); + const float32x4x2_t vmid = vld2q_f32(in_mid); + const float32x4x2_t vlow = vld2q_f32(in_low); + const float32x4_t vtop_end = vld1q_f32(in_top + 8); + const float32x4_t vmid_end = vld1q_f32(in_mid + 8); + const float32x4_t vlow_end = vld1q_f32(in_low + 8); + + out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]); + + out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]); + + out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]); + + out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]); + } + else { + const float32x4x3_t vtop = { - vld1q_f32(in_low), - vld1q_f32(in_low + 4), - vld1q_f32(in_low + 8) - } - }; - float32x4x2_t out = - { + { + vld1q_f32(in_top), + vld1q_f32(in_top + 4), + vld1q_f32(in_top + 8) + } + }; + const float32x4x3_t vmid = { - vmulq_f32(vtop.val[0], m0.val[0]), - vmulq_f32(vtop.val[1], m0.val[0]) - } - }; - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]); + { + vld1q_f32(in_mid), + vld1q_f32(in_mid + 4), + vld1q_f32(in_mid + 8) + } + }; + const float32x4x3_t vlow = + { + { + vld1q_f32(in_low), + vld1q_f32(in_low + 4), + vld1q_f32(in_low + 8) + } + }; + out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]); + out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]); - out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); - out.val[0] = 
vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]); + out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]); + out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]); - out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]); - out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]); - return out; -} + out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]); -template <> -inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - const float32x4x2_t vtop = vld2q_f32(in_top); - const float32x4x2_t vmid = vld2q_f32(in_mid); - const float32x4x2_t vlow = vld2q_f32(in_low); - const float32x4_t vtop_end = vld1q_f32(in_top + 8); - const float32x4_t vmid_end = vld1q_f32(in_mid + 8); - const float32x4_t vlow_end = vld1q_f32(in_low + 8); + out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]); + out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]); - float32x4x2_t out = - { + if(stridex == 3) { - vmulq_f32(vtop.val[0], m0.val[0]), - vdupq_n_f32(0) + out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); } - }; - out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]); - - out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]); - - out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]); - out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]); - out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]); - - return out; -} - -template <> -inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, - const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, - int input_offset) -{ - 
ARM_COMPUTE_UNUSED(input_offset); + } - float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset); - out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1); return out; } -/** Perform a 3x3 convolution for 4 consecutive elements on uint8_t when dilation.x() or dilation.y() is not 1. +/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1. * * @param[in] in_top Pointer to the first row of the input. * @param[in] in_mid Pointer to the second row of the input. @@ -334,78 +303,82 @@ inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, c * @param[in] input_offset Input quantization offset. * */ -inline int32x4_t single_convolve_3x3_dilation(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, size_t dilation_x, int input_offset) { - const int32x4_t v_input_offset = vdupq_n_s32(input_offset); + using VectorType = typename std::conditional::value, uint8x8x3_t, int8x8x3_t>::type; + using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t; - const uint8x8x3_t vtop = + const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{}); + + const VectorType vtop = { { - vld1_u8(in_top), - vld1_u8(in_top + dilation_x), - vld1_u8(in_top + 2 * dilation_x) + wrapper::vload(in_top), + wrapper::vload(in_top + dilation_x), + wrapper::vload(in_top + 2 * dilation_x) } }; - const uint8x8x3_t vmid = + const VectorType vmid = { { - vld1_u8(in_mid), - vld1_u8(in_mid + dilation_x), - vld1_u8(in_mid + 2 * dilation_x) + wrapper::vload(in_mid), + wrapper::vload(in_mid + dilation_x), + wrapper::vload(in_mid + 2 * dilation_x) } }; - const uint8x8x3_t vlow = + const VectorType vlow = { { - vld1_u8(in_low), - vld1_u8(in_low + dilation_x), - vld1_u8(in_low + 2 * dilation_x) + wrapper::vload(in_low), + wrapper::vload(in_low + dilation_x), + wrapper::vload(in_low + 2 * dilation_x) } }; const int32x4x3_t vtop_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))), //convert from uint8x8 to uint16x8, to uint16x4(lower or bottom half) to int16x4 to int32x4 - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[2])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))), } }; const int32x4x3_t vmid_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[2])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))), } }; const int32x4x3_t vlow_s32 = { { - 
vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[2])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))), } }; - int32x4_t out = vmulq_s32(vtop_s32.val[0], m0.val[0]); - out = vmlaq_s32(out, vtop_s32.val[1], m0.val[1]); - out = vmlaq_s32(out, vtop_s32.val[2], m0.val[2]); + int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]); + out = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]); + out = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]); - out = vmlaq_s32(out, vmid_s32.val[0], m1.val[0]); - out = vmlaq_s32(out, vmid_s32.val[1], m1.val[1]); - out = vmlaq_s32(out, vmid_s32.val[2], m1.val[2]); + out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]); + out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]); + out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]); - out = vmlaq_s32(out, vlow_s32.val[0], m2.val[0]); - out = vmlaq_s32(out, vlow_s32.val[1], m2.val[1]); - out = vmlaq_s32(out, vlow_s32.val[2], m2.val[2]); + out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]); + out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]); + out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]); return out; } -/** Perform a 3x3 convolution for 4 consecutive elements on uint8_t when dilation.x() or dilation.y() is not 1. +/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1. * * @param[in] in_top Pointer to the first row of the input. * @param[in] in_mid Pointer to the second row of the input. @@ -414,52 +387,37 @@ inline int32x4_t single_convolve_3x3_dilation(const uint8_t *in_top, const uint8 * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. * @param[in] dilation_x Dilation, in elements across x. + * @param[in] stridex Stride value in elements across x. * @param[in] input_offset Input quantization offset. 
* */ -template -int32x4x2_t convolve_3x3_dilation(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset); - -template <> -inline int32x4x2_t convolve_3x3_dilation<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset) +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, + const size_t dilation_x, unsigned int stridex, int input_offset) { - const int32x4x2_t out = + ARM_COMPUTE_ERROR_ON(stridex > 3); + int32x4x2_t out = { { single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset), single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset) } }; - return out; -} - -template <> -inline int32x4x2_t convolve_3x3_dilation<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - int32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 0), out.val[0], 2); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 2), out.val[0], 3); - return out; -} -template <> -inline int32x4x2_t convolve_3x3_dilation<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - const size_t dilation_x, int input_offset) -{ - int32x4x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 3), out.val[0], 1); + if(stridex == 2) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3); + } + else if(stridex == 3) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1); + } return out; } -/** Perform a convolve3x3 on uint8_t +/** Perform a convolve3x3 on 8-bit elements * * @param[in] in_top Pointer to the first row of the input. * @param[in] in_mid Pointer to the second row of the input. @@ -467,123 +425,112 @@ inline int32x4x2_t convolve_3x3_dilation<3>(const uint8_t *in_top, const uint8_t * @param[in] m0 First row of the filter. * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. - * @param[in] input_offset (Optional) Input quantization offset. + * @param[in] stridex Stride value in elements across x. + * @param[in] input_offset Input quantization offset. 
* */ -template -int32x4x2_t convolve_3x3(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, +template < typename T, REQUIRES_TA(std::is_same::value || std::is_same::value) > +int32x4x2_t convolve_3x3(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset); - -template <> -inline int32x4x2_t convolve_3x3<1>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset) + unsigned int stridex, int input_offset) { - const int32x4_t v_input_offset = vdupq_n_s32(input_offset); + ARM_COMPUTE_ERROR_ON(stridex > 3); + using VectorType = typename std::conditional::value, uint8x8x2_t, int8x8x2_t>::type; + using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t; - const uint8x8x2_t vtop = + const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{}); + + const VectorType vtop = { { - vld1_u8(in_top), - vld1_u8(in_top + 8) + wrapper::vload(in_top), + wrapper::vload(in_top + 8) } }; - const uint8x8x2_t vmid = + const VectorType vmid = { { - vld1_u8(in_mid), - vld1_u8(in_mid + 8) + wrapper::vload(in_mid), + wrapper::vload(in_mid + 8) } }; - const uint8x8x2_t vlow = + const VectorType vlow = { { - vld1_u8(in_low), - vld1_u8(in_low + 8) + wrapper::vload(in_low), + wrapper::vload(in_low + 8) } }; const int32x4x3_t vtop_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vtop.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vtop.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))), } }; const int32x4x3_t vmid_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vmid.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vmid.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))), } }; const int32x4x3_t vlow_s32 = { { - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vlow.val[0])))), - vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vlow.val[1])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))), + wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))), } }; int32x4x2_t out { { - vdupq_n_s32(0), - vdupq_n_s32(0), + wrapper::vdup_n(0, OutputTagType{}), + wrapper::vdup_n(0, OutputTagType{}), } }; // 0 - out.val[0] = vmlaq_s32(out.val[0], vtop_s32.val[0], m0.val[0]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], 
vtop_s32.val[1], 1), m0.val[1]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vtop_s32.val[0], vtop_s32.val[1], 2), m0.val[2]); + out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]); - out.val[0] = vmlaq_s32(out.val[0], vmid_s32.val[0], m1.val[0]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 1), m1.val[1]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vmid_s32.val[0], vmid_s32.val[1], 2), m1.val[2]); + out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]); - out.val[0] = vmlaq_s32(out.val[0], vlow_s32.val[0], m2.val[0]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 1), m2.val[1]); - out.val[0] = vmlaq_s32(out.val[0], vextq_s32(vlow_s32.val[0], vlow_s32.val[1], 2), m2.val[2]); + out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]); + out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]); // 1 - out.val[1] = vmlaq_s32(out.val[1], vtop_s32.val[1], m0.val[0]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 1), m0.val[1]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vtop_s32.val[1], vtop_s32.val[2], 2), m0.val[2]); + out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]); - out.val[1] = vmlaq_s32(out.val[1], vmid_s32.val[1], m1.val[0]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 1), m1.val[1]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vmid_s32.val[1], vmid_s32.val[2], 2), m1.val[2]); + out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]); - out.val[1] = vmlaq_s32(out.val[1], vlow_s32.val[1], m2.val[0]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 1), m2.val[1]); - out.val[1] = vmlaq_s32(out.val[1], vextq_s32(vlow_s32.val[1], vlow_s32.val[2], 2), m2.val[2]); + out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]); + out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]); - return out; -} - -template <> -inline int32x4x2_t convolve_3x3<2>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset) -{ - int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 
0), out.val[0], 2); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[1], 2), out.val[0], 3); - return out; -} - -template <> -inline int32x4x2_t convolve_3x3<3>(const uint8_t *in_top, const uint8_t *in_mid, const uint8_t *in_low, - const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2, - int input_offset) -{ - int32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, input_offset); - out.val[0] = vsetq_lane_s32(vgetq_lane_s32(out.val[0], 3), out.val[0], 1); + if(stridex == 2) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2); + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3); + } + else if(stridex == 3) + { + out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1); + } return out; } @@ -731,174 +678,143 @@ inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const f * @param[in] m1 Second row of the filter. * @param[in] m2 Third row of the filter. * @param[in] dilation_x Dilation, in elements across x. - * @param[in] input_offset (Optional)Input quantization offset. + * @param[in] stridex Stride value in elements across x. + * @param[in] input_offset (Optional) Input quantization offset. * */ -template -float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset = 0); - -template <> -inline float16x8x2_t convolve_3x3_dilation<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset) +inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, + const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, + const size_t dilation_x, unsigned int stridex, int input_offset = 0) { - const float16x8x2_t out = + float16x8x2_t out = { { single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset), single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset) } }; - return out; -} -template <> -inline float16x8x2_t convolve_3x3_dilation<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - float16x8x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7); - return out; -} + if(stridex == 2) + { + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2); + out.val[0] = 
vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7); + } + else if(stridex == 3) + { + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); + } -template <> -inline float16x8x2_t convolve_3x3_dilation<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - const size_t dilation_x, int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - float16x8x2_t out = convolve_3x3_dilation<1>(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); return out; } /** Perform a convolve3x3 on float16. * - * @param[in] in_top Pointer to the first row of the input. - * @param[in] in_mid Pointer to the second row of the input. - * @param[in] in_low Pointer to the third row of the input. - * @param[in] m0 First row of the filter. - * @param[in] m1 Second row of the filter. - * @param[in] m2 Third row of the filter. + * @param[in] in_top Pointer to the first row of the input. + * @param[in] in_mid Pointer to the second row of the input. + * @param[in] in_low Pointer to the third row of the input. + * @param[in] m0 First row of the filter. + * @param[in] m1 Second row of the filter. + * @param[in] m2 Third row of the filter. + * @param[in] stridex Stride value in elements across x. + * @param[in] input_offset (Optional) Input quantization offset. 
* */ -template -float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset = 0); - -template <> -inline float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - const float16x8x3_t vtop = - { - { - vld1q_f16(in_top), - vld1q_f16(in_top + 8), - vld1q_f16(in_top + 16) - } - }; - const float16x8x3_t vmid = - { - { - vld1q_f16(in_mid), - vld1q_f16(in_mid + 8), - vld1q_f16(in_mid + 16) - } - }; - const float16x8x3_t vlow = - { - { - vld1q_f16(in_low), - vld1q_f16(in_low + 8), - vld1q_f16(in_low + 16) - } - }; - float16x8x2_t out = - { - { - vmulq_f16(vtop.val[0], m0.val[0]), - vmulq_f16(vtop.val[1], m0.val[0]) - } - }; - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1])); - out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2])); - return out; -} - -template <> -inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset) +inline float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, + const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, + unsigned int stridex, int input_offset = 0) { ARM_COMPUTE_UNUSED(input_offset); - const float16x8x2_t vtop = vld2q_f16(in_top); - const float16x8x2_t vmid = vld2q_f16(in_mid); - const float16x8x2_t vlow = vld2q_f16(in_low); - const float16x8_t vtop_end = vld1q_f16(in_top + 16); - const float16x8_t vmid_end = vld1q_f16(in_mid + 16); - const float16x8_t vlow_end = vld1q_f16(in_low + 16); float16x8x2_t out = { { - vmulq_f16(vtop.val[0], m0.val[0]), + vdupq_n_f16(0), vdupq_n_f16(0) } }; - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2])); + if(stridex == 2) + { + 
const float16x8x2_t vtop = vld2q_f16(in_top); + const float16x8x2_t vmid = vld2q_f16(in_mid); + const float16x8x2_t vlow = vld2q_f16(in_low); + const float16x8_t vtop_end = vld1q_f16(in_top + 16); + const float16x8_t vmid_end = vld1q_f16(in_mid + 16); + const float16x8_t vlow_end = vld1q_f16(in_low + 16); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2])); + out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1])); - out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2])); - return out; -} + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2])); + + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2])); + } + else + { + const float16x8x3_t vtop = + { + { + vld1q_f16(in_top), + vld1q_f16(in_top + 8), + vld1q_f16(in_top + 16) + } + }; + const float16x8x3_t vmid = + { + { + vld1q_f16(in_mid), + vld1q_f16(in_mid + 8), + vld1q_f16(in_mid + 16) + } + }; + const float16x8x3_t vlow = + { + { + vld1q_f16(in_low), + vld1q_f16(in_low + 8), + vld1q_f16(in_low + 16) + } + }; + out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]); + out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]); + + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1])); + out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1])); + out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2])); + + if(stridex == 3) + { + out.val[0] = 
vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); + out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); + } + } -template <> -inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, - const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2, - int input_offset) -{ - ARM_COMPUTE_UNUSED(input_offset); - float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2); - out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3); return out; } @@ -934,39 +850,20 @@ inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values) /** Get the number of elements processed on 3x3 convolution. * * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution. + * @param[in] stridex Stride value in elements across x. * * @return The number of elements processed. */ -template -int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration); - -template <> -inline int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration) -{ - return num_elems_written_per_iteration; -} - -template <> -inline int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration) -{ - return num_elems_written_per_iteration << 1; -} - -template <> -inline int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration) -{ - return num_elems_written_per_iteration * 3; -} inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex) { switch(stridex) { case 1: - return get_input_num_elems_processed<1>(num_elems_written_per_iteration); + return num_elems_written_per_iteration; case 2: - return get_input_num_elems_processed<2>(num_elems_written_per_iteration); + return num_elems_written_per_iteration << 1; case 3: - return get_input_num_elems_processed<3>(num_elems_written_per_iteration); + return num_elems_written_per_iteration * 3; default: ARM_COMPUTE_ERROR("stridex not supported"); return 0; diff --git a/arm_compute/core/NEON/wrapper/intrinsics/ext.h b/arm_compute/core/NEON/wrapper/intrinsics/ext.h new file mode 100644 index 0000000000..70bc91aaa6 --- /dev/null +++ b/arm_compute/core/NEON/wrapper/intrinsics/ext.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2020 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_WRAPPER_EXT_H +#define ARM_COMPUTE_WRAPPER_EXT_H + +#include + +namespace arm_compute +{ +namespace wrapper +{ +#define VEXT_IMPL(vtype, prefix, postfix, size) \ + inline vtype vext_##size(vtype value_a, vtype value_b) \ + { \ + return prefix##_##postfix(value_a, value_b, size); \ + } + +VEXT_IMPL(int32x4_t, vextq, s32, 1) +VEXT_IMPL(int32x4_t, vextq, s32, 2) + +#undef VEXT_IMPL +} // namespace wrapper +} // namespace arm_compute +#endif /* ARM_COMPUTE_WRAPPER_EXT_H */ diff --git a/arm_compute/core/NEON/wrapper/intrinsics/intrinsics.h b/arm_compute/core/NEON/wrapper/intrinsics/intrinsics.h index f119642b83..3d674757e8 100644 --- a/arm_compute/core/NEON/wrapper/intrinsics/intrinsics.h +++ b/arm_compute/core/NEON/wrapper/intrinsics/intrinsics.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -37,6 +37,7 @@ #include "arm_compute/core/NEON/wrapper/intrinsics/dup_n.h" #include "arm_compute/core/NEON/wrapper/intrinsics/eor.h" #include "arm_compute/core/NEON/wrapper/intrinsics/exp.h" +#include "arm_compute/core/NEON/wrapper/intrinsics/ext.h" #include "arm_compute/core/NEON/wrapper/intrinsics/gethigh.h" #include "arm_compute/core/NEON/wrapper/intrinsics/getlane.h" #include "arm_compute/core/NEON/wrapper/intrinsics/getlow.h" diff --git a/arm_compute/core/NEON/wrapper/intrinsics/reinterpret.h b/arm_compute/core/NEON/wrapper/intrinsics/reinterpret.h index 0cff237b14..579da344a7 100644 --- a/arm_compute/core/NEON/wrapper/intrinsics/reinterpret.h +++ b/arm_compute/core/NEON/wrapper/intrinsics/reinterpret.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -30,14 +30,20 @@ namespace arm_compute { namespace wrapper { -inline int32x4_t vreinterpret_s32(const uint32x4_t &val) -{ - return vreinterpretq_s32_u32(val); -} -inline int32x4_t vreinterpret_s32(const int32x4_t &val) -{ - return val; -} +#define VREINTERPRET_IMPL(ptype, vtype, prefix, postfix1, postfix2) \ + inline ptype vreinterpret(const vtype &a) \ + { \ + return prefix##_##postfix1##_##postfix2(a); \ + } \ + \ + inline ptype vreinterpret(const ptype &a) \ + { \ + return a; \ + } + +VREINTERPRET_IMPL(int16x4_t, uint16x4_t, vreinterpret, s16, u16) + +VREINTERPRET_IMPL(int32x4_t, uint32x4_t, vreinterpretq, s32, u32) } // namespace wrapper } // namespace arm_compute #endif /* ARM_COMPUTE_WRAPPER_REINTERPRET_H */ diff --git a/arm_compute/core/NEON/wrapper/intrinsics/setlane.h b/arm_compute/core/NEON/wrapper/intrinsics/setlane.h index 86a95b8bad..6332f3025e 100644 --- a/arm_compute/core/NEON/wrapper/intrinsics/setlane.h +++ b/arm_compute/core/NEON/wrapper/intrinsics/setlane.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -205,4 +205,4 @@ VSETQLANE_IMPL_8(float16x8_t, float16_t, float16x8_t, f16) #undef VSETQLANE_IMPL_4 } // namespace wrapper } // namespace arm_compute -#endif /* ARM_COMPUTE_WRAPPER_AET_LANE_H */ +#endif /* ARM_COMPUTE_WRAPPER_SET_LANE_H */ -- cgit v1.2.1