aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/direct_convolution5x5.cl
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2018-10-03 17:11:09 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:55:19 +0000
commitd041a835041159a0a6744fc271df15e9f66167bc (patch)
tree895f01d68218cbeabf639ef027d519fa4c96d655 /src/core/CL/cl_kernels/direct_convolution5x5.cl
parentecd9d09c7c77005586250587ec8e1ddb6f224bde (diff)
downloadComputeLibrary-d041a835041159a0a6744fc271df15e9f66167bc.tar.gz
COMPMID-1610: Fixed CLDirectConvolution mismatches
Kernel size 5x5 layout NHWC. Change-Id: Ia82ff211d1c954df228962b5c2c5ad8df7112449 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/151740 Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Tested-by: bsgcomp <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/direct_convolution5x5.cl')
-rw-r--r--src/core/CL/cl_kernels/direct_convolution5x5.cl62
1 files changed, 57 insertions, 5 deletions
diff --git a/src/core/CL/cl_kernels/direct_convolution5x5.cl b/src/core/CL/cl_kernels/direct_convolution5x5.cl
index 70be058854..5299409243 100644
--- a/src/core/CL/cl_kernels/direct_convolution5x5.cl
+++ b/src/core/CL/cl_kernels/direct_convolution5x5.cl
@@ -194,11 +194,11 @@ __kernel void direct_convolution5x5_nhwc(
__global uchar *src_addr = (__global uchar *)offset(&src, 0, 0) - src_stride_x * id0 + ((id2 * STRIDE_Y) - PAD_TOP) * (int)src_stride_z;
weights_addr += id0 * weights_stride_w;
- const int coordy = id2 - PAD_TOP;
+#if(PAD_TOP == 1)
+ const int coordy = id2 - PAD_TOP;
for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
{
-#if(PAD_TOP)
if(coordy < 0) // special case Z = -1 doesn't exists
{
//skip first row and load the two next ones
@@ -224,17 +224,69 @@ __kernel void direct_convolution5x5_nhwc(
CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
}
-#else //PAD_TOP > 0
+ src_addr += src_stride_x;
+ weights_addr += weights_stride_x;
+ }
+#elif(PAD_TOP == 2)
+ const int coordy = id2 * STRIDE_Y;
+ for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
+ {
+ if(coordy == 0) // special case Z = -2 doesn't exists
+ {
+ //skip first row and load the two next ones
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+ }
+ else if(coordy == 1) // special case Z = -1 doesn't exists
+ {
+ //skip first row and load the two next ones
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+ }
+ else if(coordy == (SRC_HEIGHT - 1))
+ {
+ // special case when computing the last row of the output we must read the last three rows from the input buffer (including padding) but the
+ // Z axis has no padding at all.
+ CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ }
+ else if(coordy == (SRC_HEIGHT - 2))
+ {
+ // special case when computing the last row of the output we must read the last three rows from the input buffer (including padding) but the
+ // Z axis has no padding at all.
+ CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+ }
+ else
+ {
+ CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
+ CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
+ }
+ src_addr += src_stride_x;
+ weights_addr += weights_stride_x;
+ }
+
+#else /* PAD_TOP == 2 */
+ for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
+ {
CONVOLUTION1x5_NHWC(values0, src_addr, weights_addr);
CONVOLUTION1x5_NHWC(values0, (src_addr + 1 * (int)src_stride_z), (weights_addr + 1 * (int)weights_stride_z));
CONVOLUTION1x5_NHWC(values0, (src_addr + 2 * (int)src_stride_z), (weights_addr + 2 * (int)weights_stride_z));
CONVOLUTION1x5_NHWC(values0, (src_addr + 3 * (int)src_stride_z), (weights_addr + 3 * (int)weights_stride_z));
CONVOLUTION1x5_NHWC(values0, (src_addr + 4 * (int)src_stride_z), (weights_addr + 4 * (int)weights_stride_z));
-#endif // PAD_TOP > 0
-
src_addr += src_stride_x;
weights_addr += weights_stride_x;
}
+#endif /* PAD_TOP == 1 */
#ifdef HAS_BIAS
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);