diff options
Diffstat (limited to 'src/core/CL/cl_kernels/direct_convolution5x5.cl')
-rw-r--r-- | src/core/CL/cl_kernels/direct_convolution5x5.cl | 62 |
1 files changed, 57 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/direct_convolution5x5.cl b/src/core/CL/cl_kernels/direct_convolution5x5.cl index 70be058854..5299409243 100644 --- a/src/core/CL/cl_kernels/direct_convolution5x5.cl +++ b/src/core/CL/cl_kernels/direct_convolution5x5.cl @@ -194,11 +194,11 @@ __kernel void direct_convolution5x5_nhwc( __global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - src_stride_x * id0 + ((id2 * STRIDE_Y) - PAD_TOP) * (int)src_stride_z; weights_addr += id0 * weights_stride_w; - const int coordy = id2 - PAD_TOP; +#if(PAD_TOP == 1) + const int coordy = id2 - PAD_TOP; for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) { -#if(PAD_TOP) if(coordy < 0) // special case Z = -1 doesn't exists { //skip first row and load the two next ones @@ -224,17 +224,69 @@ __kernel void direct_convolution5x5_nhwc( CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z)); } -#else //PAD_TOP > 0 + src_addr += src_stride_x; + weights_addr += weights_stride_x; + } +#elif(PAD_TOP == 2) + const int coordy = id2 * STRIDE_Y; + for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) + { + if(coordy == 0) // special case Z = -2 doesn't exists + { + //skip first row and load the two next ones + CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z)); + } + else if(coordy == 1) // special case Z = -1 doesn't exists + { + //skip first row and load the two next ones + CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z)); + } + else if(coordy == (SRC_HEIGHT - 1)) + { + // special case when computing the last row of the output we must read the last three rows from the input buffer (including padding) but the + // Z axis has no padding at all. + CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr); + CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); + } + else if(coordy == (SRC_HEIGHT - 2)) + { + // special case when computing the last row of the output we must read the last three rows from the input buffer (including padding) but the + // Z axis has no padding at all. + CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr); + CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); + } + else + { + CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr); + CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); + CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z)); + } + src_addr += src_stride_x; + weights_addr += weights_stride_x; + } + +#else /* PAD_TOP == 2 */ + for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d) + { CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr); CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z)); CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z)); CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z)); CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z)); -#endif // PAD_TOP > 0 - src_addr += src_stride_x; weights_addr += weights_stride_x; } +#endif /* PAD_TOP == 1 */ #ifdef HAS_BIAS Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases); |