aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels')
-rw-r--r--src/core/CL/cl_kernels/direct_convolution_quantized.cl (renamed from src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl)63
1 files changed, 58 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
index 5ad9afb23c..1182428cd5 100644
--- a/src/core/CL/cl_kernels/direct_convolution_1x1_3x3_5x5_quantized.cl
+++ b/src/core/CL/cl_kernels/direct_convolution_quantized.cl
@@ -27,7 +27,50 @@
#if defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+
+#if STRIDE_X == 1
+#define CONVOLUTION1x9(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x9_STRIDE1(acc, src_row_ptr, weights_row_ptr)
+#elif STRIDE_X == 2
+#define CONVOLUTION1x9(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x9_STRIDE2(acc, src_row_ptr, weights_row_ptr)
+#else /* STRIDE_X not equals 1 or 2 */
+#error "STRIDE_X larger than 2 is not supported"
+#endif /* STRIDE_X */
+
+#define CONVOLUTION1x9_STRIDE1(acc, src_row_ptr, weights_row_ptr) \
+ ({ \
+ int8 weights_values0 = convert_int8(vload8(0, weights_row_ptr)); \
+ int weights_value1 = convert_int(*(weights_row_ptr + 8)); \
+ int16 src0 = convert_int16(vload16(0, src_row_ptr)); \
+ acc += (src0.lo + input_offset) * ((int8)weights_values0.s0 + weight_offset); \
+ acc += ((int8)(src0.s1234, src0.s5678) + input_offset) * ((int8)weights_values0.s1 + weight_offset); \
+ acc += ((int8)(src0.s2345, src0.s6789) + input_offset) * ((int8)weights_values0.s2 + weight_offset); \
+ acc += ((int8)(src0.s3456, src0.s789A) + input_offset) * ((int8)weights_values0.s3 + weight_offset); \
+ acc += ((int8)(src0.s4567, src0.s89AB) + input_offset) * ((int8)weights_values0.s4 + weight_offset); \
+ acc += ((int8)(src0.s5678, src0.s9ABC) + input_offset) * ((int8)weights_values0.s5 + weight_offset); \
+ acc += ((int8)(src0.s6789, src0.sABCD) + input_offset) * ((int8)weights_values0.s6 + weight_offset); \
+ acc += ((int8)(src0.s789A, src0.sBCDE) + input_offset) * ((int8)weights_values0.s7 + weight_offset); \
+ acc += ((int8)(src0.s89AB, src0.sCDEF) + input_offset) * ((int8)weights_value1 + weight_offset); \
+ })
+
+#define CONVOLUTION1x9_STRIDE2(acc, src_row_ptr, weights_row_ptr) \
+ ({ \
+ int8 weights_values0 = convert_int8(vload8(0, weights_row_ptr)); \
+ int weights_value1 = convert_int(*(weights_row_ptr + 8)); \
+ int16 src0 = convert_int16(vload16(0, src_row_ptr)); \
+ int8 src1 = convert_int8(vload8(0, src_row_ptr + 16)); \
+ acc += (src0.even + input_offset) * ((int8)weights_values0.s0 + weight_offset); \
+ acc += ((int8)(src0.s1357, src0.s9BDF) + input_offset) * ((int8)weights_values0.s1 + weight_offset); \
+ acc += ((int8)(src0.s2468, src0.sACE, src1.s0) + input_offset) * ((int8)weights_values0.s2 + weight_offset); \
+ acc += ((int8)(src0.s3579, src0.sBDF, src1.s1) + input_offset) * ((int8)weights_values0.s3 + weight_offset); \
+ acc += ((int8)(src0.s468A, src0.sCE, src1.s02) + input_offset) * ((int8)weights_values0.s4 + weight_offset); \
+ acc += ((int8)(src0.s579B, src0.sDF, src1.s13) + input_offset) * ((int8)weights_values0.s5 + weight_offset); \
+ acc += ((int8)(src0.s68AC, src0.sE, src1.s024) + input_offset) * ((int8)weights_values0.s6 + weight_offset); \
+ acc += ((int8)(src0.s79BD, src0.sF, src1.s135) + input_offset) * ((int8)weights_values0.s7 + weight_offset); \
+ acc += ((int8)(src0.s8ACE, src1.s0246) + input_offset) * ((int8)weights_value1 + weight_offset); \
+ })
+
+#elif KERNEL_SIZE == 5
#if STRIDE_X == 1
#define CONVOLUTION1x5(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x5_STRIDE1(acc, src_row_ptr, weights_row_ptr)
@@ -142,8 +185,8 @@ inline uchar8 extract_input_stride3(__global const uchar *input_pixel)
return (uchar8)(temp1.s0369, temp2.s0369);
}
-#else /* KERNEL_SIZE not equals 1, 3 or 5 */
-#error "Only kernel sizes 1, 3 and 5 are supported"
+#else /* KERNEL_SIZE not equals 1, 3 , 5, 9 */
+#error "Only kernel sizes 1, 3, 5 and 9 are supported"
#endif /* KERNEL_SIZE */
/** This kernel performs a direct convolution to convolve the low three dimensions.
@@ -187,7 +230,7 @@ inline uchar8 extract_input_stride3(__global const uchar *input_pixel)
* @param[in] output_multiplier Output integer multiplier quantization parameter
* @param[in] output_shift Output integer shift quantization parameter
*/
-__kernel void direct_convolution_1x1_3x3_5x5_quantized(
+__kernel void direct_convolution_quantized(
TENSOR3D_DECLARATION(src),
TENSOR3D_DECLARATION(dst),
TENSOR3D_DECLARATION(weights),
@@ -215,7 +258,17 @@ __kernel void direct_convolution_1x1_3x3_5x5_quantized(
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
-#if KERNEL_SIZE == 5
+#if KERNEL_SIZE == 9
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 0 * src_stride_y), (__global uchar *)(weights_addr + 0 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 1 * src_stride_y), (__global uchar *)(weights_addr + 1 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 2 * src_stride_y), (__global uchar *)(weights_addr + 2 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 3 * src_stride_y), (__global uchar *)(weights_addr + 3 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 4 * src_stride_y), (__global uchar *)(weights_addr + 4 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 5 * src_stride_y), (__global uchar *)(weights_addr + 5 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 6 * src_stride_y), (__global uchar *)(weights_addr + 6 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 7 * src_stride_y), (__global uchar *)(weights_addr + 7 * weights_stride_y));
+ CONVOLUTION1x9(pixels0, (__global uchar *)(src_addr + 8 * src_stride_y), (__global uchar *)(weights_addr + 8 * weights_stride_y));
+#elif KERNEL_SIZE == 5
CONVOLUTION1x5(pixels0, (__global uchar *)src_addr, (__global uchar *)weights_addr);
CONVOLUTION1x5(pixels0, (__global uchar *)(src_addr + 1 * src_stride_y), (__global uchar *)(weights_addr + 1 * weights_stride_y));
CONVOLUTION1x5(pixels0, (__global uchar *)(src_addr + 2 * src_stride_y), (__global uchar *)(weights_addr + 2 * weights_stride_y));