From e73686ac797be2d19cd9bed26d690e1431e3d848 Mon Sep 17 00:00:00 2001 From: Usama Arif Date: Mon, 8 Apr 2019 17:30:48 +0100 Subject: COMPMID-2047: Add support for dilation in CLDepthwiseConvolution. Change-Id: I3106aa34bd168985a56791613d95072756be6e9b Signed-off-by: Usama Arif Reviewed-on: https://review.mlplatform.org/c/958 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- .../cl_kernels/depthwise_convolution_quantized.cl | 171 +++++++++++++++------ 1 file changed, 125 insertions(+), 46 deletions(-) (limited to 'src/core/CL/cl_kernels/depthwise_convolution_quantized.cl') diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl index 503aa7e837..8d145a038e 100644 --- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl +++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl @@ -53,6 +53,8 @@ #if !(defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)) +#if DILATION_X == 1 + #if CONV_STRIDE_X == 1 #define GET_VALUES(first_value, left, middle, right) \ ({ \ @@ -85,6 +87,46 @@ }) #endif /* CONV_STRIDE_X */ +#else /* DILATION_X == 1 */ + +#if CONV_STRIDE_X == 1 +#define GET_VALUES(first_value, left, middle, right) \ + ({ \ + left = CONVERT(vload8(0, first_value), int8); \ + middle = CONVERT(vload8(0, first_value + DILATION_X * sizeof(uchar)), int8); \ + right = CONVERT(vload8(0, first_value + 2 * DILATION_X * sizeof(uchar)), int8); \ + }) +#elif CONV_STRIDE_X == 2 +#define GET_VALUES(first_value, left, middle, right) \ + ({ \ + int16 temp0 = CONVERT(vload16(0, first_value), int16); \ + left = CONVERT(temp0.s02468ace, int8); \ + \ + temp0 = CONVERT(vload16(0, first_value + DILATION_X * sizeof(uchar)), int16); \ + middle = CONVERT(temp0.s02468ace, int8); \ + \ + temp0 = CONVERT(vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)), int16); \ + right = CONVERT(temp0.s02468ace, int8); \ + }) +#else /* CONV_STRIDE_X */ +#define GET_VALUES(first_value, left, middle, right) \ + ({ \ + int16 temp0 = CONVERT(vload16(0, first_value), int16); \ + int8 temp1 = CONVERT(vload8(0, (first_value + 16 * sizeof(uchar))), int8); \ + left = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \ + \ + temp0 = CONVERT(vload16(0, first_value + DILATION_X * sizeof(uchar)), int16); \ + temp1 = CONVERT(vload8(0, (first_value + (16 + DILATION_X) * sizeof(uchar))), int8); \ + middle = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \ + \ + temp0 = CONVERT(vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)), int16); \ + temp1 = CONVERT(vload8(0, (first_value + (16 + 2 * DILATION_X) * sizeof(uchar))), int8); \ + right = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \ + }) + +#endif /* CONV_STRIDE_X */ +#endif /* DILATION_X==1 */ + /** This function computes the depthwise convolution quantized. * * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8 @@ -151,10 +193,10 @@ __kernel void dwc_3x3_native_qasymm8_nchw( int8 values0 = 0; int8 sum0 = 0; -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 int8 values1 = 0; int8 sum1 = 0; -#endif /* CONV_STRIDE_Y */ +#endif /* CONV_STRIDE_Y &&DILATION_Y==1 */ // Row0 int8 left, middle, right; @@ -168,44 +210,44 @@ __kernel void dwc_3x3_native_qasymm8_nchw( #endif /* WEIGHTS_OFFSET != 0 */ // Row1 - GET_VALUES(src.ptr + 1 * src_stride_y, left, middle, right); + GET_VALUES(src.ptr + DILATION_Y * src_stride_y, left, middle, right); values0 += left * (int8)(w1.s0); values0 += middle * (int8)(w1.s1); values0 += right * (int8)(w1.s2); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += left * (int8)(w0.s0); values1 += middle * (int8)(w0.s1); values1 += right * (int8)(w0.s2); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y && DILATION_Y== 1 */ #if WEIGHTS_OFFSET != 0 int8 tmp = left + middle + right; sum0 += tmp; -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 sum1 += tmp; -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y &&DILATION_Y== 1 */ #endif /* WEIGHTS_OFFSET != 0 */ // Row2 - GET_VALUES(src.ptr + 2 * src_stride_y, left, middle, right); + GET_VALUES(src.ptr + 2 * DILATION_Y * src_stride_y, left, middle, right); values0 += left * (int8)(w2.s0); values0 += middle * (int8)(w2.s1); values0 += right * (int8)(w2.s2); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += left * (int8)(w1.s0); values1 += middle * (int8)(w1.s1); values1 += right * (int8)(w1.s2); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y &&DILATION_Y == 1 */ #if WEIGHTS_OFFSET != 0 tmp = left + middle + right; sum0 += tmp; -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 sum1 += tmp; -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */ #endif /* WEIGHTS_OFFSET != 0 */ -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 // Row3 GET_VALUES(src.ptr + 3 * src_stride_y, left, middle, right); values1 += left * (int8)(w2.s0); @@ -215,20 +257,20 @@ __kernel void dwc_3x3_native_qasymm8_nchw( #if WEIGHTS_OFFSET != 0 sum1 += left + middle + right; #endif /* WEIGHTS_OFFSET != 0 */ -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y && DILATION_Y == 1 */ #if defined(HAS_BIAS) values0 += (int8)(bias_value); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += (int8)(bias_value); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y & &DILATION_Y == 1 */ #endif //defined(HAS_BIAS) #if WEIGHTS_OFFSET != 0 values0 += sum0 * (int8)(WEIGHTS_OFFSET); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += sum1 * (int8)(WEIGHTS_OFFSET); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */ #endif /* WEIGHTS_OFFSET != 0 */ #if INPUT_OFFSET != 0 @@ -236,16 +278,16 @@ __kernel void dwc_3x3_native_qasymm8_nchw( ushort3 tmp_we = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2); sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2; values0 += sum_weights * (int8)(INPUT_OFFSET); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += sum_weights * (int8)(INPUT_OFFSET); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */ #endif /* INPUT_OFFSET != 0 */ #if K_OFFSET != 0 values0 += (int8)(K_OFFSET); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += (int8)(K_OFFSET); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/ #endif /* K_OFFSET != 0 */ #if defined(REAL_MULTIPLIER) @@ -264,7 +306,7 @@ __kernel void dwc_3x3_native_qasymm8_nchw( res0 = min(res0, (uchar8)255); vstore8(ACTIVATION_FUNC(res0), 0, dst.ptr); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 #if defined(REAL_MULTIPLIER) values1 = CONVERT(round(CONVERT(values1, float8) * (float8)REAL_MULTIPLIER), int8); @@ -281,11 +323,11 @@ __kernel void dwc_3x3_native_qasymm8_nchw( res1 = min(res1, (uchar8)255); vstore8(ACTIVATION_FUNC(res1), 0, dst.ptr + dst_stride_y); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/ } #else // !(defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)) - +#if DILATION_X == 1 #if CONV_STRIDE_X == 1 #define GET_VALUES(first_value, left, middle, right) \ ({ \ @@ -317,6 +359,43 @@ __kernel void dwc_3x3_native_qasymm8_nchw( right = (uchar8)(temp0.s258b, temp0.se, temp1.s147); \ }) #endif /* CONV_STRIDE_X */ +#else /*DILATION_X==1*/ + +#if CONV_STRIDE_X == 1 +#define GET_VALUES(first_value, left, middle, right) \ + ({ \ + left = vload8(0, first_value); \ + middle = vload8(0, first_value + DILATION_X * sizeof(uchar)); \ + right = vload8(0, first_value + 2 * DILATION_X * sizeof(uchar)); \ + }) +#elif CONV_STRIDE_X == 2 +#define GET_VALUES(first_value, left, middle, right) \ + ({ \ + uchar16 temp0 = vload16(0, first_value); \ + left = temp0.s02468ace; \ + temp0 = vload16(0, first_value + DILATION_X * sizeof(uchar)); \ + middle = temp0.s02468ace; \ + temp0 = vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)); \ + right = temp0.s02468ace; \ + }) +#else /* CONV_STRIDE_X */ +#define GET_VALUES(first_value, left, middle, right) \ + ({ \ + uchar16 temp0 = vload16(0, first_value); \ + uchar8 temp1 = vload8(0, (first_value + 16 * sizeof(uchar))); \ + left = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \ + \ + temp0 = vload16(0, first_value + DILATION_X * sizeof(uchar)); \ + temp1 = vload8(0, (first_value + (16 + DILATION_X) * sizeof(uchar))); \ + middle = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \ + \ + temp0 = vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)); \ + temp1 = vload8(0, (first_value + (16 + 2 * DILATION_X) * sizeof(uchar))); \ + right = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \ + }) + +#endif /* CONV_STRIDE_X */ +#endif /*DILATION_X==1*/ /** This function computes the depthwise convolution quantized using dot product when the data layout is NCHW. * * @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8 @@ -389,8 +468,8 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( int8 sum0 = 0; GET_VALUES(src.ptr + 0 * src_stride_y, left0, middle0, right0); - GET_VALUES(src.ptr + 1 * src_stride_y, left1, middle1, right1); - GET_VALUES(src.ptr + 2 * src_stride_y, left2, middle2, right2); + GET_VALUES(src.ptr + DILATION_Y * src_stride_y, left1, middle1, right1); + GET_VALUES(src.ptr + 2 * DILATION_Y * src_stride_y, left2, middle2, right2); #if WEIGHTS_OFFSET != 0 sum0 += convert_int8(left0) + convert_int8(middle0) + convert_int8(right0); @@ -398,7 +477,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( sum0 += convert_int8(left2) + convert_int8(middle2) + convert_int8(right2); #endif /* WEIGHTS_OFFSET != 0 */ -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 // If conv_stride_y is equals to 1, we compute two output rows uchar8 left3, middle3, right3; @@ -412,7 +491,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( sum1 += convert_int8(left2) + convert_int8(middle2) + convert_int8(right2); sum1 += convert_int8(left3) + convert_int8(middle3) + convert_int8(right3); #endif /* WEIGHTS_OFFSET != 0 */ -#endif // CONV_STRIDE_Y == 1 +#endif // CONV_STRIDE_Y == 1 && DILATION_Y==1 ARM_DOT((uchar4)(left0.s0, middle0.s0, right0.s0, left1.s0), (uchar4)(w0.s0, w0.s1, w0.s2, w1.s0), values0.s0); ARM_DOT((uchar4)(middle1.s0, right1.s0, left2.s0, middle2.s0), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values0.s0); @@ -446,7 +525,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( ARM_DOT((uchar4)(middle1.s7, right1.s7, left2.s7, middle2.s7), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values0.s7); values0.s7 += right2.s7 * w2.s2; -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 ARM_DOT((uchar4)(left1.s0, middle1.s0, right1.s0, left2.s0), (uchar4)(w0.s0, w0.s1, w0.s2, w1.s0), values1.s0); ARM_DOT((uchar4)(middle2.s0, right2.s0, left3.s0, middle3.s0), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values1.s0); values1.s0 += right3.s0 * w2.s2; @@ -478,20 +557,20 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( ARM_DOT((uchar4)(left1.s7, middle1.s7, right1.s7, left2.s7), (uchar4)(w0.s0, w0.s1, w0.s2, w1.s0), values1.s7); ARM_DOT((uchar4)(middle2.s7, right2.s7, left3.s7, middle3.s7), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values1.s7); values1.s7 += right3.s7 * w2.s2; -#endif // CONV_STRIDE_Y == 1 +#endif // CONV_STRIDE_Y == 1 && DILATION_Y==1 #if defined(HAS_BIAS) values0 += (int8)(bias_value); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += (int8)(bias_value); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */ #endif //defined(HAS_BIAS) #if WEIGHTS_OFFSET != 0 values0 += sum0 * (int8)(WEIGHTS_OFFSET); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += sum1 * (int8)(WEIGHTS_OFFSET); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */ #endif /* WEIGHTS_OFFSET != 0 */ #if INPUT_OFFSET != 0 @@ -499,16 +578,16 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( ushort3 tmp_we = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2); sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2; values0 += sum_weights * (int8)(INPUT_OFFSET); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += sum_weights * (int8)(INPUT_OFFSET); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/ #endif /* INPUT_OFFSET != 0 */ #if K_OFFSET != 0 values0 += (int8)(K_OFFSET); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 values1 += (int8)(K_OFFSET); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/ #endif /* K_OFFSET != 0 */ #if defined(REAL_MULTIPLIER) @@ -527,7 +606,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( res0 = min(res0, (uchar8)255); vstore8(ACTIVATION_FUNC(res0), 0, dst.ptr); -#if CONV_STRIDE_Y == 1 +#if CONV_STRIDE_Y == 1 && DILATION_Y == 1 #if defined(REAL_MULTIPLIER) @@ -545,7 +624,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw( res1 = min(res1, (uchar8)255); vstore8(ACTIVATION_FUNC(res1), 0, dst.ptr + dst_stride_y); -#endif /* CONV_STRIDE_Y == 1 */ +#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/ } #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) @@ -669,7 +748,7 @@ __kernel void dwc_3x3_reshaped_qasymm8_nhwc( int z_coord = 0; int4 offset = 0; - int4 y_coord = ((int4)(y * CONV_STRIDE_X) + (int4)(0, 1, 2, 3)) - (int)CONV_PAD_LEFT; + int4 y_coord = ((int4)(y * CONV_STRIDE_X) + (int4)(0, DILATION_X * 1, DILATION_X * 2, DILATION_X * 3)) - (int)CONV_PAD_LEFT; // Only for y = 0 we can have a negative coordinate. If so, we convert it to SRC_DIM_1 y_coord.s0 = min((uint)y_coord.s0, (uint)SRC_DIM_1); @@ -720,16 +799,16 @@ __kernel void dwc_3x3_reshaped_qasymm8_nhwc( // z == 1 // z_coord can be only negative for z = 0 so we do not need to clamp it // Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset - z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + 1; + z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + DILATION_Y; offset = y_offset + (int4)(z_coord * src_stride_z); VEC_UCHAR values3 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0); VEC_UCHAR values4 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1); VEC_UCHAR values5 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2); // z == 2 - // After z = 1 we can simply add src_stride_z to offset without updating z_coord - // However offset can be out-of-bound so we need to check if it is greater than max_offset - offset += (int4)src_stride_z; + // Offset can be out-of-bound so we need to check if it is greater than max_offset + z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + DILATION_Y * 2; + offset = y_offset + (int4)(z_coord * src_stride_z); offset = min(offset, (int4)max_offset); VEC_UCHAR values6 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0); VEC_UCHAR values7 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1); -- cgit v1.2.1