aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/direct_convolution_quantized.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/direct_convolution_quantized.cl')
-rw-r--r--src/core/CL/cl_kernels/direct_convolution_quantized.cl140
1 files changed, 138 insertions, 2 deletions
diff --git a/src/core/CL/cl_kernels/direct_convolution_quantized.cl b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
index ed1b7cfe2a..8237fe1700 100644
--- a/src/core/CL/cl_kernels/direct_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
@@ -33,7 +33,113 @@
#if defined(DATA_LAYOUT_NHWC)
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+
+#if STRIDE_X == 1
+#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr)
+#elif STRIDE_X == 2
+#define CONVOLUTION1x9(acc, src_ptr, weights_ptr) CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr)
+#else /* STRIDE_X not equals 1 or 2 */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X */
+
+#define CONVOLUTION1x9_STRIDE1(acc, src_ptr, weights_ptr) \
+ ({ \
+ int8 weights_values0 = 0; \
+ int weights_value1 = 0; \
+ weights_values0.s0 = convert_int(*(weights_ptr + 0 * weights_stride_y)); \
+ weights_values0.s1 = convert_int(*(weights_ptr + 1 * weights_stride_y)); \
+ weights_values0.s2 = convert_int(*(weights_ptr + 2 * weights_stride_y)); \
+ weights_values0.s3 = convert_int(*(weights_ptr + 3 * weights_stride_y)); \
+ weights_values0.s4 = convert_int(*(weights_ptr + 4 * weights_stride_y)); \
+ weights_values0.s5 = convert_int(*(weights_ptr + 5 * weights_stride_y)); \
+ weights_values0.s6 = convert_int(*(weights_ptr + 6 * weights_stride_y)); \
+ weights_values0.s7 = convert_int(*(weights_ptr + 7 * weights_stride_y)); \
+ weights_value1 = convert_int(*(weights_ptr + 8 * weights_stride_y)); \
+ \
+ int8 src0 = 0; \
+ int8 src1 = 0; \
+ src0.s0 = convert_int(*(src_ptr + 0 * weights_stride_y)); \
+ src0.s1 = convert_int(*(src_ptr + 1 * weights_stride_y)); \
+ src0.s2 = convert_int(*(src_ptr + 2 * weights_stride_y)); \
+ src0.s3 = convert_int(*(src_ptr + 3 * weights_stride_y)); \
+ src0.s4 = convert_int(*(src_ptr + 4 * weights_stride_y)); \
+ src0.s5 = convert_int(*(src_ptr + 5 * weights_stride_y)); \
+ src0.s6 = convert_int(*(src_ptr + 6 * weights_stride_y)); \
+ src0.s7 = convert_int(*(src_ptr + 7 * weights_stride_y)); \
+ src1.s0 = convert_int(*(src_ptr + 8 * weights_stride_y)); \
+ src1.s1 = convert_int(*(src_ptr + 9 * weights_stride_y)); \
+ src1.s2 = convert_int(*(src_ptr + 10 * weights_stride_y)); \
+ src1.s3 = convert_int(*(src_ptr + 11 * weights_stride_y)); \
+ src1.s4 = convert_int(*(src_ptr + 12 * weights_stride_y)); \
+ src1.s5 = convert_int(*(src_ptr + 13 * weights_stride_y)); \
+ src1.s6 = convert_int(*(src_ptr + 14 * weights_stride_y)); \
+ src1.s7 = convert_int(*(src_ptr + 15 * weights_stride_y)); \
+ \
+ acc += src0 * (int8)weights_values0.s0; \
+ acc += (int8)(src0.s1234, src0.s567, src1.s0) * (int8)weights_values0.s1; \
+ acc += (int8)(src0.s234, src0.s567, src1.s01) * (int8)weights_values0.s2; \
+ acc += (int8)(src0.s345, src0.s67, src1.s012) * (int8)weights_values0.s3; \
+ acc += (int8)(src0.s4567, src1.s0123) * (int8)weights_values0.s4; \
+ acc += (int8)(src0.s567, src1.s0123, src1.s4) * (int8)weights_values0.s5; \
+ acc += (int8)(src0.s67, src1.s012, src1.s345) * (int8)weights_values0.s6; \
+ acc += (int8)(src0.s7, src1.s0123, src1.s456) * (int8)weights_values0.s7; \
+ acc += src1 * (int8)weights_value1; \
+ })
+
+#define CONVOLUTION1x9_STRIDE2(acc, src_ptr, weights_ptr) \
+ ({ \
+ int8 weights_values0 = 0; \
+ int weights_value1 = 0; \
+ weights_values0.s0 = convert_int(*(weights_ptr + 0 * weights_stride_y)); \
+ weights_values0.s1 = convert_int(*(weights_ptr + 1 * weights_stride_y)); \
+ weights_values0.s2 = convert_int(*(weights_ptr + 2 * weights_stride_y)); \
+ weights_values0.s3 = convert_int(*(weights_ptr + 3 * weights_stride_y)); \
+ weights_values0.s4 = convert_int(*(weights_ptr + 4 * weights_stride_y)); \
+ weights_values0.s5 = convert_int(*(weights_ptr + 5 * weights_stride_y)); \
+ weights_values0.s6 = convert_int(*(weights_ptr + 6 * weights_stride_y)); \
+ weights_values0.s7 = convert_int(*(weights_ptr + 7 * weights_stride_y)); \
+ weights_value1 = convert_int(*(weights_ptr + 8 * weights_stride_y)); \
+ \
+ int16 src0 = 0; \
+ int8 src1 = 0; \
+ src0.s0 = convert_int(*(src_ptr + 0 * weights_stride_y)); \
+ src0.s1 = convert_int(*(src_ptr + 1 * weights_stride_y)); \
+ src0.s2 = convert_int(*(src_ptr + 2 * weights_stride_y)); \
+ src0.s3 = convert_int(*(src_ptr + 3 * weights_stride_y)); \
+ src0.s4 = convert_int(*(src_ptr + 4 * weights_stride_y)); \
+ src0.s5 = convert_int(*(src_ptr + 5 * weights_stride_y)); \
+ src0.s6 = convert_int(*(src_ptr + 6 * weights_stride_y)); \
+ src0.s7 = convert_int(*(src_ptr + 7 * weights_stride_y)); \
+ src0.s8 = convert_int(*(src_ptr + 8 * weights_stride_y)); \
+ src0.s9 = convert_int(*(src_ptr + 9 * weights_stride_y)); \
+ src0.sA = convert_int(*(src_ptr + 10 * weights_stride_y)); \
+ src0.sB = convert_int(*(src_ptr + 11 * weights_stride_y)); \
+ src0.sC = convert_int(*(src_ptr + 12 * weights_stride_y)); \
+ src0.sD = convert_int(*(src_ptr + 13 * weights_stride_y)); \
+ src0.sE = convert_int(*(src_ptr + 14 * weights_stride_y)); \
+ src0.sF = convert_int(*(src_ptr + 15 * weights_stride_y)); \
+ src1.s0 = convert_int(*(src_ptr + 16 * weights_stride_y)); \
+ src1.s1 = convert_int(*(src_ptr + 17 * weights_stride_y)); \
+ src1.s2 = convert_int(*(src_ptr + 18 * weights_stride_y)); \
+ src1.s3 = convert_int(*(src_ptr + 19 * weights_stride_y)); \
+ src1.s4 = convert_int(*(src_ptr + 20 * weights_stride_y)); \
+ src1.s5 = convert_int(*(src_ptr + 21 * weights_stride_y)); \
+ src1.s6 = convert_int(*(src_ptr + 22 * weights_stride_y)); \
+ src1.s7 = convert_int(*(src_ptr + 23 * weights_stride_y)); \
+ \
+ acc += src0.s02468ACE * (int8)weights_values0.s0; \
+ acc += (int8)(src0.s1357, src0.s9BDF) * (int8)weights_values0.s1; \
+ acc += (int8)(src0.s2468, src0.sACE, src1.s0) * (int8)weights_values0.s2; \
+ acc += (int8)(src0.s3579, src0.sBDF, src1.s1) * (int8)weights_values0.s3; \
+ acc += (int8)(src0.s468A, src0.sCE, src1.s02) * (int8)weights_values0.s4; \
+ acc += (int8)(src0.s579, src0.sBDF, src1.s13) * (int8)weights_values0.s5; \
+ acc += (int8)(src0.s68A, src0.sCE, src1.s024) * (int8)weights_values0.s6; \
+ acc += (int8)(src0.s79B, src0.sDF, src1.s135) * (int8)weights_values0.s7; \
+ acc += (int8)(src0.s8AC, src0.sE, src1.s0246) * (int8)weights_value1; \
+ })
+
+#elif KERNEL_SIZE == 5
#if STRIDE_X == 1
#define CONVOLUTION1x5(acc, src_ptr, weights_ptr) CONVOLUTION1x5_STRIDE1(acc, src_ptr, weights_ptr)
@@ -331,7 +437,37 @@ __kernel void direct_convolution_quantized(
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+ if(y_coord < 0)
+ {
+ const int start_z = -y_coord;
+ for(int i = start_z; i < 9; ++i)
+ {
+ CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z));
+ }
+ }
+ else if(y_coord > (SRC_HEIGHT - 9))
+ {
+ // Avoid loading rows beyond the input height
+ const int end_z = SRC_HEIGHT - y_coord;
+ for(int i = 0; i < end_z; ++i)
+ {
+ CONVOLUTION1x9(values0, (src_addr + i * (int)src_stride_z), (weights_addr + i * (int)weights_stride_z));
+ }
+ }
+ else
+ {
+ CONVOLUTION1x9(values0, src_addr, weights_addr);
+ CONVOLUTION1x9(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 5 * (int)src_stride_z), (weights_addr + 5 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 6 * (int)src_stride_z), (weights_addr + 6 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 7 * (int)src_stride_z), (weights_addr + 7 * (int)weights_stride_z));
+ CONVOLUTION1x9(values0, (src_addr + 8 * (int)src_stride_z), (weights_addr + 8 * (int)weights_stride_z));
+ }
+#elif KERNEL_SIZE == 5
#if(PAD_TOP == 1) || (PAD_BOTTM == 1)
if(y_coord < 0) // special case Z = -1 doesn't exists
{