aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/depthwise_convolution_quantized.cl')
-rw-r--r--src/core/CL/cl_kernels/depthwise_convolution_quantized.cl171
1 files changed, 125 insertions, 46 deletions
diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
index 503aa7e837..8d145a038e 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
@@ -53,6 +53,8 @@
#if !(defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8))
+#if DILATION_X == 1
+
#if CONV_STRIDE_X == 1
#define GET_VALUES(first_value, left, middle, right) \
({ \
@@ -85,6 +87,46 @@
})
#endif /* CONV_STRIDE_X */
+#else /* DILATION_X == 1 */
+
+#if CONV_STRIDE_X == 1
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ left = CONVERT(vload8(0, first_value), int8); \
+ middle = CONVERT(vload8(0, first_value + DILATION_X * sizeof(uchar)), int8); \
+ right = CONVERT(vload8(0, first_value + 2 * DILATION_X * sizeof(uchar)), int8); \
+ })
+#elif CONV_STRIDE_X == 2
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ int16 temp0 = CONVERT(vload16(0, first_value), int16); \
+ left = CONVERT(temp0.s02468ace, int8); \
+ \
+ temp0 = CONVERT(vload16(0, first_value + DILATION_X * sizeof(uchar)), int16); \
+ middle = CONVERT(temp0.s02468ace, int8); \
+ \
+ temp0 = CONVERT(vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)), int16); \
+ right = CONVERT(temp0.s02468ace, int8); \
+ })
+#else /* CONV_STRIDE_X */
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ int16 temp0 = CONVERT(vload16(0, first_value), int16); \
+ int8 temp1 = CONVERT(vload8(0, (first_value + 16 * sizeof(uchar))), int8); \
+ left = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \
+ \
+ temp0 = CONVERT(vload16(0, first_value + DILATION_X * sizeof(uchar)), int16); \
+ temp1 = CONVERT(vload8(0, (first_value + (16 + DILATION_X) * sizeof(uchar))), int8); \
+ middle = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \
+ \
+ temp0 = CONVERT(vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)), int16); \
+ temp1 = CONVERT(vload8(0, (first_value + (16 + 2 * DILATION_X) * sizeof(uchar))), int8); \
+ right = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \
+ })
+
+#endif /* CONV_STRIDE_X */
+#endif /* DILATION_X==1 */
+
/** This function computes the depthwise convolution quantized.
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8
@@ -151,10 +193,10 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
int8 values0 = 0;
int8 sum0 = 0;
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
int8 values1 = 0;
int8 sum1 = 0;
-#endif /* CONV_STRIDE_Y */
+#endif /* CONV_STRIDE_Y &&DILATION_Y==1 */
// Row0
int8 left, middle, right;
@@ -168,44 +210,44 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
#endif /* WEIGHTS_OFFSET != 0 */
// Row1
- GET_VALUES(src.ptr + 1 * src_stride_y, left, middle, right);
+ GET_VALUES(src.ptr + DILATION_Y * src_stride_y, left, middle, right);
values0 += left * (int8)(w1.s0);
values0 += middle * (int8)(w1.s1);
values0 += right * (int8)(w1.s2);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += left * (int8)(w0.s0);
values1 += middle * (int8)(w0.s1);
values1 += right * (int8)(w0.s2);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y && DILATION_Y== 1 */
#if WEIGHTS_OFFSET != 0
int8 tmp = left + middle + right;
sum0 += tmp;
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
sum1 += tmp;
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y &&DILATION_Y== 1 */
#endif /* WEIGHTS_OFFSET != 0 */
// Row2
- GET_VALUES(src.ptr + 2 * src_stride_y, left, middle, right);
+ GET_VALUES(src.ptr + 2 * DILATION_Y * src_stride_y, left, middle, right);
values0 += left * (int8)(w2.s0);
values0 += middle * (int8)(w2.s1);
values0 += right * (int8)(w2.s2);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += left * (int8)(w1.s0);
values1 += middle * (int8)(w1.s1);
values1 += right * (int8)(w1.s2);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y &&DILATION_Y == 1 */
#if WEIGHTS_OFFSET != 0
tmp = left + middle + right;
sum0 += tmp;
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
sum1 += tmp;
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */
#endif /* WEIGHTS_OFFSET != 0 */
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
// Row3
GET_VALUES(src.ptr + 3 * src_stride_y, left, middle, right);
values1 += left * (int8)(w2.s0);
@@ -215,20 +257,20 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
#if WEIGHTS_OFFSET != 0
sum1 += left + middle + right;
#endif /* WEIGHTS_OFFSET != 0 */
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y && DILATION_Y == 1 */
#if defined(HAS_BIAS)
values0 += (int8)(bias_value);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += (int8)(bias_value);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y & &DILATION_Y == 1 */
#endif //defined(HAS_BIAS)
#if WEIGHTS_OFFSET != 0
values0 += sum0 * (int8)(WEIGHTS_OFFSET);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += sum1 * (int8)(WEIGHTS_OFFSET);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */
#endif /* WEIGHTS_OFFSET != 0 */
#if INPUT_OFFSET != 0
@@ -236,16 +278,16 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
ushort3 tmp_we = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2);
sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2;
values0 += sum_weights * (int8)(INPUT_OFFSET);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += sum_weights * (int8)(INPUT_OFFSET);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */
#endif /* INPUT_OFFSET != 0 */
#if K_OFFSET != 0
values0 += (int8)(K_OFFSET);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += (int8)(K_OFFSET);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/
#endif /* K_OFFSET != 0 */
#if defined(REAL_MULTIPLIER)
@@ -264,7 +306,7 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
res0 = min(res0, (uchar8)255);
vstore8(ACTIVATION_FUNC(res0), 0, dst.ptr);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
#if defined(REAL_MULTIPLIER)
values1 = CONVERT(round(CONVERT(values1, float8) * (float8)REAL_MULTIPLIER), int8);
@@ -281,11 +323,11 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
res1 = min(res1, (uchar8)255);
vstore8(ACTIVATION_FUNC(res1), 0, dst.ptr + dst_stride_y);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/
}
#else // !(defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8))
-
+#if DILATION_X == 1
#if CONV_STRIDE_X == 1
#define GET_VALUES(first_value, left, middle, right) \
({ \
@@ -317,6 +359,43 @@ __kernel void dwc_3x3_native_qasymm8_nchw(
right = (uchar8)(temp0.s258b, temp0.se, temp1.s147); \
})
#endif /* CONV_STRIDE_X */
+#else /*DILATION_X==1*/
+
+#if CONV_STRIDE_X == 1
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ left = vload8(0, first_value); \
+ middle = vload8(0, first_value + DILATION_X * sizeof(uchar)); \
+ right = vload8(0, first_value + 2 * DILATION_X * sizeof(uchar)); \
+ })
+#elif CONV_STRIDE_X == 2
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ uchar16 temp0 = vload16(0, first_value); \
+ left = temp0.s02468ace; \
+ temp0 = vload16(0, first_value + DILATION_X * sizeof(uchar)); \
+ middle = temp0.s02468ace; \
+ temp0 = vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)); \
+ right = temp0.s02468ace; \
+ })
+#else /* CONV_STRIDE_X */
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ uchar16 temp0 = vload16(0, first_value); \
+ uchar8 temp1 = vload8(0, (first_value + 16 * sizeof(uchar))); \
+ left = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \
+ \
+ temp0 = vload16(0, first_value + DILATION_X * sizeof(uchar)); \
+ temp1 = vload8(0, (first_value + (16 + DILATION_X) * sizeof(uchar))); \
+ middle = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \
+ \
+ temp0 = vload16(0, first_value + 2 * DILATION_X * sizeof(uchar)); \
+ temp1 = vload8(0, (first_value + (16 + 2 * DILATION_X) * sizeof(uchar))); \
+ right = (uchar8)(temp0.s0369, temp0.scf, temp1.s25); \
+ })
+
+#endif /* CONV_STRIDE_X */
+#endif /*DILATION_X==1*/
/** This function computes the depthwise convolution quantized using dot product when the data layout is NCHW.
*
* @param[in] src_ptr Pointer to the source tensor. Supported data types: QASYMM8
@@ -389,8 +468,8 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
int8 sum0 = 0;
GET_VALUES(src.ptr + 0 * src_stride_y, left0, middle0, right0);
- GET_VALUES(src.ptr + 1 * src_stride_y, left1, middle1, right1);
- GET_VALUES(src.ptr + 2 * src_stride_y, left2, middle2, right2);
+ GET_VALUES(src.ptr + DILATION_Y * src_stride_y, left1, middle1, right1);
+ GET_VALUES(src.ptr + 2 * DILATION_Y * src_stride_y, left2, middle2, right2);
#if WEIGHTS_OFFSET != 0
sum0 += convert_int8(left0) + convert_int8(middle0) + convert_int8(right0);
@@ -398,7 +477,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
sum0 += convert_int8(left2) + convert_int8(middle2) + convert_int8(right2);
#endif /* WEIGHTS_OFFSET != 0 */
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
// If conv_stride_y is equals to 1, we compute two output rows
uchar8 left3, middle3, right3;
@@ -412,7 +491,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
sum1 += convert_int8(left2) + convert_int8(middle2) + convert_int8(right2);
sum1 += convert_int8(left3) + convert_int8(middle3) + convert_int8(right3);
#endif /* WEIGHTS_OFFSET != 0 */
-#endif // CONV_STRIDE_Y == 1
+#endif // CONV_STRIDE_Y == 1 && DILATION_Y==1
ARM_DOT((uchar4)(left0.s0, middle0.s0, right0.s0, left1.s0), (uchar4)(w0.s0, w0.s1, w0.s2, w1.s0), values0.s0);
ARM_DOT((uchar4)(middle1.s0, right1.s0, left2.s0, middle2.s0), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values0.s0);
@@ -446,7 +525,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
ARM_DOT((uchar4)(middle1.s7, right1.s7, left2.s7, middle2.s7), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values0.s7);
values0.s7 += right2.s7 * w2.s2;
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
ARM_DOT((uchar4)(left1.s0, middle1.s0, right1.s0, left2.s0), (uchar4)(w0.s0, w0.s1, w0.s2, w1.s0), values1.s0);
ARM_DOT((uchar4)(middle2.s0, right2.s0, left3.s0, middle3.s0), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values1.s0);
values1.s0 += right3.s0 * w2.s2;
@@ -478,20 +557,20 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
ARM_DOT((uchar4)(left1.s7, middle1.s7, right1.s7, left2.s7), (uchar4)(w0.s0, w0.s1, w0.s2, w1.s0), values1.s7);
ARM_DOT((uchar4)(middle2.s7, right2.s7, left3.s7, middle3.s7), (uchar4)(w1.s1, w1.s2, w2.s0, w2.s1), values1.s7);
values1.s7 += right3.s7 * w2.s2;
-#endif // CONV_STRIDE_Y == 1
+#endif // CONV_STRIDE_Y == 1 && DILATION_Y==1
#if defined(HAS_BIAS)
values0 += (int8)(bias_value);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += (int8)(bias_value);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */
#endif //defined(HAS_BIAS)
#if WEIGHTS_OFFSET != 0
values0 += sum0 * (int8)(WEIGHTS_OFFSET);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += sum1 * (int8)(WEIGHTS_OFFSET);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1 */
#endif /* WEIGHTS_OFFSET != 0 */
#if INPUT_OFFSET != 0
@@ -499,16 +578,16 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
ushort3 tmp_we = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2);
sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2;
values0 += sum_weights * (int8)(INPUT_OFFSET);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += sum_weights * (int8)(INPUT_OFFSET);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/
#endif /* INPUT_OFFSET != 0 */
#if K_OFFSET != 0
values0 += (int8)(K_OFFSET);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
values1 += (int8)(K_OFFSET);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/
#endif /* K_OFFSET != 0 */
#if defined(REAL_MULTIPLIER)
@@ -527,7 +606,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
res0 = min(res0, (uchar8)255);
vstore8(ACTIVATION_FUNC(res0), 0, dst.ptr);
-#if CONV_STRIDE_Y == 1
+#if CONV_STRIDE_Y == 1 && DILATION_Y == 1
#if defined(REAL_MULTIPLIER)
@@ -545,7 +624,7 @@ __kernel void dwc_3x3_native_qasymm8_dot8_nchw(
res1 = min(res1, (uchar8)255);
vstore8(ACTIVATION_FUNC(res1), 0, dst.ptr + dst_stride_y);
-#endif /* CONV_STRIDE_Y == 1 */
+#endif /* CONV_STRIDE_Y == 1 && DILATION_Y==1*/
}
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
@@ -669,7 +748,7 @@ __kernel void dwc_3x3_reshaped_qasymm8_nhwc(
int z_coord = 0;
int4 offset = 0;
- int4 y_coord = ((int4)(y * CONV_STRIDE_X) + (int4)(0, 1, 2, 3)) - (int)CONV_PAD_LEFT;
+ int4 y_coord = ((int4)(y * CONV_STRIDE_X) + (int4)(0, DILATION_X * 1, DILATION_X * 2, DILATION_X * 3)) - (int)CONV_PAD_LEFT;
// Only for y = 0 we can have a negative coordinate. If so, we convert it to SRC_DIM_1
y_coord.s0 = min((uint)y_coord.s0, (uint)SRC_DIM_1);
@@ -720,16 +799,16 @@ __kernel void dwc_3x3_reshaped_qasymm8_nhwc(
// z == 1
// z_coord can be only negative for z = 0 so we do not need to clamp it
// Moreover z_coord cannot be out-of-bound for z = 1 so we do not need to clamp the offset
- z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + 1;
+ z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + DILATION_Y;
offset = y_offset + (int4)(z_coord * src_stride_z);
VEC_UCHAR values3 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
VEC_UCHAR values4 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);
VEC_UCHAR values5 = VLOAD(VEC_SIZE)(0, src_addr + offset.s2);
// z == 2
- // After z = 1 we can simply add src_stride_z to offset without updating z_coord
- // However offset can be out-of-bound so we need to check if it is greater than max_offset
- offset += (int4)src_stride_z;
+ // Offset can be out-of-bound so we need to check if it is greater than max_offset
+ z_coord = z * (int)CONV_STRIDE_Y - (int)CONV_PAD_TOP + DILATION_Y * 2;
+ offset = y_offset + (int4)(z_coord * src_stride_z);
offset = min(offset, (int4)max_offset);
VEC_UCHAR values6 = VLOAD(VEC_SIZE)(0, src_addr + offset.s0);
VEC_UCHAR values7 = VLOAD(VEC_SIZE)(0, src_addr + offset.s1);